main.cpp 16 KB


  #include <chrono>
  #include <cstdlib>
  #include <iostream>
  #include <random>
  #include <string>
  #include <vector>

  #include "BGN.hpp"
  using namespace std;
  // How many operations each *Speed benchmark performs (and times) per call.
  constexpr size_t NUM_RUNS_PER_TEST = 100;
  // Inclusive upper bound of the random test plaintexts; the lower bound is 0.
  constexpr size_t MAX_VALUE_IN_TEST = 999;
  8. bool testDecrypt(int x)
  9. {
  10. bool retval;
  11. BGN system;
  12. Scalar testVal(x);
  13. Scalar one(1);
  14. Scalar decrypted;
  15. CurveBipoint curveEnc, curveOne;
  16. TwistBipoint twistEnc, twistOne;
  17. Quadripoint quadEncA, quadEncB;
  18. system.encrypt(curveEnc, testVal);
  19. system.encrypt(curveOne, one);
  20. system.encrypt(twistEnc, testVal);
  21. system.encrypt(twistOne, one);
  22. quadEncA = system.homomorphic_multiplication(curveEnc, twistOne);
  23. quadEncB = system.homomorphic_multiplication(curveOne, twistEnc);
  24. decrypted = system.decrypt(curveEnc);
  25. retval = (decrypted == testVal);
  26. decrypted = system.decrypt(twistEnc);
  27. retval = retval && (decrypted == testVal);
  28. decrypted = system.decrypt(quadEncA);
  29. retval = retval && (decrypted == testVal);
  30. decrypted = system.decrypt(quadEncB);
  31. retval = retval && (decrypted == testVal);
  32. return retval;
  33. }
  34. double testCurveEncryptSpeed(default_random_engine& generator)
  35. {
  36. BGN system;
  37. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  38. vector<Scalar> testVals;
  39. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  40. testVals.push_back(Scalar(distribution(generator)));
  41. vector<CurveBipoint> encryptions(NUM_RUNS_PER_TEST);
  42. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  43. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  44. system.encrypt(encryptions[i], testVals[i]);
  45. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  46. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  47. return time_span.count();
  48. }
  49. double testTwistEncryptSpeed(default_random_engine& generator)
  50. {
  51. BGN system;
  52. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  53. vector<Scalar> testVals;
  54. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  55. testVals.push_back(Scalar(distribution(generator)));
  56. vector<TwistBipoint> encryptions(NUM_RUNS_PER_TEST);
  57. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  58. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  59. system.encrypt(encryptions[i], testVals[i]);
  60. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  61. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  62. return time_span.count();
  63. }
  64. double testCurveDecryptSpeed(default_random_engine& generator)
  65. {
  66. BGN system;
  67. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  68. vector<Scalar> testVals;
  69. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  70. testVals.push_back(Scalar(distribution(generator)));
  71. vector<CurveBipoint> encryptions(NUM_RUNS_PER_TEST);
  72. vector<Scalar> decryptions(NUM_RUNS_PER_TEST);
  73. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  74. system.encrypt(encryptions[i], testVals[i]);
  75. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  76. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  77. decryptions[i] = system.decrypt(encryptions[i]);
  78. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  79. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  80. return time_span.count();
  81. }
  82. double testTwistDecryptSpeed(default_random_engine& generator)
  83. {
  84. BGN system;
  85. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  86. vector<Scalar> testVals;
  87. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  88. testVals.push_back(Scalar(distribution(generator)));
  89. vector<TwistBipoint> encryptions(NUM_RUNS_PER_TEST);
  90. vector<Scalar> decryptions(NUM_RUNS_PER_TEST);
  91. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  92. system.encrypt(encryptions[i], testVals[i]);
  93. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  94. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  95. decryptions[i] = system.decrypt(encryptions[i]);
  96. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  97. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  98. return time_span.count();
  99. }
  100. double testQuadDecryptSpeed(default_random_engine& generator)
  101. {
  102. BGN system;
  103. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  104. vector<Scalar> testVals;
  105. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  106. testVals.push_back(Scalar(distribution(generator)));
  107. Scalar one(1);
  108. TwistBipoint oneEncryption;
  109. vector<CurveBipoint> firstEncryptions(NUM_RUNS_PER_TEST);
  110. vector<Quadripoint> realEncryptions(NUM_RUNS_PER_TEST);
  111. vector<Scalar> decryptions(NUM_RUNS_PER_TEST);
  112. system.encrypt(oneEncryption, one);
  113. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  114. {
  115. system.encrypt(firstEncryptions[i], testVals[i]);
  116. realEncryptions[i] = system.homomorphic_multiplication(firstEncryptions[i], oneEncryption);
  117. }
  118. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  119. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  120. decryptions[i] = system.decrypt(realEncryptions[i]);
  121. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  122. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  123. return time_span.count();
  124. }
  125. bool testAddition(int x, int y)
  126. {
  127. bool retval;
  128. BGN system;
  129. Scalar testX(x);
  130. Scalar testY(y);
  131. Scalar testSum(x + y);
  132. Scalar one(1);
  133. Scalar decrypted;
  134. CurveBipoint curveX, curveY, curveSum, curveOne;
  135. TwistBipoint twistX, twistY, twistSum, twistOne;
  136. Quadripoint quadXA, quadXB, quadYA, quadYB,
  137. quadSumAA, quadSumAB, quadSumBA, quadSumBB;
  138. system.encrypt(curveX, testX);
  139. system.encrypt(curveY, testY);
  140. system.encrypt(curveOne, one);
  141. system.encrypt(twistX, testX);
  142. system.encrypt(twistY, testY);
  143. system.encrypt(twistOne, one);
  144. curveSum = system.homomorphic_addition(curveX, curveY);
  145. twistSum = system.homomorphic_addition(twistX, twistY);
  146. quadXA = system.homomorphic_multiplication(curveX, twistOne);
  147. quadXB = system.homomorphic_multiplication(curveOne, twistX);
  148. quadYA = system.homomorphic_multiplication(curveY, twistOne);
  149. quadYB = system.homomorphic_multiplication(curveOne, twistY);
  150. quadSumAA = system.homomorphic_addition(quadXA, quadYA);
  151. quadSumAB = system.homomorphic_addition(quadXA, quadYB);
  152. quadSumBA = system.homomorphic_addition(quadXB, quadYA);
  153. quadSumBB = system.homomorphic_addition(quadXB, quadYB);
  154. decrypted = system.decrypt(curveSum);
  155. retval = (decrypted == testSum);
  156. decrypted = system.decrypt(twistSum);
  157. retval = retval && (decrypted == testSum);
  158. decrypted = system.decrypt(quadSumAA);
  159. retval = retval && (decrypted == testSum);
  160. decrypted = system.decrypt(quadSumAB);
  161. retval = retval && (decrypted == testSum);
  162. decrypted = system.decrypt(quadSumBA);
  163. retval = retval && (decrypted == testSum);
  164. decrypted = system.decrypt(quadSumBB);
  165. retval = retval && (decrypted == testSum);
  166. return retval;
  167. }
  168. double testCurveAdditionSpeed(default_random_engine& generator)
  169. {
  170. BGN system;
  171. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  172. vector<Scalar> testXs;
  173. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  174. testXs.push_back(Scalar(distribution(generator)));
  175. vector<Scalar> testYs;
  176. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  177. testYs.push_back(Scalar(distribution(generator)));
  178. vector<CurveBipoint> encXs(NUM_RUNS_PER_TEST);
  179. vector<CurveBipoint> encYs(NUM_RUNS_PER_TEST);
  180. vector<CurveBipoint> encSums(NUM_RUNS_PER_TEST);
  181. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  182. {
  183. system.encrypt(encXs[i], testXs[i]);
  184. system.encrypt(encYs[i], testYs[i]);
  185. }
  186. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  187. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  188. encSums[i] = system.homomorphic_addition(encXs[i], encYs[i]);
  189. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  190. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  191. return time_span.count();
  192. }
  193. double testTwistAdditionSpeed(default_random_engine& generator)
  194. {
  195. BGN system;
  196. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  197. vector<Scalar> testXs;
  198. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  199. testXs.push_back(Scalar(distribution(generator)));
  200. vector<Scalar> testYs;
  201. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  202. testYs.push_back(Scalar(distribution(generator)));
  203. vector<TwistBipoint> encXs(NUM_RUNS_PER_TEST);
  204. vector<TwistBipoint> encYs(NUM_RUNS_PER_TEST);
  205. vector<TwistBipoint> encSums(NUM_RUNS_PER_TEST);
  206. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  207. {
  208. system.encrypt(encXs[i], testXs[i]);
  209. system.encrypt(encYs[i], testYs[i]);
  210. }
  211. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  212. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  213. encSums[i] = system.homomorphic_addition(encXs[i], encYs[i]);
  214. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  215. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  216. return time_span.count();
  217. }
  218. double testQuadAdditionSpeed(default_random_engine& generator)
  219. {
  220. BGN system;
  221. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  222. vector<Scalar> testXs;
  223. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  224. testXs.push_back(Scalar(distribution(generator)));
  225. vector<Scalar> testYs;
  226. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  227. testYs.push_back(Scalar(distribution(generator)));
  228. Scalar one(1);
  229. TwistBipoint oneEncryption;
  230. vector<CurveBipoint> firstEncXs(NUM_RUNS_PER_TEST);
  231. vector<Quadripoint> realEncXs(NUM_RUNS_PER_TEST);
  232. vector<CurveBipoint> firstEncYs(NUM_RUNS_PER_TEST);
  233. vector<Quadripoint> realEncYs(NUM_RUNS_PER_TEST);
  234. vector<Quadripoint> encSums(NUM_RUNS_PER_TEST);
  235. system.encrypt(oneEncryption, one);
  236. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  237. {
  238. system.encrypt(firstEncXs[i], testXs[i]);
  239. system.encrypt(firstEncYs[i], testYs[i]);
  240. realEncXs[i] = system.homomorphic_multiplication(firstEncXs[i], oneEncryption);
  241. realEncYs[i] = system.homomorphic_multiplication(firstEncYs[i], oneEncryption);
  242. }
  243. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  244. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  245. encSums[i] = system.homomorphic_addition(realEncXs[i], realEncYs[i]);
  246. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  247. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  248. return time_span.count();
  249. }
  250. bool testMultiplication(int x, int y)
  251. {
  252. bool retval;
  253. BGN system;
  254. Scalar testX(x);
  255. Scalar testY(y);
  256. Scalar testProduct(x * y);
  257. Scalar decrypted;
  258. CurveBipoint curveX, curveY;
  259. TwistBipoint twistX, twistY;
  260. Quadripoint productA, productB;
  261. system.encrypt(curveX, testX);
  262. system.encrypt(curveY, testY);
  263. system.encrypt(twistX, testX);
  264. system.encrypt(twistY, testY);
  265. productA = system.homomorphic_multiplication(curveX, twistY);
  266. productB = system.homomorphic_multiplication(curveY, twistX);
  267. decrypted = system.decrypt(productA);
  268. retval = (decrypted == testProduct);
  269. decrypted = system.decrypt(productB);
  270. retval = retval && (decrypted == testProduct);
  271. return retval;
  272. }
  273. double testMultiplicationSpeed(default_random_engine& generator)
  274. {
  275. BGN system;
  276. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  277. vector<Scalar> testXs;
  278. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  279. testXs.push_back(Scalar(distribution(generator)));
  280. vector<Scalar> testYs;
  281. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  282. testYs.push_back(Scalar(distribution(generator)));
  283. vector<CurveBipoint> encXs(NUM_RUNS_PER_TEST);
  284. vector<TwistBipoint> encYs(NUM_RUNS_PER_TEST);
  285. vector<Quadripoint> encProducts(NUM_RUNS_PER_TEST);
  286. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  287. {
  288. system.encrypt(encXs[i], testXs[i]);
  289. system.encrypt(encYs[i], testYs[i]);
  290. }
  291. chrono::high_resolution_clock::time_point t0 = chrono::high_resolution_clock::now();
  292. for (size_t i = 0; i < NUM_RUNS_PER_TEST; i++)
  293. encProducts[i] = system.homomorphic_multiplication(encXs[i], encYs[i]);
  294. chrono::high_resolution_clock::time_point t1 = chrono::high_resolution_clock::now();
  295. chrono::duration<double> time_span = chrono::duration_cast<chrono::duration<double>>(t1 - t0);
  296. return time_span.count();
  297. }
  298. int main(int argc, char *argv[])
  299. {
  300. string seedStr("default");
  301. if (argc > 1)
  302. seedStr = argv[1];
  303. seed_seq seed(seedStr.begin(), seedStr.end());
  304. default_random_engine generator(seed);
  305. uniform_int_distribution<int> distribution(0, MAX_VALUE_IN_TEST);
  306. cout << "test_PointAtInfinity: ";
  307. if (testDecrypt(0))
  308. cout << "PASS" << endl;
  309. else
  310. cout << "FAIL" << endl;
  311. cout << "test_GeneratorPoint: ";
  312. if (testDecrypt(1))
  313. cout << "PASS" << endl;
  314. else
  315. cout << "FAIL" << endl;
  316. int randomPoint = distribution(generator);
  317. cout << "test_RandomPoint (" << randomPoint << "): ";
  318. if (testDecrypt(randomPoint))
  319. cout << "PASS" << endl;
  320. else
  321. cout << "FAIL" << endl;
  322. cout << "test_CurveEncryptSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  323. cout << testCurveEncryptSpeed(generator) << " seconds" << endl;
  324. cout << "test_TwistEncryptSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  325. cout << testTwistEncryptSpeed(generator) << " seconds" << endl;
  326. cout << "test_CurveDecryptSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  327. cout << testCurveDecryptSpeed(generator) << " seconds" << endl;
  328. cout << "test_TwistDecryptSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  329. cout << testTwistDecryptSpeed(generator) << " seconds" << endl;
  330. cout << "test_QuadDecryptSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  331. cout << testQuadDecryptSpeed(generator) << " seconds" << endl;
  332. int addX = distribution(generator);
  333. int addY = distribution(generator);
  334. cout << "test_Addition (" << addX << ", " << addY << "): ";
  335. if (testAddition(addX, addY))
  336. cout << "PASS" << endl;
  337. else
  338. cout << "FAIL" << endl;
  339. cout << "test_CurveAdditionSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  340. cout << testCurveAdditionSpeed(generator) << " seconds" << endl;
  341. cout << "test_TwistAdditionSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  342. cout << testTwistAdditionSpeed(generator) << " seconds" << endl;
  343. cout << "test_QuadAdditionSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  344. cout << testQuadAdditionSpeed(generator) << " seconds" << endl;
  345. int multX = distribution(generator);
  346. int multY = distribution(generator);
  347. cout << "test_Multiplication (" << multX << ", " << multY << "): ";
  348. if (testMultiplication(multX, multY))
  349. cout << "PASS" << endl;
  350. else
  351. cout << "FAIL" << endl;
  352. cout << "test_MultiplicationSpeed (" << NUM_RUNS_PER_TEST << " runs): ";
  353. cout << testMultiplicationSpeed(generator) << " seconds" << endl;
  354. return 0;
  355. }