Browse Source

bugfix: snprintf returns wrong # of written characters; and can potentially go overflow

Chia-Che Tsai 7 years ago
parent
commit
46abc3beea
5 changed files with 89 additions and 62 deletions
  1. 11 5
      LibOS/shim/src/utils/printf.c
  2. 2 2
      Pal/lib/api.h
  3. 70 51
      Pal/lib/stdlib/printfmt.c
  4. 3 2
      Pal/src/printf.c
  5. 3 2
      Pal/src/security/Linux/printf.c

+ 11 - 5
LibOS/shim/src/utils/printf.c

@@ -34,22 +34,26 @@ struct debugbuf {
     char buf[DEBUGBUF_SIZE];
     char buf[DEBUGBUF_SIZE];
 };
 };
 
 
-static inline void
+static inline int
 debug_fputs (void * f, const char * buf, int len)
 debug_fputs (void * f, const char * buf, int len)
 {
 {
-    DkStreamWrite(debug_handle, 0, len, (void *) buf, NULL);
+    if (DkStreamWrite(debug_handle, 0, len, (void *) buf, NULL) == len)
+        return 0;
+    else
+        return -1;
 }
 }
 
 
-static void
+static int
 debug_fputch (void * f, int ch, void * b)
 debug_fputch (void * f, int ch, void * b)
 {
 {
     struct debug_buf * buf = (struct debug_buf *) b;
     struct debug_buf * buf = (struct debug_buf *) b;
     buf->buf[buf->end++] = ch;
     buf->buf[buf->end++] = ch;
 
 
     if (ch == '\n') {
     if (ch == '\n') {
-        debug_fputs(NULL, buf->buf, buf->end);
+        if (debug_fputs(NULL, buf->buf, buf->end) == -1)
+            return -1;
         buf->end = buf->start;
         buf->end = buf->start;
-        return;
+        return 0;
     }
     }
 
 
     if (buf->end == DEBUGBUF_SIZE - 4) {
     if (buf->end == DEBUGBUF_SIZE - 4) {
@@ -61,6 +65,8 @@ debug_fputch (void * f, int ch, void * b)
         buf->buf[buf->end++] = '.';
         buf->buf[buf->end++] = '.';
         buf->buf[buf->end++] = '.';
         buf->buf[buf->end++] = '.';
     }
     }
+
+    return 0;
 }
 }
 
 
 void debug_puts (const char * str)
 void debug_puts (const char * str)

+ 2 - 2
Pal/lib/api.h

@@ -77,10 +77,10 @@ int memcmp (const void *s1, const void *s2, int len);
      static_strlen(force_static(str)))
      static_strlen(force_static(str)))
 
 
 /* Libc printf functions */
 /* Libc printf functions */
-void fprintfmt (void (*_fputch)(void *, int, void *), void * f, void * putdat,
+void fprintfmt (int (*_fputch)(void *, int, void *), void * f, void * putdat,
                 const char * fmt, ...);
                 const char * fmt, ...);
 
 
-void vfprintfmt (void (*_fputch)(void *, int, void *), void * f, void * putdat,
+void vfprintfmt (int (*_fputch)(void *, int, void *), void * f, void * putdat,
                  const char * fmt, va_list *ap);
                  const char * fmt, va_list *ap);
 
 
 int snprintf (char * buf, int n, const char * fmt, ...);
 int snprintf (char * buf, int n, const char * fmt, ...);

+ 70 - 51
Pal/lib/stdlib/printfmt.c

@@ -12,26 +12,31 @@
 // Print a number (base <= 16) in reverse order,
 // Print a number (base <= 16) in reverse order,
 // using specified fputch function and associated pointer putdat.
 // using specified fputch function and associated pointer putdat.
 #if !defined(__i386__)
 #if !defined(__i386__)
-static void
-printnum(void (*_fputch)(void *, int, void *), void * f, void * putdat,
+static int
+printnum(int (*_fputch)(void *, int, void *), void * f, void * putdat,
 	 unsigned long long num, unsigned base, int width, int padc)
 	 unsigned long long num, unsigned base, int width, int padc)
 #else
 #else
-static void
-printnum(void (*_fputch)(void *, int, void *), void * f, void * putdat,
+static int
+printnum(int (*_fputch)(void *, int, void *), void * f, void * putdat,
 	 unsigned long num, unsigned base, int width, int padc)
 	 unsigned long num, unsigned base, int width, int padc)
 #endif
 #endif
 {
 {
 	// first recursively print all preceding (more significant) digits
 	// first recursively print all preceding (more significant) digits
 	if (num >= base) {
 	if (num >= base) {
-		printnum(_fputch, f, putdat, num / base, base, width - 1, padc);
+		if (printnum(_fputch, f, putdat, num / base, base, width - 1, padc) == -1)
+			return -1;
 	} else {
 	} else {
 		// print any needed pad characters before first digit
 		// print any needed pad characters before first digit
 		while (--width > 0)
 		while (--width > 0)
-			(*_fputch) (f, padc, putdat);
+			if ((*_fputch) (f, padc, putdat) == -1)
+				return -1;
 	}
 	}
 
 
 	// then print this (the least significant) digit
 	// then print this (the least significant) digit
-	(*_fputch) (f, "0123456789abcdef"[num % base], putdat);
+	if ((*_fputch) (f, "0123456789abcdef"[num % base], putdat) == -1)
+		return -1;
+
+	return 0;
 }
 }
 
 
 // Get an unsigned int of various possible sizes from a varargs list,
 // Get an unsigned int of various possible sizes from a varargs list,
@@ -77,12 +82,12 @@ getint(va_list *ap, int lflag)
 }
 }
 
 
 // Main function to format and print a string.
 // Main function to format and print a string.
-void fprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
-			        const char * fmt, ...);
+void fprintfmt(int (*_fputch)(void *, int, void *), void * f, void * putdat,
+			      const char * fmt, ...);
 
 
 void
 void
-vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
-			   const char * fmt, va_list *ap)
+vfprintfmt(int (*_fputch)(void *, int, void *), void * f, void * putdat,
+			  const char * fmt, va_list *ap)
 {
 {
 	register const char *p;
 	register const char *p;
 	register int ch;
 	register int ch;
@@ -98,7 +103,8 @@ vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 		while ((ch = *(unsigned char *) (fmt++)) != '%') {
 		while ((ch = *(unsigned char *) (fmt++)) != '%') {
 			if (ch == '\0')
 			if (ch == '\0')
 				return;
 				return;
-			(*_fputch) (f, ch, putdat);
+			if ((*_fputch) (f, ch, putdat) < 0)
+				return;
 		}
 		}
 
 
 		// Process a %-escape sequence
 		// Process a %-escape sequence
@@ -163,7 +169,8 @@ vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 
 
 		// character
 		// character
 		case 'c':
 		case 'c':
-			(*_fputch) (f, va_arg(*ap, int), putdat);
+			if ((*_fputch) (f, va_arg(*ap, int), putdat) == -1)
+				return;
 			break;
 			break;
 
 
 		// string
 		// string
@@ -172,14 +179,19 @@ vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 				p = "(null)";
 				p = "(null)";
 			if (width > 0 && padc != '-')
 			if (width > 0 && padc != '-')
 				for (width -= strnlen(p, precision); width > 0; width--)
 				for (width -= strnlen(p, precision); width > 0; width--)
-					(*_fputch) (f, padc, putdat);
+					if ((*_fputch) (f, padc, putdat) == -1)
+						return;
 			for (; (ch = *p++) != '\0' && (precision < 0 || --precision >= 0); width--)
 			for (; (ch = *p++) != '\0' && (precision < 0 || --precision >= 0); width--)
-				if (altflag && (ch < ' ' || ch > '~'))
-					(*_fputch) (f, '?', putdat);
-				else
-					(*_fputch) (f, ch, putdat);
+				if (altflag && (ch < ' ' || ch > '~')) {
+					if ((*_fputch) (f, '?', putdat) == -1)
+						return;
+				} else {
+					if ((*_fputch) (f, ch, putdat) == -1)
+						return;
+				}
 			for (; width > 0; width--)
 			for (; width > 0; width--)
-				(*_fputch) (f, ' ', putdat);
+				if ((*_fputch) (f, ' ', putdat) == -1)
+					return;
 			break;
 			break;
 
 
 		// (signed) decimal
 		// (signed) decimal
@@ -187,12 +199,14 @@ vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 			num = getint(ap, lflag);
 			num = getint(ap, lflag);
 #if !defined(__i386__)
 #if !defined(__i386__)
 			if ((long long) num < 0) {
 			if ((long long) num < 0) {
-				(*_fputch) (f, '-', putdat);
+				if ((*_fputch) (f, '-', putdat) == -1)
+					return;
 				num = -(long long) num;
 				num = -(long long) num;
 			}
 			}
 #else
 #else
 			if ((long) num < 0) {
 			if ((long) num < 0) {
-				(*_fputch) (f, '-', putdat);
+				if ((*_fputch) (f, '-', putdat) == -1)
+					return;
 				num = -(long) num;
 				num = -(long) num;
 			}
 			}
 #endif
 #endif
@@ -214,8 +228,10 @@ vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 
 
 		// pointer
 		// pointer
 		case 'p':
 		case 'p':
-			(*_fputch) (f, '0', putdat);
-			(*_fputch) (f, 'x', putdat);
+			if ((*_fputch) (f, '0', putdat) == -1)
+				return;
+			if ((*_fputch) (f, 'x', putdat) == -1)
+				return;
 #if !defined(__i386__)
 #if !defined(__i386__)
 			num = (unsigned long long)
 			num = (unsigned long long)
 				(uintptr_t) va_arg(*ap, void *);
 				(uintptr_t) va_arg(*ap, void *);
@@ -231,13 +247,15 @@ vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 			num = getuint(ap, lflag);
 			num = getuint(ap, lflag);
 			base = 16;
 			base = 16;
 		number:
 		number:
-			printnum(_fputch, f, putdat, num, base, width, padc);
+			if (printnum(_fputch, f, putdat, num, base, width, padc) == -1)
+				return;
 			break;
 			break;
 
 
-                // escape character
-                case '^':
-                        (*_fputch) (f, 0x1b, putdat);
-                        break;
+		// escape character
+		case '^':
+			if ((*_fputch) (f, 0x1b, putdat) == -1)
+				return;
+			break;
 
 
 		// escaped '%' character
 		// escaped '%' character
 		case '%':
 		case '%':
@@ -255,7 +273,7 @@ vfprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 }
 }
 
 
 void
 void
-fprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
+fprintfmt(int (*_fputch)(void *, int, void *), void * f, void * putdat,
 			   const char * fmt, ...)
 			   const char * fmt, ...)
 {
 {
 	va_list ap;
 	va_list ap;
@@ -266,46 +284,47 @@ fprintfmt(void (*_fputch)(void *, int, void *), void * f, void * putdat,
 }
 }
 
 
 struct sprintbuf {
 struct sprintbuf {
-    char * buf;
-    char * ebuf;
-    int cnt;
+	int cnt, max;
+	char * buf;
 };
 };
 
 
-static void
+static int
 sprintputch(void * f, int ch, struct sprintbuf * b)
 sprintputch(void * f, int ch, struct sprintbuf * b)
 {
 {
-    b->cnt++;
-    if (b->buf < b->ebuf)
-        *b->buf++ = ch;
+	if (b->cnt >= b->max)
+		return -1;
+
+	b->buf[b->cnt++] = ch;
+	return 0;
 }
 }
 
 
 static int
 static int
 vsprintf(char * buf, int n, const char * fmt, va_list *ap)
 vsprintf(char * buf, int n, const char * fmt, va_list *ap)
 {
 {
-    struct sprintbuf b = {buf, buf + n - 1, 0};
+	struct sprintbuf b = { 0, n, buf };
 
 
-    if (buf == NULL || n < 1) {
-        return -1;
-    }
+	if (!buf || n < 1)
+		return 0;
 
 
-    // print the string to the buffer
-    vfprintfmt((void *) sprintputch, (void *) 0, &b, fmt, ap);
+	// print the string to the buffer
+	vfprintfmt((void *) sprintputch, (void *) 0, &b, fmt, ap);
 
 
-    // null terminate the buffer
-    *b.buf = '\0';
+	// null terminate the buffer
+	if (b.cnt < n)
+		b.buf[b.cnt] = '\0';
 
 
-    return b.cnt;
+	return b.cnt;
 }
 }
 
 
 int
 int
 snprintf(char * buf, int n, const char * fmt, ...)
 snprintf(char * buf, int n, const char * fmt, ...)
 {
 {
-    va_list ap;
-    int rc;
+	va_list ap;
+	int rc;
 
 
-    va_start(ap, fmt);
-    rc = vsprintf(buf, n, fmt, &ap);
-    va_end(ap);
+	va_start(ap, fmt);
+	rc = vsprintf(buf, n, fmt, &ap);
+	va_end(ap);
 
 
-    return rc;
+	return rc;
 }
 }

+ 3 - 2
Pal/src/printf.c

@@ -36,15 +36,16 @@ struct printbuf {
     char buf[PRINTBUF_SIZE];
     char buf[PRINTBUF_SIZE];
 };
 };
 
 
-static void
+static int
 fputch(void * f, int ch, struct printbuf * b)
 fputch(void * f, int ch, struct printbuf * b)
 {
 {
     b->buf[b->idx++] = ch;
     b->buf[b->idx++] = ch;
-    if (b->idx == PRINTBUF_SIZE-1) {
+    if (b->idx == PRINTBUF_SIZE - 1) {
         _DkPrintConsole(b->buf, b->idx);
         _DkPrintConsole(b->buf, b->idx);
         b->idx = 0;
         b->idx = 0;
     }
     }
     b->cnt++;
     b->cnt++;
+    return 0;
 }
 }
 
 
 static int
 static int

+ 3 - 2
Pal/src/security/Linux/printf.c

@@ -30,15 +30,16 @@ struct sprintbuf {
 
 
 #define sys_cputs(fd, bf, cnt) INLINE_SYSCALL(write, 3, (fd), (bf), (cnt))
 #define sys_cputs(fd, bf, cnt) INLINE_SYSCALL(write, 3, (fd), (bf), (cnt))
 
 
-static void
+static int
 fputch(int fd, int ch, struct printbuf *b)
 fputch(int fd, int ch, struct printbuf *b)
 {
 {
 	b->buf[b->idx++] = ch;
 	b->buf[b->idx++] = ch;
-	if (b->idx == PRINTBUF_SIZE-1) {
+	if (b->idx == PRINTBUF_SIZE - 1) {
 		sys_cputs(fd, b->buf, b->idx);
 		sys_cputs(fd, b->buf, b->idx);
 		b->idx = 0;
 		b->idx = 0;
 	}
 	}
 	b->cnt++;
 	b->cnt++;
+	return 0;
 }
 }
 
 
 static int
 static int