// pirserver.cc
// Implementation of the main loop of the pirserver, responsible for
// communicating with the tor process over stdin/stdout. All the actual
// private-lookup work is done by an appropriate subclass of PIRServer.
#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <exception>
#include "pirserver.h"
  8. #define PIRSERVER_HDR_SIZE 13
  9. #define PIRSERVER_REQUEST_PARAMS 0x01
  10. #define PIRSERVER_REQUEST_STORE 0x02
  11. #define PIRSERVER_REQUEST_LOOKUP 0x03
  12. #define PIRSERVER_RESPONSE_PARAMS 0xFF
  13. #define PIRSERVER_RESPONSE_LOOKUP_SUCCESS 0xFE
  14. #define PIRSERVER_RESPONSE_LOOKUP_FAILURE 0xFD
  15. static int
  16. read_all(char *buf, size_t len)
  17. {
  18. int tot_read = 0;
  19. while(len > 0) {
  20. int res = read(0, buf, len);
  21. if (res <= 0) return res;
  22. buf += res;
  23. len -= res;
  24. tot_read += res;
  25. }
  26. return tot_read;
  27. }
  28. static int
  29. write_all(const char *buf, size_t len)
  30. {
  31. int tot_written = 0;
  32. while(len > 0) {
  33. int res = write(1, buf, len);
  34. if (res <= 0) return res;
  35. buf += res;
  36. len -= res;
  37. tot_written += res;
  38. }
  39. return tot_written;
  40. }
  41. void
  42. PIRServer::mainloop()
  43. {
  44. char header[PIRSERVER_HDR_SIZE];
  45. size_t bodylen = 0;
  46. char *body = NULL;
  47. string query, response;
  48. size_t response_len;
  49. while(1) {
  50. // Read the request from stdin
  51. int res = read_all(header, PIRSERVER_HDR_SIZE);
  52. if (res <= 0) return; // stdin has reached EOF (or error); we
  53. // will terminate
  54. bodylen = ntohl(*(uint32_t*)(header+PIRSERVER_HDR_SIZE-4));
  55. if (bodylen > 0) {
  56. body = (char *)malloc(bodylen);
  57. res = read_all(body, bodylen);
  58. if (res <= 0) return;
  59. }
  60. // We have a complete request. Dispatch it.
  61. switch(header[8]) {
  62. case PIRSERVER_REQUEST_PARAMS:
  63. get_params(response);
  64. response_len = response.length();
  65. header[8] = PIRSERVER_RESPONSE_PARAMS;
  66. *(uint32_t*)(header+PIRSERVER_HDR_SIZE-4) = htonl(response_len);
  67. res = write_all(header, PIRSERVER_HDR_SIZE);
  68. if (res <= 0) return;
  69. if (response_len > 0) {
  70. res = write_all(response.c_str(), response_len);
  71. if (res <= 0) return;
  72. }
  73. break;
  74. case PIRSERVER_REQUEST_STORE:
  75. if (bodylen >= 32) {
  76. string key(body, 32);
  77. string value(body+32, bodylen-32);
  78. store(key, value);
  79. }
  80. break;
  81. case PIRSERVER_REQUEST_LOOKUP:
  82. query.assign(body, bodylen);
  83. if (lookup(query, response)) {
  84. response_len = response.length();
  85. header[8] = PIRSERVER_RESPONSE_LOOKUP_SUCCESS;
  86. } else {
  87. response_len = 0;
  88. header[8] = PIRSERVER_RESPONSE_LOOKUP_FAILURE;
  89. }
  90. *(uint32_t*)(header+PIRSERVER_HDR_SIZE-4) = htonl(response_len);
  91. res = write_all(header, PIRSERVER_HDR_SIZE);
  92. if (res <= 0) return;
  93. if (response_len > 0) {
  94. res = write_all(response.c_str(), response_len);
  95. if (res <= 0) return;
  96. }
  97. break;
  98. }
  99. // Clean up for the next request.
  100. free(body);
  101. body = NULL;
  102. bodylen = 0;
  103. }
  104. }