Salsa20.cpp 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. /*
  2. * Based on public domain code available at: http://cr.yp.to/snuffle.html
  3. *
  4. * This therefore is public domain.
  5. */
  #include "Salsa20.hpp"

  #include <string.h>

  #include "Constants.hpp"
  /*
   * Salsa20 word operations on uint32_t values: left-rotate by c bits
   * (c must be in 1..31, since a shift by 32 is undefined), xor, and
   * addition modulo 2^32.
   */
  #define ROTATE(v,c) (((v) << (c)) | ((v) >> (32 - (c))))
  #define XOR(v,w) ((v) ^ (w))
  #define PLUS(v,w) ((uint32_t)((v) + (w)))

  /*
   * Load / store one 32-bit word in little-endian byte order at an arbitrary
   * byte address.
   *
   * Key, IV, message, output and constant-string buffers are plain byte
   * buffers with no alignment guarantee. Accessing them through a uint32_t
   * lvalue violates strict aliasing and alignment rules, so these helpers use
   * memcpy or individual byte accesses instead. Optimizing compilers
   * typically lower either form to a single load/store where the target
   * permits it.
   *
   * The byte-by-byte form is correct for any host byte order and any
   * compiler. It is therefore used whenever the host is not positively
   * known to be little-endian.
   */
  static inline uint32_t s20LoadLittle32(const void *p)
  {
  #if defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && (__BYTE_ORDER == __LITTLE_ENDIAN)
      uint32_t v;
      memcpy(&v,p,4);
      return v;
  #else
      const uint8_t *b = (const uint8_t *)p;
      return ((uint32_t)b[0]) | ((uint32_t)b[1] << 8) | ((uint32_t)b[2] << 16) | ((uint32_t)b[3] << 24);
  #endif
  }

  static inline void s20StoreLittle32(void *p,uint32_t v)
  {
  #if defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && (__BYTE_ORDER == __LITTLE_ENDIAN)
      memcpy(p,&v,4);
  #else
      uint8_t *b = (uint8_t *)p;
      b[0] = (uint8_t)v;
      b[1] = (uint8_t)(v >> 8);
      b[2] = (uint8_t)(v >> 16);
      b[3] = (uint8_t)(v >> 24);
  #endif
  }

  /* Read a little-endian uint32_t from byte pointer p. */
  #define U8TO32_LITTLE(p) s20LoadLittle32((const void *)(p))
  /* Write uint32_t v to byte pointer c in little-endian order. */
  #define U32TO8_LITTLE(c,v) s20StoreLittle32((void *)(c),(uint32_t)(v))
  #ifdef ZT_SALSA20_SSE
  /*
   * 32-bit lane masks used by the SSE encrypt path when re-interleaving the
   * permuted state words back into output order.
   *
   *   maskLo32: lanes 0 and 2 all-ones (low half of each 64-bit lane)
   *   maskHi32: lanes 1 and 3 all-ones (high half of each 64-bit lane)
   *
   * _mm_set_epi32 takes its arguments from the highest lane down to lane 0.
   */
  class _s20sseconsts
  {
  public:
      _s20sseconsts() :
          maskLo32(_mm_set_epi32(0,-1,0,-1)),
          maskHi32(_mm_set_epi32(-1,0,-1,0))
      {
      }

      __m128i maskLo32,maskHi32;
  };
  static const _s20sseconsts _S20SSECONSTANTS;
  #endif
  35. namespace ZeroTier {
  36. void Salsa20::init(const void *key,unsigned int kbits,const void *iv,unsigned int rounds)
  37. throw()
  38. {
  39. #ifdef ZT_SALSA20_SSE
  40. const uint32_t *k = (const uint32_t *)key;
  41. _state.i[0] = 0x61707865;
  42. _state.i[3] = 0x6b206574;
  43. _state.i[13] = k[0];
  44. _state.i[10] = k[1];
  45. _state.i[7] = k[2];
  46. _state.i[4] = k[3];
  47. if (kbits == 256) {
  48. k += 4;
  49. _state.i[1] = 0x3320646e;
  50. _state.i[2] = 0x79622d32;
  51. } else {
  52. _state.i[1] = 0x3120646e;
  53. _state.i[2] = 0x79622d36;
  54. }
  55. _state.i[15] = k[0];
  56. _state.i[12] = k[1];
  57. _state.i[9] = k[2];
  58. _state.i[6] = k[3];
  59. _state.i[14] = ((const uint32_t *)iv)[0];
  60. _state.i[11] = ((const uint32_t *)iv)[1];
  61. _state.i[5] = 0;
  62. _state.i[8] = 0;
  63. #else
  64. const char *constants;
  65. const uint8_t *k = (const uint8_t *)key;
  66. _state.i[1] = U8TO32_LITTLE(k + 0);
  67. _state.i[2] = U8TO32_LITTLE(k + 4);
  68. _state.i[3] = U8TO32_LITTLE(k + 8);
  69. _state.i[4] = U8TO32_LITTLE(k + 12);
  70. if (kbits == 256) { /* recommended */
  71. k += 16;
  72. constants = "expand 32-byte k";
  73. } else { /* kbits == 128 */
  74. constants = "expand 16-byte k";
  75. }
  76. _state.i[5] = U8TO32_LITTLE(constants + 4);
  77. _state.i[6] = U8TO32_LITTLE(((const uint8_t *)iv) + 0);
  78. _state.i[7] = U8TO32_LITTLE(((const uint8_t *)iv) + 4);
  79. _state.i[8] = 0;
  80. _state.i[9] = 0;
  81. _state.i[10] = U8TO32_LITTLE(constants + 8);
  82. _state.i[11] = U8TO32_LITTLE(k + 0);
  83. _state.i[12] = U8TO32_LITTLE(k + 4);
  84. _state.i[13] = U8TO32_LITTLE(k + 8);
  85. _state.i[14] = U8TO32_LITTLE(k + 12);
  86. _state.i[15] = U8TO32_LITTLE(constants + 12);
  87. _state.i[0] = U8TO32_LITTLE(constants + 0);
  88. #endif
  89. _roundsDiv2 = rounds / 2;
  90. }
  91. void Salsa20::encrypt(const void *in,void *out,unsigned int bytes)
  92. throw()
  93. {
  94. uint8_t tmp[64];
  95. const uint8_t *m = (const uint8_t *)in;
  96. uint8_t *c = (uint8_t *)out;
  97. uint8_t *ctarget = c;
  98. unsigned int i;
  99. #ifndef ZT_SALSA20_SSE
  100. uint32_t x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15;
  101. uint32_t j0, j1, j2, j3, j4, j5, j6, j7, j8, j9, j10, j11, j12, j13, j14, j15;
  102. #endif
  103. if (!bytes)
  104. return;
  105. #ifndef ZT_SALSA20_SSE
  106. j0 = _state.i[0];
  107. j1 = _state.i[1];
  108. j2 = _state.i[2];
  109. j3 = _state.i[3];
  110. j4 = _state.i[4];
  111. j5 = _state.i[5];
  112. j6 = _state.i[6];
  113. j7 = _state.i[7];
  114. j8 = _state.i[8];
  115. j9 = _state.i[9];
  116. j10 = _state.i[10];
  117. j11 = _state.i[11];
  118. j12 = _state.i[12];
  119. j13 = _state.i[13];
  120. j14 = _state.i[14];
  121. j15 = _state.i[15];
  122. #endif
  123. for (;;) {
  124. if (bytes < 64) {
  125. for (i = 0;i < bytes;++i)
  126. tmp[i] = m[i];
  127. m = tmp;
  128. ctarget = c;
  129. c = tmp;
  130. }
  131. #ifdef ZT_SALSA20_SSE
  132. __m128i X0 = _mm_loadu_si128((const __m128i *)&(_state.v[0]));
  133. __m128i X1 = _mm_loadu_si128((const __m128i *)&(_state.v[1]));
  134. __m128i X2 = _mm_loadu_si128((const __m128i *)&(_state.v[2]));
  135. __m128i X3 = _mm_loadu_si128((const __m128i *)&(_state.v[3]));
  136. __m128i X0s = X0;
  137. __m128i X1s = X1;
  138. __m128i X2s = X2;
  139. __m128i X3s = X3;
  140. for (i=0;i<_roundsDiv2;++i) {
  141. __m128i T = _mm_add_epi32(X0, X3);
  142. X1 = _mm_xor_si128(X1, _mm_slli_epi32(T, 7));
  143. X1 = _mm_xor_si128(X1, _mm_srli_epi32(T, 25));
  144. T = _mm_add_epi32(X1, X0);
  145. X2 = _mm_xor_si128(X2, _mm_slli_epi32(T, 9));
  146. X2 = _mm_xor_si128(X2, _mm_srli_epi32(T, 23));
  147. T = _mm_add_epi32(X2, X1);
  148. X3 = _mm_xor_si128(X3, _mm_slli_epi32(T, 13));
  149. X3 = _mm_xor_si128(X3, _mm_srli_epi32(T, 19));
  150. T = _mm_add_epi32(X3, X2);
  151. X0 = _mm_xor_si128(X0, _mm_slli_epi32(T, 18));
  152. X0 = _mm_xor_si128(X0, _mm_srli_epi32(T, 14));
  153. X1 = _mm_shuffle_epi32(X1, 0x93);
  154. X2 = _mm_shuffle_epi32(X2, 0x4E);
  155. X3 = _mm_shuffle_epi32(X3, 0x39);
  156. T = _mm_add_epi32(X0, X1);
  157. X3 = _mm_xor_si128(X3, _mm_slli_epi32(T, 7));
  158. X3 = _mm_xor_si128(X3, _mm_srli_epi32(T, 25));
  159. T = _mm_add_epi32(X3, X0);
  160. X2 = _mm_xor_si128(X2, _mm_slli_epi32(T, 9));
  161. X2 = _mm_xor_si128(X2, _mm_srli_epi32(T, 23));
  162. T = _mm_add_epi32(X2, X3);
  163. X1 = _mm_xor_si128(X1, _mm_slli_epi32(T, 13));
  164. X1 = _mm_xor_si128(X1, _mm_srli_epi32(T, 19));
  165. T = _mm_add_epi32(X1, X2);
  166. X0 = _mm_xor_si128(X0, _mm_slli_epi32(T, 18));
  167. X0 = _mm_xor_si128(X0, _mm_srli_epi32(T, 14));
  168. X1 = _mm_shuffle_epi32(X1, 0x39);
  169. X2 = _mm_shuffle_epi32(X2, 0x4E);
  170. X3 = _mm_shuffle_epi32(X3, 0x93);
  171. }
  172. X0 = _mm_add_epi32(X0s,X0);
  173. X1 = _mm_add_epi32(X1s,X1);
  174. X2 = _mm_add_epi32(X2s,X2);
  175. X3 = _mm_add_epi32(X3s,X3);
  176. {
  177. __m128i k02 = _mm_or_si128(_mm_slli_epi64(X0, 32), _mm_srli_epi64(X3, 32));
  178. k02 = _mm_shuffle_epi32(k02, _MM_SHUFFLE(0, 1, 2, 3));
  179. __m128i k13 = _mm_or_si128(_mm_slli_epi64(X1, 32), _mm_srli_epi64(X0, 32));
  180. k13 = _mm_shuffle_epi32(k13, _MM_SHUFFLE(0, 1, 2, 3));
  181. __m128i k20 = _mm_or_si128(_mm_and_si128(X2, _S20SSECONSTANTS.maskLo32), _mm_and_si128(X1, _S20SSECONSTANTS.maskHi32));
  182. __m128i k31 = _mm_or_si128(_mm_and_si128(X3, _S20SSECONSTANTS.maskLo32), _mm_and_si128(X2, _S20SSECONSTANTS.maskHi32));
  183. const float *const mv = (const float *)m;
  184. float *const cv = (float *)c;
  185. _mm_storeu_ps(cv,_mm_castsi128_ps(_mm_xor_si128(_mm_unpackhi_epi64(k02,k20),_mm_castps_si128(_mm_loadu_ps(mv)))));
  186. _mm_storeu_ps(cv + 4,_mm_castsi128_ps(_mm_xor_si128(_mm_unpackhi_epi64(k13,k31),_mm_castps_si128(_mm_loadu_ps(mv + 4)))));
  187. _mm_storeu_ps(cv + 8,_mm_castsi128_ps(_mm_xor_si128(_mm_unpacklo_epi64(k20,k02),_mm_castps_si128(_mm_loadu_ps(mv + 8)))));
  188. _mm_storeu_ps(cv + 12,_mm_castsi128_ps(_mm_xor_si128(_mm_unpacklo_epi64(k31,k13),_mm_castps_si128(_mm_loadu_ps(mv + 12)))));
  189. }
  190. if (!(++_state.i[8])) {
  191. ++_state.i[5]; // state reordered for SSE
  192. /* stopping at 2^70 bytes per nonce is user's responsibility */
  193. }
  194. #else
  195. x0 = j0;
  196. x1 = j1;
  197. x2 = j2;
  198. x3 = j3;
  199. x4 = j4;
  200. x5 = j5;
  201. x6 = j6;
  202. x7 = j7;
  203. x8 = j8;
  204. x9 = j9;
  205. x10 = j10;
  206. x11 = j11;
  207. x12 = j12;
  208. x13 = j13;
  209. x14 = j14;
  210. x15 = j15;
  211. for(i=0;i<_roundsDiv2;++i) {
  212. x4 = XOR( x4,ROTATE(PLUS( x0,x12), 7));
  213. x8 = XOR( x8,ROTATE(PLUS( x4, x0), 9));
  214. x12 = XOR(x12,ROTATE(PLUS( x8, x4),13));
  215. x0 = XOR( x0,ROTATE(PLUS(x12, x8),18));
  216. x9 = XOR( x9,ROTATE(PLUS( x5, x1), 7));
  217. x13 = XOR(x13,ROTATE(PLUS( x9, x5), 9));
  218. x1 = XOR( x1,ROTATE(PLUS(x13, x9),13));
  219. x5 = XOR( x5,ROTATE(PLUS( x1,x13),18));
  220. x14 = XOR(x14,ROTATE(PLUS(x10, x6), 7));
  221. x2 = XOR( x2,ROTATE(PLUS(x14,x10), 9));
  222. x6 = XOR( x6,ROTATE(PLUS( x2,x14),13));
  223. x10 = XOR(x10,ROTATE(PLUS( x6, x2),18));
  224. x3 = XOR( x3,ROTATE(PLUS(x15,x11), 7));
  225. x7 = XOR( x7,ROTATE(PLUS( x3,x15), 9));
  226. x11 = XOR(x11,ROTATE(PLUS( x7, x3),13));
  227. x15 = XOR(x15,ROTATE(PLUS(x11, x7),18));
  228. x1 = XOR( x1,ROTATE(PLUS( x0, x3), 7));
  229. x2 = XOR( x2,ROTATE(PLUS( x1, x0), 9));
  230. x3 = XOR( x3,ROTATE(PLUS( x2, x1),13));
  231. x0 = XOR( x0,ROTATE(PLUS( x3, x2),18));
  232. x6 = XOR( x6,ROTATE(PLUS( x5, x4), 7));
  233. x7 = XOR( x7,ROTATE(PLUS( x6, x5), 9));
  234. x4 = XOR( x4,ROTATE(PLUS( x7, x6),13));
  235. x5 = XOR( x5,ROTATE(PLUS( x4, x7),18));
  236. x11 = XOR(x11,ROTATE(PLUS(x10, x9), 7));
  237. x8 = XOR( x8,ROTATE(PLUS(x11,x10), 9));
  238. x9 = XOR( x9,ROTATE(PLUS( x8,x11),13));
  239. x10 = XOR(x10,ROTATE(PLUS( x9, x8),18));
  240. x12 = XOR(x12,ROTATE(PLUS(x15,x14), 7));
  241. x13 = XOR(x13,ROTATE(PLUS(x12,x15), 9));
  242. x14 = XOR(x14,ROTATE(PLUS(x13,x12),13));
  243. x15 = XOR(x15,ROTATE(PLUS(x14,x13),18));
  244. }
  245. x0 = PLUS(x0,j0);
  246. x1 = PLUS(x1,j1);
  247. x2 = PLUS(x2,j2);
  248. x3 = PLUS(x3,j3);
  249. x4 = PLUS(x4,j4);
  250. x5 = PLUS(x5,j5);
  251. x6 = PLUS(x6,j6);
  252. x7 = PLUS(x7,j7);
  253. x8 = PLUS(x8,j8);
  254. x9 = PLUS(x9,j9);
  255. x10 = PLUS(x10,j10);
  256. x11 = PLUS(x11,j11);
  257. x12 = PLUS(x12,j12);
  258. x13 = PLUS(x13,j13);
  259. x14 = PLUS(x14,j14);
  260. x15 = PLUS(x15,j15);
  261. U32TO8_LITTLE(c + 0,XOR(x0,U8TO32_LITTLE(m + 0)));
  262. U32TO8_LITTLE(c + 4,XOR(x1,U8TO32_LITTLE(m + 4)));
  263. U32TO8_LITTLE(c + 8,XOR(x2,U8TO32_LITTLE(m + 8)));
  264. U32TO8_LITTLE(c + 12,XOR(x3,U8TO32_LITTLE(m + 12)));
  265. U32TO8_LITTLE(c + 16,XOR(x4,U8TO32_LITTLE(m + 16)));
  266. U32TO8_LITTLE(c + 20,XOR(x5,U8TO32_LITTLE(m + 20)));
  267. U32TO8_LITTLE(c + 24,XOR(x6,U8TO32_LITTLE(m + 24)));
  268. U32TO8_LITTLE(c + 28,XOR(x7,U8TO32_LITTLE(m + 28)));
  269. U32TO8_LITTLE(c + 32,XOR(x8,U8TO32_LITTLE(m + 32)));
  270. U32TO8_LITTLE(c + 36,XOR(x9,U8TO32_LITTLE(m + 36)));
  271. U32TO8_LITTLE(c + 40,XOR(x10,U8TO32_LITTLE(m + 40)));
  272. U32TO8_LITTLE(c + 44,XOR(x11,U8TO32_LITTLE(m + 44)));
  273. U32TO8_LITTLE(c + 48,XOR(x12,U8TO32_LITTLE(m + 48)));
  274. U32TO8_LITTLE(c + 52,XOR(x13,U8TO32_LITTLE(m + 52)));
  275. U32TO8_LITTLE(c + 56,XOR(x14,U8TO32_LITTLE(m + 56)));
  276. U32TO8_LITTLE(c + 60,XOR(x15,U8TO32_LITTLE(m + 60)));
  277. if (!(++j8)) {
  278. ++j9;
  279. /* stopping at 2^70 bytes per nonce is user's responsibility */
  280. }
  281. #endif
  282. if (bytes <= 64) {
  283. if (bytes < 64) {
  284. for (i = 0;i < bytes;++i)
  285. ctarget[i] = c[i];
  286. }
  287. #ifndef ZT_SALSA20_SSE
  288. _state.i[8] = j8;
  289. _state.i[9] = j9;
  290. #endif
  291. return;
  292. }
  293. bytes -= 64;
  294. c += 64;
  295. m += 64;
  296. }
  297. }
  298. } // namespace ZeroTier