Browse Source

Buffer data in the tor -> pirserver direction

The pirserver will be doing a lot of computation, and we don't
want to be blocking the tor process if it tries to write to the
pirserver process while it's computing.
Ian Goldberg 3 years ago
parent
commit
7cccc11e90
1 changed files with 76 additions and 15 deletions
  1. 76 15
      src/feature/hs/hs_cache.c

+ 76 - 15
src/feature/hs/hs_cache.c

@@ -977,6 +977,8 @@ hs_cache_get_max_descriptor_size(void)
 }
 
 static process_handle_t *pirserver;
+static buf_t *pirserver_stdin_buf;
+static struct event *pirserver_stdin_ev;
 static struct event *pirserver_stdout_ev;
 static struct event *pirserver_stderr_ev;
 
@@ -1005,7 +1007,7 @@ hs_cache_pirserver_received(const unsigned char *hdrbuf,
 
 /* This is called when the pirserver has output for us. */
 static void
-hs_cache_pirserver_recvcb(evutil_socket_t fd, short what,
+hs_cache_pirserver_stdoutcb(evutil_socket_t fd, short what,
         ATTR_UNUSED void *arg) {
     static PIRServerReadState readstate = PIRSERVER_READSTATE_HEADER;
     static size_t readoff = 0;
@@ -1057,6 +1059,56 @@ hs_cache_pirserver_recvcb(evutil_socket_t fd, short what,
     }
 }
 
+/* This is called when the pirserver is ready to read from its stdin. */
+static void
+hs_cache_pirserver_stdincb(evutil_socket_t fd, short what,
+        ATTR_UNUSED void *arg) {
+    int res;
+    size_t bufsize = buf_datalen(pirserver_stdin_buf);
+    char *netbuf = NULL;
+
+    if (!(what & EV_WRITE)) {
+        /* Not sure why we're here */
+        log_info(LD_DIRSERV,"PIRSERVER bailing");
+        return;
+    }
+
+    if (bufsize == 0) {
+        log_err(LD_DIRSERV,"PIRSERVER trying to write 0-length buffer");
+        return;
+    }
+
+    netbuf = malloc(bufsize);
+    if (netbuf == NULL) {
+        log_err(LD_DIRSERV,"PIRSERVER failed to allocate buffer");
+        return;
+    }
+
+    /* One might think that just calling buf_flush_to_socket would be
+     * the thing to do, but that function ends up calling sendto()
+     * instead of write(), which doesn't work on pipes.  So we do it
+     * more manually.  Using a bufferevent may be another option. */
+    buf_peek(pirserver_stdin_buf, netbuf, bufsize);
+    res = write(fd, netbuf, bufsize);
+    free(netbuf);
+    if (res < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) {
+        /* Try writing again later */
+        event_add(pirserver_stdin_ev, NULL);
+        return;
+    }
+    if (res <= 0) {
+        /* Stop trying to write. */
+        return;
+    }
+    buf_drain(pirserver_stdin_buf, res);
+    bufsize -= res;
+
+    if (bufsize > 0) {
+        /* There's more to write */
+        event_add(pirserver_stdin_ev, NULL);
+    }
+}
+
 /* This is called when the pirserver writes something to its stderr. */
 static void
 hs_cache_pirserver_stderrcb(evutil_socket_t fd, short what,
@@ -1090,6 +1142,14 @@ hs_cache_pir_poke(void)
     }
 
     if (pirserver) {
+        if (pirserver_stdin_buf) {
+            buf_free(pirserver_stdin_buf);
+            pirserver_stdin_buf = NULL;
+        }
+        if (pirserver_stdin_ev) {
+            event_free(pirserver_stdin_ev);
+            pirserver_stdin_ev = NULL;
+        }
         if (pirserver_stdout_ev) {
             event_free(pirserver_stdout_ev);
             pirserver_stdout_ev = NULL;
@@ -1125,7 +1185,7 @@ hs_cache_pir_poke(void)
     /* Create a libevent event to listen to the PIR server's responses. */
     pirserver_stdout_ev = event_new(tor_libevent_get_base(),
         pirserver->stdout_pipe, EV_READ|EV_PERSIST,
-        hs_cache_pirserver_recvcb, NULL);
+        hs_cache_pirserver_stdoutcb, NULL);
     event_add(pirserver_stdout_ev, NULL);
 
     /* And one to listen to the PIR server's stderr. */
@@ -1133,6 +1193,13 @@ hs_cache_pir_poke(void)
         pirserver->stderr_pipe, EV_READ|EV_PERSIST,
         hs_cache_pirserver_stderrcb, NULL);
     event_add(pirserver_stderr_ev, NULL);
+
+    /* And one for writability to the pirserver's stdin, but don't add
+     * it just yet.  Also create the buffer it will use. */
+    pirserver_stdin_buf = buf_new();
+    pirserver_stdin_ev = event_new(tor_libevent_get_base(),
+        pirserver->stdin_pipe, EV_WRITE, hs_cache_pirserver_stdincb,
+        NULL);
 }
 
 /* Initialize the hidden service cache PIR subsystem. */
@@ -1140,6 +1207,8 @@ static void
 hs_cache_pir_init(void)
 {
     pirserver = NULL;
+    pirserver_stdin_buf = NULL;
+    pirserver_stdin_ev = NULL;
     pirserver_stdout_ev = NULL;
     pirserver_stderr_ev = NULL;
 }
@@ -1147,26 +1216,18 @@ hs_cache_pir_init(void)
 static int
 hs_cache_pirserver_send(const unsigned char *buf, size_t len)
 {
-    size_t written = 0;
-
     hs_cache_pir_poke();
     if (pirserver == NULL || pirserver->status != PROCESS_STATUS_RUNNING) {
         /* Launch failed */
         return -1;
     }
 
-    /* PIRONION TODO: For now, we're going to assume that writing to the
-     * pir server will never block, but we should actually put this into
-     * a write buffer. */
-    while (len) {
-        ssize_t res = write(pirserver->stdin_pipe, buf, len);
-        if (res < 0) return res;
-        if (res == 0) return written;
-        written += res;
-        buf += res;
-        len -= res;
+    /* Write the data to the stdin buffer */
+    if (len > 0) {
+        buf_add(pirserver_stdin_buf, (const char *)buf, len);
+        event_add(pirserver_stdin_ev, NULL);
     }
-    return written;
+    return len;
 }
 
 static int