Re: [PATCH] riscv: lib: Implement optimized memchr function

From: Andrew Jones
Date: Mon Dec 11 2023 - 10:08:54 EST


On Fri, Dec 08, 2023 at 02:18:39PM +0000, Ivan Orlov wrote:
> At the moment we don't have an architecture-specific memchr
> implementation for riscv in the Kernel. The generic version of this
> function iterates the memory area bytewise looking for the target value,
> which is not the optimal approach.
>
> Instead of iterating the memory byte by byte, we can iterate over words
> of memory. Word still takes only one cycle to be loaded, and it could be
> checked for containing the target byte in just 5 operations:
>
> 1. Let's say we are looking for the byte BA. XOR the word with
> 0xBABA..BA
> 2. If we have zero byte in the result, the word contains byte BA. Let's
> subtract 0x0101..01 from the xor result.
> 3. Calculate the ~(xor result).
> 4. And the results of steps 2 and 3. If in the xor result we had a zero
> bit somewhere, and after subtracting the 0x0101..01 it turned to 1,
> we will get 1 in the result
> 5. And the result of step 4 with 0x8080..80. If we had a leading zero
> bit in the xor result which turned to 1 after subtracting 0x0101..01,
> it was the leading bit of a zero byte. So, if result of this step != 0,
> the word contains the byte we are looking for.
>
> The same approach is used in the arm64 implementation of this function.
>
> So, this patch introduces the riscv-specific memchr function which
> accepts 3 parameters (address, target byte and count) and works in the
> following way:
>
> 0. If count is smaller than 128, iterate the area byte by byte as we
> would not get any performance gain here.
> 1. If address is not aligned, iterate SZREG - (address % SZREG) bytes
> to avoid unaligned memory access.
> 2. If count is larger than 128, iterate words of memory until we find
> the word which contains the target byte.
> 3. If we have found the word, iterate through it byte by byte and return
> the address of the first occurrence.
> 4. If we have not found the word, iterate the remainder (in case if
> the count was not divisible by 8).
> 5. If we still have not found the target byte, return 0.
>
> Here you can see the benchmark results for "Sifive Hifive Unmatched"
> board, which compares the old and new memchr implementations.
>
> | test_count | array_size | old_mean_ktime | new_mean_ktime |
> ---------------------------------------------------------------
> | 10000 | 10 | 415 | 409 |
> | 10000 | 100 | 642 | 717 |
> | 10000 | 128 | 714 | 775 |
> | 10000 | 256 | 1031 | 611 |
> | 5000 | 512 | 1686 | 769 |
> | 5000 | 768 | 2320 | 925 |
> | 5000 | 1024 | 2968 | 1095 |
> | 5000 | 1500 | 4165 | 1383 |
> | 5000 | 2048 | 5567 | 1731 |
> | 3000 | 4096 | 10698 | 3028 |
> | 3000 | 16384 | 41630 | 10766 |
> | 1000 | 524288 | 1475454 | 498183 |
> | 1000 | 1048576 | 2952080 | 997018 |
> | 500 | 10485760 | 49491492 | 29335358 |
> | 100 | 134217728 | 636033660 | 377157970 |
> | 20 | 536870912 | 2546979300 | 1510817350 |
> | 20 | 1073741824 | 5095776750 | 3019167250 |
>
> The target symbol was always placed at the last index of the array, and
> the mean time of function execution was measured using the ktime_get
> function.
>
> As you can see, the new function shows much better results even for
> the small arrays of 256 elements, therefore I believe it could be a
> useful addition to the existing riscv-specific string functions.

Looks good, but do we want to maintain both this version and a zbb
version? I'd expect a zbb version to be even better.

A couple more comments below.

>
> Signed-off-by: Ivan Orlov <ivan.orlov@xxxxxxxxxxxxxxx>
> ---
> arch/riscv/include/asm/string.h | 2 +
> arch/riscv/kernel/riscv_ksyms.c | 1 +
> arch/riscv/lib/Makefile | 1 +
> arch/riscv/lib/memchr.S | 98 +++++++++++++++++++++++++++++++++
> 4 files changed, 102 insertions(+)
> create mode 100644 arch/riscv/lib/memchr.S
>
> diff --git a/arch/riscv/include/asm/string.h b/arch/riscv/include/asm/string.h
> index a96b1fea24fe..ec1a643cb625 100644
> --- a/arch/riscv/include/asm/string.h
> +++ b/arch/riscv/include/asm/string.h
> @@ -18,6 +18,8 @@ extern asmlinkage void *__memcpy(void *, const void *, size_t);
> #define __HAVE_ARCH_MEMMOVE
> extern asmlinkage void *memmove(void *, const void *, size_t);
> extern asmlinkage void *__memmove(void *, const void *, size_t);
> +#define __HAVE_ARCH_MEMCHR
> +extern asmlinkage void *memchr(const void *, int, size_t);
>
> #define __HAVE_ARCH_STRCMP
> extern asmlinkage int strcmp(const char *cs, const char *ct);
> diff --git a/arch/riscv/kernel/riscv_ksyms.c b/arch/riscv/kernel/riscv_ksyms.c
> index a72879b4249a..08c0d846366b 100644
> --- a/arch/riscv/kernel/riscv_ksyms.c
> +++ b/arch/riscv/kernel/riscv_ksyms.c
> @@ -9,6 +9,7 @@
> /*
> * Assembly functions that may be used (directly or indirectly) by modules
> */
> +EXPORT_SYMBOL(memchr);
> EXPORT_SYMBOL(memset);
> EXPORT_SYMBOL(memcpy);
> EXPORT_SYMBOL(memmove);
> diff --git a/arch/riscv/lib/Makefile b/arch/riscv/lib/Makefile
> index 26cb2502ecf8..0a8b64f8ca88 100644
> --- a/arch/riscv/lib/Makefile
> +++ b/arch/riscv/lib/Makefile
> @@ -1,4 +1,5 @@
> # SPDX-License-Identifier: GPL-2.0-only
> +lib-y += memchr.o
> lib-y += delay.o
> lib-y += memcpy.o
> lib-y += memset.o
> diff --git a/arch/riscv/lib/memchr.S b/arch/riscv/lib/memchr.S
> new file mode 100644
> index 000000000000..d48e0fa3cd84
> --- /dev/null
> +++ b/arch/riscv/lib/memchr.S
> @@ -0,0 +1,98 @@
> +/* SPDX-License-Identifier: GPL-2.0-only */
> +/*
> + * Copyright (c) 2023 Codethink Ltd.
> + * Author: Ivan Orlov <ivan.orlov@xxxxxxxxxxxxxxx>
> + */
> +
> +#include <linux/linkage.h>
> +#include <asm/asm.h>
> +
> +#define REP_01 __REG_SEL(0x0101010101010101, 0x01010101)
> +#define REP_80 __REG_SEL(0x8080808080808080, 0x80808080)
> +
> +#define MIN_BORDER 128
> +
> +SYM_FUNC_START(memchr)
> + andi a1, a1, 0xFF
> +
> + // use byte-wide iteration for small numbers
> + add t1, x0, a2

move t1, a2

and for the remainder of the function s/x0/zero/

> + sltiu t2, a2, MIN_BORDER
> + bnez t2, 6f
> +
> + // get the number of bytes we should iterate before alignment

I'm not sure, but I think even in assembly we prefer the /* */ comment
format.

> + andi t0, a0, SZREG - 1
> + beqz t0, 4f
> +
> + # get the SZREG - t0

I'm 99% sure we don't want to use the # comment syntax.

> + xor t0, t0, SZREG - 1

xori?

> + addi t0, t0, 1
> +
> + sub a2, a2, t0

nit: Looks a bit odd to put a blank line above the sub line above,
instead of above the below comment.

> + // iterate before alignment
> +1:
> + beq t0, x0, 4f
> + lbu t2, 0(a0)
> + beq t2, a1, 3f
> + addi t0, t0, -1

This addi t0... isn't necessary if we do

add t0, a0, t0
1:
beq a0, t0, 4f
...
...
addi a0, a0, 1
j 1b

> + addi a0, a0, 1
> + j 1b
> +
> +2:
> + // found a word. Iterate it until we find the target byte
> + li t1, SZREG
> + j 6f

These instructions seem oddly placed among the rest.

> +3:
> + ret

And this is an odd place to put this ret (after unconditional jump and
in the middle of the function). We can just put a label at the bottom ret.

> +
> +4:
> + // get the count remainder
> + andi t1, a2, SZREG - 1
> +
> + // align the count
> + sub a2, a2, t1
> +
> + // if we have no words to iterate, iterate the remainder
> + beqz a2, 6f
> +
> + // from 0xBA we will get 0xBABABABABABABABA
> + li t3, REP_01
> + mul t3, t3, a1

I don't think we want to implement an optimized assembly function with
mul. We can just use a few shifts and ors.

slli t3, a1, 8
or t3, t3, a1
slli t4, t3, 16
or t3, t4, t3
#if __riscv_xlen == 64
slli t4, t3, 32
or t3, t4, t3
#endif

> +
> + add a2, a2, a0
> +
> + li t4, REP_01
> + li t5, REP_80
> +
> +5:
> + REG_L t2, 0(a0)
> +
> + // after this xor we will get one zero byte in the word if it contains the target byte
> + xor t2, t2, t3
> +
> + // word v contains the target byte if (v - 0x01010101) & (~v) & 0x80808080 is positive

s/is positive/is not zero/

> + sub t0, t2, t4
> +
> + not t2, t2
> +
> + and t0, t0, t2
> + and t0, t0, t5
> +
> + bnez t0, 2b
> + addi a0, a0, SZREG
> + bne a0, a2, 5b
> +
> +6:
> + // iterate the remainder
> + beq t1, x0, 7f
> + lbu t4, 0(a0)
> + beq t4, a1, 3b
> + addi a0, a0, 1
> + addi t1, t1, -1

Same comment as above about being able to drop the addi t1...

> + j 6b
> +
> +7:
> + addi a0, x0, 0

li a0, 0

> + ret
> +SYM_FUNC_END(memchr)
> +SYM_FUNC_ALIAS(__pi_memchr, memchr)
> --
> 2.34.1
>

Thanks,
drew