فهرست منبع

Small AES optimizations on ARM64.

Adam Ierymenko 5 سال پیش
والد
کامیت
d0cc3ac333
1فایلهای تغییر یافته به همراه26 افزوده شده و 22 حذف شده
  1. 26 22
      core/AES.cpp

+ 26 - 22
core/AES.cpp

@@ -22,17 +22,19 @@ namespace {
 
 #ifdef ZT_AES_NEON
 
-ZT_INLINE uint8x16_t s_clmul_armneon_crypto(uint8x16_t a8, const uint8x16_t y, const uint8_t b[16]) noexcept
+ZT_INLINE uint8x16_t s_clmul_armneon_crypto(uint8x16_t h, uint8x16_t y, const uint8_t b[16]) noexcept
 {
-	const uint8x16_t p = vreinterpretq_u8_u64(vdupq_n_u64(0x0000000000000087));
-	const uint8x16_t z = vdupq_n_u8(0);
-	uint8x16_t b8 = vrbitq_u8(veorq_u8(vld1q_u8(b), y));
 	uint8x16_t r0, r1, t0, t1;
-	__asm__ __volatile__("pmull     %0.1q, %1.1d, %2.1d \n\t" : "=w" (r0) : "w" (a8), "w" (b8));
-	__asm__ __volatile__("pmull2   %0.1q, %1.2d, %2.2d \n\t" :"=w" (r1) : "w" (a8), "w" (b8));
-	t0 = vextq_u8(b8, b8, 8);
-	__asm__ __volatile__("pmull     %0.1q, %1.1d, %2.1d \n\t" : "=w" (t1) : "w" (a8), "w" (t0));
-	__asm__ __volatile__("pmull2   %0.1q, %1.2d, %2.2d \n\t" :"=w" (t0) : "w" (a8), "w" (t0));
+	r0 = vld1q_u8(b);
+	const uint8x16_t z = veorq_u8(h, h);
+	y = veorq_u8(r0, y);
+	y = vrbitq_u8(y);
+	const uint8x16_t p = vreinterpretq_u8_u64(vdupq_n_u64(0x0000000000000087));
+	t0 = vextq_u8(y, y, 8);
+	__asm__ __volatile__("pmull     %0.1q, %1.1d, %2.1d \n\t" : "=w" (r0) : "w" (h), "w" (y));
+	__asm__ __volatile__("pmull2   %0.1q, %1.2d, %2.2d \n\t" :"=w" (r1) : "w" (h), "w" (y));
+	__asm__ __volatile__("pmull     %0.1q, %1.1d, %2.1d \n\t" : "=w" (t1) : "w" (h), "w" (t0));
+	__asm__ __volatile__("pmull2   %0.1q, %1.2d, %2.2d \n\t" :"=w" (t0) : "w" (h), "w" (t0));
 	t0 = veorq_u8(t0, t1);
 	t1 = vextq_u8(z, t0, 8);
 	r0 = veorq_u8(r0, t1);
@@ -871,9 +873,9 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 				--len;
 				out[totalLen++] = *(in++);
 				if (!(totalLen & 15U)) {
-					uint8x16_t pt = vld1q_u8(out + (totalLen - 16));
+					uint8_t *const otmp = out + (totalLen - 16);
 					uint8x16_t d0 = vrev32q_u8(dd);
-					dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
+					uint8x16_t pt = vld1q_u8(otmp);
 					d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 					d0 = vaesmcq_u8(vaeseq_u8(d0, k1));
 					d0 = vaesmcq_u8(vaeseq_u8(d0, k2));
@@ -888,7 +890,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 					d0 = vaesmcq_u8(vaeseq_u8(d0, k11));
 					d0 = vaesmcq_u8(vaeseq_u8(d0, k12));
 					d0 = veorq_u8(vaeseq_u8(d0, k13), k14);
-					vst1q_u8(out + (totalLen - 16), veorq_u8(pt, d0));
+					vst1q_u8(otmp, veorq_u8(pt, d0));
+					dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 					break;
 				}
 			}
@@ -898,23 +901,18 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 		_len = totalLen + len;
 
 		if (likely(len >= 64)) {
-			const uint32x4_t four = {0,0,0,4};
+			const uint32x4_t four = vshlq_n_u32(one, 2);
 			uint8x16_t dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 			uint8x16_t dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, one);
 			uint8x16_t dd3 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, one);
 			for (;;) {
 				len -= 64;
-				uint8x16_t pt0 = vld1q_u8(in);
-				uint8x16_t pt1 = vld1q_u8(in + 16);
-				uint8x16_t pt2 = vld1q_u8(in + 32);
-				uint8x16_t pt3 = vld1q_u8(in + 48);
-				in += 64;
-
 				uint8x16_t d0 = vrev32q_u8(dd);
 				uint8x16_t d1 = vrev32q_u8(dd1);
 				uint8x16_t d2 = vrev32q_u8(dd2);
 				uint8x16_t d3 = vrev32q_u8(dd3);
-
+				uint8x16_t pt0 = vld1q_u8(in);
+				in += 16;
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 				d1 = vaesmcq_u8(vaeseq_u8(d1, k0));
 				d2 = vaesmcq_u8(vaeseq_u8(d2, k0));
@@ -927,6 +925,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 				d1 = vaesmcq_u8(vaeseq_u8(d1, k2));
 				d2 = vaesmcq_u8(vaeseq_u8(d2, k2));
 				d3 = vaesmcq_u8(vaeseq_u8(d3, k2));
+				uint8x16_t pt1 = vld1q_u8(in);
+				in += 16;
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k3));
 				d1 = vaesmcq_u8(vaeseq_u8(d1, k3));
 				d2 = vaesmcq_u8(vaeseq_u8(d2, k3));
@@ -939,6 +939,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 				d1 = vaesmcq_u8(vaeseq_u8(d1, k5));
 				d2 = vaesmcq_u8(vaeseq_u8(d2, k5));
 				d3 = vaesmcq_u8(vaeseq_u8(d3, k5));
+				uint8x16_t pt2 = vld1q_u8(in);
+				in += 16;
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k6));
 				d1 = vaesmcq_u8(vaeseq_u8(d1, k6));
 				d2 = vaesmcq_u8(vaeseq_u8(d2, k6));
@@ -951,6 +953,8 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 				d1 = vaesmcq_u8(vaeseq_u8(d1, k8));
 				d2 = vaesmcq_u8(vaeseq_u8(d2, k8));
 				d3 = vaesmcq_u8(vaeseq_u8(d3, k8));
+				uint8x16_t pt3 = vld1q_u8(in);
+				in += 16;
 				d0 = vaesmcq_u8(vaeseq_u8(d0, k9));
 				d1 = vaesmcq_u8(vaeseq_u8(d1, k9));
 				d2 = vaesmcq_u8(vaeseq_u8(d2, k9));
@@ -984,7 +988,7 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 				out += 64;
 
 				dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, four);
-				if (len < 64)
+				if (unlikely(len < 64))
 					break;
 				dd1 = (uint8x16_t)vaddq_u32((uint32x4_t)dd1, four);
 				dd2 = (uint8x16_t)vaddq_u32((uint32x4_t)dd2, four);
@@ -994,9 +998,9 @@ void AES::CTR::crypt(const void *const input, unsigned int len) noexcept
 
 		while (len >= 16) {
 			len -= 16;
+			uint8x16_t d0 = vrev32q_u8(dd);
 			uint8x16_t pt = vld1q_u8(in);
 			in += 16;
-			uint8x16_t d0 = vrev32q_u8(dd);
 			dd = (uint8x16_t)vaddq_u32((uint32x4_t)dd, one);
 			d0 = vaesmcq_u8(vaeseq_u8(d0, k0));
 			d0 = vaesmcq_u8(vaeseq_u8(d0, k1));