Salsa20.hpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /*
  2. * Based on public domain code available at: http://cr.yp.to/snuffle.html
  3. *
  4. * This therefore is public domain.
  5. */
  6. #ifndef ZT_SALSA20_HPP
  7. #define ZT_SALSA20_HPP
  8. #include <stdio.h>
  9. #include <stdint.h>
  10. #include <stdlib.h>
  11. #include <string.h>
  12. #include "Constants.hpp"
  13. #include "Utils.hpp"
  14. #if (!defined(ZT_SALSA20_SSE)) && (defined(__SSE2__) || (defined(__WINDOWS__) && !defined(__MINGW32__) && !defined(_M_ARM64)))
  15. #define ZT_SALSA20_SSE 1
  16. #endif
  17. #ifdef ZT_SALSA20_SSE
  18. #include <emmintrin.h>
  19. #endif // ZT_SALSA20_SSE
  20. namespace ZeroTier {
  21. /**
  22. * Salsa20 stream cipher
  23. */
  24. class Salsa20
  25. {
  26. public:
  27. Salsa20() {}
  28. ~Salsa20() { Utils::burn(&_state,sizeof(_state)); }
  29. /**
  30. * XOR d with s
  31. *
  32. * This is done efficiently using e.g. SSE if available. It's used when
  33. * alternative Salsa20 implementations are used in Packet and is here
  34. * since this is where all the SSE stuff is already included.
  35. *
  36. * @param d Destination to XOR
  37. * @param s Source bytes to XOR with destination
  38. * @param len Length of s and d
  39. */
  40. static inline void memxor(uint8_t *d,const uint8_t *s,unsigned int len)
  41. {
  42. #ifdef ZT_SALSA20_SSE
  43. while (len >= 128) {
  44. __m128i s0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s));
  45. __m128i s1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 16));
  46. __m128i s2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 32));
  47. __m128i s3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 48));
  48. __m128i s4 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 64));
  49. __m128i s5 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 80));
  50. __m128i s6 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 96));
  51. __m128i s7 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(s + 112));
  52. __m128i d0 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d));
  53. __m128i d1 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 16));
  54. __m128i d2 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 32));
  55. __m128i d3 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 48));
  56. __m128i d4 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 64));
  57. __m128i d5 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 80));
  58. __m128i d6 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 96));
  59. __m128i d7 = _mm_loadu_si128(reinterpret_cast<__m128i *>(d + 112));
  60. d0 = _mm_xor_si128(d0,s0);
  61. d1 = _mm_xor_si128(d1,s1);
  62. d2 = _mm_xor_si128(d2,s2);
  63. d3 = _mm_xor_si128(d3,s3);
  64. d4 = _mm_xor_si128(d4,s4);
  65. d5 = _mm_xor_si128(d5,s5);
  66. d6 = _mm_xor_si128(d6,s6);
  67. d7 = _mm_xor_si128(d7,s7);
  68. _mm_storeu_si128(reinterpret_cast<__m128i *>(d),d0);
  69. _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 16),d1);
  70. _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 32),d2);
  71. _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 48),d3);
  72. _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 64),d4);
  73. _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 80),d5);
  74. _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 96),d6);
  75. _mm_storeu_si128(reinterpret_cast<__m128i *>(d + 112),d7);
  76. s += 128;
  77. d += 128;
  78. len -= 128;
  79. }
  80. while (len >= 16) {
  81. _mm_storeu_si128(reinterpret_cast<__m128i *>(d),_mm_xor_si128(_mm_loadu_si128(reinterpret_cast<__m128i *>(d)),_mm_loadu_si128(reinterpret_cast<const __m128i *>(s))));
  82. s += 16;
  83. d += 16;
  84. len -= 16;
  85. }
  86. #else
  87. #ifndef ZT_NO_TYPE_PUNNING
  88. while (len >= 16) {
  89. (*reinterpret_cast<uint64_t *>(d)) ^= (*reinterpret_cast<const uint64_t *>(s));
  90. s += 8;
  91. d += 8;
  92. (*reinterpret_cast<uint64_t *>(d)) ^= (*reinterpret_cast<const uint64_t *>(s));
  93. s += 8;
  94. d += 8;
  95. len -= 16;
  96. }
  97. #endif
  98. #endif
  99. while (len) {
  100. --len;
  101. *(d++) ^= *(s++);
  102. }
  103. }
  104. /**
  105. * @param key 256-bit (32 byte) key
  106. * @param iv 64-bit initialization vector
  107. */
  108. Salsa20(const void *key,const void *iv)
  109. {
  110. init(key,iv);
  111. }
  112. /**
  113. * Initialize cipher
  114. *
  115. * @param key Key bits
  116. * @param iv 64-bit initialization vector
  117. */
  118. void init(const void *key,const void *iv);
  119. /**
  120. * Encrypt/decrypt data using Salsa20/12
  121. *
  122. * @param in Input data
  123. * @param out Output buffer
  124. * @param bytes Length of data
  125. */
  126. void crypt12(const void *in,void *out,unsigned int bytes);
  127. /**
  128. * Encrypt/decrypt data using Salsa20/20
  129. *
  130. * @param in Input data
  131. * @param out Output buffer
  132. * @param bytes Length of data
  133. */
  134. void crypt20(const void *in,void *out,unsigned int bytes);
  135. private:
  136. union {
  137. #ifdef ZT_SALSA20_SSE
  138. __m128i v[4];
  139. #endif // ZT_SALSA20_SSE
  140. uint32_t i[16];
  141. } _state;
  142. };
  143. } // namespace ZeroTier
  144. #endif