Re: [PATCH 06/12] RISC-V: crypto: add accelerated AES-CBC/CTR/ECB/XTS implementations

From: Eric Biggers
Date: Thu Nov 02 2023 - 01:16:46 EST


On Thu, Oct 26, 2023 at 02:36:38AM +0800, Jerry Shih wrote:
> +config CRYPTO_AES_BLOCK_RISCV64
> + default y if RISCV_ISA_V
> + tristate "Ciphers: AES, modes: ECB/CBC/CTR/XTS"
> + depends on 64BIT && RISCV_ISA_V
> + select CRYPTO_AES_RISCV64
> + select CRYPTO_SKCIPHER
> + help
> + Length-preserving ciphers: AES cipher algorithms (FIPS-197)
> + with block cipher modes:
> + - ECB (Electronic Codebook) mode (NIST SP 800-38A)
> + - CBC (Cipher Block Chaining) mode (NIST SP 800-38A)
> + - CTR (Counter) mode (NIST SP 800-38A)
> + - XTS (XOR Encrypt XOR Tweakable Block Cipher with Ciphertext
> + Stealing) mode (NIST SP 800-38E and IEEE 1619)
> +
> + Architecture: riscv64 using:
> + - Zvbb vector extension (XTS)
> + - Zvkb vector crypto extension (CTR/XTS)
> + - Zvkg vector crypto extension (XTS)
> + - Zvkned vector crypto extension

Maybe list Zvkned first since it's the most important one in this context.

> +#define AES_BLOCK_VALID_SIZE_MASK (~(AES_BLOCK_SIZE - 1))
> +#define AES_BLOCK_REMAINING_SIZE_MASK (AES_BLOCK_SIZE - 1)

I think it would be easier to read if these values were just used directly.

> +static int ecb_encrypt(struct skcipher_request *req)
> +{
> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> + const struct riscv64_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> + struct skcipher_walk walk;
> + unsigned int nbytes;
> + int err;
> +
> + /* If we have error here, the `nbytes` will be zero. */
> + err = skcipher_walk_virt(&walk, req, false);
> + while ((nbytes = walk.nbytes)) {
> + kernel_vector_begin();
> + rv64i_zvkned_ecb_encrypt(walk.src.virt.addr, walk.dst.virt.addr,
> + nbytes & AES_BLOCK_VALID_SIZE_MASK,
> + &ctx->key);
> + kernel_vector_end();
> + err = skcipher_walk_done(
> + &walk, nbytes & AES_BLOCK_REMAINING_SIZE_MASK);
> + }
> +
> + return err;
> +}

There's no fallback for !crypto_simd_usable() here. I really like it this way.
However, for that to work (for skciphers and AEADs), RISC-V needs to allow the
vector registers to be used in softirq context. Is that already the case?

> +/* ctr */
> +static int ctr_encrypt(struct skcipher_request *req)
> +{
> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> + const struct riscv64_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
> + struct skcipher_walk walk;
> + unsigned int ctr32;
> + unsigned int nbytes;
> + unsigned int blocks;
> + unsigned int current_blocks;
> + unsigned int current_length;
> + int err;
> +
> + /* the ctr iv uses big endian */
> + ctr32 = get_unaligned_be32(req->iv + 12);
> + err = skcipher_walk_virt(&walk, req, false);
> + while ((nbytes = walk.nbytes)) {
> + if (nbytes != walk.total) {
> + nbytes &= AES_BLOCK_VALID_SIZE_MASK;
> + blocks = nbytes / AES_BLOCK_SIZE;
> + } else {
> + /* This is the last walk. We should handle the tail data. */
> + blocks = (nbytes + (AES_BLOCK_SIZE - 1)) /
> + AES_BLOCK_SIZE;

'(nbytes + (AES_BLOCK_SIZE - 1)) / AES_BLOCK_SIZE' can be replaced with
'DIV_ROUND_UP(nbytes, AES_BLOCK_SIZE)'

> +static int xts_crypt(struct skcipher_request *req, aes_xts_func func)
> +{
> + struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> + const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
> + struct skcipher_request sub_req;
> + struct scatterlist sg_src[2], sg_dst[2];
> + struct scatterlist *src, *dst;
> + struct skcipher_walk walk;
> + unsigned int walk_size = crypto_skcipher_walksize(tfm);
> + unsigned int tail_bytes;
> + unsigned int head_bytes;
> + unsigned int nbytes;
> + unsigned int update_iv = 1;
> + int err;
> +
> + /* xts input size should be bigger than AES_BLOCK_SIZE */
> + if (req->cryptlen < AES_BLOCK_SIZE)
> + return -EINVAL;
> +
> + /*
> + * The tail size should be small than walk_size. Thus, we could make sure the
> + * walk size for tail elements could be bigger than AES_BLOCK_SIZE.
> + */
> + if (req->cryptlen <= walk_size) {
> + tail_bytes = req->cryptlen;
> + head_bytes = 0;
> + } else {
> + if (req->cryptlen & AES_BLOCK_REMAINING_SIZE_MASK) {
> + tail_bytes = req->cryptlen &
> + AES_BLOCK_REMAINING_SIZE_MASK;
> + tail_bytes = walk_size + tail_bytes - AES_BLOCK_SIZE;
> + head_bytes = req->cryptlen - tail_bytes;
> + } else {
> + tail_bytes = 0;
> + head_bytes = req->cryptlen;
> + }
> + }
> +
> + riscv64_aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);
> +
> + if (head_bytes && tail_bytes) {
> + skcipher_request_set_tfm(&sub_req, tfm);
> + skcipher_request_set_callback(
> + &sub_req, skcipher_request_flags(req), NULL, NULL);
> + skcipher_request_set_crypt(&sub_req, req->src, req->dst,
> + head_bytes, req->iv);
> + req = &sub_req;
> + }
> +
> + if (head_bytes) {
> + err = skcipher_walk_virt(&walk, req, false);
> + while ((nbytes = walk.nbytes)) {
> + if (nbytes == walk.total)
> + update_iv = (tail_bytes > 0);
> +
> + nbytes &= AES_BLOCK_VALID_SIZE_MASK;
> + kernel_vector_begin();
> + func(walk.src.virt.addr, walk.dst.virt.addr, nbytes,
> + &ctx->ctx1.key, req->iv, update_iv);
> + kernel_vector_end();
> +
> + err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
> + }
> + if (err || !tail_bytes)
> + return err;
> +
> + dst = src = scatterwalk_next(sg_src, &walk.in);
> + if (req->dst != req->src)
> + dst = scatterwalk_next(sg_dst, &walk.out);
> + skcipher_request_set_crypt(req, src, dst, tail_bytes, req->iv);
> + }
> +
> + /* tail */
> + err = skcipher_walk_virt(&walk, req, false);
> + if (err)
> + return err;
> + if (walk.nbytes != tail_bytes)
> + return -EINVAL;
> + kernel_vector_begin();
> + func(walk.src.virt.addr, walk.dst.virt.addr, walk.nbytes,
> + &ctx->ctx1.key, req->iv, 0);
> + kernel_vector_end();
> +
> + return skcipher_walk_done(&walk, 0);
> +}

This function is hard to follow: it builds a sub-request covering the head, then
re-points the request at whatever scatterlist entries remain for the tail. It is
also the only caller of the scatterwalk_next() function that you're adding. I
haven't looked at this super closely, but I expect there's a cleaner way of
handling the "tail", for example using scatterwalk_map_and_copy() to copy it
from/to a stack buffer.

- Eric