Browse Source

More AES tweaks

Adam Ierymenko 5 years ago
parent
commit
61b72d42b8
1 changed files with 57 additions and 60 deletions
  1. 57 60
      node/AES.cpp

+ 57 - 60
node/AES.cpp

@@ -477,19 +477,26 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 		uint64_t c0 = _ctr[0];
 		uint64_t c1 = Utils::ntoh(_ctr[1]);
 
-		// There are 16 XMM registers. We can reserve six of them for the
-		// first six parts of the expanded AES key. The rest are used for
-		// other key material, counter, or data depending on the chunk size.
-		const __m128i k0 = _aes._k.ni.k[0];
-		const __m128i k1 = _aes._k.ni.k[1];
-		const __m128i k2 = _aes._k.ni.k[2];
-		const __m128i k3 = _aes._k.ni.k[3];
-		const __m128i k4 = _aes._k.ni.k[4];
-		const __m128i k5 = _aes._k.ni.k[5];
+		// This uses some spare XMM registers to hold some of the key.
+		const __m128i *const k = _aes._k.ni.k;
+		const __m128i k0 = k[0];
+		const __m128i k1 = k[1];
+		const __m128i k2 = k[2];
+		const __m128i k3 = k[3];
+		const __m128i k4 = k[4];
+		const __m128i k5 = k[5];
 
 		// Complete any unfinished blocks from previous calls to crypt().
 		unsigned int totalLen = _len;
 		if ((totalLen & 15U)) {
+			const __m128i k7 = k[7];
+			const __m128i k8 = k[8];
+			const __m128i k9 = k[9];
+			const __m128i k10 = k[10];
+			const __m128i k11 = k[11];
+			const __m128i k12 = k[12];
+			const __m128i k13 = k[13];
+			const __m128i k14 = k[14];
 			for (;;) {
 				if (!len) {
 					_ctr[0] = c0;
@@ -503,30 +510,21 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 					__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 					d0 = _mm_xor_si128(d0,k0);
 					d0 = _mm_aesenc_si128(d0,k1);
-					__m128i ka = _aes._k.ni.k[6];
 					d0 = _mm_aesenc_si128(d0,k2);
-					__m128i kb = _aes._k.ni.k[7];
 					d0 = _mm_aesenc_si128(d0,k3);
-					__m128i kc = _aes._k.ni.k[8];
 					d0 = _mm_aesenc_si128(d0,k4);
-					__m128i kd = _aes._k.ni.k[9];
 					d0 = _mm_aesenc_si128(d0,k5);
-					__m128i ke = _aes._k.ni.k[10];
-					d0 = _mm_aesenc_si128(d0,ka);
-					__m128i kf = _aes._k.ni.k[11];
-					d0 = _mm_aesenc_si128(d0,kb);
-					__m128i kg = _aes._k.ni.k[12];
-					d0 = _mm_aesenc_si128(d0,kc);
-					__m128i kh = _aes._k.ni.k[13];
-					d0 = _mm_aesenc_si128(d0,kd);
-					ka = _aes._k.ni.k[14];
-					d0 = _mm_aesenc_si128(d0,ke);
+					d0 = _mm_aesenc_si128(d0,k[6]);
+					d0 = _mm_aesenc_si128(d0,k7);
+					d0 = _mm_aesenc_si128(d0,k8);
+					d0 = _mm_aesenc_si128(d0,k9);
+					d0 = _mm_aesenc_si128(d0,k10);
 					__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
-					d0 = _mm_aesenc_si128(d0,kf);
+					d0 = _mm_aesenc_si128(d0,k11);
 					const __m128i p0 = _mm_loadu_si128(outblk);
-					d0 = _mm_aesenc_si128(d0,kg);
-					d0 = _mm_aesenc_si128(d0,kh);
-					d0 = _mm_aesenclast_si128(d0,ka);
+					d0 = _mm_aesenc_si128(d0,k12);
+					d0 = _mm_aesenc_si128(d0,k13);
+					d0 = _mm_aesenclast_si128(d0,k14);
 					_mm_storeu_si128(outblk,_mm_xor_si128(p0,d0));
 					if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 					break;
@@ -564,47 +562,47 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 			d1 = _mm_aesenc_si128(d1,k1);
 			d2 = _mm_aesenc_si128(d2,k1);
 			d3 = _mm_aesenc_si128(d3,k1);
-			__m128i ka = _aes._k.ni.k[6];
+			__m128i ka = k[6];
 			d0 = _mm_aesenc_si128(d0,k2);
 			d1 = _mm_aesenc_si128(d1,k2);
 			d2 = _mm_aesenc_si128(d2,k2);
 			d3 = _mm_aesenc_si128(d3,k2);
-			__m128i kb = _aes._k.ni.k[7];
+			__m128i kb = k[7];
 			d0 = _mm_aesenc_si128(d0,k3);
 			d1 = _mm_aesenc_si128(d1,k3);
 			d2 = _mm_aesenc_si128(d2,k3);
 			d3 = _mm_aesenc_si128(d3,k3);
-			__m128i kc = _aes._k.ni.k[8];
+			__m128i kc = k[8];
 			d0 = _mm_aesenc_si128(d0,k4);
 			d1 = _mm_aesenc_si128(d1,k4);
 			d2 = _mm_aesenc_si128(d2,k4);
 			d3 = _mm_aesenc_si128(d3,k4);
-			__m128i kd = _aes._k.ni.k[9];
+			__m128i kd = k[9];
 			d0 = _mm_aesenc_si128(d0,k5);
 			d1 = _mm_aesenc_si128(d1,k5);
 			d2 = _mm_aesenc_si128(d2,k5);
 			d3 = _mm_aesenc_si128(d3,k5);
-			__m128i ke = _aes._k.ni.k[10];
+			__m128i ke = k[10];
 			d0 = _mm_aesenc_si128(d0,ka);
 			d1 = _mm_aesenc_si128(d1,ka);
 			d2 = _mm_aesenc_si128(d2,ka);
 			d3 = _mm_aesenc_si128(d3,ka);
-			__m128i kf = _aes._k.ni.k[11];
+			__m128i kf = k[11];
 			d0 = _mm_aesenc_si128(d0,kb);
 			d1 = _mm_aesenc_si128(d1,kb);
 			d2 = _mm_aesenc_si128(d2,kb);
 			d3 = _mm_aesenc_si128(d3,kb);
-			ka = _aes._k.ni.k[12];
+			ka = k[12];
 			d0 = _mm_aesenc_si128(d0,kc);
 			d1 = _mm_aesenc_si128(d1,kc);
 			d2 = _mm_aesenc_si128(d2,kc);
 			d3 = _mm_aesenc_si128(d3,kc);
-			kb = _aes._k.ni.k[13];
+			kb = k[13];
 			d0 = _mm_aesenc_si128(d0,kd);
 			d1 = _mm_aesenc_si128(d1,kd);
 			d2 = _mm_aesenc_si128(d2,kd);
 			d3 = _mm_aesenc_si128(d3,kd);
-			kc = _aes._k.ni.k[14];
+			kc = k[14];
 			d0 = _mm_aesenc_si128(d0,ke);
 			d1 = _mm_aesenc_si128(d1,ke);
 			d2 = _mm_aesenc_si128(d2,ke);
@@ -644,41 +642,40 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 		}
 
 		{
-			__m128i ka = _aes._k.ni.k[6];
-			__m128i kb = _aes._k.ni.k[7];
-			const __m128i kc = _aes._k.ni.k[8];
-			const __m128i kd = _aes._k.ni.k[9];
-			const __m128i ke = _aes._k.ni.k[10];
-			const __m128i kf = _aes._k.ni.k[11];
-			const __m128i kg = _aes._k.ni.k[12];
-			const __m128i kh = _aes._k.ni.k[13];
+			const __m128i k7 = k[7];
+			const __m128i k8 = k[8];
+			const __m128i k9 = k[9];
+			const __m128i k10 = k[10];
+			const __m128i k11 = k[11];
+			const __m128i k12 = k[12];
+			const __m128i k13 = k[13];
+			const __m128i k14 = k[14];
 			while (len >= 16) {
-				__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
+				__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1++),(long long)c0);
+				if (unlikely(c1 == 0)) {
+					c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
+					d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
+				}
 				d0 = _mm_xor_si128(d0,k0);
 				d0 = _mm_aesenc_si128(d0,k1);
 				d0 = _mm_aesenc_si128(d0,k2);
 				d0 = _mm_aesenc_si128(d0,k3);
 				d0 = _mm_aesenc_si128(d0,k4);
 				d0 = _mm_aesenc_si128(d0,k5);
-				d0 = _mm_aesenc_si128(d0,ka);
-				d0 = _mm_aesenc_si128(d0,kb);
-				d0 = _mm_aesenc_si128(d0,kc);
-				d0 = _mm_aesenc_si128(d0,kd);
-				ka = _aes._k.ni.k[14];
-				d0 = _mm_aesenc_si128(d0,ke);
-				d0 = _mm_aesenc_si128(d0,kf);
-				d0 = _mm_aesenc_si128(d0,kg);
-				d0 = _mm_aesenc_si128(d0,kh);
-				kb = _mm_loadu_si128(reinterpret_cast<const __m128i *>(in));
-				d0 = _mm_aesenclast_si128(d0,ka);
-				kb = _mm_xor_si128(d0,kb);
-				_mm_storeu_si128(reinterpret_cast<__m128i *>(out),kb);
+				d0 = _mm_aesenc_si128(d0,k[6]);
+				d0 = _mm_aesenc_si128(d0,k7);
+				d0 = _mm_aesenc_si128(d0,k8);
+				d0 = _mm_aesenc_si128(d0,k9);
+				d0 = _mm_aesenc_si128(d0,k10);
+				d0 = _mm_aesenc_si128(d0,k11);
+				d0 = _mm_aesenc_si128(d0,k12);
+				d0 = _mm_aesenc_si128(d0,k13);
+				d0 = _mm_aesenclast_si128(d0,k14);
+				_mm_storeu_si128(reinterpret_cast<__m128i *>(out),_mm_xor_si128(d0,_mm_loadu_si128(reinterpret_cast<const __m128i *>(in))));
 
 				in += 16;
 				len -= 16;
 				out += 16;
-
-				if (unlikely(++c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 			}
 		}