[PATCH 4/5] riscv: Vector checksum library

From: Charlie Jenkins
Date: Sat Aug 26 2023 - 21:27:43 EST


This patch is not ready for merge because vector support in the kernel
is still limited. However, the code has been tested in QEMU, so the
algorithms do work. When vector support is more mature, I will test
this code more thoroughly. It is written in assembly rather than with
the GCC vector intrinsics because the intrinsics did not produce
optimal code.

Signed-off-by: Charlie Jenkins <charlie@xxxxxxxxxxxx>
---
arch/riscv/lib/csum.c | 165 ++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 165 insertions(+)

diff --git a/arch/riscv/lib/csum.c b/arch/riscv/lib/csum.c
index 2037041ce8a0..049a10596008 100644
--- a/arch/riscv/lib/csum.c
+++ b/arch/riscv/lib/csum.c
@@ -12,6 +12,10 @@

#include <net/checksum.h>

+#ifdef CONFIG_RISCV_ISA_V
+#include <riscv_vector.h>
+#endif
+
/* Default version is sufficient for 32 bit */
#ifdef CONFIG_64BIT
__sum16 csum_ipv6_magic(const struct in6_addr *saddr,
@@ -64,6 +68,166 @@ typedef unsigned long csum_t;
* the bytes that it shouldn't. The same thing will occur on the tail-end of the
* read.
*/
+#ifdef CONFIG_RISCV_ISA_V
+#ifdef CONFIG_32BIT
+unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
+{
+	vuint64m1_t acc;	/* element 0 holds the 64-bit running sum */
+	vuint32m1_t data;	/* current chunk of 32-bit words */
+	unsigned int vl, csum, offset, words, pad, shift, tail_seg;
+	unsigned int high_result, low_result, result;
+	const unsigned int *ptr;
+
+	if (len <= 0)
+		return 0;
+
+	/*
+	 * Sum whole 32-bit words starting at the 4-byte aligned address at or
+	 * below buff. The vector sum is exact in 64 bits, so the leading pad
+	 * bytes are cancelled by subtracting them from the initial accumulator.
+	 */
+	offset = (unsigned int)buff & 3;
+	kasan_check_read(buff, len);
+	ptr = (const unsigned int *)(buff - offset);
+	len += offset;
+
+	/* The trailing 1-3 bytes are summed here (little endian). */
+	tail_seg = len % 4;
+	len -= tail_seg;
+	csum = 0;
+	if (tail_seg) {
+		shift = (4 - tail_seg) * 8;
+		csum = (ptr[len / 4] << shift) >> shift;
+		if (!len) /* the tail word also holds the pad bytes */
+			csum = (csum >> (offset * 8)) << (offset * 8);
+	}
+	pad = len ? *ptr & ~(~0U << (offset * 8)) : 0;
+	words = len / 4;
+
+	/*
+	 * NOTE(review): nothing here checks at runtime that the hart
+	 * implements V or saves the vector state of the interrupted
+	 * context; both are needed before this can replace the scalar code.
+	 *
+	 * On RV32, vmv.s.x sign-extends into the 64-bit element, so csum - pad
+	 * (magnitude below 2^24) still yields the right 64-bit start value.
+	 */
+	asm("vsetivli	x0, 1, e64, m1, ta, ma\n\t"
+	    "vmv.s.x	%[acc], %[init]\n"
+	    "1:\n\t"
+	    /* Count vl in 32-bit elements so every chunk is whole words. */
+	    "vsetvli	%[vl], %[words], e32, m1, ta, ma\n\t"
+	    "vle32.v	%[data], (%[ptr])\n\t"
+	    /* acc[0] += zero-extended sum of data[0..vl-1] */
+	    "vwredsumu.vs	%[acc], %[data], %[acc]\n\t"
+	    "sub	%[words], %[words], %[vl]\n\t"
+#ifdef CONFIG_RISCV_ISA_ZBA
+	    "sh2add	%[ptr], %[vl], %[ptr]\n\t"
+#else
+	    "slli	%[vl], %[vl], 2\n\t"
+	    "add	%[ptr], %[ptr], %[vl]\n\t"
+#endif
+	    "bnez	%[words], 1b\n\t"
+	    /* vmv.x.s only returns XLEN bits: read the sum in two halves. */
+	    "vsetivli	x0, 1, e64, m1, ta, ma\n\t"
+	    "vmv.x.s	%[low_result], %[acc]\n\t"
+	    "li	%[vl], 32\n\t"
+	    "vsrl.vx	%[acc], %[acc], %[vl]\n\t"
+	    "vmv.x.s	%[high_result], %[acc]"
+	    : [vl] "=&r" (vl), [acc] "=&vd" (acc), [data] "=&vd" (data),
+	      [words] "+r" (words), [ptr] "+r" (ptr),
+	      [high_result] "=&r" (high_result),
+	      [low_result] "=&r" (low_result)
+	    : [init] "r" (csum - pad)
+	    : "memory");
+
+	/* Fold 64 -> 32 bits with end-around carry, then 32 -> 16. */
+	high_result += low_result;
+	high_result += high_result < low_result;
+	result = high_result + ((high_result >> 16) | (high_result << 16));
+	/* An odd start address swaps the byte lanes of the result. */
+	if (offset & 1)
+		return (unsigned short)swab32(result);
+	return result >> 16;
+}
+#else
+unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
+{
+	vuint64m1_t acc;	/* element 0 holds the 64-bit running sum */
+	vuint32m1_t data;	/* current chunk of 32-bit words */
+	unsigned long vl, result, csum, offset, words, pad;
+	unsigned int tail_seg, shift;
+	const unsigned int *ptr;
+
+	if (len <= 0)
+		return 0;
+
+	/*
+	 * Sum whole 32-bit words starting at the 4-byte aligned address at or
+	 * below buff. The vector sum is exact in 64 bits, so the leading pad
+	 * bytes are cancelled by subtracting them from the initial accumulator.
+	 */
+	offset = (unsigned long)buff & 3;
+	kasan_check_read(buff, len);
+	ptr = (const unsigned int *)(buff - offset);
+	len += offset;
+
+	/* The trailing 1-3 bytes are summed here (little endian). */
+	tail_seg = len % 4;
+	len -= tail_seg;
+	csum = 0;
+	if (tail_seg) {
+		shift = (4 - tail_seg) * 8;
+		csum = (ptr[len / 4] << shift) >> shift;
+		if (!len) /* the tail word also holds the pad bytes */
+			csum = (csum >> (offset * 8)) << (offset * 8);
+	}
+	pad = len ? *ptr & ~(~0U << (offset * 8)) : 0;
+	words = len / 4;
+
+	/*
+	 * NOTE(review): nothing here checks at runtime that the hart
+	 * implements V or saves the vector state of the interrupted
+	 * context; both are needed before this can replace the scalar code.
+	 *
+	 * If words == 0 the loop runs once with vl == 0; the reduction then
+	 * leaves acc, and therefore csum - pad, unchanged.
+	 */
+	asm("vsetivli	x0, 1, e64, m1, ta, ma\n\t"
+	    "vmv.s.x	%[acc], %[init]\n"
+	    "1:\n\t"
+	    /* Count vl in 32-bit elements so every chunk is whole words. */
+	    "vsetvli	%[vl], %[words], e32, m1, ta, ma\n\t"
+	    "vle32.v	%[data], (%[ptr])\n\t"
+	    /* acc[0] += zero-extended sum of data[0..vl-1] */
+	    "vwredsumu.vs	%[acc], %[data], %[acc]\n\t"
+	    "sub	%[words], %[words], %[vl]\n\t"
+#ifdef CONFIG_RISCV_ISA_ZBA
+	    "sh2add	%[ptr], %[vl], %[ptr]\n\t"
+#else
+	    "slli	%[vl], %[vl], 2\n\t"
+	    "add	%[ptr], %[ptr], %[vl]\n\t"
+#endif
+	    "bnez	%[words], 1b\n\t"
+	    /* Read back the full 64-bit sum. */
+	    "vsetivli	x0, 1, e64, m1, ta, ma\n\t"
+	    "vmv.x.s	%[result], %[acc]"
+	    : [vl] "=&r" (vl), [acc] "=&vd" (acc), [data] "=&vd" (data),
+	      [words] "+r" (words), [ptr] "+r" (ptr), [result] "=&r" (result)
+	    : [init] "r" (csum - pad)
+	    : "memory");
+
+	/* Fold 64 -> 32 bits with end-around carry, then 32 -> 16. */
+	result += (result >> 32) | (result << 32);
+	result >>= 32;
+	result = (unsigned int)result + (((unsigned int)result >> 16) | ((unsigned int)result << 16));
+	/* An odd start address swaps the byte lanes of the result. */
+	if (offset & 1)
+		return (unsigned short)swab32(result);
+	return result >> 16;
+}
+#endif
+#else
unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
{
unsigned int offset, shift;
@@ -116,3 +280,4 @@ unsigned int __no_sanitize_address do_csum(const unsigned char *buff, int len)
return (unsigned short)swab32(csum);
return csum >> 16;
}
+#endif

--
2.41.0