// Copyright (C) 2014 by Xiao Shaun Wang <wangxiao@cs.umd.edu> package com.oblivm.backend.oram; import java.util.ArrayList; import java.util.Arrays; import com.oblivm.backend.flexsc.CompEnv; import com.oblivm.backend.flexsc.Party; public class RecursiveCircuitOram<T> { public LinearScanOram<T> baseOram; public ArrayList<CircuitOram<T>> clients = new ArrayList<>(); public int lengthOfIden; int recurFactor; int cutoff; int capacity; Party p; public RecursiveCircuitOram(CompEnv<T> env, int N, int dataSize, int cutoff, int recurFactor, int capacity, int sp) { init(env, N, dataSize, cutoff, recurFactor, capacity, sp); } public RecursiveCircuitOram(CompEnv<T> env, int N, int dataSize, int cutoff, int recurFactor) { init(env, N, dataSize, cutoff, recurFactor, 3, 80); } // with default params public RecursiveCircuitOram(CompEnv<T> env, int N, int dataSize) { init(env, N, dataSize, 1 << 6, 8, 3, 80); } public void setInitialValue(int initial) { clients.get(0).setInitialValue(initial); } void init(CompEnv<T> env, int N, int dataSize, int cutoff, int recurFactor, int capacity, int sp) { this.p = env.party; this.cutoff = cutoff; this.recurFactor = recurFactor; this.capacity = capacity; CircuitOram<T> oram = new CircuitOram<T>(env, N, dataSize, capacity, sp); lengthOfIden = oram.lengthOfIden; clients.add(oram); int newDataSize = oram.lengthOfPos * recurFactor, newN = (1 << oram.lengthOfIden) / recurFactor; while (newN > cutoff) { oram = new CircuitOram<T>(env, newN, newDataSize, capacity, sp); clients.add(oram); newDataSize = oram.lengthOfPos * recurFactor; newN = (1 << oram.lengthOfIden) / recurFactor; } CircuitOram<T> last = clients.get(clients.size() - 1); baseOram = new LinearScanOram<T>(env, (1 << last.lengthOfIden), last.lengthOfPos); } public T[] read(T[] iden) { T[][] poses = travelToDeep(iden, 1); CircuitOram<T> currentOram = clients.get(0); boolean[] oldPos = baseOram.lib.declassifyToBoth(poses[0]); T[] res = currentOram.read(iden, oldPos, poses[1]); return res; } public void write(T[] iden, T[] data) { T[][] poses = travelToDeep(iden, 1); CircuitOram<T> currentOram = clients.get(0); boolean[] oldPos = baseOram.lib.declassifyToBoth(poses[0]); currentOram.write(iden, oldPos, poses[1], data); } public void write(T[] iden, T[] data, T dummy) { T[][] poses = travelToDeep(iden, 1); CircuitOram<T> currentOram = clients.get(0); currentOram.write(iden, poses[0], poses[1], data, dummy); } public T[] access(T[] iden, T[] data, T op) { T[][] poses = travelToDeep(iden, 1); CircuitOram<T> currentOram = clients.get(0); boolean[] oldPos = baseOram.lib.declassifyToBoth(poses[0]); return currentOram.access(iden, oldPos, poses[1], data, op); } public T[][] travelToDeep(T[] iden, int level) { if (level == clients.size()) { T[] baseMap = baseOram.readAndRemove(baseOram.lib.padSignal(iden, baseOram.lengthOfIden)); T[] ithPos = baseOram.lib.rightPublicShift(iden, baseOram.lengthOfIden);// iden>>baseOram.lengthOfIden; T[] pos = extract(baseMap, ithPos, clients.get(level - 1).lengthOfPos); T[] newPos = baseOram.lib.randBools(clients.get(level - 1).lengthOfPos); put(baseMap, ithPos, newPos); baseOram.putBack(baseOram.lib.padSignal(iden, baseOram.lengthOfIden), baseMap); T[][] result = baseOram.env.newTArray(2, 0); result[0] = pos; result[1] = newPos; return result; } else { CircuitOram<T> currentOram = clients.get(level); T[][] poses = travelToDeep(subIdentifier(iden, currentOram), level + 1); boolean[] oldPos = baseOram.lib.declassifyToBoth(poses[0]); T[] data = currentOram.readAndRemove(subIdentifier(iden, currentOram), oldPos, true); T[] ithPos = currentOram.lib.rightPublicShift(iden, currentOram.lengthOfIden);// iden>>currentOram.lengthOfIden;//iden/(1<<currentOram.lengthOfIden); T[] pos = extract(data, ithPos, clients.get(level - 1).lengthOfPos); T[] tmpNewPos = baseOram.lib.randBools(clients.get(level - 1).lengthOfPos); put(data, ithPos, tmpNewPos); currentOram.putBack(subIdentifier(iden, currentOram), poses[1], data); T[][] result = currentOram.env.newTArray(2, 0); result[0] = pos; result[1] = tmpNewPos; return result; } } public T[] subIdentifier(T[] iden, OramParty<T> o) { // int a = iden & ((1<<o.lengthOfIden)-1);//(iden % (1<<o.lengthOfIden)) return o.lib.padSignal(iden, o.lengthOfIden); } public T[] extract(T[] array, T[] ithPos, int length) { int numberOfEntry = array.length / length; T[] result = Arrays.copyOfRange(array, 0, length); for (int i = 1; i < numberOfEntry; ++i) { T hit = baseOram.lib.eq(baseOram.lib.toSignals(i, ithPos.length), ithPos); result = baseOram.lib.mux(result, Arrays.copyOfRange(array, i * length, (i + 1) * length), hit); } return result; } public void put(T[] array, T[] ithPos, T[] content) { int numberOfEntry = array.length / content.length; for (int i = 0; i < numberOfEntry; ++i) { T hit = baseOram.lib.eq(baseOram.lib.toSignals(i, ithPos.length), ithPos); T[] tmp = baseOram.lib.mux(Arrays.copyOfRange(array, i * content.length, (i + 1) * content.length), content, hit); System.arraycopy(tmp, 0, array, i * content.length, content.length); } } }