Browse Source

Ridiculously fast AES-CTR

Adam Ierymenko 5 years ago
parent
commit
1f02250dd8
1 changed file with 127 additions and 105 deletions
  1. 127 105
      node/AES.cpp

+ 127 - 105
node/AES.cpp

@@ -477,6 +477,15 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 		uint64_t c0 = _ctr[0];
 		uint64_t c0 = _ctr[0];
 		uint64_t c1 = Utils::ntoh(_ctr[1]);
 		uint64_t c1 = Utils::ntoh(_ctr[1]);
 
 
+		// There are 16 XMM registers. We can reserve six of them for the
+		// first six parts of the expanded AES key.
+		const __m128i k0 = _aes._k.ni.k[0];
+		const __m128i k1 = _aes._k.ni.k[1];
+		const __m128i k2 = _aes._k.ni.k[2];
+		const __m128i k3 = _aes._k.ni.k[3];
+		const __m128i k4 = _aes._k.ni.k[4];
+		const __m128i k5 = _aes._k.ni.k[5];
+
 		// Complete any unfinished blocks from previous calls to crypt().
 		// Complete any unfinished blocks from previous calls to crypt().
 		unsigned int totalLen = _len;
 		unsigned int totalLen = _len;
 		if ((totalLen & 15U)) {
 		if ((totalLen & 15U)) {
@@ -491,23 +500,33 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 				out[totalLen++] = *(in++);
 				out[totalLen++] = *(in++);
 				if (!(totalLen & 15U)) {
 				if (!(totalLen & 15U)) {
 					__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 					__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-					d0 = _mm_xor_si128(d0,_aes._k.ni.k[0]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[1]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[2]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[3]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[4]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[5]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[6]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[7]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[8]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[9]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[10]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[11]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[12]);
-					d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[13]);
-					d0 = _mm_aesenclast_si128(d0,_aes._k.ni.k[14]);
+					d0 = _mm_xor_si128(d0,k0);
+					d0 = _mm_aesenc_si128(d0,k1);
+					__m128i ka = _aes._k.ni.k[6];
+					d0 = _mm_aesenc_si128(d0,k2);
+					__m128i kb = _aes._k.ni.k[7];
+					d0 = _mm_aesenc_si128(d0,k3);
+					__m128i kc = _aes._k.ni.k[8];
+					d0 = _mm_aesenc_si128(d0,k4);
+					__m128i kd = _aes._k.ni.k[9];
+					d0 = _mm_aesenc_si128(d0,k5);
+					__m128i ke = _aes._k.ni.k[10];
+					d0 = _mm_aesenc_si128(d0,ka);
+					__m128i kf = _aes._k.ni.k[11];
+					d0 = _mm_aesenc_si128(d0,kb);
+					__m128i kg = _aes._k.ni.k[12];
+					d0 = _mm_aesenc_si128(d0,kc);
+					__m128i kh = _aes._k.ni.k[13];
+					d0 = _mm_aesenc_si128(d0,kd);
+					ka = _aes._k.ni.k[14];
+					d0 = _mm_aesenc_si128(d0,ke);
 					__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
 					__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
-					_mm_storeu_si128(outblk,_mm_xor_si128(_mm_loadu_si128(outblk),d0));
+					d0 = _mm_aesenc_si128(d0,kf);
+					const __m128i p0 = _mm_loadu_si128(outblk);
+					d0 = _mm_aesenc_si128(d0,kg);
+					d0 = _mm_aesenc_si128(d0,kh);
+					d0 = _mm_aesenclast_si128(d0,ka);
+					_mm_storeu_si128(outblk,_mm_xor_si128(p0,d0));
 					if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 					if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 					break;
 					break;
 				}
 				}
@@ -536,10 +555,6 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 			}
 			}
 
 
-			__m128i k0 = _aes._k.ni.k[0];
-			__m128i k1 = _aes._k.ni.k[1];
-			__m128i k2 = _aes._k.ni.k[2];
-			__m128i k3 = _aes._k.ni.k[3];
 			d0 = _mm_xor_si128(d0,k0);
 			d0 = _mm_xor_si128(d0,k0);
 			d1 = _mm_xor_si128(d1,k0);
 			d1 = _mm_xor_si128(d1,k0);
 			d2 = _mm_xor_si128(d2,k0);
 			d2 = _mm_xor_si128(d2,k0);
@@ -548,82 +563,79 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			d1 = _mm_aesenc_si128(d1,k1);
 			d1 = _mm_aesenc_si128(d1,k1);
 			d2 = _mm_aesenc_si128(d2,k1);
 			d2 = _mm_aesenc_si128(d2,k1);
 			d3 = _mm_aesenc_si128(d3,k1);
 			d3 = _mm_aesenc_si128(d3,k1);
-			k0 = _aes._k.ni.k[4];
-			k1 = _aes._k.ni.k[5];
-			d0 = _mm_aesenc_si128(d0,k2);
-			d1 = _mm_aesenc_si128(d1,k2);
-			d2 = _mm_aesenc_si128(d2,k2);
-			d3 = _mm_aesenc_si128(d3,k2);
-			d0 = _mm_aesenc_si128(d0,k3);
-			d1 = _mm_aesenc_si128(d1,k3);
-			d2 = _mm_aesenc_si128(d2,k3);
-			d3 = _mm_aesenc_si128(d3,k3);
-			k2 = _aes._k.ni.k[6];
-			k3 = _aes._k.ni.k[7];
-			d0 = _mm_aesenc_si128(d0,k0);
-			d1 = _mm_aesenc_si128(d1,k0);
-			d2 = _mm_aesenc_si128(d2,k0);
-			d3 = _mm_aesenc_si128(d3,k0);
-			d0 = _mm_aesenc_si128(d0,k1);
-			d1 = _mm_aesenc_si128(d1,k1);
-			d2 = _mm_aesenc_si128(d2,k1);
-			d3 = _mm_aesenc_si128(d3,k1);
-			k0 = _aes._k.ni.k[8];
-			k1 = _aes._k.ni.k[9];
-			d0 = _mm_aesenc_si128(d0,k2);
-			d1 = _mm_aesenc_si128(d1,k2);
-			d2 = _mm_aesenc_si128(d2,k2);
-			d3 = _mm_aesenc_si128(d3,k2);
-			d0 = _mm_aesenc_si128(d0,k3);
-			d1 = _mm_aesenc_si128(d1,k3);
-			d2 = _mm_aesenc_si128(d2,k3);
-			d3 = _mm_aesenc_si128(d3,k3);
-			k2 = _aes._k.ni.k[10];
-			k3 = _aes._k.ni.k[11];
-			d0 = _mm_aesenc_si128(d0,k0);
-			d1 = _mm_aesenc_si128(d1,k0);
-			d2 = _mm_aesenc_si128(d2,k0);
-			d3 = _mm_aesenc_si128(d3,k0);
-			d0 = _mm_aesenc_si128(d0,k1);
-			d1 = _mm_aesenc_si128(d1,k1);
-			d2 = _mm_aesenc_si128(d2,k1);
-			d3 = _mm_aesenc_si128(d3,k1);
-			k0 = _aes._k.ni.k[12];
-			k1 = _aes._k.ni.k[13];
+			__m128i ka = _aes._k.ni.k[6];
 			d0 = _mm_aesenc_si128(d0,k2);
 			d0 = _mm_aesenc_si128(d0,k2);
 			d1 = _mm_aesenc_si128(d1,k2);
 			d1 = _mm_aesenc_si128(d1,k2);
 			d2 = _mm_aesenc_si128(d2,k2);
 			d2 = _mm_aesenc_si128(d2,k2);
 			d3 = _mm_aesenc_si128(d3,k2);
 			d3 = _mm_aesenc_si128(d3,k2);
-			__m128i p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
-			__m128i p1 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 16));
+			__m128i kb = _aes._k.ni.k[7];
 			d0 = _mm_aesenc_si128(d0,k3);
 			d0 = _mm_aesenc_si128(d0,k3);
 			d1 = _mm_aesenc_si128(d1,k3);
 			d1 = _mm_aesenc_si128(d1,k3);
 			d2 = _mm_aesenc_si128(d2,k3);
 			d2 = _mm_aesenc_si128(d2,k3);
 			d3 = _mm_aesenc_si128(d3,k3);
 			d3 = _mm_aesenc_si128(d3,k3);
-			k2 = _aes._k.ni.k[14];
-			d0 = _mm_aesenc_si128(d0,k0);
-			d1 = _mm_aesenc_si128(d1,k0);
-			d2 = _mm_aesenc_si128(d2,k0);
-			d3 = _mm_aesenc_si128(d3,k0);
-			__m128i p2 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 32));
-			__m128i p3 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 48));
-			d0 = _mm_aesenc_si128(d0,k1);
-			d1 = _mm_aesenc_si128(d1,k1);
-			d2 = _mm_aesenc_si128(d2,k1);
-			d3 = _mm_aesenc_si128(d3,k1);
-			d0 = _mm_aesenclast_si128(d0,k2);
-			d1 = _mm_aesenclast_si128(d1,k2);
-			d2 = _mm_aesenclast_si128(d2,k2);
-			d3 = _mm_aesenclast_si128(d3,k2);
-
-			p0 = _mm_xor_si128(d0,p0);
-			p1 = _mm_xor_si128(d1,p1);
-			p2 = _mm_xor_si128(d2,p2);
-			p3 = _mm_xor_si128(d3,p3);
-			_mm_storeu_si128(reinterpret_cast<__m128i *>(out),p0);
-			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 16),p1);
-			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 32),p2);
-			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 48),p3);
+			__m128i kc = _aes._k.ni.k[8];
+			d0 = _mm_aesenc_si128(d0,k4);
+			d1 = _mm_aesenc_si128(d1,k4);
+			d2 = _mm_aesenc_si128(d2,k4);
+			d3 = _mm_aesenc_si128(d3,k4);
+			__m128i kd = _aes._k.ni.k[9];
+			d0 = _mm_aesenc_si128(d0,k5);
+			d1 = _mm_aesenc_si128(d1,k5);
+			d2 = _mm_aesenc_si128(d2,k5);
+			d3 = _mm_aesenc_si128(d3,k5);
+			__m128i ke = _aes._k.ni.k[10];
+			d0 = _mm_aesenc_si128(d0,ka);
+			d1 = _mm_aesenc_si128(d1,ka);
+			d2 = _mm_aesenc_si128(d2,ka);
+			d3 = _mm_aesenc_si128(d3,ka);
+			__m128i kf = _aes._k.ni.k[11];
+			d0 = _mm_aesenc_si128(d0,kb);
+			d1 = _mm_aesenc_si128(d1,kb);
+			d2 = _mm_aesenc_si128(d2,kb);
+			d3 = _mm_aesenc_si128(d3,kb);
+			ka = _aes._k.ni.k[12];
+			d0 = _mm_aesenc_si128(d0,kc);
+			d1 = _mm_aesenc_si128(d1,kc);
+			d2 = _mm_aesenc_si128(d2,kc);
+			d3 = _mm_aesenc_si128(d3,kc);
+			kb = _aes._k.ni.k[13];
+			d0 = _mm_aesenc_si128(d0,kd);
+			d1 = _mm_aesenc_si128(d1,kd);
+			d2 = _mm_aesenc_si128(d2,kd);
+			d3 = _mm_aesenc_si128(d3,kd);
+			kc = _aes._k.ni.k[14];
+			d0 = _mm_aesenc_si128(d0,ke);
+			d1 = _mm_aesenc_si128(d1,ke);
+			d2 = _mm_aesenc_si128(d2,ke);
+			d3 = _mm_aesenc_si128(d3,ke);
+			kd = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
+			d0 = _mm_aesenc_si128(d0,kf);
+			d1 = _mm_aesenc_si128(d1,kf);
+			d2 = _mm_aesenc_si128(d2,kf);
+			d3 = _mm_aesenc_si128(d3,kf);
+			ke = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 16));
+			d0 = _mm_aesenc_si128(d0,ka);
+			d1 = _mm_aesenc_si128(d1,ka);
+			d2 = _mm_aesenc_si128(d2,ka);
+			d3 = _mm_aesenc_si128(d3,ka);
+			kf = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 32));
+			d0 = _mm_aesenc_si128(d0,kb);
+			d1 = _mm_aesenc_si128(d1,kb);
+			d2 = _mm_aesenc_si128(d2,kb);
+			d3 = _mm_aesenc_si128(d3,kb);
+			ka = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in + 48));
+			d0 = _mm_aesenclast_si128(d0,kc);
+			d1 = _mm_aesenclast_si128(d1,kc);
+			d2 = _mm_aesenclast_si128(d2,kc);
+			d3 = _mm_aesenclast_si128(d3,kc);
+			kd = _mm_xor_si128(d0,kd);
+			ke = _mm_xor_si128(d1,ke);
+			kf = _mm_xor_si128(d2,kf);
+			ka = _mm_xor_si128(d3,ka);
+			_mm_storeu_si128(reinterpret_cast<__m128i *>(out),kd);
+			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 16),ke);
+			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 32),kf);
+			_mm_storeu_si128(reinterpret_cast<__m128i *>(out + 48),ka);
 
 
 			in += 64;
 			in += 64;
 			len -= 64;
 			len -= 64;
@@ -632,23 +644,31 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 
 
 		while (len >= 16) {
 		while (len >= 16) {
 			__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 			__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
-			d0 = _mm_xor_si128(d0,_aes._k.ni.k[0]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[1]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[2]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[3]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[4]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[5]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[6]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[7]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[8]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[9]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[10]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[11]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[12]);
-			d0 = _mm_aesenc_si128(d0,_aes._k.ni.k[13]);
-			d0 = _mm_aesenclast_si128(d0,_aes._k.ni.k[14]);
-
+			d0 = _mm_xor_si128(d0,k0);
+			d0 = _mm_aesenc_si128(d0,k1);
+			__m128i ka = _aes._k.ni.k[6];
+			d0 = _mm_aesenc_si128(d0,k2);
+			__m128i kb = _aes._k.ni.k[7];
+			d0 = _mm_aesenc_si128(d0,k3);
+			__m128i kc = _aes._k.ni.k[8];
+			d0 = _mm_aesenc_si128(d0,k4);
+			__m128i kd = _aes._k.ni.k[9];
+			d0 = _mm_aesenc_si128(d0,k5);
+			__m128i ke = _aes._k.ni.k[10];
+			d0 = _mm_aesenc_si128(d0,ka);
+			__m128i kf = _aes._k.ni.k[11];
+			d0 = _mm_aesenc_si128(d0,kb);
+			__m128i kg = _aes._k.ni.k[12];
+			d0 = _mm_aesenc_si128(d0,kc);
 			__m128i p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
 			__m128i p0 = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
+			d0 = _mm_aesenc_si128(d0,kd);
+			__m128i kh = _aes._k.ni.k[13];
+			d0 = _mm_aesenc_si128(d0,ke);
+			ka = _aes._k.ni.k[14];
+			d0 = _mm_aesenc_si128(d0,kf);
+			d0 = _mm_aesenc_si128(d0,kg);
+			d0 = _mm_aesenc_si128(d0,kh);
+			d0 = _mm_aesenclast_si128(d0,ka);
 			p0 = _mm_xor_si128(d0,p0);
 			p0 = _mm_xor_si128(d0,p0);
 			_mm_storeu_si128(reinterpret_cast<__m128i *>(out),p0);
 			_mm_storeu_si128(reinterpret_cast<__m128i *>(out),p0);
 
 
@@ -678,8 +698,10 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 	unsigned int totalLen = _len;
 	unsigned int totalLen = _len;
 	if ((totalLen & 15U)) {
 	if ((totalLen & 15U)) {
 		for (;;) {
 		for (;;) {
-			if (!len)
+			if (!len) {
+				_len = (totalLen + len);
 				return;
 				return;
+			}
 			--len;
 			--len;
 			out[totalLen++] = *(in++);
 			out[totalLen++] = *(in++);
 			if (!(totalLen & 15U)) {
 			if (!(totalLen & 15U)) {