CLI.java 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package ui;
  2. import java.lang.reflect.Constructor;
  3. import java.lang.reflect.InvocationTargetException;
  4. import java.net.InetSocketAddress;
  5. import org.apache.commons.cli.CommandLine;
  6. import org.apache.commons.cli.CommandLineParser;
  7. import org.apache.commons.cli.GnuParser;
  8. import org.apache.commons.cli.Options;
  9. import org.apache.commons.cli.ParseException;
  10. import communication.Communication;
  11. import exceptions.NoSuchPartyException;
  12. import oram.Global;
  13. import oram.Metadata;
  14. import protocols.*;
  15. import pir.*;
  16. import protocols.struct.Party;
  17. public class CLI {
  18. public static final int DEFAULT_PORT = 8000;
  19. public static final String DEFAULT_IP = "localhost";
  20. public static void main(String[] args) {
  21. // Setup command line argument parser
  22. Options options = new Options();
  23. options.addOption("config", true, "Config file");
  24. options.addOption("forest", true, "Forest file");
  25. options.addOption("eddie_ip", true, "IP to look for eddie");
  26. options.addOption("debbie_ip", true, "IP to look for debbie");
  27. options.addOption("protocol", true, "Algorithim to test");
  28. options.addOption("pipeline", false, "Whether to do pipelined eviction");
  29. // Parse the command line arguments
  30. CommandLineParser cmdParser = new GnuParser();
  31. CommandLine cmd = null;
  32. try {
  33. cmd = cmdParser.parse(options, args);
  34. } catch (ParseException e1) {
  35. e1.printStackTrace();
  36. }
  37. Global.pipeline = cmd.hasOption("pipeline");
  38. String configFile = cmd.getOptionValue("config", "config.yaml");
  39. String forestFile = cmd.getOptionValue("forest", null);
  40. String party = null;
  41. String[] positionalArgs = cmd.getArgs();
  42. if (positionalArgs.length > 0) {
  43. party = positionalArgs[0];
  44. } else {
  45. try {
  46. throw new ParseException("No party specified");
  47. } catch (ParseException e) {
  48. e.printStackTrace();
  49. System.exit(-1);
  50. }
  51. }
  52. int extra_port = 1;
  53. int eddiePort1 = DEFAULT_PORT;
  54. int eddiePort2 = eddiePort1 + extra_port;
  55. int debbiePort = eddiePort2 + extra_port;
  56. String eddieIp = cmd.getOptionValue("eddie_ip", DEFAULT_IP);
  57. String debbieIp = cmd.getOptionValue("debbie_ip", DEFAULT_IP);
  58. Class<? extends Protocol> operation = null;
  59. String protocol = cmd.getOptionValue("protocol", "retrieve").toLowerCase();
  60. if (protocol.equals("acc")) {
  61. operation = Access.class;
  62. } else if (protocol.equals("cot")) {
  63. operation = SSCOT.class;
  64. } else if (protocol.equals("iot")) {
  65. operation = SSIOT.class;
  66. } else if (protocol.equals("rsf")) {
  67. operation = Reshuffle.class;
  68. } else if (protocol.equals("ppt")) {
  69. operation = PostProcessT.class;
  70. } else if (protocol.equals("ur")) {
  71. operation = UpdateRoot.class;
  72. } else if (protocol.equals("evi")) {
  73. operation = Eviction.class;
  74. } else if (protocol.equals("pt")) {
  75. operation = PermuteTarget.class;
  76. } else if (protocol.equals("pi")) {
  77. operation = PermuteIndex.class;
  78. } else if (protocol.equals("xot")) {
  79. operation = SSXOT.class;
  80. } else if (protocol.equals("rtv")) {
  81. operation = Retrieve.class;
  82. } else if (protocol.equals("pircot")) {
  83. operation = PIRCOT.class;
  84. } else if (protocol.equals("piriot")) {
  85. operation = PIRIOT.class;
  86. } else if (protocol.equals("piracc")) {
  87. operation = PIRAccess.class;
  88. } else if (protocol.equals("pirrtv")) {
  89. operation = PIRRetrieve.class;
  90. } else if (protocol.equals("pirrsf")) {
  91. operation = PIRReshuffle.class;
  92. } else if (protocol.equals("sspir")) {
  93. operation = SSPIR.class;
  94. } else if (protocol.equals("shiftpir")) {
  95. operation = ShiftPIR.class;
  96. } else if (protocol.equals("tspir")) {
  97. operation = ThreeShiftPIR.class;
  98. } else if (protocol.equals("shiftxorpir")) {
  99. operation = ShiftXorPIR.class;
  100. } else if (protocol.equals("tsxpir")) {
  101. operation = ThreeShiftXorPIR.class;
  102. } else if (protocol.equals("shift")) {
  103. operation = Shift.class;
  104. } else {
  105. System.out.println("Protocol " + protocol + " not supported");
  106. System.exit(-1);
  107. }
  108. Constructor<? extends Protocol> operationCtor = null;
  109. try {
  110. operationCtor = operation.getDeclaredConstructor(Communication.class, Communication.class);
  111. } catch (NoSuchMethodException | SecurityException e1) {
  112. e1.printStackTrace();
  113. }
  114. // For now all logic happens here. Eventually this will get wrapped
  115. // up in party specific classes.
  116. System.out.println("Starting " + party + "...");
  117. Metadata md = new Metadata(configFile);
  118. int numComs = Global.pipeline ? md.getNumTrees() + 1 : 1;
  119. Communication[] con1 = new Communication[numComs];
  120. Communication[] con2 = new Communication[numComs];
  121. if (party.equals("eddie")) {
  122. System.out.print("Waiting to establish debbie connections...");
  123. for (int i = 0; i < numComs; i++) {
  124. con1[i] = new Communication();
  125. con1[i].start(eddiePort1);
  126. eddiePort1 += 3;
  127. while (con1[i].getState() != Communication.STATE_CONNECTED)
  128. ;
  129. }
  130. System.out.println(" done!");
  131. System.out.print("Waiting to establish charlie connections...");
  132. for (int i = 0; i < numComs; i++) {
  133. con2[i] = new Communication();
  134. con2[i].start(eddiePort2);
  135. eddiePort2 += 3;
  136. while (con2[i].getState() != Communication.STATE_CONNECTED)
  137. ;
  138. }
  139. System.out.println(" done!");
  140. for (int i = 0; i < numComs; i++) {
  141. con1[i].setTcpNoDelay(true);
  142. con2[i].setTcpNoDelay(true);
  143. }
  144. try {
  145. Protocol p = operationCtor.newInstance(con1[0], con2[0]);
  146. if (protocol.equals("rtv")) {
  147. ((Retrieve) p).setCons(con1, con2);
  148. }
  149. if (protocol.equals("pirrtv")) {
  150. ((PIRRetrieve) p).setCons(con1, con2);
  151. }
  152. if (!Global.usePIR) {
  153. p.run(Party.Eddie, md, forestFile);
  154. } else {
  155. p.run(Party.Eddie, md);
  156. }
  157. } catch (InstantiationException | IllegalAccessException | IllegalArgumentException
  158. | InvocationTargetException e) {
  159. e.printStackTrace();
  160. }
  161. } else if (party.equals("debbie")) {
  162. System.out.print("Waiting to establish eddie connections...");
  163. for (int i = 0; i < numComs; i++) {
  164. con1[i] = new Communication();
  165. InetSocketAddress addr = new InetSocketAddress(eddieIp, eddiePort1);
  166. con1[i].connect(addr);
  167. eddiePort1 += 3;
  168. while (con1[i].getState() != Communication.STATE_CONNECTED)
  169. ;
  170. }
  171. System.out.println(" done!");
  172. System.out.print("Waiting to establish charlie connections...");
  173. for (int i = 0; i < numComs; i++) {
  174. con2[i] = new Communication();
  175. con2[i].start(debbiePort);
  176. debbiePort += 3;
  177. while (con2[i].getState() != Communication.STATE_CONNECTED)
  178. ;
  179. }
  180. System.out.println(" done!");
  181. for (int i = 0; i < numComs; i++) {
  182. con1[i].setTcpNoDelay(true);
  183. con2[i].setTcpNoDelay(true);
  184. }
  185. try {
  186. Protocol p = operationCtor.newInstance(con1[0], con2[0]);
  187. if (protocol.equals("rtv")) {
  188. ((Retrieve) p).setCons(con1, con2);
  189. }
  190. if (protocol.equals("pirrtv")) {
  191. ((PIRRetrieve) p).setCons(con1, con2);
  192. }
  193. if (!Global.usePIR) {
  194. p.run(Party.Debbie, md, forestFile);
  195. } else {
  196. p.run(Party.Debbie, md);
  197. }
  198. } catch (InstantiationException | IllegalAccessException | IllegalArgumentException
  199. | InvocationTargetException e) {
  200. e.printStackTrace();
  201. }
  202. } else if (party.equals("charlie")) {
  203. System.out.print("Waiting to establish eddie connections...");
  204. for (int i = 0; i < numComs; i++) {
  205. con1[i] = new Communication();
  206. InetSocketAddress addr = new InetSocketAddress(eddieIp, eddiePort2);
  207. con1[i].connect(addr);
  208. eddiePort2 += 3;
  209. while (con1[i].getState() != Communication.STATE_CONNECTED)
  210. ;
  211. }
  212. System.out.println(" done!");
  213. System.out.print("Waiting to establish debbie connections...");
  214. for (int i = 0; i < numComs; i++) {
  215. con2[i] = new Communication();
  216. InetSocketAddress addr = new InetSocketAddress(debbieIp, debbiePort);
  217. con2[i].connect(addr);
  218. debbiePort += 3;
  219. while (con2[i].getState() != Communication.STATE_CONNECTED)
  220. ;
  221. }
  222. System.out.println(" done!");
  223. for (int i = 0; i < numComs; i++) {
  224. con1[i].setTcpNoDelay(true);
  225. con2[i].setTcpNoDelay(true);
  226. }
  227. try {
  228. Protocol p = operationCtor.newInstance(con1[0], con2[0]);
  229. if (protocol.equals("rtv")) {
  230. ((Retrieve) p).setCons(con1, con2);
  231. }
  232. if (protocol.equals("pirrtv")) {
  233. ((PIRRetrieve) p).setCons(con1, con2);
  234. }
  235. if (!Global.usePIR) {
  236. p.run(Party.Charlie, md, forestFile);
  237. } else {
  238. p.run(Party.Charlie, md);
  239. }
  240. } catch (InstantiationException | IllegalAccessException | IllegalArgumentException
  241. | InvocationTargetException e) {
  242. e.printStackTrace();
  243. }
  244. } else {
  245. throw new NoSuchPartyException(party);
  246. }
  247. try {
  248. Thread.sleep(1000);
  249. } catch (InterruptedException e) {
  250. e.printStackTrace();
  251. }
  252. for (int i = 0; i < numComs; i++) {
  253. con1[i].stop();
  254. con2[i].stop();
  255. }
  256. System.out.println(party + " exiting...");
  257. }
  258. }