obliv.cpp 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  #include "oasm_lib.h"
  #include "enclave_api.h"
  #include "obliv.hpp"

  #include <cstring>
  4. // Routines for processing private data obliviously
  5. // Obliviously tally the number of messages in the given buffer destined
  6. // for each storage node. Each message is of size msg_size bytes.
  7. // There are num_msgs messages in the buffer. There are
  8. // num_storage_nodes storage nodes in total. The destination storage
  9. // node of each message is determined by looking at the top
  10. // DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the
  11. // beginning of the message; this will be a number between 0 and
  12. // num_storage_nodes-1, which is not necessarily the node number of the
  13. // storage node, which may be larger if, for example, there are a bunch
  14. // of routing or ingestion nodes that are not also storage nodes. The
  15. // return value is a vector of length num_storage_nodes containing the
  16. // tally.
  17. std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
  18. uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes)
  19. {
  20. // The _contents_ of buf are private, but everything else in the
  21. // input is public. The contents of the output tally (but not its
  22. // length) are also private.
  23. std::vector<uint32_t> tally(num_storage_nodes, 0);
  24. // This part must all be oblivious except for the length checks on
  25. // num_msgs and num_storage_nodes
  26. while (num_msgs) {
  27. uint32_t storage_node_id = (*(const uint32_t*)buf) >> DEST_UID_BITS;
  28. for (uint32_t i=0; i<num_storage_nodes; ++i) {
  29. tally[i] += (storage_node_id == i);
  30. }
  31. buf += msg_size;
  32. --num_msgs;
  33. }
  34. return tally;
  35. }
  36. // Obliviously create padding messages destined for the various storage
  37. // nodes, using the (private) counts in the tally vector. The tally
  38. // vector may be modified by this function. tot_padding must be the sum
  39. // of the elements in tally, which need _not_ be private.
  40. void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
  41. std::vector<uint32_t> &tally, uint32_t tot_padding)
  42. {
  43. // A value with 0 in the top DEST_STORAGE_NODE_BITS and all 1s in
  44. // the bottom DEST_UID_BITS.
  45. uint32_t pad_user = (1<<DEST_UID_BITS)-1;
  46. // This value is not oblivious
  47. const uint32_t num_storage_nodes = uint32_t(tally.size());
  48. // This part must all be oblivious except for the length checks on
  49. // tot_padding and num_storage_nodes
  50. while (tot_padding) {
  51. bool found = false;
  52. uint32_t found_node = 0;
  53. for (uint32_t i=0; i<num_storage_nodes; ++i) {
  54. bool found_here = (!found) & (!!tally[i]);
  55. found_node = oselect_uint32_t(found_node, i, found_here);
  56. found = found | found_here;
  57. tally[i] -= found_here;
  58. }
  59. *(uint32_t*)buf = ((found_node<<DEST_UID_BITS) | pad_user);
  60. buf += msg_size;
  61. --tot_padding;
  62. }
  63. }
  64. // Determine the number of messages exceeding the maximum that can be sent to a
  65. // storage server. Oblivious to the contents of tally vector.
  66. std::vector<uint32_t> obliv_excess_stg(std::vector<uint32_t> &tally,
  67. nodenum_t num_storage_nodes, uint32_t msgs_per_stg)
  68. {
  69. std::vector<uint32_t> excess(num_storage_nodes, 0);
  70. for (nodenum_t i=0; i<num_storage_nodes; ++i) {
  71. bool exceeds = tally[i] > msgs_per_stg;
  72. uint32_t diff = tally[i] - msgs_per_stg; // nonsensical if !exceeds
  73. excess[i] = oselect_uint32_t(0, diff, exceeds);
  74. }
  75. return excess;
  76. }
  77. // Determine the number of messages under the maximum that can be sent to a
  78. // storage server. Oblivious to the contents of tally vector.
  79. std::vector<uint32_t> obliv_padding_stg(std::vector<uint32_t> &tally,
  80. nodenum_t num_storage_nodes, uint32_t msgs_per_stg)
  81. {
  82. std::vector<uint32_t> padding(num_storage_nodes, 0);
  83. for (nodenum_t i=0; i<num_storage_nodes; ++i) {
  84. bool under = tally[i] < msgs_per_stg;
  85. uint32_t diff = msgs_per_stg - tally[i]; // nonsensical if !under
  86. padding[i] = oselect_uint32_t(0, diff, under);
  87. }
  88. return padding;
  89. }
  90. // For each excess messages, convert into padding for nodes that will need some.
  91. // Oblivious to contents of excess, padding, and tally vectors. May modify
  92. // excess, padding, and tally vectors.
  93. void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,
  94. std::vector<uint32_t> &excess, std::vector<uint32_t> &padding,
  95. std::vector<uint32_t> &tally, nodenum_t num_storage_nodes)
  96. {
  97. uint8_t *cur_msg = buf + ((num_msgs-1)*msg_size);
  98. uint32_t pad_user = (1<<DEST_UID_BITS)-1;
  99. for (uint32_t i=0; i<num_msgs; ++i) {
  100. // Determine if storage node for current node has excess messages.
  101. // Also, decrement excess count and tally if so.
  102. uint32_t storage_node_id = (*(const uint32_t*)cur_msg) >> DEST_UID_BITS;
  103. bool stg_node_excess = false;
  104. for (uint32_t j=0; j<num_storage_nodes; ++j) {
  105. bool at_msg_node = (storage_node_id == j);
  106. bool cur_node_excess = (excess[j] > 0);
  107. stg_node_excess = oselect_uint32_t(stg_node_excess,
  108. cur_node_excess, at_msg_node);
  109. excess[j] -= (at_msg_node & cur_node_excess);
  110. tally[j] -= (at_msg_node & cur_node_excess);
  111. }
  112. // Find first node that needs padding. Also, decrement padding count
  113. // and increment tally if current node has excess messages.
  114. bool found_padding = false;
  115. nodenum_t found_padding_node = 0;
  116. for (uint32_t j=0; j<num_storage_nodes; ++j) {
  117. bool found_padding_here = (!found_padding) & (!!padding[j]);
  118. found_padding_node = oselect_uint32_t(found_padding_node, j,
  119. found_padding_here);
  120. found_padding = found_padding | found_padding_here;
  121. padding[j] -= (found_padding_here & stg_node_excess);
  122. tally[j] += (found_padding_here & stg_node_excess);
  123. }
  124. // Convert to padding if excess
  125. uint32_t pad = ((found_padding_node<<DEST_UID_BITS) | pad_user);
  126. *(uint32_t*)cur_msg = oselect_uint32_t(*(uint32_t*)cur_msg, pad,
  127. stg_node_excess);
  128. // Go to previous message for backwards iteration through messages
  129. cur_msg -= msg_size;
  130. }
  131. }