Re: [PATCH 06/12] RISC-V: crypto: add accelerated AES-CBC/CTR/ECB/XTS implementations

From: Jerry Shih
Date: Sun Nov 19 2023 - 21:47:42 EST


On Nov 2, 2023, at 13:16, Eric Biggers <ebiggers@xxxxxxxxxx> wrote:
> On Thu, Oct 26, 2023 at 02:36:38AM +0800, Jerry Shih wrote:
>> +config CRYPTO_AES_BLOCK_RISCV64
>> + default y if RISCV_ISA_V
>> + tristate "Ciphers: AES, modes: ECB/CBC/CTR/XTS"
>> + depends on 64BIT && RISCV_ISA_V
>> + select CRYPTO_AES_RISCV64
>> + select CRYPTO_SKCIPHER
>> + help
>> + Length-preserving ciphers: AES cipher algorithms (FIPS-197)
>> + with block cipher modes:
>> + - ECB (Electronic Codebook) mode (NIST SP 800-38A)
>> + - CBC (Cipher Block Chaining) mode (NIST SP 800-38A)
>> + - CTR (Counter) mode (NIST SP 800-38A)
>> + - XTS (XOR Encrypt XOR Tweakable Block Cipher with Ciphertext
>> + Stealing) mode (NIST SP 800-38E and IEEE 1619)
>> +
>> + Architecture: riscv64 using:
>> + - Zvbb vector extension (XTS)
>> + - Zvkb vector crypto extension (CTR/XTS)
>> + - Zvkg vector crypto extension (XTS)
>> + - Zvkned vector crypto extension
>
> Maybe list Zvkned first since it's the most important one in this context.

Fixed.


>> +#define AES_BLOCK_VALID_SIZE_MASK (~(AES_BLOCK_SIZE - 1))
>> +#define AES_BLOCK_REMAINING_SIZE_MASK (AES_BLOCK_SIZE - 1)
>
> I think it would be easier to read if these values were just used directly.

Fixed.
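
For illustration, here is a minimal userspace sketch of what using the values
directly amounts to. The helper names are made up for this example and are not
in the patch; only AES_BLOCK_SIZE (16) matches the kernel.

```c
#include <assert.h>
#include <stddef.h>

#define AES_BLOCK_SIZE 16 /* same value as the kernel constant */

/* Hypothetical helper: number of bytes covered by whole AES blocks. */
static size_t full_block_bytes(size_t nbytes)
{
	/* The mask trick only works for a power-of-two block size. */
	assert((AES_BLOCK_SIZE & (AES_BLOCK_SIZE - 1)) == 0);
	return nbytes & ~(size_t)(AES_BLOCK_SIZE - 1);
}

/* Hypothetical helper: bytes left over after the last whole block. */
static size_t partial_block_bytes(size_t nbytes)
{
	return nbytes & (AES_BLOCK_SIZE - 1);
}
```

In ecb_encrypt() this corresponds to passing `nbytes & ~(AES_BLOCK_SIZE - 1)`
to the cipher routine and `nbytes & (AES_BLOCK_SIZE - 1)` to
skcipher_walk_done().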

>> +static int ecb_encrypt(struct skcipher_request *req)
>> +{
>> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
>> + const struct riscv64_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
>> + struct skcipher_walk walk;
>> + unsigned int nbytes;
>> + int err;
>> +
>> + /* If we have error here, the `nbytes` will be zero. */
>> + err = skcipher_walk_virt(&walk, req, false);
>> + while ((nbytes = walk.nbytes)) {
>> + kernel_vector_begin();
>> + rv64i_zvkned_ecb_encrypt(walk.src.virt.addr, walk.dst.virt.addr,
>> + nbytes & AES_BLOCK_VALID_SIZE_MASK,
>> + &ctx->key);
>> + kernel_vector_end();
>> + err = skcipher_walk_done(
>> + &walk, nbytes & AES_BLOCK_REMAINING_SIZE_MASK);
>> + }
>> +
>> + return err;
>> +}
>
> There's no fallback for !crypto_simd_usable() here. I really like it this way.
> However, for it to work (for skciphers and aeads), RISC-V needs to allow the
> vector registers to be used in softirq context. Is that already the case?

I have switched to the simd skcipher interface. More details will be in the
v2 patch set.

>> +/* ctr */
>> +static int ctr_encrypt(struct skcipher_request *req)
>> +{
>> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
>> + const struct riscv64_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
>> + struct skcipher_walk walk;
>> + unsigned int ctr32;
>> + unsigned int nbytes;
>> + unsigned int blocks;
>> + unsigned int current_blocks;
>> + unsigned int current_length;
>> + int err;
>> +
>> + /* the ctr iv uses big endian */
>> + ctr32 = get_unaligned_be32(req->iv + 12);
>> + err = skcipher_walk_virt(&walk, req, false);
>> + while ((nbytes = walk.nbytes)) {
>> + if (nbytes != walk.total) {
>> + nbytes &= AES_BLOCK_VALID_SIZE_MASK;
>> + blocks = nbytes / AES_BLOCK_SIZE;
>> + } else {
>> + /* This is the last walk. We should handle the tail data. */
>> + blocks = (nbytes + (AES_BLOCK_SIZE - 1)) /
>> + AES_BLOCK_SIZE;
>
> '(nbytes + (AES_BLOCK_SIZE - 1)) / AES_BLOCK_SIZE' can be replaced with
> 'DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE)'

Fixed.
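
As a rough userspace sketch of the block count in the CTR loop: the macro below
is a stand-in that I believe mirrors the kernel's DIV_ROUND_UP(), and
ctr_blocks() is a hypothetical helper name, not a function from the patch.

```c
#include <assert.h>

#define AES_BLOCK_SIZE 16 /* same value as the kernel constant */

/* Userspace stand-in for the kernel's DIV_ROUND_UP() macro. */
#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

/*
 * Hypothetical helper: number of AES blocks to process in one walk step.
 * Only the last step may contain a partial block, so only that step
 * rounds up; earlier steps drop the partial block and leave it to the
 * next iteration.
 */
static unsigned int ctr_blocks(unsigned int nbytes, int last_walk)
{
	assert(nbytes > 0);
	if (!last_walk)
		return nbytes / AES_BLOCK_SIZE;
	return DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE);
}
```

The rounded-up count is what the 32-bit counter has to advance by, since a
partial final block still consumes one counter value.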

>> +static int xts_crypt(struct skcipher_request *req, aes_xts_func func)
>> +{
>> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
>> + const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
>> + struct skcipher_request sub_req;
>> + struct scatterlist sg_src[2], sg_dst[2];
>> + struct scatterlist *src, *dst;
>> + struct skcipher_walk walk;
>> + unsigned int walk_size = crypto_skcipher_walksize(tfm);
>> + unsigned int tail_bytes;
>> + unsigned int head_bytes;
>> + unsigned int nbytes;
>> + unsigned int update_iv = 1;
>> + int err;
>> +
>> + /* xts input size should be bigger than AES_BLOCK_SIZE */
>> + if (req->cryptlen < AES_BLOCK_SIZE)
>> + return -EINVAL;
>> +
>> + /*
>> + * The tail size should be small than walk_size. Thus, we could make sure the
>> + * walk size for tail elements could be bigger than AES_BLOCK_SIZE.
>> + */
>> + if (req->cryptlen <= walk_size) {
>> + tail_bytes = req->cryptlen;
>> + head_bytes = 0;
>> + } else {
>> + if (req->cryptlen & AES_BLOCK_REMAINING_SIZE_MASK) {
>> + tail_bytes = req->cryptlen &
>> + AES_BLOCK_REMAINING_SIZE_MASK;
>> + tail_bytes = walk_size + tail_bytes - AES_BLOCK_SIZE;
>> + head_bytes = req->cryptlen - tail_bytes;
>> + } else {
>> + tail_bytes = 0;
>> + head_bytes = req->cryptlen;
>> + }
>> + }
>> +
>> + riscv64_aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);
>> +
>> + if (head_bytes && tail_bytes) {
>> + skcipher_request_set_tfm(&sub_req, tfm);
>> + skcipher_request_set_callback(
>> + &sub_req, skcipher_request_flags(req), NULL, NULL);
>> + skcipher_request_set_crypt(&sub_req, req->src, req->dst,
>> + head_bytes, req->iv);
>> + req = &sub_req;
>> + }
>> +
>> + if (head_bytes) {
>> + err = skcipher_walk_virt(&walk, req, false);
>> + while ((nbytes = walk.nbytes)) {
>> + if (nbytes == walk.total)
>> + update_iv = (tail_bytes > 0);
>> +
>> + nbytes &= AES_BLOCK_VALID_SIZE_MASK;
>> + kernel_vector_begin();
>> + func(walk.src.virt.addr, walk.dst.virt.addr, nbytes,
>> + &ctx->ctx1.key, req->iv, update_iv);
>> + kernel_vector_end();
>> +
>> + err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
>> + }
>> + if (err || !tail_bytes)
>> + return err;
>> +
>> + dst = src = scatterwalk_next(sg_src, &walk.in);
>> + if (req->dst != req->src)
>> + dst = scatterwalk_next(sg_dst, &walk.out);
>> + skcipher_request_set_crypt(req, src, dst, tail_bytes, req->iv);
>> + }
>> +
>> + /* tail */
>> + err = skcipher_walk_virt(&walk, req, false);
>> + if (err)
>> + return err;
>> + if (walk.nbytes != tail_bytes)
>> + return -EINVAL;
>> + kernel_vector_begin();
>> + func(walk.src.virt.addr, walk.dst.virt.addr, walk.nbytes,
>> + &ctx->ctx1.key, req->iv, 0);
>> + kernel_vector_end();
>> +
>> + return skcipher_walk_done(&walk, 0);
>> +}
>
> This function looks a bit weird. I see it's also the only caller of the
> scatterwalk_next() function that you're adding. I haven't looked at this super
> closely, but I expect that there's a cleaner way of handling the "tail" than
> this -- maybe use scatterwalk_map_and_copy() to copy from/to a stack buffer?
>
> - Eric
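
To make the head/tail split in the quoted xts_crypt() easier to follow, here is
the same arithmetic as a standalone userspace sketch. The struct and function
names are hypothetical, and the assumption that walk_size is a multiple of
AES_BLOCK_SIZE is mine, not something the patch states.

```c
#include <assert.h>

#define AES_BLOCK_SIZE 16 /* same value as the kernel constant */

/* Hypothetical result type for the split. */
struct xts_split {
	unsigned int head_bytes; /* whole blocks, handled by the walk loop */
	unsigned int tail_bytes; /* handled by one final call, may end in a partial block */
};

/* Mirrors the head/tail computation in the quoted v1 xts_crypt(). */
static struct xts_split xts_split_len(unsigned int cryptlen,
				      unsigned int walk_size)
{
	struct xts_split s;
	unsigned int rem = cryptlen % AES_BLOCK_SIZE;

	assert(cryptlen >= AES_BLOCK_SIZE);
	/* Assumption for this sketch: walk_size is block-aligned. */
	assert(walk_size >= AES_BLOCK_SIZE && walk_size % AES_BLOCK_SIZE == 0);

	if (cryptlen <= walk_size) {
		/* Small request: everything goes through the tail path. */
		s.head_bytes = 0;
		s.tail_bytes = cryptlen;
	} else if (rem) {
		/* Keep walk_size - 16 bytes of whole blocks plus the partial block. */
		s.tail_bytes = walk_size + rem - AES_BLOCK_SIZE;
		s.head_bytes = cryptlen - s.tail_bytes;
	} else {
		/* Block-aligned request: no ciphertext stealing needed. */
		s.head_bytes = cryptlen;
		s.tail_bytes = 0;
	}
	return s;
}
```

With a walk_size of 128 (picked only for the example), a 261-byte request
splits into 144 head bytes and 117 tail bytes, so the head stays block-aligned
and the tail stays smaller than walk_size.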

I have added more comments in the v2 patch set; I hope that makes it clearer.
Even if we use `scatterwalk_map_and_copy()`, it still calls
`scatterwalk_ffwd()` internally. `scatterwalk_next()` just moves to the
next scatterlist, starting from where the previous walk ended,
instead of iterating from the head again.
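
The difference can be shown with a toy model in userspace C. This is not the
kernel scatterlist API: the "scatterlist" here is only an array of segment
lengths, and toy_pos / toy_advance() are hypothetical names for this sketch.

```c
#include <assert.h>
#include <stddef.h>

/* Toy position inside a list of segments. */
struct toy_pos {
	size_t seg;    /* index of the current segment */
	size_t offset; /* byte offset inside that segment */
};

/*
 * Advance `pos` by `nbytes`, stepping over segment boundaries.
 * Returns how many segments had to be stepped over, which is the
 * cost of fast-forwarding from wherever `pos` started.
 */
static size_t toy_advance(const size_t *seg_len, size_t nsegs,
			  struct toy_pos *pos, size_t nbytes)
{
	size_t steps = 0;

	while (nbytes && pos->seg < nsegs) {
		size_t room;

		assert(pos->offset <= seg_len[pos->seg]);
		room = seg_len[pos->seg] - pos->offset;
		if (nbytes < room) {
			pos->offset += nbytes;
			break;
		}
		nbytes -= room;
		pos->seg++;
		pos->offset = 0;
		steps++;
	}
	return steps;
}
```

Fast-forwarding from the head means calling toy_advance() from position {0, 0}
with the number of head bytes, which steps over every segment the head
covered. Continuing from the previous walk means reusing the position the walk
already ended at, so no segments are stepped over a second time.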

-Jerry