PRF.java 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package crypto;
  2. import java.security.InvalidKeyException;
  3. import java.security.NoSuchAlgorithmException;
  4. import java.util.Random;
  5. import javax.crypto.BadPaddingException;
  6. import javax.crypto.Cipher;
  7. import javax.crypto.IllegalBlockSizeException;
  8. import javax.crypto.NoSuchPaddingException;
  9. import javax.crypto.spec.SecretKeySpec;
  10. import org.bouncycastle.util.Arrays;
  11. import exceptions.IllegalInputException;
  12. import exceptions.LengthNotMatchException;
  13. import util.Util;
  14. public class PRF {
  15. private Cipher cipher;
  16. private int l; // output bit length
  17. private int maxInputBytes = 12;
  18. public PRF(int l) {
  19. try {
  20. cipher = Cipher.getInstance("AES/ECB/NoPadding");
  21. } catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
  22. e.printStackTrace();
  23. }
  24. this.l = l;
  25. }
  26. public void init(byte[] key) {
  27. if (key.length != 16)
  28. throw new LengthNotMatchException(key.length + " != 16");
  29. SecretKeySpec skey = new SecretKeySpec(key, "AES");
  30. try {
  31. cipher.init(Cipher.ENCRYPT_MODE, skey);
  32. } catch (InvalidKeyException e) {
  33. e.printStackTrace();
  34. }
  35. }
  36. public synchronized byte[] compute(byte[] input) {
  37. if (input.length > maxInputBytes)
  38. throw new IllegalInputException(input.length + " > " + maxInputBytes);
  39. byte[] in = new byte[16];
  40. System.arraycopy(input, 0, in, in.length - input.length, input.length);
  41. byte[] output = null;
  42. if (l <= 128)
  43. output = leq128(in, l);
  44. else
  45. output = g128(in);
  46. return output;
  47. }
  48. private byte[] leq128(byte[] input, int np) {
  49. byte[] ctext = null;
  50. try {
  51. ctext = cipher.doFinal(input);
  52. } catch (IllegalBlockSizeException | BadPaddingException e) {
  53. e.printStackTrace();
  54. }
  55. int outBytes = (np + 7) / 8;
  56. if (ctext.length == outBytes)
  57. return ctext;
  58. else
  59. return Arrays.copyOfRange(ctext, ctext.length - outBytes, ctext.length);
  60. }
  61. private byte[] g128(byte[] input) {
  62. int n = l / 128;
  63. int outBytes = (l + 7) / 8;
  64. byte[] output = new byte[outBytes];
  65. int len = Math.min(16 - maxInputBytes, 4);
  66. for (int i = 0; i < n; i++) {
  67. byte[] index = Util.intToBytes(i + 1);
  68. System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
  69. byte[] seg = leq128(input, 128);
  70. System.arraycopy(seg, 0, output, i * seg.length, seg.length);
  71. }
  72. int np = l % 128;
  73. if (np == 0)
  74. return output;
  75. byte[] index = Util.intToBytes(n + 1);
  76. System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
  77. byte[] last = leq128(input, np);
  78. System.arraycopy(last, 0, output, outBytes - last.length, last.length);
  79. return output;
  80. }
  81. public static byte[] generateKey(Random rand) {
  82. byte[] key = new byte[16];
  83. rand.nextBytes(key);
  84. return key;
  85. }
  86. }