AES-aesni.c

/*
 * Copyright (c)2019 ZeroTier, Inc.
 *
 * Use of this software is governed by the Business Source License included
 * in the LICENSE.TXT file in the project's root directory.
 *
 * Change Date: 2023-01-01
 *
 * On the date above, in accordance with the Business Source License, use
 * of this software will be governed by version 2.0 of the Apache License.
 */
/****/
/* This is done in plain C because compilers (at least GCC and Clang) seem to
 * do a *slightly* better job optimizing this intrinsic code when compiling
 * plain C. C also gives us the register hint keyword, which actually seems to
 * make a small difference. */
#if (defined(__amd64) || defined(__amd64__) || defined(__x86_64) || defined(__x86_64__) || defined(__AMD64) || defined(__AMD64__) || defined(_M_X64))
#include <stdint.h>
#include <wmmintrin.h>
#include <emmintrin.h>
#include <smmintrin.h>
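
/* One AES round applied to four independent blocks. Interleaving four blocks
 * per round is a standard AES-NI technique: aesenc has multi-cycle latency but
 * is pipelined, so independent blocks keep the execution unit busy. */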
#define ZT_AES_CTR_AESNI_ROUND(kk) c0 = _mm_aesenc_si128(c0,kk); c1 = _mm_aesenc_si128(c1,kk); c2 = _mm_aesenc_si128(c2,kk); c3 = _mm_aesenc_si128(c3,kk);
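
/* AES-256 CTR using AES-NI. The key argument is the expanded schedule of 15
 * round keys, key[0]..key[14] (the declared array length of 14 is not enforced
 * in C, since array parameters decay to pointers). The IV is a full 128-bit
 * big-endian counter block. In-place operation (out == in) appears safe, as
 * each input block is loaded before the corresponding output is stored. */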
void zt_crypt_ctr_aesni(const __m128i key[14],const uint8_t iv[16],const uint8_t *in,unsigned int len,uint8_t *out)
{
	/* Because our CTR supports full 128-bit nonces, we must do a full 128-bit (big-endian)
	 * increment to be compatible with canonical NIST-certified CTR implementations. That's
	 * because it's possible to have a lot of bit saturation in the least significant 64
	 * bits, which could on rare occasions actually cause a 64-bit wrap. If this happened
	 * without carry it would result in incompatibility and quietly dropped packets. The
	 * probability is low, so this would be a one-in-billions packet loss bug that would
	 * probably never be found.
	 *
	 * This crazy code does a branch-free 128-bit increment by adding a one or a zero to
	 * the most significant 64 bits of the 128-bit vector based on whether the add we want
	 * to do to the least significant 64 bits would overflow. This can be computed by
	 * NOTing those bits and comparing with what we want to add, since NOT is the same
	 * as subtracting from uint64_max. This generates branch-free ASM on x64 with most
	 * good compilers. */
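	/* Worked example of the carry rule: (lsq + n) overflows exactly when
	 * n > (2^64 - 1) - lsq == ~lsq. E.g. if lsq == 0xFFFFFFFFFFFFFFFE then
	 * ~lsq == 1: adding 1 does not carry (1 < 1 is false) while adding 2
	 * does (1 < 2 is true). */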
	register __m128i swap128 = _mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
	register __m128i ctr0 = _mm_shuffle_epi8(_mm_loadu_si128((__m128i *)iv),swap128);
	register uint64_t notctr0msq = ~((uint64_t)_mm_extract_epi64(ctr0,0));
	register __m128i ctr1 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)(notctr0msq < 1ULL),1LL)),swap128);
	register __m128i ctr2 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)(notctr0msq < 2ULL),2LL)),swap128);
	register __m128i ctr3 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)(notctr0msq < 3ULL),3LL)),swap128);
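	/* ctr1..ctr3 (counter +1..+3) were swapped back to big-endian byte order as
	 * they were built; swap ctr0 back as well so all four match. */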
	ctr0 = _mm_shuffle_epi8(ctr0,swap128);
	register __m128i k0 = key[0];
	register __m128i k1 = key[1];
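
	/* Bulk phase: encrypt four counter blocks per iteration while at least 64
	 * bytes of input remain. */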
	while (len >= 64) {
		register __m128i ka = key[2];
		register __m128i c0 = _mm_xor_si128(ctr0,k0);
		register __m128i c1 = _mm_xor_si128(ctr1,k0);
		register __m128i c2 = _mm_xor_si128(ctr2,k0);
		register __m128i c3 = _mm_xor_si128(ctr3,k0);
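		/* Advance the counters to +4..+7 for the next batch up front, letting the
		 * adds overlap with the AES rounds below. */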
		ctr0 = _mm_shuffle_epi8(ctr0,swap128);
		notctr0msq = ~((uint64_t)_mm_extract_epi64(ctr0,0));
		ctr1 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)(notctr0msq < 5ULL),5LL)),swap128);
		ctr2 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)(notctr0msq < 6ULL),6LL)),swap128);
		ctr3 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)(notctr0msq < 7ULL),7LL)),swap128);
		ctr0 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)(notctr0msq < 4ULL),4LL)),swap128);
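		/* 13 middle AES rounds (key[1]..key[13]), rotating round keys through a few
		 * registers so each key load overlaps the preceding round. */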
		register __m128i kb = key[3];
		ZT_AES_CTR_AESNI_ROUND(k1);
		register __m128i kc = key[4];
		ZT_AES_CTR_AESNI_ROUND(ka);
		register __m128i kd = key[5];
		ZT_AES_CTR_AESNI_ROUND(kb);
		ka = key[6];
		ZT_AES_CTR_AESNI_ROUND(kc);
		kb = key[7];
		ZT_AES_CTR_AESNI_ROUND(kd);
		kc = key[8];
		ZT_AES_CTR_AESNI_ROUND(ka);
		kd = key[9];
		ZT_AES_CTR_AESNI_ROUND(kb);
		ka = key[10];
		ZT_AES_CTR_AESNI_ROUND(kc);
		kb = key[11];
		ZT_AES_CTR_AESNI_ROUND(kd);
		kc = key[12];
		ZT_AES_CTR_AESNI_ROUND(ka);
		kd = key[13];
		ZT_AES_CTR_AESNI_ROUND(kb);
		ka = key[14];
		ZT_AES_CTR_AESNI_ROUND(kc);
		ZT_AES_CTR_AESNI_ROUND(kd);
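		/* Final round (aesenclast with key[14]) fused with the keystream XOR. */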
		_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,ka)));
		_mm_storeu_si128((__m128i *)(out + 16),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 16)),_mm_aesenclast_si128(c1,ka)));
		_mm_storeu_si128((__m128i *)(out + 32),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 32)),_mm_aesenclast_si128(c2,ka)));
		_mm_storeu_si128((__m128i *)(out + 48),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 48)),_mm_aesenclast_si128(c3,ka)));
		in += 64;
		out += 64;
		len -= 64;
	}
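
	/* Process any remaining whole 16-byte blocks one at a time. */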
	register __m128i k2 = key[2];
	register __m128i k3 = key[3];
	register __m128i k4 = key[4];
	register __m128i k5 = key[5];
	register __m128i k6 = key[6];
	register __m128i k7 = key[7];
	while (len >= 16) {
		register __m128i c0 = _mm_xor_si128(ctr0,k0);
		ctr0 = _mm_shuffle_epi8(ctr0,swap128);
		ctr0 = _mm_shuffle_epi8(_mm_add_epi64(ctr0,_mm_set_epi64x((long long)((~((uint64_t)_mm_extract_epi64(ctr0,0))) < 1ULL),1LL)),swap128);
		c0 = _mm_aesenc_si128(c0,k1);
		c0 = _mm_aesenc_si128(c0,k2);
		c0 = _mm_aesenc_si128(c0,k3);
		c0 = _mm_aesenc_si128(c0,k4);
		c0 = _mm_aesenc_si128(c0,k5);
		c0 = _mm_aesenc_si128(c0,k6);
		register __m128i ka = key[8];
		c0 = _mm_aesenc_si128(c0,k7);
		register __m128i kb = key[9];
		c0 = _mm_aesenc_si128(c0,ka);
		ka = key[10];
		c0 = _mm_aesenc_si128(c0,kb);
		kb = key[11];
		c0 = _mm_aesenc_si128(c0,ka);
		ka = key[12];
		c0 = _mm_aesenc_si128(c0,kb);
		kb = key[13];
		c0 = _mm_aesenc_si128(c0,ka);
		ka = key[14];
		c0 = _mm_aesenc_si128(c0,kb);
		_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,ka)));
		in += 16;
		out += 16;
		len -= 16;
	}
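
	/* Tail of fewer than 16 bytes: encrypt one final counter block and XOR only
	 * the remaining len bytes of keystream into the output. */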
	if (len) {
		register __m128i c0 = _mm_xor_si128(ctr0,k0);
		k0 = key[8];
		c0 = _mm_aesenc_si128(c0,k1);
		c0 = _mm_aesenc_si128(c0,k2);
		k1 = key[9];
		c0 = _mm_aesenc_si128(c0,k3);
		c0 = _mm_aesenc_si128(c0,k4);
		k2 = key[10];
		c0 = _mm_aesenc_si128(c0,k5);
		c0 = _mm_aesenc_si128(c0,k6);
		k3 = key[11];
		c0 = _mm_aesenc_si128(c0,k7);
		c0 = _mm_aesenc_si128(c0,k0);
		k0 = key[12];
		c0 = _mm_aesenc_si128(c0,k1);
		c0 = _mm_aesenc_si128(c0,k2);
		k1 = key[13];
		c0 = _mm_aesenc_si128(c0,k3);
		c0 = _mm_aesenc_si128(c0,k0);
		k2 = key[14];
		c0 = _mm_aesenc_si128(c0,k1);
		c0 = _mm_aesenclast_si128(c0,k2);
		uint8_t tmp[16];
		_mm_storeu_si128((__m128i *)tmp,c0);
		for(unsigned int i=0;i<len;++i)
			out[i] = in[i] ^ tmp[i];
	}
}
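
/* The helpers below are an illustrative sketch, not part of ZeroTier's code:
 * they show one way to produce the 15-entry round-key schedule
 * (key[0]..key[14]) that zt_crypt_ctr_aesni() consumes, following the
 * standard AES-256 key-expansion pattern from Intel's AES-NI white paper.
 * ZeroTier's actual key setup lives elsewhere; the zt_aes256_* names here
 * are hypothetical. */
static inline void zt_aes256_expand_step1(__m128i *t1,__m128i *t2)
{
	/* Fold the aeskeygenassist result (RotWord/SubWord + rcon) into the running key. */
	__m128i t4;
	*t2 = _mm_shuffle_epi32(*t2,0xff);
	t4 = _mm_slli_si128(*t1,0x04);
	*t1 = _mm_xor_si128(*t1,t4);
	t4 = _mm_slli_si128(t4,0x04);
	*t1 = _mm_xor_si128(*t1,t4);
	t4 = _mm_slli_si128(t4,0x04);
	*t1 = _mm_xor_si128(*t1,t4);
	*t1 = _mm_xor_si128(*t1,*t2);
}
static inline void zt_aes256_expand_step2(__m128i *t1,__m128i *t3)
{
	/* Odd-numbered round keys use SubWord only, with no rotation or rcon. */
	__m128i t2,t4;
	t4 = _mm_aeskeygenassist_si128(*t1,0x00);
	t2 = _mm_shuffle_epi32(t4,0xaa);
	t4 = _mm_slli_si128(*t3,0x04);
	*t3 = _mm_xor_si128(*t3,t4);
	t4 = _mm_slli_si128(t4,0x04);
	*t3 = _mm_xor_si128(*t3,t4);
	t4 = _mm_slli_si128(t4,0x04);
	*t3 = _mm_xor_si128(*t3,t4);
	*t3 = _mm_xor_si128(*t3,t2);
}
/* Expand a raw 32-byte AES-256 key into the 15 round keys. */
static inline void zt_aes256_expand_key(const uint8_t raw[32],__m128i key[15])
{
	__m128i t1 = _mm_loadu_si128((const __m128i *)raw);
	__m128i t3 = _mm_loadu_si128((const __m128i *)(raw + 16));
	__m128i t2;
	key[0] = t1;
	key[1] = t3;
	t2 = _mm_aeskeygenassist_si128(t3,0x01);
	zt_aes256_expand_step1(&t1,&t2); key[2] = t1;
	zt_aes256_expand_step2(&t1,&t3); key[3] = t3;
	t2 = _mm_aeskeygenassist_si128(t3,0x02);
	zt_aes256_expand_step1(&t1,&t2); key[4] = t1;
	zt_aes256_expand_step2(&t1,&t3); key[5] = t3;
	t2 = _mm_aeskeygenassist_si128(t3,0x04);
	zt_aes256_expand_step1(&t1,&t2); key[6] = t1;
	zt_aes256_expand_step2(&t1,&t3); key[7] = t3;
	t2 = _mm_aeskeygenassist_si128(t3,0x08);
	zt_aes256_expand_step1(&t1,&t2); key[8] = t1;
	zt_aes256_expand_step2(&t1,&t3); key[9] = t3;
	t2 = _mm_aeskeygenassist_si128(t3,0x10);
	zt_aes256_expand_step1(&t1,&t2); key[10] = t1;
	zt_aes256_expand_step2(&t1,&t3); key[11] = t3;
	t2 = _mm_aeskeygenassist_si128(t3,0x20);
	zt_aes256_expand_step1(&t1,&t2); key[12] = t1;
	zt_aes256_expand_step2(&t1,&t3); key[13] = t3;
	t2 = _mm_aeskeygenassist_si128(t3,0x40);
	zt_aes256_expand_step1(&t1,&t2); key[14] = t1;
}
/* Hypothetical usage:
 *
 *   __m128i ks[15];
 *   zt_aes256_expand_key(raw_key_32_bytes,ks);
 *   zt_crypt_ctr_aesni(ks,iv_16_bytes,plaintext,len,ciphertext);
 */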
#endif