1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
#include "oasm_lib.h"
#include "enclave_api.h"
#include "obliv.hpp"

#include <cstdint>
#include <cstring>
#include <vector>
// Routines for processing private data obliviously.
- // Obliviously tally the number of messages in the given buffer destined
- // for each storage node. Each message is of size msg_size bytes.
- // There are num_msgs messages in the buffer. There are
- // num_storage_nodes storage nodes in total. The destination storage
- // node of each message is determined by looking at the top
- // DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the
- // beginning of the message; this will be a number between 0 and
- // num_storage_nodes-1, which is not necessarily the node number of the
- // storage node, which may be larger if, for example, there are a bunch
- // of routing or ingestion nodes that are not also storage nodes. The
- // return value is a vector of length num_storage_nodes containing the
- // tally.
- std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,
- uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes)
- {
- // The _contents_ of buf are private, but everything else in the
- // input is public. The contents of the output tally (but not its
- // length) are also private.
- std::vector<uint32_t> tally(num_storage_nodes, 0);
- // This part must all be oblivious except for the length checks on
- // num_msgs and num_storage_nodes
- while (num_msgs) {
- uint32_t storage_node_id = (*(const uint32_t*)buf) >> DEST_UID_BITS;
- for (uint32_t i=0; i<num_storage_nodes; ++i) {
- tally[i] += (storage_node_id == i);
- }
- buf += msg_size;
- --num_msgs;
- }
- return tally;
- }
- // Obliviously create padding messages destined for the various storage
- // nodes, using the (private) counts in the tally vector. The tally
- // vector may be modified by this function. tot_padding must be the sum
- // of the elements in tally, which need _not_ be private.
- void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,
- std::vector<uint32_t> &tally, uint32_t tot_padding)
- {
- // A value with 0 in the top DEST_STORAGE_NODE_BITS and all 1s in
- // the bottom DEST_UID_BITS.
- uint32_t pad_user = (1<<DEST_UID_BITS)-1;
- // This value is not oblivious
- const uint32_t num_storage_nodes = uint32_t(tally.size());
- // This part must all be oblivious except for the length checks on
- // tot_padding and num_storage_nodes
- while (tot_padding) {
- bool found = false;
- uint32_t found_node = 0;
- for (uint32_t i=0; i<num_storage_nodes; ++i) {
- bool found_here = (!found) & (!!tally[i]);
- found_node = oselect_uint32_t(found_node, i, found_here);
- found = found | found_here;
- tally[i] -= found_here;
- }
- *(uint32_t*)buf = ((found_node<<DEST_UID_BITS) | pad_user);
- buf += msg_size;
- --tot_padding;
- }
- }
|