PRF.java 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package crypto;
  2. import java.security.InvalidKeyException;
  3. import java.security.NoSuchAlgorithmException;
  4. import javax.crypto.BadPaddingException;
  5. import javax.crypto.Cipher;
  6. import javax.crypto.IllegalBlockSizeException;
  7. import javax.crypto.NoSuchPaddingException;
  8. import javax.crypto.spec.SecretKeySpec;
  9. import org.bouncycastle.util.Arrays;
  10. import exceptions.IllegalInputException;
  11. import exceptions.LengthNotMatchException;
  12. import util.Util;
  13. public class PRF {
  14. private Cipher cipher;
  15. private int l; // output bit length
  16. private int maxInputBytes = 12;
  17. public PRF(int l) {
  18. try {
  19. cipher = Cipher.getInstance("AES/ECB/NoPadding");
  20. } catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
  21. e.printStackTrace();
  22. }
  23. this.l = l;
  24. }
  25. public void init(byte[] key) {
  26. if (key.length != 16)
  27. throw new LengthNotMatchException(key.length + " != 16");
  28. SecretKeySpec skey = new SecretKeySpec(key, "AES");
  29. try {
  30. cipher.init(Cipher.ENCRYPT_MODE, skey);
  31. } catch (InvalidKeyException e) {
  32. e.printStackTrace();
  33. }
  34. }
  35. public byte[] compute(byte[] input) {
  36. if (input.length > maxInputBytes)
  37. throw new IllegalInputException(input.length + " > " + maxInputBytes);
  38. byte[] in = new byte[16];
  39. System.arraycopy(input, 0, in, in.length - input.length, input.length);
  40. byte[] output = null;
  41. if (l <= 128)
  42. output = leq128(in, l);
  43. else
  44. output = g128(in);
  45. return output;
  46. }
  47. private byte[] leq128(byte[] input, int np) {
  48. byte[] ctext = null;
  49. try {
  50. ctext = cipher.doFinal(input);
  51. } catch (IllegalBlockSizeException | BadPaddingException e) {
  52. e.printStackTrace();
  53. }
  54. int outBytes = (np + 7) / 8;
  55. if (ctext.length == outBytes)
  56. return ctext;
  57. else
  58. return Arrays.copyOfRange(ctext, ctext.length - outBytes, ctext.length);
  59. }
  60. private byte[] g128(byte[] input) {
  61. int n = l / 128;
  62. int outBytes = (l + 7) / 8;
  63. byte[] output = new byte[outBytes];
  64. int len = Math.min(16 - maxInputBytes, 4);
  65. for (int i = 0; i < n; i++) {
  66. byte[] index = Util.intToBytes(i + 1);
  67. System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
  68. byte[] seg = leq128(input, 128);
  69. System.arraycopy(seg, 0, output, i * seg.length, seg.length);
  70. }
  71. int np = l % 128;
  72. if (np == 0)
  73. return output;
  74. byte[] index = Util.intToBytes(n + 1);
  75. System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
  76. byte[] last = leq128(input, np);
  77. System.arraycopy(last, 0, output, outBytes - last.length, last.length);
  78. return output;
  79. }
  80. }