bitutils.hpp 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. /* Adapted from preprocessing/bitutils.h from
  2. * https://git-crysp.uwaterloo.ca/avadapal/duoram by Adithya Vadapalli,
  3. * itself adapted from code by Ryan Henry */
  4. #ifndef __BITUTILS_HPP__
  5. #define __BITUTILS_HPP__
  6. #include <array>
  7. #include <cstdint>
  8. #include <x86intrin.h> // SSE and AVX intrinsics
  // Two single-bit 128-bit masks: entry 0 has only bit 0 set (the lowest
  // bit of the low 64-bit half), entry 1 has only bit 64 set (the lowest
  // bit of the high 64-bit half).
  static const __m128i bool128_mask[2] = {
      _mm_set_epi32(0, 0, 0, 1),  // 0b00...0001
      _mm_set_epi32(0, 1, 0, 0)   // 0b00...0001 << 64
  };
  // lsb128_mask[k], for k in 0..3, is the 128-bit value whose two lowest
  // bits spell out k and whose other 126 bits are zero.
  static const __m128i lsb128_mask[4] = {
      _mm_set_epi64x(0, 0b00),  // 0b00...0000
      _mm_set_epi64x(0, 0b01),  // 0b00...0001
      _mm_set_epi64x(0, 0b10),  // 0b00...0010
      _mm_set_epi64x(0, 0b11)   // 0b00...0011
  };
  // lsb128_mask_inv[k], for k in 0..3, is the bitwise complement of the
  // 128-bit value k: every bit is set except those set in k's two lowest
  // bits.
  static const __m128i lsb128_mask_inv[4] = {
      _mm_set_epi64x(-1, ~0LL),  // 0b11...1111
      _mm_set_epi64x(-1, ~1LL),  // 0b11...1110
      _mm_set_epi64x(-1, ~2LL),  // 0b11...1101
      _mm_set_epi64x(-1, ~3LL)   // 0b11...1100
  };
  // Selection masks indexed by a boolean: entry 0 is all zeros and
  // entry 1 is all ones, so ANDing a block with if128_mask[b] yields
  // either zero or the block itself.
  static const __m128i if128_mask[2] = {
      _mm_set_epi64x(0, 0),    // 0b00...0000
      _mm_set_epi64x(-1, -1)   // 0b11...1111
  };
  // Masked XOR: returns block1 ^ (block2 & flag).  Wherever a bit of
  // flag is 1 the corresponding bit of block2 is XORed into block1;
  // wherever it is 0 the bit of block1 passes through unchanged.
  inline __m128i xor_if(const __m128i & block1, const __m128i & block2, __m128i flag)
  {
      const __m128i selected = _mm_and_si128(flag, block2);
      return _mm_xor_si128(selected, block1);
  }
  33. inline __m128i xor_if(const __m128i & block1, const __m128i & block2, bool flag)
  34. {
  35. return _mm_xor_si128(block1, _mm_and_si128(block2, if128_mask[flag ? 1 : 0]));
  36. }
  37. template <size_t LWIDTH>
  38. inline std::array<__m128i,LWIDTH> xor_if(
  39. const std::array<__m128i,LWIDTH> & block1,
  40. const std::array<__m128i,LWIDTH> & block2, bool flag)
  41. {
  42. std::array<__m128i,LWIDTH> res;
  43. for (size_t j=0;j<LWIDTH;++j) {
  44. res[j] = xor_if(block1[j], block2[j], flag);
  45. }
  46. return res;
  47. }
  // Return 1 if every bit selected by bits is set in the low-order bits
  // of block, and 0 otherwise.  With the default bits = 0b01 this is
  // simply the least significant bit of block; bits = 0 always yields 1.
  //
  // Only the low 64-bit half of block is inspected, and bits may be any
  // 8-bit pattern (the table-based version this replaces read past the
  // end of a 4-entry table for bits > 3).
  inline uint8_t get_lsb(const __m128i & block, uint8_t bits = 0b01)
  {
      const uint64_t low = static_cast<uint64_t>(_mm_cvtsi128_si64(block));
      return static_cast<uint8_t>((low & bits) == bits);
  }
  53. template <size_t LWIDTH>
  54. inline uint8_t get_lsb(const std::array<__m128i,LWIDTH> & block)
  55. {
  56. return get_lsb(block[0]);
  57. }
  // Return a copy of block with every bit selected by bits cleared in
  // its low-order byte; all other bits are left as they were.  With the
  // default bits = 0b01 only the least significant bit is cleared.
  //
  // bits may be any 8-bit pattern (the table-based version this replaces
  // read past the end of a 4-entry table for bits > 3).
  inline __m128i clear_lsb(const __m128i & block, uint8_t bits = 0b01)
  {
      // _mm_cvtsi64_si128 places bits in the low half and zeroes the
      // rest; _mm_andnot_si128(a, b) computes (~a) & b.
      const __m128i to_clear = _mm_cvtsi64_si128(bits);
      return _mm_andnot_si128(to_clear, block);
  }
  62. inline __m128i set_lsb(const __m128i & block, const bool val = true)
  63. {
  64. return _mm_or_si128(clear_lsb(block, 0b01), lsb128_mask[val ? 0b01 : 0b00]);
  65. }
  66. // The following can probably be improved by someone who knows the SIMD
  67. // instruction sets better than I do.
  // Return the parity of the number of bits set in block: 1 when an odd
  // number of the 128 bits are set, 0 when an even number are.
  inline uint8_t parity(const __m128i & block)
  {
      const uint64_t lower = static_cast<uint64_t>(_mm_cvtsi128_si64(block));
      const uint64_t upper = static_cast<uint64_t>(
          _mm_cvtsi128_si64(_mm_srli_si128(block, 8)));
      // XORing the halves preserves the overall parity, so a single
      // 64-bit parity computation suffices.
      return static_cast<uint8_t>(__builtin_parityll(lower ^ upper));
  }
  // Return the parity of the number of bits set in block strictly above
  // the given position, where position 0 is the least significant bit
  // of the low 64-bit half and position 127 the most significant bit of
  // the high half.  For position >= 127 no bits lie above it, so the
  // result is 0.
  inline uint8_t parity_above(const __m128i &block, uint8_t position)
  {
      // Positions beyond the block would need a 64-bit shift by 64 or
      // more, which is undefined; nothing lies above them anyway.
      if (position >= 128) {
          return 0;
      }
      const uint64_t high =
          uint64_t(_mm_cvtsi128_si64(_mm_srli_si128(block, 8)));
      if (position >= 64) {
          // Only the part of the high half above the position counts.
          const uint64_t at = uint64_t(1) << (position - 64);
          const uint64_t keep = ~(at | (at - 1));
          return static_cast<uint8_t>(__builtin_popcountll(high & keep) & 1);
      }
      // The whole high half counts, plus the part of the low half above
      // the position.
      const uint64_t low = uint64_t(_mm_cvtsi128_si64(block));
      const uint64_t at = uint64_t(1) << position;
      const uint64_t keep = ~(at | (at - 1));
      return static_cast<uint8_t>((__builtin_popcountll(high) +
          __builtin_popcountll(low & keep)) & 1);
  }
  // Return the parity of the number of bits set in block strictly below
  // the given position, where position 0 is the least significant bit
  // of the low 64-bit half and position 127 the most significant bit of
  // the high half.  For position >= 128 every bit of the block lies
  // below it, so the result is the parity of the whole block.
  inline uint8_t parity_below(const __m128i &block, uint8_t position)
  {
      const uint64_t low = uint64_t(_mm_cvtsi128_si64(block));
      if (position < 64) {
          // Only the part of the low half below the position counts.
          const uint64_t keep = (uint64_t(1) << position) - 1;
          return static_cast<uint8_t>(__builtin_popcountll(low & keep) & 1);
      }
      const uint64_t high =
          uint64_t(_mm_cvtsi128_si64(_mm_srli_si128(block, 8)));
      // Shifting a 64-bit value by 64 or more is undefined, so positions
      // past the top of the block select the entire high half directly.
      const uint64_t keep = (position >= 128)
          ? ~uint64_t(0)
          : (uint64_t(1) << (position - 64)) - 1;
      return static_cast<uint8_t>((__builtin_popcountll(low) +
          __builtin_popcountll(high & keep)) & 1);
  }
  // Return the bit (0 or 1) at the given position in block, where
  // position 0 is the least significant bit of the low 64-bit half and
  // position 127 the most significant bit of the high half.  Positions
  // of 128 and above name no bit of the block and yield 0.
  inline uint8_t bit_at(const __m128i &block, uint8_t position)
  {
      // Without this guard the shift below would be by 64 or more,
      // which is undefined for a 64-bit value.
      if (position >= 128) {
          return 0;
      }
      // Bring the half containing the bit into the low 64 bits.
      const __m128i half = (position >= 64) ? _mm_srli_si128(block, 8) : block;
      const uint64_t word = uint64_t(_mm_cvtsi128_si64(half));
      return static_cast<uint8_t>((word >> (position & 63)) & 1);
  }
  121. #endif