
Add 4X parallel ARM AES so VTEC will kick in, yo. Seems to help on Graviton, not much on small chips, but that's okay.

Adam Ierymenko, 5 years ago
parent commit 7efaab2af1
1 changed file with 134 additions and 46 deletions

core/AES.cpp: +134, -46
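
Before the diff, a minimal illustrative sketch (not the committed code) of the pattern this change introduces: generate four independent AES-256-CTR keystream blocks per loop iteration so that a wide out-of-order ARMv8 core such as Graviton can overlap the latency of its AES instructions, while small in-order cores gain little. The helper name aes256_block, the rk[15] key-schedule array, and the host-lane-order counter argument are assumptions made for this sketch only; the committed code below keeps the round keys in named locals and interleaves the four blocks' rounds by hand rather than calling a per-block helper.

// Sketch only: 4-way interleaved AES-256-CTR keystream with ARMv8 crypto intrinsics.
// Assumes a pre-expanded key schedule rk[0..14] and a counter whose 32-bit lanes are
// already in host order (vrev32q_u8 applied on load), low word in lane 3, mirroring
// the committed code. Build for AArch64 with -march=armv8-a+crypto.
#include <arm_neon.h>
#include <cstddef>
#include <cstdint>

// One block through all 14 AES-256 rounds.
// AESE = AddRoundKey + SubBytes + ShiftRows; AESMC = MixColumns.
static inline uint8x16_t aes256_block(uint8x16_t b, const uint8x16_t rk[15]) noexcept
{
	for (int r = 0; r < 13; ++r)
		b = vaesmcq_u8(vaeseq_u8(b, rk[r]));
	return veorq_u8(vaeseq_u8(b, rk[13]), rk[14]); // final round has no MixColumns
}

// XOR len bytes (assumed here to be a multiple of 64) of keystream into out.
void ctr_crypt_4x(uint8_t *out, const uint8_t *in, size_t len,
                  uint8x16_t &ctr, const uint8x16_t rk[15]) noexcept
{
	const uint32x4_t one = {0, 0, 0, 1}; // same lane convention as the diff
	while (len >= 64) {
		// Four consecutive counter values; no carry across 32-bit lanes,
		// matching the committed code's behavior.
		uint8x16_t c0 = ctr;
		uint8x16_t c1 = vreinterpretq_u8_u32(vaddq_u32(vreinterpretq_u32_u8(c0), one));
		uint8x16_t c2 = vreinterpretq_u8_u32(vaddq_u32(vreinterpretq_u32_u8(c1), one));
		uint8x16_t c3 = vreinterpretq_u8_u32(vaddq_u32(vreinterpretq_u32_u8(c2), one));
		ctr = vreinterpretq_u8_u32(vaddq_u32(vreinterpretq_u32_u8(c3), one));

		// Four independent dependency chains: byte-swap each counter back to
		// big-endian block order, encrypt, then XOR with the plaintext.
		uint8x16_t ks0 = aes256_block(vrev32q_u8(c0), rk);
		uint8x16_t ks1 = aes256_block(vrev32q_u8(c1), rk);
		uint8x16_t ks2 = aes256_block(vrev32q_u8(c2), rk);
		uint8x16_t ks3 = aes256_block(vrev32q_u8(c3), rk);

		vst1q_u8(out,      veorq_u8(vld1q_u8(in),      ks0));
		vst1q_u8(out + 16, veorq_u8(vld1q_u8(in + 16), ks1));
		vst1q_u8(out + 32, veorq_u8(vld1q_u8(in + 32), ks2));
		vst1q_u8(out + 48, veorq_u8(vld1q_u8(in + 48), ks3));
		in += 64; out += 64; len -= 64;
	}
}

Whether the four chains are interleaved by hand (as in the diff) or left to the compiler and the core's out-of-order machinery (as in this sketch), the essential point is the same: four independent AES dependency chains in flight per iteration instead of one.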

@@ -841,7 +841,7 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 
 #ifdef ZT_AES_NEON
 	if (Utils::ARMCAP.aes) {
-		uint8x16_t dd = vld1q_u8(reinterpret_cast<uint8_t *>(_ctr));
+		uint8x16_t dd = vrev32q_u8(vld1q_u8(reinterpret_cast<uint8_t *>(_ctr)));
 		const uint32x4_t one = {0,0,0,1};
 
 		uint8x16_t k0 = _aes._k.neon.ek[0];
@@ -864,36 +864,31 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 		if ((totalLen & 15U)) {
 			for (;;) {
 				if (unlikely(!len)) {
-					vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), dd);
+					vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), vrev32q_u8(dd));
 					_len = totalLen;
 					return;
 				}
 				--len;
 				out[totalLen++] = *(in++);
 				if (!(totalLen & 15U)) {
-					uint8x16_t tmp = dd;
-					dd = vrev32q_u8(dd);
+					uint8x16_t pt = vld1q_u8(out + (totalLen - 16));
+					uint8x16_t d0 = vrev32q_u8(dd);
 					dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
-					dd = vrev32q_u8(dd);
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k0));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k1));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k2));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k3));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k4));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k5));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k6));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k7));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k8));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k9));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k10));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k11));
-					tmp = vaesmcq_u8(vaeseq_u8(tmp, k12));
-					tmp = veorq_u8(vaeseq_u8(tmp, k13), k14);
-					uint8x16_t pt = vld1q_u8(reinterpret_cast<const uint8_t *>(out + (totalLen - 16)));
-					vst1q_u8(reinterpret_cast<uint8_t *>(out + (totalLen - 16)), veorq_u8(pt, tmp));
-					//__m128i *const outblk = reinterpret_cast<__m128i *>(out + (totalLen - 16));
-					//const __m128i p0 = _mm_loadu_si128(outblk);
-					//_mm_storeu_si128(outblk, _mm_xor_si128(p0, d0));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k4));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k5));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k7));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k8));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k10));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
+					d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
+					d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
+					vst1q_u8(out + (totalLen - 16), veorq_u8(pt, d0));
 					break;
 				}
 			}
@@ -902,29 +897,122 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 		out += totalLen;
 		_len = totalLen + len;
 
+		if (len >= 64) {
+			const uint32x4_t four = {0,0,0,4};
+			uint8x16_t dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
+			uint8x16_t dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, one);
+			uint8x16_t dd3 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, one);
+			for (;;) {
+				len -= 64;
+				uint8x16_t pt0 = vld1q_u8(in);
+				uint8x16_t pt1 = vld1q_u8(in + 16);
+				uint8x16_t pt2 = vld1q_u8(in + 32);
+				uint8x16_t pt3 = vld1q_u8(in + 48);
+				in += 64;
+
+				uint8x16_t d0 = vrev32q_u8(dd);
+				uint8x16_t d1 = vrev32q_u8(dd1);
+				uint8x16_t d2 = vrev32q_u8(dd2);
+				uint8x16_t d3 = vrev32q_u8(dd3);
+
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k0));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k0));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k0));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k1));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k1));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k1));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k2));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k2));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k2));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k3));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k3));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k3));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k4));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k4));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k4));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k4));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k5));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k5));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k5));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k5));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k6));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k6));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k6));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k7));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k7));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k7));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k7));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k8));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k8));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k8));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k8));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k9));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k9));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k9));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k10));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k10));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k10));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k10));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k11));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k11));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k11));
+				d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
+				d1 = vaesmcq_u8(vaeseq_u8(d1, k12));
+				d2 = vaesmcq_u8(vaeseq_u8(d2, k12));
+				d3 = vaesmcq_u8(vaeseq_u8(d3, k12));
+				d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
+				d1 = veorq_u8(vaeseq_u8(d1, k13), k14);
+				d2 = veorq_u8(vaeseq_u8(d2, k13), k14);
+				d3 = veorq_u8(vaeseq_u8(d3, k13), k14);
+
+				d0 = veorq_u8(pt0, d0);
+				d1 = veorq_u8(pt1, d1);
+				d2 = veorq_u8(pt2, d2);
+				d3 = veorq_u8(pt3, d3);
+
+				vst1q_u8(out, d0);
+				vst1q_u8(out + 16, d1);
+				vst1q_u8(out + 32, d2);
+				vst1q_u8(out + 48, d3);
+				out += 64;
+
+				dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, four);
+				if (len < 64)
+					break;
+				dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, four);
+				dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, four);
+				dd3 = (uint8x16_t)vaddq_u32((uint32x4_t)dd3, four);
+			}
+		}
+
 		while (len >= 16) {
-			uint8x16_t tmp = dd;
-			dd = vrev32q_u8(dd);
-			dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
-			dd = vrev32q_u8(dd);
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k0));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k1));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k2));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k3));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k4));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k5));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k6));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k7));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k8));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k9));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k10));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k11));
-			tmp = vaesmcq_u8(vaeseq_u8(tmp, k12));
-			tmp = veorq_u8(vaeseq_u8(tmp, k13), k14);
-			uint8x16_t pt = vld1q_u8(reinterpret_cast<const uint8_t *>(in));
-			vst1q_u8(reinterpret_cast<uint8_t *>(out), veorq_u8(pt, tmp));
-			in += 16;
 			len -= 16;
+			uint8x16_t pt = vld1q_u8(in);
+			in += 16;
+			uint8x16_t d0 = vrev32q_u8(dd);
+			dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k4));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k5));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k7));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k8));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k10));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
+			d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
+			d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
+			vst1q_u8(out, veorq_u8(pt, d0));
 			out += 16;
 		}
 
@@ -934,7 +1022,7 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 		for (unsigned int i = 0; i < len; ++i)
 			out[i] = in[i];
 
-		vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), dd);
+		vst1q_u8(reinterpret_cast<uint8_t *>(_ctr), vrev32q_u8(dd));
 		return;
 	}
 #endif // ZT_AES_NEON