obliv.cpp 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #include "oasm_lib.h"
  2. #include "enclave_api.h"
  3. #include "obliv.hpp"
  4. // Routines for processing private data obliviously
  5. // Obliviously tally the number of messages in the given buffer destined
  6. // for each storage node. Each message is of size msg_size bytes.
  7. // There are num_msgs messages in the buffer. There are
  8. // num_storage_nodes storage nodes in total. The destination storage
  9. // node of each message is determined by looking at the top
  10. // DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the
  11. // beginning of the message; this will be a number between 0 and
  12. // num_storage_nodes-1, which is not necessarily the node number of the
  13. // storage node, which may be larger if, for example, there are a bunch
  14. // of routing or ingestion nodes that are not also storage nodes. The
  15. // return value is a vector of length num_storage_nodes containing the
  16. // tally.
  17. std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
  18. uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes)
  19. {
  20. // The _contents_ of buf are private, but everything else in the
  21. // input is public. The contents of the output tally (but not its
  22. // length) are also private.
  23. std::vector<uint32_t> tally(num_storage_nodes, 0);
  24. // This part must all be oblivious except for the length checks on
  25. // num_msgs and num_storage_nodes
  26. while (num_msgs) {
  27. uint32_t storage_node_id = (*(const uint32_t*)buf) >> DEST_UID_BITS;
  28. for (uint32_t i=0; i<num_storage_nodes; ++i) {
  29. tally[i] += (storage_node_id == i);
  30. }
  31. buf += msg_size;
  32. --num_msgs;
  33. }
  34. return tally;
  35. }
  36. // Obliviously convert global padding (receiver id 0xffffffff) into
  37. // padding for each storage node according to the (private) padding
  38. // tally.
  39. void obliv_stg_padding(uint8_t *buf, uint32_t msg_size,
  40. std::vector<uint32_t> &tally, uint32_t num_msgs)
  41. {
  42. // A value with 0 in the top DEST_STORAGE_NODE_BITS and all 1s in
  43. // the bottom DEST_UID_BITS.
  44. const uint32_t pad_user = (1<<DEST_UID_BITS)-1;
  45. // This value is not oblivious
  46. const nodenum_t num_storage_nodes = nodenum_t(tally.size());
  47. uint8_t *cur_msg = buf;
  48. // For each message, obliviously turn global padding into padding
  49. // for some storage node whose tally shows it still needs more
  50. // padding.
  51. for (uint32_t m=0; m<num_msgs; ++m) {
  52. uint32_t receiver_id = *(uint32_t*)cur_msg;
  53. bool is_padding = (receiver_id == 0xffffffff);
  54. // Obliviously find a storage node that still needs more
  55. // padding, if is_padding is true. If is_padding is false, this
  56. // whole block is a no-op.
  57. bool found = !is_padding;
  58. uint32_t found_node = 0;
  59. for (uint32_t i=0; i<num_storage_nodes; ++i) {
  60. bool found_here = (!found) & (!!tally[i]);
  61. found_node = oselect_uint32_t(found_node, i, found_here);
  62. found = found | found_here;
  63. tally[i] -= found_here;
  64. }
  65. // If this was padding, overwrite the receiver id with the
  66. // padding id specific to the found storage node; otherwise
  67. // write the original receiver id back.
  68. receiver_id = oselect_uint32_t(receiver_id,
  69. ((found_node<<DEST_UID_BITS) | pad_user), is_padding);
  70. *(uint32_t*)cur_msg = receiver_id;
  71. cur_msg += msg_size;
  72. }
  73. }