123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- package crypto;
- import java.security.InvalidKeyException;
- import java.security.NoSuchAlgorithmException;
- import java.util.Random;
- import javax.crypto.BadPaddingException;
- import javax.crypto.Cipher;
- import javax.crypto.IllegalBlockSizeException;
- import javax.crypto.NoSuchPaddingException;
- import javax.crypto.spec.SecretKeySpec;
- import org.bouncycastle.util.Arrays;
- import exceptions.IllegalInputException;
- import exceptions.LengthNotMatchException;
- import util.Util;
- public class PRF {
- private Cipher cipher;
- private int l; // output bit length
- private int maxInputBytes = 12;
- public PRF(int l) {
- try {
- cipher = Cipher.getInstance("AES/ECB/NoPadding");
- } catch (NoSuchAlgorithmException | NoSuchPaddingException e) {
- e.printStackTrace();
- }
- this.l = l;
- }
- public void init(byte[] key) {
- if (key.length != 16)
- throw new LengthNotMatchException(key.length + " != 16");
- SecretKeySpec skey = new SecretKeySpec(key, "AES");
- try {
- cipher.init(Cipher.ENCRYPT_MODE, skey);
- } catch (InvalidKeyException e) {
- e.printStackTrace();
- }
- }
- public synchronized byte[] compute(byte[] input) {
- if (input.length > maxInputBytes)
- throw new IllegalInputException(input.length + " > " + maxInputBytes);
- byte[] in = new byte[16];
- System.arraycopy(input, 0, in, in.length - input.length, input.length);
- byte[] output = null;
- if (l <= 128)
- output = leq128(in, l);
- else
- output = g128(in);
- return output;
- }
- private byte[] leq128(byte[] input, int np) {
- byte[] ctext = null;
- try {
- ctext = cipher.doFinal(input);
- } catch (IllegalBlockSizeException | BadPaddingException e) {
- e.printStackTrace();
- }
- int outBytes = (np + 7) / 8;
- if (ctext.length == outBytes)
- return ctext;
- else
- return Arrays.copyOfRange(ctext, ctext.length - outBytes, ctext.length);
- }
- private byte[] g128(byte[] input) {
- int n = l / 128;
- int outBytes = (l + 7) / 8;
- byte[] output = new byte[outBytes];
- int len = Math.min(16 - maxInputBytes, 4);
- for (int i = 0; i < n; i++) {
- byte[] index = Util.intToBytes(i + 1);
- System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
- byte[] seg = leq128(input, 128);
- System.arraycopy(seg, 0, output, i * seg.length, seg.length);
- }
- int np = l % 128;
- if (np == 0)
- return output;
- byte[] index = Util.intToBytes(n + 1);
- System.arraycopy(index, 4 - len, input, 16 - maxInputBytes - len, len);
- byte[] last = leq128(input, np);
- System.arraycopy(last, 0, output, outBytes - last.length, last.length);
- return output;
- }
- public static byte[] generateKey(Random rand) {
- byte[] key = new byte[16];
- rand.nextBytes(key);
- return key;
- }
- }
|