Re: [PATCH v2 07/13] RISC-V: crypto: add accelerated AES-CBC/CTR/ECB/XTS implementations

From: Eric Biggers
Date: Mon Nov 27 2023 - 23:07:23 EST


On Mon, Nov 27, 2023 at 03:06:57PM +0800, Jerry Shih wrote:
> +typedef void (*aes_xts_func)(const u8 *in, u8 *out, size_t length,
> +			     const struct crypto_aes_ctx *key, u8 *iv,
> +			     int update_iv);

There's no need for this indirection, because the function pointer can only have
one value.

Note also that when Control Flow Integrity is enabled, assembly functions can
only be called indirectly when they use SYM_TYPED_FUNC_START. That's another
reason to avoid indirect calls that aren't actually necessary.
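
For example (just an untested sketch; the asm function names below are
placeholders, not the actual symbols from this patch), the callers could pass
a flag and the dispatch could be done with direct calls:

	static int xts_crypt(struct skcipher_request *req, bool enc)
	{
		/* ... same setup and walk handling as above ... */

		kernel_vector_begin();
		if (enc)
			aes_xts_encrypt_zvkned_zvbb_zvkg(walk.src.virt.addr,
							 walk.dst.virt.addr,
							 nbytes, &ctx->ctx1,
							 req->iv, update_iv);
		else
			aes_xts_decrypt_zvkned_zvbb_zvkg(walk.src.virt.addr,
							 walk.dst.virt.addr,
							 nbytes, &ctx->ctx1,
							 req->iv, update_iv);
		kernel_vector_end();

		/* ... */
	}

	static int xts_encrypt(struct skcipher_request *req)
	{
		return xts_crypt(req, true);
	}

	static int xts_decrypt(struct skcipher_request *req)
	{
		return xts_crypt(req, false);
	}

Alternatively, if the indirect calls are kept, the .S file would need to
#include <linux/cfi_types.h> and define those functions with
SYM_TYPED_FUNC_START() instead of SYM_FUNC_START() so that they get CFI type
annotations.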

> +		nbytes &= (~(AES_BLOCK_SIZE - 1));

Expressions like ~(n - 1) should not have another set of parentheses around them.
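I.e., the above can simply be:

	nbytes &= ~(AES_BLOCK_SIZE - 1);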

> +static int xts_crypt(struct skcipher_request *req, aes_xts_func func)
> +{
> +	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
> +	const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
> +	struct skcipher_request sub_req;
> +	struct scatterlist sg_src[2], sg_dst[2];
> +	struct scatterlist *src, *dst;
> +	struct skcipher_walk walk;
> +	unsigned int walk_size = crypto_skcipher_walksize(tfm);
> +	unsigned int tail_bytes;
> +	unsigned int head_bytes;
> +	unsigned int nbytes;
> +	unsigned int update_iv = 1;
> +	int err;
> +
> +	/* xts input size should be at least AES_BLOCK_SIZE */
> +	if (req->cryptlen < AES_BLOCK_SIZE)
> +		return -EINVAL;
> +
> +	/*
> +	 * We split xts-aes cryption into `head` and `tail` parts.
> +	 * The head part contains the input from the beginning which doesn't
> +	 * need the `ciphertext stealing` method.
> +	 * The tail part contains at least two AES blocks including the
> +	 * ciphertext stealing data from the end.
> +	 */
> +	if (req->cryptlen <= walk_size) {
> +		/*
> +		 * All data is in one `walk`. We could handle it within one
> +		 * AES-XTS call at the end.
> +		 */
> +		tail_bytes = req->cryptlen;
> +		head_bytes = 0;
> +	} else {
> +		if (req->cryptlen & (AES_BLOCK_SIZE - 1)) {
> +			/*
> +			 * with ciphertext stealing
> +			 *
> +			 * Find the largest tail size which is smaller than the
> +			 * `walk` size while the head part still fits the AES
> +			 * block boundary.
> +			 */
> +			tail_bytes = req->cryptlen & (AES_BLOCK_SIZE - 1);
> +			tail_bytes = walk_size + tail_bytes - AES_BLOCK_SIZE;
> +			head_bytes = req->cryptlen - tail_bytes;
> +		} else {
> +			/* no ciphertext stealing */
> +			tail_bytes = 0;
> +			head_bytes = req->cryptlen;
> +		}
> +	}
> +
> +	riscv64_aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);
> +
> +	if (head_bytes && tail_bytes) {
> +		/* If we have two parts, set up a new request for the head part only. */
> +		skcipher_request_set_tfm(&sub_req, tfm);
> +		skcipher_request_set_callback(
> +			&sub_req, skcipher_request_flags(req), NULL, NULL);
> +		skcipher_request_set_crypt(&sub_req, req->src, req->dst,
> +					   head_bytes, req->iv);
> +		req = &sub_req;
> +	}
> +
> +	if (head_bytes) {
> +		err = skcipher_walk_virt(&walk, req, false);
> +		while ((nbytes = walk.nbytes)) {
> +			if (nbytes == walk.total)
> +				update_iv = (tail_bytes > 0);
> +
> +			nbytes &= (~(AES_BLOCK_SIZE - 1));
> +			kernel_vector_begin();
> +			func(walk.src.virt.addr, walk.dst.virt.addr, nbytes,
> +			     &ctx->ctx1, req->iv, update_iv);
> +			kernel_vector_end();
> +
> +			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
> +		}
> +		if (err || !tail_bytes)
> +			return err;
> +
> +		/*
> +		 * Set up a new request for the tail part.
> +		 * We use `scatterwalk_next()` to find the next scatterlist
> +		 * from the last walk instead of iterating from the beginning.
> +		 */
> +		dst = src = scatterwalk_next(sg_src, &walk.in);
> +		if (req->dst != req->src)
> +			dst = scatterwalk_next(sg_dst, &walk.out);
> +		skcipher_request_set_crypt(req, src, dst, tail_bytes, req->iv);
> +	}
> +
> +	/* tail */
> +	err = skcipher_walk_virt(&walk, req, false);
> +	if (err)
> +		return err;
> +	if (walk.nbytes != tail_bytes)
> +		return -EINVAL;
> +	kernel_vector_begin();
> +	func(walk.src.virt.addr, walk.dst.virt.addr, walk.nbytes, &ctx->ctx1,
> +	     req->iv, 0);
> +	kernel_vector_end();
> +
> +	return skcipher_walk_done(&walk, 0);
> +}

Did you consider writing xts_crypt() the way that arm64 and x86 do it? The
above seems to reinvent more or less the same thing from first principles. I'm
wondering if you should just copy the existing approach for now. Then there
would be no need to add the scatterwalk_next() function, and the handling of
inputs that don't need ciphertext stealing would be a bit more streamlined.
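
For reference, roughly what that looks like (modeled on __xts_crypt() in
arch/arm64/crypto/aes-glue.c, transplanted onto the names used in this patch;
untested, and xts_crypt_asm() is just a placeholder for the zvkned-zvbb-zvkg
asm routine, which is assumed to handle the ciphertext-stealing tail itself
when given a length that isn't a multiple of the block size):

	static int xts_crypt(struct skcipher_request *req, bool enc)
	{
		struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
		const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
		int tail = req->cryptlen % AES_BLOCK_SIZE;
		struct scatterlist sg_src[2], sg_dst[2];
		struct skcipher_request subreq;
		struct scatterlist *src, *dst;
		struct skcipher_walk walk;
		int err;

		if (req->cryptlen < AES_BLOCK_SIZE)
			return -EINVAL;

		/* Compute the first tweak, as in the current patch. */
		riscv64_aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);

		err = skcipher_walk_virt(&walk, req, false);

		if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
			/*
			 * The message isn't evenly divisible into blocks and
			 * spans more than one walk step.  Shrink the request
			 * to exclude the last two blocks; they are handled
			 * below with ciphertext stealing.
			 */
			int xts_blocks = DIV_ROUND_UP(req->cryptlen,
						      AES_BLOCK_SIZE) - 2;

			skcipher_walk_abort(&walk);

			skcipher_request_set_tfm(&subreq, tfm);
			skcipher_request_set_callback(&subreq,
						      skcipher_request_flags(req),
						      NULL, NULL);
			skcipher_request_set_crypt(&subreq, req->src, req->dst,
						   xts_blocks * AES_BLOCK_SIZE,
						   req->iv);
			req = &subreq;
			err = skcipher_walk_virt(&walk, req, false);
		} else {
			tail = 0;
		}

		/* Process everything that doesn't involve the stolen tail. */
		while (walk.nbytes) {
			unsigned int nbytes = walk.nbytes;

			if (nbytes < walk.total)
				nbytes &= ~(AES_BLOCK_SIZE - 1);

			kernel_vector_begin();
			xts_crypt_asm(walk.src.virt.addr, walk.dst.virt.addr,
				      nbytes, &ctx->ctx1, req->iv, enc);
			kernel_vector_end();

			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
		}

		if (err || likely(!tail))
			return err;

		/* Fast-forward to the last two blocks for ciphertext stealing. */
		dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
		skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
					   req->iv);

		err = skcipher_walk_virt(&walk, req, false);
		if (err)
			return err;

		kernel_vector_begin();
		xts_crypt_asm(walk.src.virt.addr, walk.dst.virt.addr,
			      walk.nbytes, &ctx->ctx1, req->iv, enc);
		kernel_vector_end();

		return skcipher_walk_done(&walk, 0);
	}

Note that this relies on scatterwalk_ffwd(), which already exists, instead of
adding a new scatterwalk_next() helper.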

> +static int __init riscv64_aes_block_mod_init(void)
> +{
> +	int ret = -ENODEV;
> +
> +	if (riscv_isa_extension_available(NULL, ZVKNED) &&
> +	    riscv_vector_vlen() >= 128 && riscv_vector_vlen() <= 2048) {
> +		ret = simd_register_skciphers_compat(
> +			riscv64_aes_algs_zvkned,
> +			ARRAY_SIZE(riscv64_aes_algs_zvkned),
> +			riscv64_aes_simd_algs_zvkned);
> +		if (ret)
> +			return ret;
> +
> +		if (riscv_isa_extension_available(NULL, ZVBB)) {
> +			ret = simd_register_skciphers_compat(
> +				riscv64_aes_alg_zvkned_zvkb,
> +				ARRAY_SIZE(riscv64_aes_alg_zvkned_zvkb),
> +				riscv64_aes_simd_alg_zvkned_zvkb);
> +			if (ret)
> +				goto unregister_zvkned;

This makes the registration of the zvkned-zvkb algorithm conditional on zvbb,
not zvkb. Shouldn't the extension checks actually look like:

    ZVKNED
    ZVKB
    ZVBB && ZVKG
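
I.e., something like this (untested; the vector length check, the registration
calls, and the error handling are elided):

	if (!riscv_isa_extension_available(NULL, ZVKNED))
		return -ENODEV;

	/* register riscv64_aes_algs_zvkned */

	if (riscv_isa_extension_available(NULL, ZVKB)) {
		/* register riscv64_aes_alg_zvkned_zvkb */
	}

	if (riscv_isa_extension_available(NULL, ZVBB) &&
	    riscv_isa_extension_available(NULL, ZVKG)) {
		/* register the zvkned-zvbb-zvkg algorithm */
	}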

- Eric