
Faster without const-variable second-guessing of the compiler

Adam Ierymenko, 6 years ago
commit c0e92d06a5
1 changed file with 90 additions and 100 deletions

node/AES.hpp: +90 -100

@@ -167,6 +167,9 @@ public:
 	 * to use makes the IV itself a secret. This is not strictly necessary
 	 * but comes at little cost.
 	 *
+	 * This code is ZeroTier-specific in a few ways, like the way the IV
+	 * is specified, but would not be hard to generalize.
+	 *
 	 * @param k1 GMAC key
 	 * @param k2 GMAC auth tag keyed hash key
 	 * @param k3 CTR IV keyed hash key
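
The added comment, together with the @param list, describes an SIV-style layering: GMAC authenticates the plaintext, the resulting auth tag goes through a keyed hash, and the CTR IV is itself derived with a keyed hash so it never travels in the clear. A rough summary of the roles (the CTR step and the exact k3 input sit outside the hunks shown here, so this is a sketch rather than the function body):

	auth tag : AES-ECB[k2]( GMAC[k1](miv, plaintext) )[0:8]
	CTR IV   : keyed hash under k3 (keeps the IV itself secret)
	payload  : AES-CTR under the derived IV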
@@ -199,7 +202,7 @@ public:
 		miv[10] = (uint8_t)(len >> 8);
 		miv[11] = (uint8_t)len;
 
-		// Compute auth TAG: AES-ECB[k2](GMAC[k1](miv,plaintext))[0:8]
+		// Compute auth tag: AES-ECB[k2](GMAC[k1](miv,plaintext))[0:8]
 		k1.gmac(miv,in,len,ctrIv);
 		k2.encrypt(ctrIv,ctrIv); // ECB mode encrypt step is because GMAC is not a PRF
 #ifdef ZT_NO_TYPE_PUNNING
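
The tag is simply the first 8 bytes of the ECB-encrypted GMAC output. ZT_NO_TYPE_PUNNING chooses between a plain byte copy and a type-punned 64-bit load/store for that truncation; a minimal sketch of the two forms (illustrative only, not the actual branch bodies, and assuming ctrIv is a 16-byte uint8_t buffer):

	uint8_t tag[8];
	memcpy(tag, ctrIv, 8);                             // punning-free copy of the first 8 bytes (<cstring>)
	// vs. the type-punned fast path:
	// *(uint64_t *)tag = *(const uint64_t *)ctrIv;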
@@ -525,22 +528,6 @@ private:
 		const __m64 iv0 = (__m64)(*((const uint64_t *)iv));
 		uint64_t ctr = Utils::ntoh(*((const uint64_t *)(iv+8)));
 
-		const __m128i k0 = _k.ni.k[0];
-		const __m128i k1 = _k.ni.k[1];
-		const __m128i k2 = _k.ni.k[2];
-		const __m128i k3 = _k.ni.k[3];
-		const __m128i k4 = _k.ni.k[4];
-		const __m128i k5 = _k.ni.k[5];
-		const __m128i k6 = _k.ni.k[6];
-		const __m128i k7 = _k.ni.k[7];
-		const __m128i k8 = _k.ni.k[8];
-		const __m128i k9 = _k.ni.k[9];
-		const __m128i k10 = _k.ni.k[10];
-		const __m128i k11 = _k.ni.k[11];
-		const __m128i k12 = _k.ni.k[12];
-		const __m128i k13 = _k.ni.k[13];
-		const __m128i k14 = _k.ni.k[14];
-
 #define ZT_AES_CTR_AESNI_ROUND(k) \
 	c0 = _mm_aesenc_si128(c0,k); \
 	c1 = _mm_aesenc_si128(c1,k); \
@@ -552,36 +539,41 @@ private:
 	c7 = _mm_aesenc_si128(c7,k)
 
 		while (len >= 128) {
-			__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr),iv0),k0);
-			__m128i c1 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+1ULL)),iv0),k0);
-			__m128i c2 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+2ULL)),iv0),k0);
-			__m128i c3 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+3ULL)),iv0),k0);
-			__m128i c4 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+4ULL)),iv0),k0);
-			__m128i c5 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+5ULL)),iv0),k0);
-			__m128i c6 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+6ULL)),iv0),k0);
-			__m128i c7 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+7ULL)),iv0),k0);
+			_mm_prefetch(in,_MM_HINT_T0);
+			_mm_prefetch(in + 32,_MM_HINT_T0);
+			_mm_prefetch(in + 64,_MM_HINT_T0);
+			_mm_prefetch(in + 96,_MM_HINT_T0);
+			_mm_prefetch(in + 128,_MM_HINT_T0);
+			__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr),iv0),_k.ni.k[0]);
+			__m128i c1 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+1ULL)),iv0),_k.ni.k[0]);
+			__m128i c2 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+2ULL)),iv0),_k.ni.k[0]);
+			__m128i c3 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+3ULL)),iv0),_k.ni.k[0]);
+			__m128i c4 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+4ULL)),iv0),_k.ni.k[0]);
+			__m128i c5 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+5ULL)),iv0),_k.ni.k[0]);
+			__m128i c6 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+6ULL)),iv0),_k.ni.k[0]);
+			__m128i c7 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton((uint64_t)(ctr+7ULL)),iv0),_k.ni.k[0]);
 			ctr += 8;
-			ZT_AES_CTR_AESNI_ROUND(k1);
-			ZT_AES_CTR_AESNI_ROUND(k2);
-			ZT_AES_CTR_AESNI_ROUND(k3);
-			ZT_AES_CTR_AESNI_ROUND(k4);
-			ZT_AES_CTR_AESNI_ROUND(k5);
-			ZT_AES_CTR_AESNI_ROUND(k6);
-			ZT_AES_CTR_AESNI_ROUND(k7);
-			ZT_AES_CTR_AESNI_ROUND(k8);
-			ZT_AES_CTR_AESNI_ROUND(k9);
-			ZT_AES_CTR_AESNI_ROUND(k10);
-			ZT_AES_CTR_AESNI_ROUND(k11);
-			ZT_AES_CTR_AESNI_ROUND(k12);
-			ZT_AES_CTR_AESNI_ROUND(k13);
-			_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,k14)));
-			_mm_storeu_si128((__m128i *)(out + 16),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 16)),_mm_aesenclast_si128(c1,k14)));
-			_mm_storeu_si128((__m128i *)(out + 32),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 32)),_mm_aesenclast_si128(c2,k14)));
-			_mm_storeu_si128((__m128i *)(out + 48),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 48)),_mm_aesenclast_si128(c3,k14)));
-			_mm_storeu_si128((__m128i *)(out + 64),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 64)),_mm_aesenclast_si128(c4,k14)));
-			_mm_storeu_si128((__m128i *)(out + 80),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 80)),_mm_aesenclast_si128(c5,k14)));
-			_mm_storeu_si128((__m128i *)(out + 96),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 96)),_mm_aesenclast_si128(c6,k14)));
-			_mm_storeu_si128((__m128i *)(out + 112),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 112)),_mm_aesenclast_si128(c7,k14)));
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[1]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[2]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[3]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[4]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[5]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[6]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[7]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[8]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[9]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[10]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[11]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[12]);
+			ZT_AES_CTR_AESNI_ROUND(_k.ni.k[13]);
+			_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,_k.ni.k[14])));
+			_mm_storeu_si128((__m128i *)(out + 16),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 16)),_mm_aesenclast_si128(c1,_k.ni.k[14])));
+			_mm_storeu_si128((__m128i *)(out + 32),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 32)),_mm_aesenclast_si128(c2,_k.ni.k[14])));
+			_mm_storeu_si128((__m128i *)(out + 48),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 48)),_mm_aesenclast_si128(c3,_k.ni.k[14])));
+			_mm_storeu_si128((__m128i *)(out + 64),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 64)),_mm_aesenclast_si128(c4,_k.ni.k[14])));
+			_mm_storeu_si128((__m128i *)(out + 80),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 80)),_mm_aesenclast_si128(c5,_k.ni.k[14])));
+			_mm_storeu_si128((__m128i *)(out + 96),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 96)),_mm_aesenclast_si128(c6,_k.ni.k[14])));
+			_mm_storeu_si128((__m128i *)(out + 112),_mm_xor_si128(_mm_loadu_si128((const __m128i *)(in + 112)),_mm_aesenclast_si128(c7,_k.ni.k[14])));
 			in += 128;
 			out += 128;
 			len -= 128;
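
The change in this hunk is mechanical: the fifteen const __m128i round-key locals are gone, each AESENC reads _k.ni.k[i] directly (leaving register allocation to the compiler, per the commit message), and prefetches for the next 128 bytes of input are added. The round structure itself is plain AES-256; a minimal single-block sketch with the keys read straight from an expanded-key array (rk[] is a stand-in for _k.ni.k):

	#include <immintrin.h>

	// One AES-256 block with AES-NI: initial AddRoundKey, 13 full rounds, final round.
	static inline __m128i aes256_encrypt_block(__m128i block, const __m128i rk[15])
	{
		block = _mm_xor_si128(block, rk[0]);          // whitening / round 0
		for (int r = 1; r < 14; ++r)
			block = _mm_aesenc_si128(block, rk[r]);   // rounds 1..13
		return _mm_aesenclast_si128(block, rk[14]);   // round 14, no MixColumns
	}

The unrolled loop above does the same work for eight counter blocks at once, so the multi-cycle AESENC latency of any one block is hidden behind the other seven.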
@@ -590,42 +582,42 @@ private:
 #undef ZT_AES_CTR_AESNI_ROUND
 
 		while (len >= 16) {
-			__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),k0);
-			c0 = _mm_aesenc_si128(c0,k1);
-			c0 = _mm_aesenc_si128(c0,k2);
-			c0 = _mm_aesenc_si128(c0,k3);
-			c0 = _mm_aesenc_si128(c0,k4);
-			c0 = _mm_aesenc_si128(c0,k5);
-			c0 = _mm_aesenc_si128(c0,k6);
-			c0 = _mm_aesenc_si128(c0,k7);
-			c0 = _mm_aesenc_si128(c0,k8);
-			c0 = _mm_aesenc_si128(c0,k9);
-			c0 = _mm_aesenc_si128(c0,k10);
-			c0 = _mm_aesenc_si128(c0,k11);
-			c0 = _mm_aesenc_si128(c0,k12);
-			c0 = _mm_aesenc_si128(c0,k13);
-			_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,k14)));
+			__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),_k.ni.k[0]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[1]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[2]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[3]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[4]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[5]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[6]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[7]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[8]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[9]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[10]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[11]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[12]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[13]);
+			_mm_storeu_si128((__m128i *)out,_mm_xor_si128(_mm_loadu_si128((const __m128i *)in),_mm_aesenclast_si128(c0,_k.ni.k[14])));
 			in += 16;
 			out += 16;
 			len -= 16;
 		}
 
 		if (len) {
-			__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),k0);
-			c0 = _mm_aesenc_si128(c0,k1);
-			c0 = _mm_aesenc_si128(c0,k2);
-			c0 = _mm_aesenc_si128(c0,k3);
-			c0 = _mm_aesenc_si128(c0,k4);
-			c0 = _mm_aesenc_si128(c0,k5);
-			c0 = _mm_aesenc_si128(c0,k6);
-			c0 = _mm_aesenc_si128(c0,k7);
-			c0 = _mm_aesenc_si128(c0,k8);
-			c0 = _mm_aesenc_si128(c0,k9);
-			c0 = _mm_aesenc_si128(c0,k10);
-			c0 = _mm_aesenc_si128(c0,k11);
-			c0 = _mm_aesenc_si128(c0,k12);
-			c0 = _mm_aesenc_si128(c0,k13);
-			c0 = _mm_aesenclast_si128(c0,k14);
+			__m128i c0 = _mm_xor_si128(_mm_set_epi64((__m64)Utils::hton(ctr++),(__m64)iv0),_k.ni.k[0]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[1]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[2]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[3]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[4]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[5]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[6]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[7]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[8]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[9]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[10]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[11]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[12]);
+			c0 = _mm_aesenc_si128(c0,_k.ni.k[13]);
+			c0 = _mm_aesenclast_si128(c0,_k.ni.k[14]);
 			for(unsigned int i=0;i<len;++i)
 				out[i] = in[i] ^ ((const uint8_t *)&c0)[i];
 		}
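
Throughout this routine the counter blocks are built with _mm_set_epi64 on __m64 operands: the low 64 bits are the per-message IV and the high 64 bits are the big-endian block counter. The same block can be expressed without MMX types, which may read more clearly (a sketch; hton64() stands in for Utils::hton and __builtin_bswap64 assumes a little-endian GCC/Clang target; the existing code keeps the __m64 form, so this is only a readability comparison):

	#include <immintrin.h>
	#include <cstdint>

	static inline uint64_t hton64(uint64_t x) { return __builtin_bswap64(x); } // little-endian host assumed

	// Equivalent of _mm_set_epi64((__m64)Utils::hton(ctr), iv0):
	// low 64 bits = IV, high 64 bits = big-endian counter.
	static inline __m128i ctr_block(uint64_t iv0, uint64_t ctr)
	{
		return _mm_set_epi64x((long long)hton64(ctr), (long long)iv0);
	}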
@@ -680,10 +672,6 @@ private:
 		unsigned int pblocks = blocks - (blocks % 4);
 		unsigned int rem = len % 16;
 
-		const __m128i h1 = _k.ni.hhhh;
-		const __m128i h2 = _k.ni.hhh;
-		const __m128i h3 = _k.ni.hh;
-		const __m128i h4 = _k.ni.h;
 		const __m128i shuf = _mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
 		__m128i y = _mm_setzero_si128();
 		unsigned int i = 0;
@@ -692,35 +680,37 @@ private:
 			__m128i d2 = _mm_shuffle_epi8(_mm_loadu_si128(ab + i + 1),shuf);
 			__m128i d3 = _mm_shuffle_epi8(_mm_loadu_si128(ab + i + 2),shuf);
 			__m128i d4 = _mm_shuffle_epi8(_mm_loadu_si128(ab + i + 3),shuf);
-			__m128i t0 = _mm_clmulepi64_si128(h1,d1,0x00);
-			__m128i t1 = _mm_clmulepi64_si128(h2,d2,0x00);
-			__m128i t2 = _mm_clmulepi64_si128(h3,d3,0x00);
-			__m128i t3 = _mm_clmulepi64_si128(h4,d4,0x00);
+			_mm_prefetch(ab + i + 4,_MM_HINT_T0);
+			_mm_prefetch(ab + i + 6,_MM_HINT_T0);
+			__m128i t0 = _mm_clmulepi64_si128(_k.ni.hhhh,d1,0x00);
+			__m128i t1 = _mm_clmulepi64_si128(_k.ni.hhh,d2,0x00);
+			__m128i t2 = _mm_clmulepi64_si128(_k.ni.hh,d3,0x00);
+			__m128i t3 = _mm_clmulepi64_si128(_k.ni.h,d4,0x00);
 			__m128i t8 = _mm_xor_si128(t0,t1);
 			t8 = _mm_xor_si128(t8,t2);
 			t8 = _mm_xor_si128(t8,t3);
-			__m128i t4 = _mm_clmulepi64_si128(h1,d1,0x11);
-			__m128i t5 = _mm_clmulepi64_si128(h2,d2,0x11);
-			__m128i t6 = _mm_clmulepi64_si128(h3,d3,0x11);
-			__m128i t7 = _mm_clmulepi64_si128(h4,d4,0x11);
+			__m128i t4 = _mm_clmulepi64_si128(_k.ni.hhhh,d1,0x11);
+			__m128i t5 = _mm_clmulepi64_si128(_k.ni.hhh,d2,0x11);
+			__m128i t6 = _mm_clmulepi64_si128(_k.ni.hh,d3,0x11);
+			__m128i t7 = _mm_clmulepi64_si128(_k.ni.h,d4,0x11);
 			__m128i t9 = _mm_xor_si128(t4,t5);
 			t9 = _mm_xor_si128(t9,t6);
 			t9 = _mm_xor_si128(t9,t7);
-			t0 = _mm_shuffle_epi32(h1,78);
+			t0 = _mm_shuffle_epi32(_k.ni.hhhh,78);
 			t4 = _mm_shuffle_epi32(d1,78);
-			t0 = _mm_xor_si128(t0,h1);
+			t0 = _mm_xor_si128(t0,_k.ni.hhhh);
 			t4 = _mm_xor_si128(t4,d1);
-			t1 = _mm_shuffle_epi32(h2,78);
+			t1 = _mm_shuffle_epi32(_k.ni.hhh,78);
 			t5 = _mm_shuffle_epi32(d2,78);
-			t1 = _mm_xor_si128(t1,h2);
+			t1 = _mm_xor_si128(t1,_k.ni.hhh);
 			t5 = _mm_xor_si128(t5,d2);
-			t2 = _mm_shuffle_epi32(h3,78);
+			t2 = _mm_shuffle_epi32(_k.ni.hh,78);
 			t6 = _mm_shuffle_epi32(d3,78);
-			t2 = _mm_xor_si128(t2,h3);
+			t2 = _mm_xor_si128(t2,_k.ni.hh);
 			t6 = _mm_xor_si128(t6,d3);
-			t3 = _mm_shuffle_epi32(h4,78);
+			t3 = _mm_shuffle_epi32(_k.ni.h,78);
 			t7 = _mm_shuffle_epi32(d4,78);
-			t3 = _mm_xor_si128(t3,h4);
+			t3 = _mm_xor_si128(t3,_k.ni.h);
 			t7 = _mm_xor_si128(t7,d4);
 			t0 = _mm_clmulepi64_si128(t0,t4,0x00);
 			t1 = _mm_clmulepi64_si128(t1,t5,0x00);
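
As in the CTR path, the h1..h4 locals are replaced by direct reads of _k.ni.hhhh, hhh, hh and h. Pairing the oldest of the four blocks with hhhh and the newest with h is consistent with these fields holding H^4, H^3, H^2 and H, i.e. the standard four-way GHASH aggregation (with the running value Y folded into the first block of each group):

	Y_new = (Y ⊕ X1)·H^4  ⊕  X2·H^3  ⊕  X3·H^2  ⊕  X4·H        (all operations in GF(2^128))

This equals four dependent steps of Y = (Y ⊕ Xi)·H, but lets the four carry-less multiplies above run independently before a single reduction.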
@@ -763,17 +753,17 @@ private:
 			t6 = _mm_xor_si128(t6,t3);
 			y = _mm_shuffle_epi8(t6,shuf);
 		}
-
+#undef h1
 		for (;i<blocks;++i)
-			y = _ghash_aesni(shuf,h4,y,_mm_loadu_si128(ab + i));
+			y = _ghash_aesni(shuf,_k.ni.h,y,_mm_loadu_si128(ab + i));
 
 		if (rem) {
 			__m128i last = _mm_setzero_si128();
 			memcpy(&last,ab + blocks,rem);
-			y = _ghash_aesni(shuf,h4,y,last);
+			y = _ghash_aesni(shuf,_k.ni.h,y,last);
 		}
 
-		y = _ghash_aesni(shuf,h4,y,_mm_set_epi64((__m64)0LL,(__m64)Utils::hton((uint64_t)len * (uint64_t)8)));
+		y = _ghash_aesni(shuf,_k.ni.h,y,_mm_set_epi64((__m64)0LL,(__m64)Utils::hton((uint64_t)len * (uint64_t)8)));
 
 		__m128i t = _mm_xor_si128(_mm_set_epi32(0x01000000,(int)*((const uint32_t *)(iv+8)),(int)*((const uint32_t *)(iv+4)),(int)*((const uint32_t *)(iv))),_k.ni.k[0]);
 		t = _mm_aesenc_si128(t,_k.ni.k[1]);
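
The block built in the last two lines is the standard GCM J0 for a 96-bit IV: the 12 IV bytes followed by a big-endian 32-bit counter of 1, which _mm_set_epi32 places in the top lane as 0x01000000. After the remaining rounds it is presumably combined with the GHASH value Y accumulated above in the usual GMAC way:

	tag = AES-ECB[k]( IV || 0x00000001 ) ⊕ Y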