ソースを参照

update PRF/PRG

Boyoung- 8 年 前
コミット
4e530cd7df

BIN
key/prg.key


+ 1 - 1
src/crypto/OramCrypto.java → src/crypto/Crypto.java

@@ -3,7 +3,7 @@ package crypto;
 import java.security.NoSuchAlgorithmException;
 import java.security.SecureRandom;
 
-public class OramCrypto {
+public class Crypto {
 	public static SecureRandom sr;
 
 	static {

+ 98 - 0
src/crypto/PRF.java

@@ -0,0 +1,98 @@
+package crypto;
+
+import java.security.InvalidKeyException;
+import java.security.NoSuchAlgorithmException;
+
+import javax.crypto.BadPaddingException;
+import javax.crypto.Cipher;
+import javax.crypto.IllegalBlockSizeException;
+import javax.crypto.NoSuchPaddingException;
+import javax.crypto.spec.SecretKeySpec;
+
+import org.bouncycastle.util.Arrays;
+
+import exceptions.IllegalInputException;
+import exceptions.LengthNotMatchException;
+import util.Util;
+
+public class PRF {
+
+	private Cipher cipher;
+	private int l; // output bit length
+
+	private int maxInputBytes = 12;
+
+	public PRF(int l) {
+		try {
+			cipher = Cipher.getInstance("AES/ECB/NoPadding");
+		} catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
+			e.printStackTrace();
+		}
+		this.l = l;
+	}
+
+	public void init(byte[] key) {
+		if (key.length != 16)
+			throw new LengthNotMatchException(key.length + " != 16");
+
+		SecretKeySpec skey = new SecretKeySpec(key, "AES");
+		try {
+			cipher.init(Cipher.ENCRYPT_MODE, skey);
+		} catch (InvalidKeyException e) {
+			e.printStackTrace();
+		}
+	}
+
+	public byte[] compute(byte[] input) {
+		if (input.length > maxInputBytes)
+			throw new IllegalInputException(input.length + " > " + maxInputBytes);
+
+		byte[] in = new byte[16];
+		System.arraycopy(input, 0, in, in.length - input.length, input.length);
+		byte[] output = null;
+		if (l <= 128)
+			output = leq128(in, l);
+		else
+			output = g128(in);
+
+		return output;
+	}
+
+	private byte[] leq128(byte[] input, int np) {
+		byte[] ctext = null;
+		try {
+			ctext = cipher.doFinal(input);
+		} catch (IllegalBlockSizeException | BadPaddingException e) {
+			e.printStackTrace();
+		}
+
+		int outBytes = (np + 7) / 8;
+		if (ctext.length == outBytes)
+			return ctext;
+		else
+			return Arrays.copyOfRange(ctext, ctext.length - outBytes, ctext.length);
+	}
+
+	private byte[] g128(byte[] input) {
+		int n = l / 128;
+		int outBytes = (l + 7) / 8;
+		byte[] output = new byte[outBytes];
+
+		int len = Math.min(16 - maxInputBytes, 4);
+		for (int i = 0; i < n; i++) {
+			byte[] index = Util.intToBytes(i + 1);
+			System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
+			byte[] seg = leq128(input, 128);
+			System.arraycopy(seg, 0, output, i * seg.length, seg.length);
+		}
+		int np = l % 128;
+		if (np == 0)
+			return output;
+
+		byte[] index = Util.intToBytes(n + 1);
+		System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
+		byte[] last = leq128(input, np);
+		System.arraycopy(last, 0, output, outBytes - last.length, last.length);
+		return output;
+	}
+}

+ 105 - 0
src/crypto/PRG.java

@@ -0,0 +1,105 @@
+package crypto;
+
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.security.InvalidAlgorithmParameterException;
+import java.security.InvalidKeyException;
+import java.security.NoSuchAlgorithmException;
+import java.util.Random;
+
+import javax.crypto.BadPaddingException;
+import javax.crypto.Cipher;
+import javax.crypto.IllegalBlockSizeException;
+import javax.crypto.NoSuchPaddingException;
+import javax.crypto.spec.IvParameterSpec;
+import javax.crypto.spec.SecretKeySpec;
+
+import exceptions.IllegalInputException;
+
+public class PRG {
+	private Cipher cipher;
+	private SecretKeySpec skey;
+	private int l; // output bit length
+
+	public PRG(int l) {
+		try {
+			cipher = Cipher.getInstance("AES/CTR/NoPadding");
+		} catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
+			e.printStackTrace();
+		}
+		readKey();
+		this.l = l;
+	}
+
+	public byte[] compute(byte[] seed) {
+		byte[] input;
+		if (seed.length > 16) {
+			throw new IllegalInputException(seed.length + " > 16");
+		} else if (seed.length == 16) {
+			input = seed;
+		} else {
+			input = new byte[16];
+			System.arraycopy(seed, 0, input, input.length - seed.length, seed.length);
+		}
+
+		IvParameterSpec IV = new IvParameterSpec(input);
+		byte[] msg = new byte[(l + 7) / 8];
+		byte[] output = null;
+
+		try {
+			cipher.init(Cipher.ENCRYPT_MODE, skey, IV);
+			output = cipher.doFinal(msg);
+		} catch (InvalidKeyException | InvalidAlgorithmParameterException | IllegalBlockSizeException
+				| BadPaddingException e) {
+			e.printStackTrace();
+		}
+
+		return output;
+	}
+
+	public static void generateKey(Random rand) {
+		byte[] key = new byte[16];
+		rand.nextBytes(key);
+		SecretKeySpec skey = new SecretKeySpec(key, "AES");
+		FileOutputStream fos = null;
+		ObjectOutputStream oos = null;
+		try {
+			fos = new FileOutputStream("key/prg.key");
+			oos = new ObjectOutputStream(fos);
+			oos.writeObject(skey);
+		} catch (IOException e) {
+			e.printStackTrace();
+		} finally {
+			if (oos != null) {
+				try {
+					oos.close();
+				} catch (IOException e) {
+					e.printStackTrace();
+				}
+			}
+		}
+	}
+
+	public void readKey() {
+		FileInputStream fis = null;
+		ObjectInputStream ois = null;
+		try {
+			fis = new FileInputStream("key/prg.key");
+			ois = new ObjectInputStream(fis);
+			skey = (SecretKeySpec) ois.readObject();
+		} catch (IOException | ClassNotFoundException e) {
+			e.printStackTrace();
+		} finally {
+			if (ois != null) {
+				try {
+					ois.close();
+				} catch (IOException e) {
+					e.printStackTrace();
+				}
+			}
+		}
+	}
+}

+ 16 - 0
src/exceptions/IllegalInputException.java

@@ -0,0 +1,16 @@
+package exceptions;
+
+public class IllegalInputException extends RuntimeException {
+	/**
+	 * 
+	 */
+	private static final long serialVersionUID = 1L;
+
+	public IllegalInputException() {
+		super();
+	}
+
+	public IllegalInputException(String message) {
+		super(message);
+	}
+}

+ 3 - 3
test/ui/InitForest.java → src/init/InitForest.java

@@ -1,6 +1,6 @@
-package ui;
+package init;
 
-import crypto.OramCrypto;
+import crypto.Crypto;
 import oram.Forest;
 import oram.Metadata;
 
@@ -13,7 +13,7 @@ public class InitForest {
 		forest.writeToFile();
 
 		Forest share1 = forest;
-		Forest share2 = new Forest(md, OramCrypto.sr);
+		Forest share2 = new Forest(md, Crypto.sr);
 		share1.setXor(share2);
 		share1.writeToFile(md.getDefaultSharesName1());
 		share2.writeToFile(md.getDefaultSharesName2());

+ 2 - 2
src/oram/Forest.java

@@ -10,7 +10,7 @@ import java.math.BigInteger;
 import java.util.HashMap;
 import java.util.Random;
 
-import crypto.OramCrypto;
+import crypto.Crypto;
 import exceptions.LengthNotMatchException;
 import util.Util;
 
@@ -154,7 +154,7 @@ public class Forest implements Serializable {
 					// if N is a new address, then find an unused leaf tuple
 					if (leafTupleIndex == null) {
 						do {
-							leafTupleIndex = Util.nextLong(OramCrypto.sr, (numBuckets / 2 + 1) * w);
+							leafTupleIndex = Util.nextLong(Crypto.sr, (numBuckets / 2 + 1) * w);
 						} while (addrToTuple[i].containsValue(leafTupleIndex));
 						addrToTuple[i].put(N[i], leafTupleIndex);
 					}

+ 246 - 0
src/protocols/SSCOT.java

@@ -0,0 +1,246 @@
+package protocols;
+
+import java.math.BigInteger;
+
+import org.apache.commons.lang3.tuple.Pair;
+
+import communication.Communication;
+import oram.Forest;
+import oram.Metadata;
+
+public class SSCOT extends Protocol {
+	public SSCOT(Communication con1, Communication con2) {
+		super(con1, con2);
+	}
+
+	public Pair<Integer, BigInteger> executeCharlie(Communication D,
+			Communication E, int i, int N, int l, int l_p) {
+		// protocol
+		// step 1
+		timing.stopwatch[PID.sscot][TID.online_read].start();
+		byte[] msg_ev = E.read();
+
+		// step 2
+		byte[] msg_pw = D.read();
+		timing.stopwatch[PID.sscot][TID.online_read].stop();
+
+		// step 3
+		timing.stopwatch[PID.sscot][TID.online].start();
+		byte[][] e = new byte[N][];
+		byte[][] v = new byte[N][];
+		byte[][] p = new byte[N][];
+		byte[][] w = new byte[N][];
+		PRG G = new PRG(l);
+		int gBytes = (l + 7) / 8;
+
+		for (int t = 0; t < N; t++) {
+			e[t] = Arrays.copyOfRange(msg_ev, t * gBytes, (t + 1) * gBytes);
+			v[t] = Arrays.copyOfRange(msg_ev, N * gBytes + t * SR.kBytes, N
+					* gBytes + (t + 1) * SR.kBytes);
+			p[t] = Arrays.copyOfRange(msg_pw, t * SR.kBytes, (t + 1)
+					* SR.kBytes);
+			w[t] = Arrays.copyOfRange(msg_pw, (N + t) * SR.kBytes, (N + t + 1)
+					* SR.kBytes);
+
+			if (new BigInteger(1, v[t]).compareTo(new BigInteger(1, w[t])) == 0) {
+				//BigInteger m_t = new BigInteger(1, e[t]).xor(new BigInteger(1, G.compute(p[t])));
+				timing.stopwatch[PID.aes_prg][TID.online].start();
+				byte[] tmp = G.compute(p[t]);
+				timing.stopwatch[PID.aes_prg][TID.online].stop();
+				BigInteger m_t = new BigInteger(1, e[t]).xor(new BigInteger(1, tmp));
+				timing.stopwatch[PID.sscot][TID.online].stop();
+				return Pair.of(t, m_t);
+			}
+		}
+		timing.stopwatch[PID.sscot][TID.online].stop();
+
+		// error
+		return null;
+	}
+
+	public void executeDebbie(Communication C, Communication E, int i, int N,
+			int l, int l_p, BigInteger[] b) {
+		// protocol
+		// step 2
+		timing.stopwatch[PID.sscot][TID.online].start();
+		int diffBits = SR.kBits - l_p;
+		BigInteger[] y = new BigInteger[N];
+		byte[][][] pw = new byte[2][N][];
+		byte[] msg_pw = new byte[SR.kBytes * N * 2];
+		PRF F_k = new PRF(SR.kBits);
+		PRF F_k_p = new PRF(SR.kBits);
+		F_k.init(PreData.sscot_k[i]);
+		F_k_p.init(PreData.sscot_k_p[i]);
+
+		for (int t = 0; t < N; t++) {
+			y[t] = PreData.sscot_r[i][t].xor(b[t].shiftLeft(diffBits));
+			timing.stopwatch[PID.aes_prf][TID.online].start();
+			pw[0][t] = F_k.compute(y[t].toByteArray());
+			pw[1][t] = F_k_p.compute(y[t].toByteArray());
+			timing.stopwatch[PID.aes_prf][TID.online].stop();
+			System.arraycopy(pw[0][t], 0, msg_pw, t * SR.kBytes, SR.kBytes);
+			System.arraycopy(pw[1][t], 0, msg_pw, (N + t) * SR.kBytes,
+					SR.kBytes);
+		}
+		timing.stopwatch[PID.sscot][TID.online].stop();
+
+		timing.stopwatch[PID.sscot][TID.online_write].start();
+		C.write(msg_pw, PID.sscot);
+		timing.stopwatch[PID.sscot][TID.online_write].stop();
+	}
+
+	public void executeEddie(Communication C, Communication D, int i, int N,
+			int l, int l_p, BigInteger[] m, BigInteger[] a) {
+		// protocol
+		// step 1
+		timing.stopwatch[PID.sscot][TID.online].start();
+		int gBytes = (l + 7) / 8;
+		int diffBits = SR.kBits - l_p;
+		BigInteger[] x = new BigInteger[N];
+		byte[][][] ev = new byte[2][N][];
+		byte[] msg_ev = new byte[(SR.kBytes + gBytes) * N];
+		PRF F_k = new PRF(SR.kBits);
+		PRF F_k_p = new PRF(SR.kBits);
+		PRG G = new PRG(l);
+		F_k.init(PreData.sscot_k[i]);
+		F_k_p.init(PreData.sscot_k_p[i]);
+
+		for (int t = 0; t < N; t++) {
+			x[t] = PreData.sscot_r[i][t].xor(a[t].shiftLeft(diffBits));
+			//ev[0][t] = new BigInteger(1, G.compute(F_k.compute(x[t].toByteArray()))).xor(m[t]).toByteArray();
+			timing.stopwatch[PID.aes_prf][TID.online].start();
+			ev[1][t] = F_k_p.compute(x[t].toByteArray());
+			byte[] tmp = F_k.compute(x[t].toByteArray());
+			timing.stopwatch[PID.aes_prf][TID.online].stop();
+			timing.stopwatch[PID.aes_prg][TID.online].start();
+			tmp = G.compute(tmp);
+			timing.stopwatch[PID.aes_prg][TID.online].stop();
+			ev[0][t] = new BigInteger(1, tmp).xor(m[t]).toByteArray();
+			if (ev[0][t].length < gBytes)
+				System.arraycopy(ev[0][t], 0, msg_ev, (t + 1) * gBytes
+						- ev[0][t].length, ev[0][t].length);
+			else
+				System.arraycopy(ev[0][t], ev[0][t].length - gBytes, msg_ev, t
+						* gBytes, gBytes);
+			System.arraycopy(ev[1][t], 0, msg_ev, N * gBytes + t * SR.kBytes,
+					SR.kBytes);
+		}
+		timing.stopwatch[PID.sscot][TID.online].stop();
+
+		timing.stopwatch[PID.sscot][TID.online_write].start();
+		C.write(msg_ev, PID.sscot);
+		timing.stopwatch[PID.sscot][TID.online_write].stop();
+	}
+
+	// for testing correctness
+	@Override
+	public void run(Party party, Forest forest) throws ForestException {
+		System.out.println("#####  Testing SSCOT  #####");
+
+		timing = new Timing();
+
+		for (int ii = 0; ii < 20; ii++) {
+
+			if (party == Party.Eddie) {
+				int levels = ForestMetadata.getLevels();
+				int i = SR.rand.nextInt(levels - 1) + 1;
+				int N = SR.rand.nextInt(50) + 150; // 150-199
+				int l = ForestMetadata.getTupleBits(i);
+				int l_p = 1 + ForestMetadata.getNBits(i);
+				int t = SR.rand.nextInt(N);
+
+				PreData.sscot_k = new byte[levels][16];
+				PreData.sscot_k_p = new byte[levels][16];
+				PreData.sscot_r = new BigInteger[levels][N];
+				BigInteger[] a = new BigInteger[N];
+				BigInteger[] b = new BigInteger[N];
+				BigInteger[] m = new BigInteger[N];
+
+				SR.rand.nextBytes(PreData.sscot_k[i]);
+				SR.rand.nextBytes(PreData.sscot_k_p[i]);
+				for (int o = 0; o < N; o++) {
+					PreData.sscot_r[i][o] = new BigInteger(SR.kBits, SR.rand);
+					a[o] = new BigInteger(l_p, SR.rand);
+					b[o] = new BigInteger(l_p, SR.rand);
+					while (a[o].compareTo(b[o]) == 0)
+						b[o] = new BigInteger(l_p, SR.rand);
+					m[o] = new BigInteger(l, SR.rand);
+				}
+				a[t] = b[t];
+
+				con1.write(i);
+				con1.write(N);
+				con1.write(l);
+				con1.write(l_p);
+
+				con2.write(i);
+				con2.write(N);
+				con2.write(l);
+				con2.write(l_p);
+				con2.write(PreData.sscot_k[i]);
+				con2.write(PreData.sscot_k_p[i]);
+				con2.write(PreData.sscot_r[i]);
+				con2.write(b);
+
+				executeEddie(con1, con2, i, N, l, l_p, m, a);
+
+				int output_t = con1.readInt();
+				BigInteger m_t = con1.readBigInteger();
+
+				System.out.println("i = " + i);
+				if (t == output_t && m[t].compareTo(m_t) == 0) {
+					System.out.println("SSCOT test passed:");
+				} else {
+					System.out.println("SSCOT test failed:");
+				}
+				System.out.println("t=" + t + ", output_t=" + output_t);
+				System.out.println("m[t]=" + m[t] + ", m_t=" + m_t);
+			} else if (party == Party.Debbie) {
+				int i = con2.readInt();
+				int N = con2.readInt();
+				int l = con2.readInt();
+				int l_p = con2.readInt();
+
+				int levels = ForestMetadata.getLevels();
+				PreData.sscot_k = new byte[levels][];
+				PreData.sscot_k_p = new byte[levels][];
+				PreData.sscot_r = new BigInteger[levels][];
+
+				PreData.sscot_k[i] = con2.read();
+				PreData.sscot_k_p[i] = con2.read();
+				PreData.sscot_r[i] = con2.readBigIntegerArray();
+				BigInteger[] b = con2.readBigIntegerArray();
+
+				executeDebbie(con1, con2, i, N, l, l_p, b);
+			} else if (party == Party.Charlie) {
+				int i = con2.readInt();
+				int N = con2.readInt();
+				int l = con2.readInt();
+				int l_p = con2.readInt();
+
+				Pair<Integer, BigInteger> output = executeCharlie(con1, con2,
+						i, N, l, l_p);
+				int t = output.getLeft();
+				BigInteger m_t = output.getRight();
+
+				con2.write(t);
+				con2.write(m_t);
+			}
+
+		}
+
+		System.out.println("#####  Testing SSCOT Finished  #####");
+	}
+
+	@Override
+	public void run(protocols.Party party, Metadata md, oram.Forest forest) {
+		// TODO Auto-generated method stub
+		
+	}
+
+	@Override
+	public void run(Party party, Metadata md, Forest forest) {
+		// TODO Auto-generated method stub
+		
+	}
+}

+ 11 - 4
src/util/Util.java

@@ -1,6 +1,7 @@
 package util;
 
 import java.math.BigInteger;
+import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.Random;
 
@@ -70,19 +71,25 @@ public class Util {
 		for (int i = 0; i < a.length; i++)
 			a[i] = (byte) (a[i] ^ b[i]);
 	}
-	
+
+	public static byte[] intToBytes(int i) {
+		ByteBuffer bb = ByteBuffer.allocate(4);
+		bb.putInt(i);
+		return bb.array();
+	}
+
 	public static void debug(String s) {
 		// only to make Communication.java compile
 	}
-	
+
 	public static void disp(String s) {
 		// only to make Communication.java compile
 	}
-	
+
 	public static void error(String s) {
 		// only to make Communication.java compile
 	}
-	
+
 	public static void error(String s, Exception e) {
 		// only to make Communication.java compile
 	}

+ 39 - 0
test/crypto/TestPRF.java

@@ -0,0 +1,39 @@
+package crypto;
+
+import java.math.BigInteger;
+
+public class TestPRF {
+
+	public static void main(String[] args) {
+		try {
+			for (int l = 1; l < 5000; l++) {
+				System.out.println("Round: l=" + l);
+				PRF f1 = new PRF(l);
+				PRF f2 = new PRF(l);
+				byte[] k = new byte[16];
+				Crypto.sr.nextBytes(k);
+				byte[] input = new byte[Crypto.sr.nextInt(12) + 1];
+				Crypto.sr.nextBytes(input);
+				f1.init(k);
+				f2.init(k);
+				byte[] output1 = f1.compute(input);
+				byte[] output2 = f2.compute(input);
+				for (int i = 0; i < output2.length; i++)
+					System.out.print(String.format("%02X", output2[i]));
+				System.out.println("");
+				boolean test1 = new BigInteger(1, output1).compareTo(new BigInteger(1, output2)) == 0;
+				boolean test2 = output1.length == (l + 7) / 8;
+				if (!test1 || !test2) {
+					System.out.println("Fail: l=" + l + "  " + test1 + "  " + test2);
+					break;
+				}
+			}
+
+			System.out.println("done");
+
+		} catch (Exception e) {
+			e.printStackTrace();
+		}
+	}
+
+}

+ 31 - 0
test/crypto/TestPRG.java

@@ -0,0 +1,31 @@
+package crypto;
+
+import java.math.BigInteger;
+
+public class TestPRG {
+
+	public static void main(String[] args) {
+		// PRG.generateKey(Crypto.sr);
+
+		int n = 10;
+		int outBits = 1000;
+		int outBytes = (outBits + 7) / 8;
+		byte[][] input = new byte[n][16];
+		byte[][] output = new byte[n][];
+		PRG G = new PRG(outBits);
+
+		for (int i = 0; i < n; i++) {
+			Crypto.sr.nextBytes(input[i]);
+			output[i] = G.compute(input[i]);
+			System.out.println(new BigInteger(1, output[i]).toString(16));
+		}
+
+		for (int i = 0; i < n; i++) {
+			byte[] tmp = G.compute(input[i]);
+			System.out.println(
+					"deterministic:\t" + (new BigInteger(1, tmp).compareTo(new BigInteger(1, output[i])) == 0));
+			System.out.println("right length:\t" + (output[i].length == outBytes));
+		}
+	}
+
+}

+ 2 - 2
test/oram/TestForest.java

@@ -2,7 +2,7 @@ package oram;
 
 import java.math.BigInteger;
 
-import crypto.OramCrypto;
+import crypto.Crypto;
 import util.Util;
 
 public class TestForest {
@@ -20,7 +20,7 @@ public class TestForest {
 		// long numTests = numRecords;
 		for (long n = 0; n < numTests; n++) {
 			// address of record we want to test
-			long testAddr = Util.nextLong(OramCrypto.sr, numRecords);
+			long testAddr = Util.nextLong(Crypto.sr, numRecords);
 			// long testAddr = n;
 			long L = 0;
 			long outRecord = 0;