nss_countbytes.c 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. /* Copyright 2018-2019, The Tor Project Inc. */
  2. /* See LICENSE for licensing information */
  3. /**
  4. * \file nss_countbytes.c
 * \brief A PRFileDesc layer to let us count the number of bytes
 * actually read and written on a PRFileDesc.
  7. **/
  8. #include "orconfig.h"
  9. #include "lib/log/util_bug.h"
  10. #include "lib/malloc/malloc.h"
  11. #include "lib/tls/nss_countbytes.h"
  12. #include <stdlib.h>
  13. #include <string.h>
  14. #include <prio.h>
  15. /** Boolean: have we initialized this module */
  16. static bool countbytes_initialized = false;
  17. /** Integer to identity this layer. */
  18. static PRDescIdentity countbytes_layer_id = PR_INVALID_IO_LAYER;
  19. /** Table of methods for this layer.*/
  20. static PRIOMethods countbytes_methods;
  21. /** Default close function provided by NSPR. We use this to help
  22. * implement our own close function.*/
  23. static PRStatus(*default_close_fn)(PRFileDesc *fd);
  24. static PRStatus countbytes_close_fn(PRFileDesc *fd);
  25. static PRInt32 countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount);
  26. static PRInt32 countbytes_write_fn(PRFileDesc *fd, const void *buf,
  27. PRInt32 amount);
  28. static PRInt32 countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
  29. PRInt32 size, PRIntervalTime timeout);
  30. static PRInt32 countbytes_send_fn(PRFileDesc *fd, const void *buf,
  31. PRInt32 amount, PRIntn flags,
  32. PRIntervalTime timeout);
  33. static PRInt32 countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
  34. PRIntn flags, PRIntervalTime timeout);
  /** Private fields for the byte-counter layer. A pointer to this struct is
 * stored in the layer's 'secret' field; we cast it to and from
 * PRFilePrivate*, which is supposed to be allowed. */
struct tor_nss_bytecounts_t {
  /** Running total of bytes the lower layer reported as read. */
  uint64_t n_read;
  /** Running total of bytes the lower layer reported as written. */
  uint64_t n_written;
};
typedef struct tor_nss_bytecounts_t tor_nss_bytecounts_t;
  41. /**
  42. * Initialize this module, if it is not already initialized.
  43. **/
  44. void
  45. tor_nss_countbytes_init(void)
  46. {
  47. if (countbytes_initialized)
  48. return;
  49. countbytes_layer_id = PR_GetUniqueIdentity("Tor byte-counting layer");
  50. tor_assert(countbytes_layer_id != PR_INVALID_IO_LAYER);
  51. memcpy(&countbytes_methods, PR_GetDefaultIOMethods(), sizeof(PRIOMethods));
  52. default_close_fn = countbytes_methods.close;
  53. countbytes_methods.close = countbytes_close_fn;
  54. countbytes_methods.read = countbytes_read_fn;
  55. countbytes_methods.write = countbytes_write_fn;
  56. countbytes_methods.writev = countbytes_writev_fn;
  57. countbytes_methods.send = countbytes_send_fn;
  58. countbytes_methods.recv = countbytes_recv_fn;
  59. /* NOTE: We aren't wrapping recvfrom, sendto, or sendfile, since I think
  60. * NSS won't be using them for TLS connections. */
  61. countbytes_initialized = true;
  62. }
  63. /**
  64. * Return the tor_nss_bytecounts_t object for a given IO layer. Asserts that
  65. * the IO layer is in fact a layer created by this module.
  66. */
  67. static tor_nss_bytecounts_t *
  68. get_counts(PRFileDesc *fd)
  69. {
  70. tor_assert(fd->identity == countbytes_layer_id);
  71. return (tor_nss_bytecounts_t*) fd->secret;
  72. }
  /** Helper: add n to the counter field 'member' of fd's byte counts. */
#define INC_COUNT_(fd, member, n) STMT_BEGIN  \
    get_counts(fd)->member += (n);              \
  STMT_END
/** Helper: increment the read-count of an fd by n. */
#define INC_READ(fd, n) INC_COUNT_((fd), n_read, (n))
/** Helper: increment the write-count of an fd by n. */
#define INC_WRITTEN(fd, n) INC_COUNT_((fd), n_written, (n))
  81. /** Implementation for PR_Close: frees the 'secret' field, then passes control
  82. * to the default close function */
  83. static PRStatus
  84. countbytes_close_fn(PRFileDesc *fd)
  85. {
  86. tor_assert(fd);
  87. tor_nss_bytecounts_t *counts = (tor_nss_bytecounts_t *)fd->secret;
  88. tor_free(counts);
  89. fd->secret = NULL;
  90. return default_close_fn(fd);
  91. }
  92. /** Implementation for PR_Read: Calls the lower-level read function,
  93. * and records what it said. */
  94. static PRInt32
  95. countbytes_read_fn(PRFileDesc *fd, void *buf, PRInt32 amount)
  96. {
  97. tor_assert(fd);
  98. tor_assert(fd->lower);
  99. PRInt32 result = (fd->lower->methods->read)(fd->lower, buf, amount);
  100. if (result > 0)
  101. INC_READ(fd, result);
  102. return result;
  103. }
  104. /** Implementation for PR_Write: Calls the lower-level write function,
  105. * and records what it said. */
  106. static PRInt32
  107. countbytes_write_fn(PRFileDesc *fd, const void *buf, PRInt32 amount)
  108. {
  109. tor_assert(fd);
  110. tor_assert(fd->lower);
  111. PRInt32 result = (fd->lower->methods->write)(fd->lower, buf, amount);
  112. if (result > 0)
  113. INC_WRITTEN(fd, result);
  114. return result;
  115. }
  116. /** Implementation for PR_Writev: Calls the lower-level writev function,
  117. * and records what it said. */
  118. static PRInt32
  119. countbytes_writev_fn(PRFileDesc *fd, const PRIOVec *iov,
  120. PRInt32 size, PRIntervalTime timeout)
  121. {
  122. tor_assert(fd);
  123. tor_assert(fd->lower);
  124. PRInt32 result = (fd->lower->methods->writev)(fd->lower, iov, size, timeout);
  125. if (result > 0)
  126. INC_WRITTEN(fd, result);
  127. return result;
  128. }
  129. /** Implementation for PR_Send: Calls the lower-level send function,
  130. * and records what it said. */
  131. static PRInt32
  132. countbytes_send_fn(PRFileDesc *fd, const void *buf,
  133. PRInt32 amount, PRIntn flags, PRIntervalTime timeout)
  134. {
  135. tor_assert(fd);
  136. tor_assert(fd->lower);
  137. PRInt32 result = (fd->lower->methods->send)(fd->lower, buf, amount, flags,
  138. timeout);
  139. if (result > 0)
  140. INC_WRITTEN(fd, result);
  141. return result;
  142. }
  143. /** Implementation for PR_Recv: Calls the lower-level recv function,
  144. * and records what it said. */
  145. static PRInt32
  146. countbytes_recv_fn(PRFileDesc *fd, void *buf, PRInt32 amount,
  147. PRIntn flags, PRIntervalTime timeout)
  148. {
  149. tor_assert(fd);
  150. tor_assert(fd->lower);
  151. PRInt32 result = (fd->lower->methods->recv)(fd->lower, buf, amount, flags,
  152. timeout);
  153. if (result > 0)
  154. INC_READ(fd, result);
  155. return result;
  156. }
  157. /**
  158. * Wrap a PRFileDesc from NSPR with a new PRFileDesc that will count the
  159. * total number of bytes read and written. Return the new PRFileDesc.
  160. *
  161. * This function takes ownership of its input.
  162. */
  163. PRFileDesc *
  164. tor_wrap_prfiledesc_with_byte_counter(PRFileDesc *stack)
  165. {
  166. if (BUG(! countbytes_initialized)) {
  167. tor_nss_countbytes_init();
  168. }
  169. tor_nss_bytecounts_t *bytecounts = tor_malloc_zero(sizeof(*bytecounts));
  170. PRFileDesc *newfd = PR_CreateIOLayerStub(countbytes_layer_id,
  171. &countbytes_methods);
  172. tor_assert(newfd);
  173. newfd->secret = (PRFilePrivate *)bytecounts;
  174. /* This does some complicated messing around with the headers of these
  175. objects; see the NSPR documentation for more. The upshot is that
  176. after PushIOLayer, "stack" will be the head of the stack.
  177. */
  178. PRStatus status = PR_PushIOLayer(stack, PR_TOP_IO_LAYER, newfd);
  179. tor_assert(status == PR_SUCCESS);
  180. return stack;
  181. }
  182. /**
  183. * Given a PRFileDesc returned by tor_wrap_prfiledesc_with_byte_counter(),
  184. * or another PRFileDesc wrapping that PRFileDesc, set the provided
  185. * pointers to the number of bytes read and written on the descriptor since
  186. * it was created.
  187. *
  188. * Return 0 on success, -1 on failure.
  189. */
  190. int
  191. tor_get_prfiledesc_byte_counts(PRFileDesc *fd,
  192. uint64_t *n_read_out,
  193. uint64_t *n_written_out)
  194. {
  195. if (BUG(! countbytes_initialized)) {
  196. tor_nss_countbytes_init();
  197. }
  198. tor_assert(fd);
  199. PRFileDesc *bclayer = PR_GetIdentitiesLayer(fd, countbytes_layer_id);
  200. if (BUG(bclayer == NULL))
  201. return -1;
  202. tor_nss_bytecounts_t *counts = get_counts(bclayer);
  203. *n_read_out = counts->n_read;
  204. *n_written_out = counts->n_written;
  205. return 0;
  206. }