瀏覽代碼

Make all the other read/writes into recv/sends, except when they shouldn't be.

svn:r1260
Nick Mathewson 21 年之前
父節點
當前提交
976bacae94
共有 4 個文件被更改,包括 21 次插入15 次删除
  1. 11 5
      src/common/util.c
  2. 2 2
      src/common/util.h
  3. 4 4
      src/or/cpuworker.c
  4. 4 4
      src/or/dns.c

+ 11 - 5
src/common/util.c

@@ -282,12 +282,15 @@ time_t tor_timegm (struct tm *tm) {
 
 /* a wrapper for write(2) that makes sure to write all count bytes.
  * Only use if fd is a blocking fd. */
-int write_all(int fd, const char *buf, size_t count) {
+int write_all(int fd, const char *buf, size_t count, int isSocket) {
   size_t written = 0;
   int result;
 
   while(written != count) {
-    result = write(fd, buf+written, count-written);
+    if (isSocket)
+      result = send(fd, buf+written, count-written, 0);
+    else
+      result = write(fd, buf+written, count-written);
     if(result<0)
       return -1;
     written += result;
@@ -297,12 +300,15 @@ int write_all(int fd, const char *buf, size_t count) {
 
 /* a wrapper for read(2) that makes sure to read all count bytes.
  * Only use if fd is a blocking fd. */
-int read_all(int fd, char *buf, size_t count) {
+int read_all(int fd, char *buf, size_t count, int isSocket) {
   size_t numread = 0;
   int result;
 
   while(numread != count) {
-    result = read(fd, buf+numread, count-numread);
+    if (isSocket) 
+      result = recv(fd, buf+numread, count-numread, 0);
+    else
+      result = read(fd, buf+numread, count-numread);
     if(result<=0)
       return -1;
     numread += result;
@@ -615,7 +621,7 @@ char *read_file_to_str(const char *filename) {
 
   string = tor_malloc(statbuf.st_size+1);
 
-  if(read_all(fd,string,statbuf.st_size) != statbuf.st_size) {
+  if(read_all(fd,string,statbuf.st_size,0) != statbuf.st_size) {
     log_fn(LOG_WARN,"Couldn't read all %ld bytes of file '%s'.",
            (long)statbuf.st_size,filename);
     free(string);

+ 2 - 2
src/common/util.h

@@ -66,8 +66,8 @@ void tv_add(struct timeval *a, struct timeval *b);
 int tv_cmp(struct timeval *a, struct timeval *b);
 time_t tor_timegm (struct tm *tm);
 
-int write_all(int fd, const char *buf, size_t count);
-int read_all(int fd, char *buf, size_t count);
+int write_all(int fd, const char *buf, size_t count, int isSocket);
+int read_all(int fd, char *buf, size_t count, int isSocket);
 
 void set_socket_nonblocking(int socket);
 

+ 4 - 4
src/or/cpuworker.c

@@ -132,19 +132,19 @@ int cpuworker_main(void *data) {
 
   for(;;) {
 
-    if(read(fd, &question_type, 1) != 1) {
+    if(recv(fd, &question_type, 1, 0) != 1) {
 //      log_fn(LOG_ERR,"read type failed. Exiting.");
       log_fn(LOG_INFO,"cpuworker exiting because tor process died.");
       spawn_exit();
     }
     assert(question_type == CPUWORKER_TASK_ONION);
 
-    if(read_all(fd, tag, TAG_LEN) != TAG_LEN) {
+    if(read_all(fd, tag, TAG_LEN, 1) != TAG_LEN) {
       log_fn(LOG_ERR,"read tag failed. Exiting.");
       spawn_exit();
     }
 
-    if(read_all(fd, question, ONIONSKIN_CHALLENGE_LEN) != ONIONSKIN_CHALLENGE_LEN) {
+    if(read_all(fd, question, ONIONSKIN_CHALLENGE_LEN, 1) != ONIONSKIN_CHALLENGE_LEN) {
       log_fn(LOG_ERR,"read question failed. Exiting.");
       spawn_exit();
     }
@@ -163,7 +163,7 @@ int cpuworker_main(void *data) {
         memcpy(buf+1+TAG_LEN,reply_to_proxy,ONIONSKIN_REPLY_LEN);
         memcpy(buf+1+TAG_LEN+ONIONSKIN_REPLY_LEN,keys,40+32);
       }
-      if(write_all(fd, buf, LEN_ONION_RESPONSE) != LEN_ONION_RESPONSE) {
+      if(write_all(fd, buf, LEN_ONION_RESPONSE, 1) != LEN_ONION_RESPONSE) {
         log_fn(LOG_ERR,"writing response buf failed. Exiting.");
         spawn_exit();
       }

+ 4 - 4
src/or/dns.c

@@ -397,14 +397,14 @@ int dnsworker_main(void *data) {
 
   for(;;) {
 
-    if(read(fd, &address_len, 1) != 1) {
+    if(recv(fd, &address_len, 1, 0) != 1) {
 //      log_fn(LOG_INFO,"read length failed. Child exiting.");
       log_fn(LOG_INFO,"dnsworker exiting because tor process died.");
       spawn_exit();
     }
     assert(address_len > 0);
 
-    if(read_all(fd, address, address_len) != address_len) {
+    if(read_all(fd, address, address_len, 1) != address_len) {
       log_fn(LOG_ERR,"read hostname failed. Child exiting.");
       spawn_exit();
     }
@@ -413,13 +413,13 @@ int dnsworker_main(void *data) {
     rent = gethostbyname(address);
     if (!rent) {
       log_fn(LOG_INFO,"Could not resolve dest addr %s. Returning nulls.",address);
-      if(write_all(fd, "\0\0\0\0", 4) != 4) {
+      if(write_all(fd, "\0\0\0\0", 4, 1) != 4) {
         log_fn(LOG_ERR,"writing nulls failed. Child exiting.");
         spawn_exit();
       }
     } else {
       assert(rent->h_length == 4); /* break to remind us if we move away from ipv4 */
-      if(write_all(fd, rent->h_addr, 4) != 4) {
+      if(write_all(fd, rent->h_addr, 4, 1) != 4) {
         log_fn(LOG_INFO,"writing answer failed. Child exiting.");
         spawn_exit();
       }