Salsa20.hpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. /*
  2. * Based on public domain code available at: http://cr.yp.to/snuffle.html
  3. *
  4. * This therefore is public domain.
  5. */
  6. #ifndef ZT_SALSA20_HPP
  7. #define ZT_SALSA20_HPP
  8. #include "Constants.hpp"
  9. #include "Utils.hpp"
  10. #include <stdint.h>
  11. #include <stdio.h>
  12. #include <stdlib.h>
  13. #include <string.h>
  14. #if (! defined(ZT_SALSA20_SSE)) && (defined(__SSE2__) || (defined(__WINDOWS__) && ! defined(__MINGW32__) && ! defined(_M_ARM64)))
  15. #define ZT_SALSA20_SSE 1
  16. #endif
  17. #ifdef ZT_SALSA20_SSE
  18. #include <emmintrin.h>
  19. #endif // ZT_SALSA20_SSE
  20. namespace ZeroTier {
  21. /**
  22. * Salsa20 stream cipher
  23. */
  24. class Salsa20 {
  25. public:
  26. Salsa20()
  27. {
  28. }
  29. ~Salsa20()
  30. {
  31. Utils::burn(&_state, sizeof(_state));
  32. }
  33. /**
  34. * XOR d with s
  35. *
  36. * This is done efficiently using e.g. SSE if available. It's used when
  37. * alternative Salsa20 implementations are used in Packet and is here
  38. * since this is where all the SSE stuff is already included.
  39. *
  40. * @param d Destination to XOR
  41. * @param s Source bytes to XOR with destination
  42. * @param len Length of s and d
  43. */
  44. static inline void memxor(uint8_t* d, const uint8_t* s, unsigned int len)
  45. {
  46. #ifdef ZT_SALSA20_SSE
  47. while (len >= 128) {
  48. __m128i s0 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s));
  49. __m128i s1 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + 16));
  50. __m128i s2 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + 32));
  51. __m128i s3 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + 48));
  52. __m128i s4 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + 64));
  53. __m128i s5 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + 80));
  54. __m128i s6 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + 96));
  55. __m128i s7 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(s + 112));
  56. __m128i d0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d));
  57. __m128i d1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d + 16));
  58. __m128i d2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d + 32));
  59. __m128i d3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d + 48));
  60. __m128i d4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d + 64));
  61. __m128i d5 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d + 80));
  62. __m128i d6 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d + 96));
  63. __m128i d7 = _mm_loadu_si128(reinterpret_cast<__m128i*>(d + 112));
  64. d0 = _mm_xor_si128(d0, s0);
  65. d1 = _mm_xor_si128(d1, s1);
  66. d2 = _mm_xor_si128(d2, s2);
  67. d3 = _mm_xor_si128(d3, s3);
  68. d4 = _mm_xor_si128(d4, s4);
  69. d5 = _mm_xor_si128(d5, s5);
  70. d6 = _mm_xor_si128(d6, s6);
  71. d7 = _mm_xor_si128(d7, s7);
  72. _mm_storeu_si128(reinterpret_cast<__m128i*>(d), d0);
  73. _mm_storeu_si128(reinterpret_cast<__m128i*>(d + 16), d1);
  74. _mm_storeu_si128(reinterpret_cast<__m128i*>(d + 32), d2);
  75. _mm_storeu_si128(reinterpret_cast<__m128i*>(d + 48), d3);
  76. _mm_storeu_si128(reinterpret_cast<__m128i*>(d + 64), d4);
  77. _mm_storeu_si128(reinterpret_cast<__m128i*>(d + 80), d5);
  78. _mm_storeu_si128(reinterpret_cast<__m128i*>(d + 96), d6);
  79. _mm_storeu_si128(reinterpret_cast<__m128i*>(d + 112), d7);
  80. s += 128;
  81. d += 128;
  82. len -= 128;
  83. }
  84. while (len >= 16) {
  85. _mm_storeu_si128(reinterpret_cast<__m128i*>(d), _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<__m128i*>(d)), _mm_loadu_si128(reinterpret_cast<const __m128i*>(s))));
  86. s += 16;
  87. d += 16;
  88. len -= 16;
  89. }
  90. #else
  91. #ifndef ZT_NO_TYPE_PUNNING
  92. while (len >= 16) {
  93. (*reinterpret_cast<uint64_t*>(d)) ^= (*reinterpret_cast<const uint64_t*>(s));
  94. s += 8;
  95. d += 8;
  96. (*reinterpret_cast<uint64_t*>(d)) ^= (*reinterpret_cast<const uint64_t*>(s));
  97. s += 8;
  98. d += 8;
  99. len -= 16;
  100. }
  101. #endif
  102. #endif
  103. while (len) {
  104. --len;
  105. *(d++) ^= *(s++);
  106. }
  107. }
  108. /**
  109. * @param key 256-bit (32 byte) key
  110. * @param iv 64-bit initialization vector
  111. */
  112. Salsa20(const void* key, const void* iv)
  113. {
  114. init(key, iv);
  115. }
  116. /**
  117. * Initialize cipher
  118. *
  119. * @param key Key bits
  120. * @param iv 64-bit initialization vector
  121. */
  122. void init(const void* key, const void* iv);
  123. /**
  124. * Encrypt/decrypt data using Salsa20/12
  125. *
  126. * @param in Input data
  127. * @param out Output buffer
  128. * @param bytes Length of data
  129. */
  130. void crypt12(const void* in, void* out, unsigned int bytes);
  131. /**
  132. * Encrypt/decrypt data using Salsa20/20
  133. *
  134. * @param in Input data
  135. * @param out Output buffer
  136. * @param bytes Length of data
  137. */
  138. void crypt20(const void* in, void* out, unsigned int bytes);
  139. private:
  140. union {
  141. #ifdef ZT_SALSA20_SSE
  142. __m128i v[4];
  143. #endif // ZT_SALSA20_SSE
  144. uint32_t i[16];
  145. } _state;
  146. };
  147. } // namespace ZeroTier
  148. #endif