Re: [PATCH 1/3] riscv: optimized memcpy

From: Nick Kossifidis
Date: Tue Jan 30 2024 - 07:12:07 EST


On 1/28/24 13:10, Jisheng Zhang wrote:
+
+void *__memcpy(void *dest, const void *src, size_t count)
+{
+        union const_types s = { .as_u8 = src };
+        union types d = { .as_u8 = dest };
+        int distance = 0;
+
+        if (count < MIN_THRESHOLD)
+                goto copy_remainder;
+
+        if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS)) {
+                /* Copy a byte at a time until destination is aligned. */
+                for (; d.as_uptr & WORD_MASK; count--)
+                        *d.as_u8++ = *s.as_u8++;
+
+                distance = s.as_uptr & WORD_MASK;
+        }
+
+        if (distance) {
+                unsigned long last, next;
+
+                /*
+                 * s is distance bytes ahead of d, and d just reached
+                 * the alignment boundary. Move s backward to word align it
+                 * and shift data to compensate for distance, in order to do
+                 * word-by-word copy.
+                 */
+                s.as_u8 -= distance;
+
+                next = s.as_ulong[0];
+                for (; count >= BYTES_LONG; count -= BYTES_LONG) {
+                        last = next;
+                        next = s.as_ulong[1];
+
+                        d.as_ulong[0] = last >> (distance * 8) |
+                                        next << ((BYTES_LONG - distance) * 8);
+
+                        d.as_ulong++;
+                        s.as_ulong++;
+                }
+
+                /* Restore s with the original offset. */
+                s.as_u8 += distance;
+        } else {
+                /*
+                 * If the source and dest lower bits are the same, do a simple
+                 * aligned copy.
+                 */
+                size_t aligned_count = count & ~(BYTES_LONG * 8 - 1);
+
+                __memcpy_aligned(d.as_ulong, s.as_ulong, aligned_count);
+                d.as_u8 += aligned_count;
+                s.as_u8 += aligned_count;
+                count &= BYTES_LONG * 8 - 1;
+        }
+
+copy_remainder:
+        while (count--)
+                *d.as_u8++ = *s.as_u8++;
+
+        return dest;
+}
+EXPORT_SYMBOL(__memcpy);
+
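
By the way, the shift-and-merge step above is easy to sanity-check in
isolation. A minimal userspace sketch (values made up for illustration;
assumes a 64-bit little-endian machine, i.e. BYTES_LONG == 8, with
distance == 3):

#include <stdio.h>
#include <string.h>

int main(void)
{
        const unsigned char src[17] = "ABCDEFGHIJKLMNOP";
        unsigned long last, next, out;
        int distance = 3;

        /* Stand-ins for the two word-aligned loads in the patch. */
        memcpy(&last, src, sizeof(last));
        memcpy(&next, src + 8, sizeof(next));

        /* On little-endian this picks up bytes src[3..10] as one word. */
        out = last >> (distance * 8) |
              next << ((8 - distance) * 8);

        printf("%.8s\n", (const char *)&out);   /* prints "DEFGHIJK" */
        return 0;
}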

We could also implement memcmp this way, e.g.:

int
memcmp(const void *s1, const void *s2, size_t len)
{
        union const_data a = { .as_bytes = s1 };
        union const_data b = { .as_bytes = s2 };
        unsigned long a_val = 0;
        unsigned long b_val = 0;
        size_t remaining = len;
        size_t a_offt = 0;

        /* Nothing to do */
        if (!s1 || !s2 || s1 == s2 || !len)
                return 0;

        if (len < 2 * WORD_SIZE)
                goto trailing_fw;

        /* Compare a byte at a time until b is word aligned. */
        for (; b.as_uptr & WORD_MASK; remaining--) {
                a_val = *a.as_bytes++;
                b_val = *b.as_bytes++;
                if (a_val != b_val)
                        goto done;
        }

        a_offt = a.as_uptr & WORD_MASK;
        if (!a_offt) {
                for (; remaining >= WORD_SIZE; remaining -= WORD_SIZE) {
                        a_val = *a.as_ulong++;
                        b_val = *b.as_ulong++;
                        if (a_val != b_val) {
                                /*
                                 * Back up and let the trailing loop find
                                 * the first differing byte; a word-sized
                                 * subtraction would get the sign wrong on
                                 * little-endian.
                                 */
                                a.as_bytes -= WORD_SIZE;
                                b.as_bytes -= WORD_SIZE;
                                goto trailing_fw;
                        }
                }
        } else {
                unsigned long a_cur, a_next;

                a.as_bytes -= a_offt;
                a_next = *a.as_ulong;
                for (; remaining >= WORD_SIZE; remaining -= WORD_SIZE, b.as_ulong++) {
                        a_cur = a_next;
                        a_next = *++a.as_ulong;
                        a_val = a_cur >> (a_offt * 8) |
                                a_next << ((WORD_SIZE - a_offt) * 8);
                        b_val = *b.as_ulong;
                        if (a_val != b_val) {
                                /* Rewind a to the differing word's start. */
                                a.as_bytes -= WORD_SIZE - a_offt;
                                goto trailing_fw;
                        }
                }
                /* Restore a with the original offset. */
                a.as_bytes += a_offt;
        }

trailing_fw:
        while (remaining-- > 0) {
                a_val = *a.as_bytes++;
                b_val = *b.as_bytes++;
                if (a_val != b_val)
                        break;
        }

done:
        /* a_val and b_val hold the first differing bytes, or are equal. */
        return (int)(a_val - b_val);
}
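
This assumes read-only counterparts of the patch's unions/macros, along
these lines (userspace spellings for brevity; the kernel version would
reuse the definitions already in the file):

#include <stdint.h>

#define WORD_SIZE       sizeof(unsigned long)
#define WORD_MASK       (WORD_SIZE - 1)

union const_data {
        const uint8_t *as_bytes;
        const unsigned long *as_ulong;
        uintptr_t as_uptr;
};

As with the memcpy above, the misaligned path may load one aligned word
past the last byte it actually uses, but it never reads beyond the
aligned word that contains valid data, so it can't cross into an
unmapped page.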

Regards,
Nick