Re: [PATCH v2 07/13] RISC-V: crypto: add accelerated AES-CBC/CTR/ECB/XTS implementations

From: Jerry Shih
Date: Wed Nov 29 2023 - 02:57:35 EST


On Nov 28, 2023, at 12:07, Eric Biggers <ebiggers@xxxxxxxxxx> wrote:
> On Mon, Nov 27, 2023 at 03:06:57PM +0800, Jerry Shih wrote:
>> +typedef void (*aes_xts_func)(const u8 *in, u8 *out, size_t length,
>> +			     const struct crypto_aes_ctx *key, u8 *iv,
>> +			     int update_iv);
>
> There's no need for this indirection, because the function pointer can only have
> one value.
>
> Note also that when Control Flow Integrity is enabled, assembly functions can
> only be called indirectly when they use SYM_TYPED_FUNC_START. That's another
> reason to avoid indirect calls that aren't actually necessary.

We actually have two values for that function pointer, one for encryption and
one for decryption:

static int xts_encrypt(struct skcipher_request *req)
{
	return xts_crypt(req, rv64i_zvbb_zvkg_zvkned_aes_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return xts_crypt(req, rv64i_zvbb_zvkg_zvkned_aes_xts_decrypt);
}
The encryption and decryption paths could be folded together into `xts_crypt()`,
but then we would need an extra branch to pick the enc/dec routine if we want to
avoid the indirect calls. Using `SYM_TYPED_FUNC_START` in the asm might be the
better option.
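
For comparison, the folded version without the indirect call would need a branch
roughly like the sketch below (sketch only; the `encrypt` flag is hypothetical,
the prototypes follow the ones used in this series):

static int xts_crypt(struct skcipher_request *req, bool encrypt);

static int xts_encrypt(struct skcipher_request *req)
{
	return xts_crypt(req, true);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return xts_crypt(req, false);
}

/* ...and each former indirect call inside xts_crypt() becomes: */
	if (encrypt)
		rv64i_zvbb_zvkg_zvkned_aes_xts_encrypt(walk.src.virt.addr,
						       walk.dst.virt.addr,
						       nbytes, &ctx->ctx1,
						       req->iv, update_iv);
	else
		rv64i_zvbb_zvkg_zvkned_aes_xts_decrypt(walk.src.virt.addr,
						       walk.dst.virt.addr,
						       nbytes, &ctx->ctx1,
						       req->iv, update_iv);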

>> +			nbytes &= (~(AES_BLOCK_SIZE - 1));
>
> Expressions like ~(n - 1) should not have another set of parentheses around them

Fixed.
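
With the redundant outer parentheses dropped, the line now reads:

	nbytes &= ~(AES_BLOCK_SIZE - 1);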

>> +static int xts_crypt(struct skcipher_request *req, aes_xts_func func)
>> +{
>> +	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
>> +	const struct riscv64_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
>> +	struct skcipher_request sub_req;
>> +	struct scatterlist sg_src[2], sg_dst[2];
>> +	struct scatterlist *src, *dst;
>> +	struct skcipher_walk walk;
>> +	unsigned int walk_size = crypto_skcipher_walksize(tfm);
>> +	unsigned int tail_bytes;
>> +	unsigned int head_bytes;
>> +	unsigned int nbytes;
>> +	unsigned int update_iv = 1;
>> +	int err;
>> +
>> +	/* xts input size should be bigger than AES_BLOCK_SIZE */
>> +	if (req->cryptlen < AES_BLOCK_SIZE)
>> +		return -EINVAL;
>> +
>> +	/*
>> +	 * We split xts-aes cryption into `head` and `tail` parts.
>> +	 * The head block contains the input from the beginning which doesn't need
>> +	 * `ciphertext stealing` method.
>> +	 * The tail block contains at least two AES blocks including ciphertext
>> +	 * stealing data from the end.
>> +	 */
>> +	if (req->cryptlen <= walk_size) {
>> +		/*
>> +		 * All data is in one `walk`. We could handle it within one AES-XTS call in
>> +		 * the end.
>> +		 */
>> +		tail_bytes = req->cryptlen;
>> +		head_bytes = 0;
>> +	} else {
>> +		if (req->cryptlen & (AES_BLOCK_SIZE - 1)) {
>> +			/*
>> +			 * with ciphertext stealing
>> +			 *
>> +			 * Find the largest tail size which is small than `walk` size while the
>> +			 * head part still fits AES block boundary.
>> +			 */
>> +			tail_bytes = req->cryptlen & (AES_BLOCK_SIZE - 1);
>> +			tail_bytes = walk_size + tail_bytes - AES_BLOCK_SIZE;
>> +			head_bytes = req->cryptlen - tail_bytes;
>> +		} else {
>> +			/* no ciphertext stealing */
>> +			tail_bytes = 0;
>> +			head_bytes = req->cryptlen;
>> +		}
>> +	}
>> +
>> +	riscv64_aes_encrypt_zvkned(&ctx->ctx2, req->iv, req->iv);
>> +
>> +	if (head_bytes && tail_bytes) {
>> +		/* If we have to parts, setup new request for head part only. */
>> +		skcipher_request_set_tfm(&sub_req, tfm);
>> +		skcipher_request_set_callback(
>> +			&sub_req, skcipher_request_flags(req), NULL, NULL);
>> +		skcipher_request_set_crypt(&sub_req, req->src, req->dst,
>> +					   head_bytes, req->iv);
>> +		req = &sub_req;
>> +	}
>> +
>> +	if (head_bytes) {
>> +		err = skcipher_walk_virt(&walk, req, false);
>> +		while ((nbytes = walk.nbytes)) {
>> +			if (nbytes == walk.total)
>> +				update_iv = (tail_bytes > 0);
>> +
>> +			nbytes &= (~(AES_BLOCK_SIZE - 1));
>> +			kernel_vector_begin();
>> +			func(walk.src.virt.addr, walk.dst.virt.addr, nbytes,
>> +			     &ctx->ctx1, req->iv, update_iv);
>> +			kernel_vector_end();
>> +
>> +			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
>> +		}
>> +		if (err || !tail_bytes)
>> +			return err;
>> +
>> +		/*
>> +		 * Setup new request for tail part.
>> +		 * We use `scatterwalk_next()` to find the next scatterlist from last
>> +		 * walk instead of iterating from the beginning.
>> +		 */
>> +		dst = src = scatterwalk_next(sg_src, &walk.in);
>> +		if (req->dst != req->src)
>> +			dst = scatterwalk_next(sg_dst, &walk.out);
>> +		skcipher_request_set_crypt(req, src, dst, tail_bytes, req->iv);
>> +	}
>> +
>> +	/* tail */
>> +	err = skcipher_walk_virt(&walk, req, false);
>> +	if (err)
>> +		return err;
>> +	if (walk.nbytes != tail_bytes)
>> +		return -EINVAL;
>> +	kernel_vector_begin();
>> +	func(walk.src.virt.addr, walk.dst.virt.addr, walk.nbytes, &ctx->ctx1,
>> +	     req->iv, 0);
>> +	kernel_vector_end();
>> +
>> +	return skcipher_walk_done(&walk, 0);
>> +}
>
> Did you consider writing xts_crypt() the way that arm64 and x86 do it? The
> above seems to reinvent sort of the same thing from first principles. I'm
> wondering if you should just copy the existing approach for now. Then there
> would be no need to add the scatterwalk_next() function, and also the handling
> of inputs that don't need ciphertext stealing would be a bit more streamlined.

I will check the arm64 and x86 implementations.
But the `scatterwalk_next()` proposed in this series does the same job as the
`scatterwalk_ffwd()` call in the arm64 and x86 implementations.
`scatterwalk_ffwd()` iterates from the beginning of the scatterlist (O(n)), while
`scatterwalk_next()` just continues from where the last used scatterlist entry
ended (O(1)).
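
For comparison, the tail handling in the existing arm64/x86 xts_crypt() looks
roughly like the following (paraphrased sketch, reusing the head_bytes/tail_bytes
names from this patch):

	/*
	 * Fast-forward from the start of the request's scatterlists to the
	 * ciphertext-stealing tail.  Each scatterwalk_ffwd() call walks the
	 * scatterlist from its first entry, hence O(n).
	 */
	dst = src = scatterwalk_ffwd(sg_src, req->src, head_bytes);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, head_bytes);
	skcipher_request_set_crypt(req, src, dst, tail_bytes, req->iv);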

>> +static int __init riscv64_aes_block_mod_init(void)
>> +{
>> +	int ret = -ENODEV;
>> +
>> +	if (riscv_isa_extension_available(NULL, ZVKNED) &&
>> +	    riscv_vector_vlen() >= 128 && riscv_vector_vlen() <= 2048) {
>> +		ret = simd_register_skciphers_compat(
>> +			riscv64_aes_algs_zvkned,
>> +			ARRAY_SIZE(riscv64_aes_algs_zvkned),
>> +			riscv64_aes_simd_algs_zvkned);
>> +		if (ret)
>> +			return ret;
>> +
>> +		if (riscv_isa_extension_available(NULL, ZVBB)) {
>> +			ret = simd_register_skciphers_compat(
>> +				riscv64_aes_alg_zvkned_zvkb,
>> +				ARRAY_SIZE(riscv64_aes_alg_zvkned_zvkb),
>> +				riscv64_aes_simd_alg_zvkned_zvkb);
>> +			if (ret)
>> +				goto unregister_zvkned;
>
> This makes the registration of the zvkned-zvkb algorithm conditional on zvbb,
> not zvkb. Shouldn't the extension checks actually look like:
>
> ZVKNED
> ZVKB
> ZVBB && ZVKG

Fixed.
But then the registration logic will have nested conditions like:

if (ZVKNED) {
	reg_cipher_1();
	if (ZVKB) {
		reg_cipher_2();
	}
	if (ZVBB && ZVKG) {
		reg_cipher_3();
	}
}

A fuller sketch with the actual extension checks is below.
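
Concretely, a rough sketch of the init path with those checks (sketch only: the
*_zvkned_zvbb_zvkg array names, the simd_unregister_skciphers() unwinding and the
extra error label are assumptions for illustration, not taken from the quoted
hunk):

static int __init riscv64_aes_block_mod_init(void)
{
	int ret = -ENODEV;

	if (riscv_isa_extension_available(NULL, ZVKNED) &&
	    riscv_vector_vlen() >= 128 && riscv_vector_vlen() <= 2048) {
		ret = simd_register_skciphers_compat(
			riscv64_aes_algs_zvkned,
			ARRAY_SIZE(riscv64_aes_algs_zvkned),
			riscv64_aes_simd_algs_zvkned);
		if (ret)
			return ret;

		/* Algorithms that additionally require Zvkb. */
		if (riscv_isa_extension_available(NULL, ZVKB)) {
			ret = simd_register_skciphers_compat(
				riscv64_aes_alg_zvkned_zvkb,
				ARRAY_SIZE(riscv64_aes_alg_zvkned_zvkb),
				riscv64_aes_simd_alg_zvkned_zvkb);
			if (ret)
				goto unregister_zvkned;
		}

		/* Algorithms that additionally require Zvbb and Zvkg. */
		if (riscv_isa_extension_available(NULL, ZVBB) &&
		    riscv_isa_extension_available(NULL, ZVKG)) {
			ret = simd_register_skciphers_compat(
				riscv64_aes_alg_zvkned_zvbb_zvkg,
				ARRAY_SIZE(riscv64_aes_alg_zvkned_zvbb_zvkg),
				riscv64_aes_simd_alg_zvkned_zvbb_zvkg);
			if (ret)
				goto unregister_zvkned_zvkb;
		}
	}

	return ret;

unregister_zvkned_zvkb:
	/* Only registered above when Zvkb was available. */
	if (riscv_isa_extension_available(NULL, ZVKB))
		simd_unregister_skciphers(riscv64_aes_alg_zvkned_zvkb,
					  ARRAY_SIZE(riscv64_aes_alg_zvkned_zvkb),
					  riscv64_aes_simd_alg_zvkned_zvkb);
unregister_zvkned:
	simd_unregister_skciphers(riscv64_aes_algs_zvkned,
				  ARRAY_SIZE(riscv64_aes_algs_zvkned),
				  riscv64_aes_simd_algs_zvkned);
	return ret;
}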

> - Eric