Browse Source

Add decrypt

Adam Ierymenko 6 years ago
parent
commit
7bdca83de3
3 changed files with 96 additions and 32 deletions
  1. 0 0
      node/AES.cpp
  2. 86 27
      node/AES.hpp
  3. 10 5
      selftest.cpp

File diff suppressed because it is too large
+ 0 - 0
node/AES.cpp


+ 86 - 27
node/AES.hpp

@@ -37,14 +37,13 @@
 #define ZT_AES_AESNI 1
 #define ZT_AES_AESNI 1
 #endif
 #endif
 
 
+#define ZT_AES_KEY_SIZE 32
+#define ZT_AES_BLOCK_SIZE 16
+
 namespace ZeroTier {
 namespace ZeroTier {
 
 
 /**
 /**
- * AES-256 and GCM AEAD
- * 
- * AES with 128-bit or 192-bit key sizes isn't supported here. This also only
- * supports the encrypt operation since we use AES in GCM mode. For HW acceleration
- * the code is inlined for maximum performance.
+ * AES-256 and AES-GCM AEAD
  */
  */
 class AES
 class AES
 {
 {
@@ -81,7 +80,18 @@ public:
 		_encryptSW(in,out);
 		_encryptSW(in,out);
 	}
 	}
 
 
-	inline void ecbEncrypt(const void *in,unsigned int inlen,void *out)
+	inline void decrypt(const uint8_t in[16],uint8_t out[16]) const
+	{
+#ifdef ZT_AES_AESNI
+		if (likely(HW_ACCEL)) {
+			_decrypt_aesni(in,out);
+			return;
+		}
+#endif
+		_decryptSW(in,out);
+	}
+
+	inline void ecbScramble(const void *in,unsigned int inlen,void *out)
 	{
 	{
 		if (inlen < 16)
 		if (inlen < 16)
 			return;
 			return;
@@ -101,7 +111,7 @@ public:
 				o += 16;
 				o += 16;
 				inlen -= 16;
 				inlen -= 16;
 			}
 			}
-			if (inlen != 0) {
+			if (inlen) {
 				i -= (16 - inlen);
 				i -= (16 - inlen);
 				o -= (16 - inlen);
 				o -= (16 - inlen);
 				_encrypt_aesni(i,o);
 				_encrypt_aesni(i,o);
@@ -117,7 +127,7 @@ public:
 			o += 16;
 			o += 16;
 			inlen -= 16;
 			inlen -= 16;
 		}
 		}
-		if (inlen != 0) {
+		if (inlen) {
 			i -= (16 - inlen);
 			i -= (16 - inlen);
 			o -= (16 - inlen);
 			o -= (16 - inlen);
 			_encryptSW(i,o);
 			_encryptSW(i,o);
@@ -151,16 +161,18 @@ public:
 private:
 private:
 	void _initSW(const uint8_t key[32]);
 	void _initSW(const uint8_t key[32]);
 	void _encryptSW(const uint8_t in[16],uint8_t out[16]) const;
 	void _encryptSW(const uint8_t in[16],uint8_t out[16]) const;
+	void _decryptSW(const uint8_t in[16],uint8_t out[16]) const;
 
 
 	union {
 	union {
 #ifdef ZT_AES_AESNI
 #ifdef ZT_AES_AESNI
 		struct {
 		struct {
-			__m128i k[15];
+			__m128i k[28];
 			__m128i h,hh,hhh,hhhh;
 			__m128i h,hh,hhh,hhhh;
 		} ni;
 		} ni;
 #endif
 #endif
 		struct {
 		struct {
-			uint32_t k[60];
+			uint32_t ek[60];
+			uint32_t dk[60];
 		} sw;
 		} sw;
 	} _k;
 	} _k;
 
 
@@ -211,6 +223,19 @@ private:
 		_k.ni.k[12] = t1 = _init256_1_aesni(t1,_mm_aeskeygenassist_si128(t2,0x20));
 		_k.ni.k[12] = t1 = _init256_1_aesni(t1,_mm_aeskeygenassist_si128(t2,0x20));
 		_k.ni.k[13] = t2 = _init256_2_aesni(t1,t2);
 		_k.ni.k[13] = t2 = _init256_2_aesni(t1,t2);
 		_k.ni.k[14] = _init256_1_aesni(t1,_mm_aeskeygenassist_si128(t2,0x40));
 		_k.ni.k[14] = _init256_1_aesni(t1,_mm_aeskeygenassist_si128(t2,0x40));
+		_k.ni.k[15] = _mm_aesimc_si128(_k.ni.k[13]);
+		_k.ni.k[16] = _mm_aesimc_si128(_k.ni.k[12]);
+		_k.ni.k[17] = _mm_aesimc_si128(_k.ni.k[11]);
+		_k.ni.k[18] = _mm_aesimc_si128(_k.ni.k[10]);
+		_k.ni.k[19] = _mm_aesimc_si128(_k.ni.k[9]);
+		_k.ni.k[20] = _mm_aesimc_si128(_k.ni.k[8]);
+		_k.ni.k[21] = _mm_aesimc_si128(_k.ni.k[7]);
+		_k.ni.k[22] = _mm_aesimc_si128(_k.ni.k[6]);
+		_k.ni.k[23] = _mm_aesimc_si128(_k.ni.k[5]);
+		_k.ni.k[24] = _mm_aesimc_si128(_k.ni.k[4]);
+		_k.ni.k[25] = _mm_aesimc_si128(_k.ni.k[3]);
+		_k.ni.k[26] = _mm_aesimc_si128(_k.ni.k[2]);
+		_k.ni.k[27] = _mm_aesimc_si128(_k.ni.k[1]);
 
 
 		/* Init GCM / GHASH */
 		/* Init GCM / GHASH */
 		__m128i h = _mm_xor_si128(_mm_setzero_si128(),_k.ni.k[0]);
 		__m128i h = _mm_xor_si128(_mm_setzero_si128(),_k.ni.k[0]);
@@ -412,6 +437,26 @@ private:
 			_mm_storeu_si128((__m128i *)((uint8_t *)out + 112),_mm_aesenclast_si128(tmp7,k14));
 			_mm_storeu_si128((__m128i *)((uint8_t *)out + 112),_mm_aesenclast_si128(tmp7,k14));
 		}
 		}
 	}
 	}
+	inline void _decrypt_aesni(const void *in,void *out) const
+	{
+		__m128i tmp;
+		tmp = _mm_loadu_si128((const __m128i *)in);
+		tmp = _mm_xor_si128(tmp,_k.ni.k[14]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[15]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[16]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[17]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[18]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[19]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[20]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[21]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[22]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[23]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[24]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[25]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[26]);
+		tmp = _mm_aesdec_si128(tmp,_k.ni.k[27]);
+		_mm_storeu_si128((__m128i *)out,_mm_aesdeclast_si128(tmp,_k.ni.k[0]));
+	}
 
 
 	static inline __m128i _swap128_aesni(__m128i x) { return _mm_shuffle_epi8(x,_mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)); }
 	static inline __m128i _swap128_aesni(__m128i x) { return _mm_shuffle_epi8(x,_mm_set_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)); }
 	static inline __m128i _mult_block_aesni(__m128i h,__m128i y)
 	static inline __m128i _mult_block_aesni(__m128i h,__m128i y)
@@ -828,22 +873,6 @@ private:
 		__m128i *bi = (__m128i *)in;
 		__m128i *bi = (__m128i *)in;
 		__m128i *bo = (__m128i *)out;
 		__m128i *bo = (__m128i *)out;
 
 
-		__m128i k0 = _k.ni.k[0];
-		__m128i k1 = _k.ni.k[1];
-		__m128i k2 = _k.ni.k[2];
-		__m128i k3 = _k.ni.k[3];
-		__m128i k4 = _k.ni.k[4];
-		__m128i k5 = _k.ni.k[5];
-		__m128i k6 = _k.ni.k[6];
-		__m128i k7 = _k.ni.k[7];
-		__m128i k8 = _k.ni.k[8];
-		__m128i k9 = _k.ni.k[9];
-		__m128i k10 = _k.ni.k[10];
-		__m128i k11 = _k.ni.k[11];
-		__m128i k12 = _k.ni.k[12];
-		__m128i k13 = _k.ni.k[13];
-		__m128i k14 = _k.ni.k[14];
-
 		unsigned int i;
 		unsigned int i;
 		for (i=0;i<pblocks;i+=4) {
 		for (i=0;i<pblocks;i+=4) {
 			__m128i d1 = _mm_loadu_si128(bi + i + 0);
 			__m128i d1 = _mm_loadu_si128(bi + i + 0);
@@ -852,7 +881,11 @@ private:
 			__m128i d4 = _mm_loadu_si128(bi + i + 3);
 			__m128i d4 = _mm_loadu_si128(bi + i + 3);
 			y = _mm_xor_si128(y,d1);
 			y = _mm_xor_si128(y,d1);
 			y = _mult4xor_aesni(_k.ni.hhhh,_k.ni.hhh,_k.ni.hh,_k.ni.h,y,d2,d3,d4);
 			y = _mult4xor_aesni(_k.ni.hhhh,_k.ni.hhh,_k.ni.hh,_k.ni.h,y,d2,d3,d4);
-			__m128i t1 = _mm_xor_si128(cb,k0);
+			__m128i k0 = _k.ni.k[0];
+			__m128i k1 = _k.ni.k[1];
+			__m128i k2 = _k.ni.k[2];
+			__m128i k3 = _k.ni.k[3];
+				__m128i t1 = _mm_xor_si128(cb,k0);
 			cb = _increment_be_aesni(cb);
 			cb = _increment_be_aesni(cb);
 			__m128i t2 = _mm_xor_si128(cb,k0);
 			__m128i t2 = _mm_xor_si128(cb,k0);
 			cb = _increment_be_aesni(cb);
 			cb = _increment_be_aesni(cb);
@@ -872,6 +905,10 @@ private:
 			t2 = _mm_aesenc_si128(t2,k3);
 			t2 = _mm_aesenc_si128(t2,k3);
 			t3 = _mm_aesenc_si128(t3,k3);
 			t3 = _mm_aesenc_si128(t3,k3);
 			t4 = _mm_aesenc_si128(t4,k3);
 			t4 = _mm_aesenc_si128(t4,k3);
+			__m128i k4 = _k.ni.k[4];
+			__m128i k5 = _k.ni.k[5];
+			__m128i k6 = _k.ni.k[6];
+			__m128i k7 = _k.ni.k[7];
 			t1 = _mm_aesenc_si128(t1,k4);
 			t1 = _mm_aesenc_si128(t1,k4);
 			t2 = _mm_aesenc_si128(t2,k4);
 			t2 = _mm_aesenc_si128(t2,k4);
 			t3 = _mm_aesenc_si128(t3,k4);
 			t3 = _mm_aesenc_si128(t3,k4);
@@ -888,6 +925,10 @@ private:
 			t2 = _mm_aesenc_si128(t2,k7);
 			t2 = _mm_aesenc_si128(t2,k7);
 			t3 = _mm_aesenc_si128(t3,k7);
 			t3 = _mm_aesenc_si128(t3,k7);
 			t4 = _mm_aesenc_si128(t4,k7);
 			t4 = _mm_aesenc_si128(t4,k7);
+			__m128i k8 = _k.ni.k[8];
+			__m128i k9 = _k.ni.k[9];
+			__m128i k10 = _k.ni.k[10];
+			__m128i k11 = _k.ni.k[11];
 			t1 = _mm_aesenc_si128(t1,k8);
 			t1 = _mm_aesenc_si128(t1,k8);
 			t2 = _mm_aesenc_si128(t2,k8);
 			t2 = _mm_aesenc_si128(t2,k8);
 			t3 = _mm_aesenc_si128(t3,k8);
 			t3 = _mm_aesenc_si128(t3,k8);
@@ -904,6 +945,9 @@ private:
 			t2 = _mm_aesenc_si128(t2,k11);
 			t2 = _mm_aesenc_si128(t2,k11);
 			t3 = _mm_aesenc_si128(t3,k11);
 			t3 = _mm_aesenc_si128(t3,k11);
 			t4 = _mm_aesenc_si128(t4,k11);
 			t4 = _mm_aesenc_si128(t4,k11);
+			__m128i k12 = _k.ni.k[12];
+			__m128i k13 = _k.ni.k[13];
+			__m128i k14 = _k.ni.k[14];
 			t1 = _mm_aesenc_si128(t1,k12);
 			t1 = _mm_aesenc_si128(t1,k12);
 			t2 = _mm_aesenc_si128(t2,k12);
 			t2 = _mm_aesenc_si128(t2,k12);
 			t3 = _mm_aesenc_si128(t3,k12);
 			t3 = _mm_aesenc_si128(t3,k12);
@@ -929,18 +973,33 @@ private:
 		for (i=pblocks;i<blocks;i++) {
 		for (i=pblocks;i<blocks;i++) {
 			__m128i d1 = _mm_loadu_si128(bi + i);
 			__m128i d1 = _mm_loadu_si128(bi + i);
 			y = _ghash_aesni(_k.ni.h,y,d1);
 			y = _ghash_aesni(_k.ni.h,y,d1);
+			__m128i k0 = _k.ni.k[0];
+			__m128i k1 = _k.ni.k[1];
+			__m128i k2 = _k.ni.k[2];
+			__m128i k3 = _k.ni.k[3];
 			__m128i t1 = _mm_xor_si128(cb,k0);
 			__m128i t1 = _mm_xor_si128(cb,k0);
 			t1 = _mm_aesenc_si128(t1,k1);
 			t1 = _mm_aesenc_si128(t1,k1);
 			t1 = _mm_aesenc_si128(t1,k2);
 			t1 = _mm_aesenc_si128(t1,k2);
 			t1 = _mm_aesenc_si128(t1,k3);
 			t1 = _mm_aesenc_si128(t1,k3);
+			__m128i k4 = _k.ni.k[4];
+			__m128i k5 = _k.ni.k[5];
+			__m128i k6 = _k.ni.k[6];
+			__m128i k7 = _k.ni.k[7];
 			t1 = _mm_aesenc_si128(t1,k4);
 			t1 = _mm_aesenc_si128(t1,k4);
 			t1 = _mm_aesenc_si128(t1,k5);
 			t1 = _mm_aesenc_si128(t1,k5);
 			t1 = _mm_aesenc_si128(t1,k6);
 			t1 = _mm_aesenc_si128(t1,k6);
 			t1 = _mm_aesenc_si128(t1,k7);
 			t1 = _mm_aesenc_si128(t1,k7);
+			__m128i k8 = _k.ni.k[8];
+			__m128i k9 = _k.ni.k[9];
+			__m128i k10 = _k.ni.k[10];
+			__m128i k11 = _k.ni.k[11];
 			t1 = _mm_aesenc_si128(t1,k8);
 			t1 = _mm_aesenc_si128(t1,k8);
 			t1 = _mm_aesenc_si128(t1,k9);
 			t1 = _mm_aesenc_si128(t1,k9);
 			t1 = _mm_aesenc_si128(t1,k10);
 			t1 = _mm_aesenc_si128(t1,k10);
 			t1 = _mm_aesenc_si128(t1,k11);
 			t1 = _mm_aesenc_si128(t1,k11);
+			__m128i k12 = _k.ni.k[12];
+			__m128i k13 = _k.ni.k[13];
+			__m128i k14 = _k.ni.k[14];
 			t1 = _mm_aesenc_si128(t1,k12);
 			t1 = _mm_aesenc_si128(t1,k12);
 			t1 = _mm_aesenc_si128(t1,k13);
 			t1 = _mm_aesenc_si128(t1,k13);
 			t1 = _mm_aesenclast_si128(t1,k14);
 			t1 = _mm_aesenclast_si128(t1,k14);

+ 10 - 5
selftest.cpp

@@ -184,7 +184,12 @@ static int testCrypto()
 	AES tv(AES_TEST_VECTOR_0_KEY);
 	AES tv(AES_TEST_VECTOR_0_KEY);
 	tv.encrypt(AES_TEST_VECTOR_0_IN,(uint8_t *)buf1);
 	tv.encrypt(AES_TEST_VECTOR_0_IN,(uint8_t *)buf1);
 	if (memcmp(buf1,AES_TEST_VECTOR_0_OUT,16) != 0) {
 	if (memcmp(buf1,AES_TEST_VECTOR_0_OUT,16) != 0) {
-		std::cout << "  FAILED (test vector 0)" << std::endl;
+		std::cout << "FAILED (test vector 0 encrypt)" << std::endl;
+		return -1;
+	}
+	tv.decrypt((const uint8_t *)buf1,(uint8_t *)buf2);
+	if (memcmp(AES_TEST_VECTOR_0_IN,buf2,16) != 0) {
+		std::cout << "FAILED (test vector 0 decrypt)" << std::endl;
 		return -1;
 		return -1;
 	}
 	}
 	std::cout << "PASS" << std::endl << "  AES-256 GCM (test vectors, benchmark): "; std::cout.flush();
 	std::cout << "PASS" << std::endl << "  AES-256 GCM (test vectors, benchmark): "; std::cout.flush();
@@ -220,8 +225,8 @@ static int testCrypto()
 	double ecbBytes = 0.0;
 	double ecbBytes = 0.0;
 	start = OSUtils::now();
 	start = OSUtils::now();
 	for(unsigned long i=0;i<100000;++i) {
 	for(unsigned long i=0;i<100000;++i) {
-		tv.ecbEncrypt(buf1,sizeof(buf1),buf2);
-		tv.ecbEncrypt(buf2,sizeof(buf1),buf1);
+		tv.ecbScramble(buf1,sizeof(buf1),buf2);
+		tv.ecbScramble(buf2,sizeof(buf1),buf1);
 		ecbBytes += (double)(sizeof(buf1) * 2);
 		ecbBytes += (double)(sizeof(buf1) * 2);
 	}
 	}
 	end = OSUtils::now();
 	end = OSUtils::now();
@@ -231,9 +236,9 @@ static int testCrypto()
 	start = OSUtils::now();
 	start = OSUtils::now();
 	for(unsigned long i=0;i<100000;++i) {
 	for(unsigned long i=0;i<100000;++i) {
 		tv.gcmEncrypt((const uint8_t *)hexbuf,buf1,sizeof(buf1),nullptr,0,buf2,(uint8_t *)(hexbuf + 32),16);
 		tv.gcmEncrypt((const uint8_t *)hexbuf,buf1,sizeof(buf1),nullptr,0,buf2,(uint8_t *)(hexbuf + 32),16);
-		tv.ecbEncrypt(buf1,sizeof(buf1),buf2);
+		tv.ecbScramble(buf1,sizeof(buf1),buf2);
 		tv.gcmEncrypt((const uint8_t *)hexbuf,buf2,sizeof(buf2),nullptr,0,buf1,(uint8_t *)(hexbuf + 32),16);
 		tv.gcmEncrypt((const uint8_t *)hexbuf,buf2,sizeof(buf2),nullptr,0,buf1,(uint8_t *)(hexbuf + 32),16);
-		tv.ecbEncrypt(buf2,sizeof(buf1),buf1);
+		tv.ecbScramble(buf2,sizeof(buf1),buf1);
 		ecbBytes += (double)(sizeof(buf1) * 2);
 		ecbBytes += (double)(sizeof(buf1) * 2);
 	}
 	}
 	end = OSUtils::now();
 	end = OSUtils::now();

Some files were not shown because too many files changed in this diff