| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 | #include "oasm_lib.h"#include "enclave_api.h"#include "obliv.hpp"// Routines for processing private data obliviously// Obliviously tally the number of messages in the given buffer destined// for each storage node.  Each message is of size msg_size bytes.// There are num_msgs messages in the buffer.  There are// num_storage_nodes storage nodes in total.  The destination storage// node of each message is determined by looking at the top// DEST_STORAGE_NODE_BITS bits of the (little-endian) 32-bit word at the// beginning of the message; this will be a number between 0 and// num_storage_nodes-1, which is not necessarily the node number of the// storage node, which may be larger if, for example, there are a bunch// of routing or ingestion nodes that are not also storage nodes.  The// return value is a vector of length num_storage_nodes containing the// tally.std::vector<uint32_t> obliv_tally_stg(const uint8_t *buf,    uint32_t msg_size, uint32_t num_msgs, uint32_t num_storage_nodes){    // The _contents_ of buf are private, but everything else in the    // input is public.  The contents of the output tally (but not its    // length) are also private.    std::vector<uint32_t> tally(num_storage_nodes, 0);    // This part must all be oblivious except for the length checks on    // num_msgs and num_storage_nodes    while (num_msgs) {        uint32_t storage_node_id = (*(const uint32_t*)buf) >> DEST_UID_BITS;        for (uint32_t i=0; i<num_storage_nodes; ++i) {            tally[i] += (storage_node_id == i);        }        buf += msg_size;        --num_msgs;    }    return tally;}// Obliviously create padding messages destined for the various storage// nodes, using the (private) counts in the tally vector.  The tally// vector may be modified by this function.  tot_padding must be the sum// of the elements in tally, which need _not_ be private.void obliv_pad_stg(uint8_t *buf, uint32_t msg_size,    std::vector<uint32_t> &tally, uint32_t tot_padding){    // A value with 0 in the top DEST_STORAGE_NODE_BITS and all 1s in    // the bottom DEST_UID_BITS.    uint32_t pad_user = (1<<DEST_UID_BITS)-1;    // This value is not oblivious    const uint32_t num_storage_nodes = uint32_t(tally.size());    // This part must all be oblivious except for the length checks on    // tot_padding and num_storage_nodes    while (tot_padding) {        bool found = false;        uint32_t found_node = 0;        for (uint32_t i=0; i<num_storage_nodes; ++i) {            bool found_here = (!found) & (!!tally[i]);            found_node = oselect_uint32_t(found_node, i, found_here);            found = found | found_here;            tally[i] -= found_here;        }        *(uint32_t*)buf = ((found_node<<DEST_UID_BITS) | pad_user);        buf += msg_size;        --tot_padding;    }}// For each excess message, convert into padding for nodes that will need some.// Oblivious to contents of message buffer and tally vector. May modify message// buffer and tally vector.void obliv_excess_to_padding(uint8_t *buf, uint32_t msg_size, uint32_t num_msgs,    std::vector<uint32_t> &tally, uint32_t msgs_per_stg){    const nodenum_t num_storage_nodes = nodenum_t(tally.size());    // Determine the number of messages exceeding and under the maximum that    // can be sent to a storage server. Oblivious to the contents of tally    // vector.    std::vector<uint32_t> excess(num_storage_nodes, 0);    std::vector<uint32_t> padding(num_storage_nodes, 0);    for (nodenum_t i=0; i<num_storage_nodes; ++i) {        bool exceeds = tally[i] > msgs_per_stg;        uint32_t diff = tally[i] - msgs_per_stg;        excess[i] = oselect_uint32_t(0, diff, exceeds);        diff = msgs_per_stg - tally[i];        padding[i] = oselect_uint32_t(0, diff, !exceeds);    }    uint8_t *cur_msg = buf + ((num_msgs-1)*msg_size);    uint32_t pad_user = (1<<DEST_UID_BITS)-1;    for (uint32_t i=0; i<num_msgs; ++i) {        // Determine if storage node for current node has excess messages.        // Also, decrement excess count and tally if so.        uint32_t storage_node_id = (*(const uint32_t*)cur_msg) >> DEST_UID_BITS;        bool node_excess = false;        for (uint32_t j=0; j<num_storage_nodes; ++j) {            bool at_msg_node = (storage_node_id == j);            bool cur_node_excess = (excess[j] > 0);            node_excess = oselect_uint32_t(node_excess, cur_node_excess,                at_msg_node);            excess[j] -= (at_msg_node & cur_node_excess);            tally[j] -= (at_msg_node & cur_node_excess);        }        // Find first node that needs padding. Decrement padding count and        // increment tally for that if current-message node has excess messages.        bool found_padding = false;        nodenum_t found_padding_node = 0;        for (uint32_t j=0; j<num_storage_nodes; ++j) {            bool found_padding_here = (!found_padding) & (!!padding[j]);            found_padding_node = oselect_uint32_t(found_padding_node, j,                found_padding_here);            found_padding = found_padding | found_padding_here;            padding[j] -= (found_padding_here & node_excess);            tally[j] += (found_padding_here & node_excess);        }        // Convert to padding if excess        uint32_t pad = ((found_padding_node<<DEST_UID_BITS) | pad_user);        *(uint32_t*)cur_msg = oselect_uint32_t(*(uint32_t*)cur_msg, pad,            node_excess);        // Go to previous message for backwards iteration through messages        cur_msg -= msg_size;    }}
 |