[PATCH v3 1/1] lib/vsprintf: Implement spprintf() to catch truncated strings

From: Lee Jones
Date: Tue Jan 30 2024 - 11:10:13 EST


There is an ongoing effort to replace the use of {v}snprintf() variants
with safer alternatives - for a more in-depth view, see Jon's write-up
on LWN [0] and/or Alex's on the Kernel Self Protection Project [1].

Whilst executing the task, it quickly became apparent that the initial
thought of simply s/snprintf/scnprintf/ wasn't going to be adequate for
a number of cases. Specifically ones where the caller needs to know
whether the given string ends up being truncated. This is where
spprintf() comes in, since it takes the best parts of both of the
aforementioned variants. It has the testability of truncation of
snprintf() and returns the number of Bytes *actually* written, similar
to scnprintf(), making it a very programmer friendly alternative.

Here's some examples to show the differences:

Success: No truncation - all 9 Bytes successfully written to the buffer

ret = snprintf (buf, 10, "%s", "123456789"); // ret = 9
ret = scnprintf(buf, 10, "%s", "123456789"); // ret = 9
ret = spprintf (buf, 10, "%s", "123456789"); // ret = 9

Failure: Truncation - only 9 of 12 Bytes written; "---" is truncated

ret = snprintf (buf, 10, "%s", "123456789---"); // ret = 12

Reports: "12 Bytes would have been written if buf was large enough"
Issue: Too easy for programmers to assume ret is Bytes written

ret = scnprintf(buf, 10, "%s", "123456789---"); // ret = 9

Reports: "9 Bytes actually written"
Issue: Not testable - returns 9 on success AND failure (see above)

ret = spprintf (buf, 10, "%s", "123456789---"); // ret = 10

Reports: "Data provided is too large to fit in the buffer"
Issue: Minimal impact - there is no way to tell how much data was lost

Since spprintf() reports the total size of the buffer when truncation
occurs, it's easy to test whether the buffer overflowed: once the
compulsory '\0' is accounted for, only 9 data Bytes can fit, so a return
value of 10 informs the caller of an overflow. Also, if the return value
is plugged straight into an additional call to spprintf() after an
overflow has occurred, no out-of-bounds access will occur:

int size = 10;
char buf[size];
char *b = buf;

ret = spprintf(b, size, "1234");
size -= ret;
b += ret;
// ret = 4 size = 6 buf = "1234\0"

ret = spprintf(b, size, "5678");
size -= ret;
b += ret;
// ret = 4 size = 2 buf = "12345678\0"

ret = spprintf(b, size, "9***");
size -= ret;
b += ret;
// ret = 2 size = 0 buf = "123456789\0"

Since size is now 0, further calls result in no changes of state.

ret = spprintf(b, size, "----");
size -= ret;
b += ret;
// ret = 0 size = 0 buf = "123456789\0"

[0] https://lwn.net/Articles/69419/
[1] https://github.com/KSPP/linux/issues/105
Signed-off-by: Lee Jones <lee@xxxxxxxxxx>
---
Changelog:

v1 => v2:
- Address Rasmus Villemoes's review comments:
- Remove explicit check for zero sized buffer (-E2BIG is appropriate)
- Remove unreachable branch in vssprintf()

v2 => v3:
- Address suggestion from David Laight
- Return 'size' instead of '-E2BIG'

include/linux/sprintf.h | 2 ++
lib/vsprintf.c | 51 +++++++++++++++++++++++++++++++++++++++++
2 files changed, 53 insertions(+)

Cc: Andrew Morton <akpm@xxxxxxxxxxxxxxxxxxxx>
Cc: Petr Mladek <pmladek@xxxxxxxx>
Cc: Steven Rostedt <rostedt@xxxxxxxxxxx>
Cc: Andy Shevchenko <andriy.shevchenko@xxxxxxxxxxxxxxx>
Cc: Rasmus Villemoes <linux@xxxxxxxxxxxxxxxxxx>
Cc: Sergey Senozhatsky <senozhatsky@xxxxxxxxxxxx>
Cc: Crutcher Dunnavant <crutcher+kernel@xxxxxxxxxxxxxx>
Cc: Juergen Quade <quade@xxxxxxx>
Cc: David Laight <David.Laight@xxxxxxxxxx>

diff --git a/include/linux/sprintf.h b/include/linux/sprintf.h
index 33dcbec719254..5c4b7e612ba04 100644
--- a/include/linux/sprintf.h
+++ b/include/linux/sprintf.h
@@ -13,6 +13,8 @@ __printf(3, 4) int snprintf(char *buf, size_t size, const char *fmt, ...);
__printf(3, 0) int vsnprintf(char *buf, size_t size, const char *fmt, va_list args);
__printf(3, 4) int scnprintf(char *buf, size_t size, const char *fmt, ...);
__printf(3, 0) int vscnprintf(char *buf, size_t size, const char *fmt, va_list args);
+__printf(3, 4) int spprintf(char *buf, size_t size, const char *fmt, ...);
+__printf(3, 0) int vspprintf(char *buf, size_t size, const char *fmt, va_list args);
__printf(2, 3) __malloc char *kasprintf(gfp_t gfp, const char *fmt, ...);
__printf(2, 0) __malloc char *kvasprintf(gfp_t gfp, const char *fmt, va_list args);
__printf(2, 0) const char *kvasprintf_const(gfp_t gfp, const char *fmt, va_list args);
diff --git a/lib/vsprintf.c b/lib/vsprintf.c
index 552738f14275a..54d4e170ded1d 100644
--- a/lib/vsprintf.c
+++ b/lib/vsprintf.c
@@ -2936,6 +2936,34 @@ int vscnprintf(char *buf, size_t size, const char *fmt, va_list args)
}
EXPORT_SYMBOL(vscnprintf);

+/**
+ * vspprintf - Format a string and place it in a buffer
+ * @buf: The buffer to place the result into
+ * @size: The size of the buffer, including the trailing null space
+ * @fmt: The format string to use
+ * @args: Arguments for the format string
+ *
+ * The return value is the number of characters which have been written into
+ * @buf, not including the trailing '\0', or @size if the output was
+ * truncated. Callers detect truncation with ``ret == size``; a zero @size
+ * always reports truncation by returning 0. Because the return value never
+ * exceeds @size, the pattern ``buf += ret; size -= ret;`` can be chained
+ * across calls without @size underflowing.
+ *
+ * If you're not already dealing with a va_list consider using spprintf().
+ *
+ * See the vsnprintf() documentation for format string extensions over C99.
+ */
+int vspprintf(char *buf, size_t size, const char *fmt, va_list args)
+{
+	int i;
+
+	i = vsnprintf(buf, size, fmt, args);
+
+	if (likely(i < size))
+		return i;
+
+	/*
+	 * Output was truncated (or @size is 0). vsnprintf() rejects sizes
+	 * above INT_MAX by returning 0, which takes the branch above, so here
+	 * size <= INT_MAX and the narrowing conversion is lossless.
+	 */
+	return (int)size;
+}
+EXPORT_SYMBOL(vspprintf);
+
/**
* snprintf - Format a string and place it in a buffer
* @buf: The buffer to place the result into
@@ -2987,6 +3015,29 @@ int scnprintf(char *buf, size_t size, const char *fmt, ...)
}
EXPORT_SYMBOL(scnprintf);

+/**
+ * spprintf - Format a string and place it in a buffer
+ * @buf: The buffer to place the result into
+ * @size: The size of the buffer, including the trailing null space
+ * @fmt: The format string to use
+ * @...: Arguments for the format string
+ *
+ * Variadic front-end to vspprintf(). Returns the number of characters
+ * written into @buf, excluding the trailing '\0', or @size if the output
+ * did not fit and was truncated.
+ */
+int spprintf(char *buf, size_t size, const char *fmt, ...)
+{
+	va_list ap;
+	int written;
+
+	va_start(ap, fmt);
+	written = vspprintf(buf, size, fmt, ap);
+	va_end(ap);
+
+	return written;
+}
+EXPORT_SYMBOL(spprintf);
+
/**
* vsprintf - Format a string and place it in a buffer
* @buf: The buffer to place the result into
--
2.43.0.429.g432eaa2c6b-goog