aes_ctr_hw_intel.odin

#+build amd64
package aes

import "base:intrinsics"
import "core:crypto/_aes"
import "core:math/bits"
import "core:mem"
import "core:simd/x86"

@(private)
CTR_STRIDE_HW :: 4
@(private)
CTR_STRIDE_BYTES_HW :: CTR_STRIDE_HW * BLOCK_SIZE

@(private, enable_target_feature = "sse2,aes")
ctr_blocks_hw :: proc(ctx: ^Context_CTR, dst, src: []byte, nr_blocks: int) #no_bounds_check {
	hw_ctx := ctx._impl.(Context_Impl_Hardware)

	sks: [15]x86.__m128i = ---
	for i in 0 ..= hw_ctx._num_rounds {
		sks[i] = intrinsics.unaligned_load((^x86.__m128i)(&hw_ctx._sk_exp_enc[i]))
	}
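
	// hw_inc_ctr serializes the current 128-bit counter (hi:lo) into a
	// big-endian block ready for encryption, and returns the incremented
	// counter words to use for the next block.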
	hw_inc_ctr := #force_inline proc "contextless" (hi, lo: u64) -> (x86.__m128i, u64, u64) {
		ret := x86.__m128i{
			i64(intrinsics.byte_swap(hi)),
			i64(intrinsics.byte_swap(lo)),
		}

		hi, lo := hi, lo
		carry: u64
		lo, carry = bits.add_u64(lo, 1, 0)
		hi, _ = bits.add_u64(hi, 0, carry)

		return ret, hi, lo
	}

	// The latency of AESENC depends on the manufacturer and microarchitecture:
	// - 7 -> up to Broadwell
	// - 4 -> AMD and Skylake - Cascade Lake
	// - 3 -> Ice Lake and newer
	//
	// This implementation does 4 blocks at once, since performance
	// should be "adequate" across most CPUs.
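	//
	// Interleaving CTR_STRIDE_HW independent blocks lets their AESENC
	// chains overlap in the pipeline, which is what hides that latency.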

	src, dst := src, dst
	nr_blocks := nr_blocks
	ctr_hi, ctr_lo := ctx._ctr_hi, ctx._ctr_lo
	blks: [CTR_STRIDE_HW]x86.__m128i = ---
	for nr_blocks >= CTR_STRIDE_HW {
		#unroll for i in 0 ..< CTR_STRIDE_HW {
			blks[i], ctr_hi, ctr_lo = hw_inc_ctr(ctr_hi, ctr_lo)
		}

		// Initial AddRoundKey with round key 0.
		#unroll for i in 0 ..< CTR_STRIDE_HW {
			blks[i] = x86._mm_xor_si128(blks[i], sks[0])
		}
		// Rounds 1..9 are common to all key sizes.
		#unroll for i in 1 ..= 9 {
			#unroll for j in 0 ..< CTR_STRIDE_HW {
				blks[j] = x86._mm_aesenc_si128(blks[j], sks[i])
			}
		}
		// The remaining rounds and the final round depend on the key size.
		switch hw_ctx._num_rounds {
		case _aes.ROUNDS_128:
			#unroll for i in 0 ..< CTR_STRIDE_HW {
				blks[i] = x86._mm_aesenclast_si128(blks[i], sks[10])
			}
		case _aes.ROUNDS_192:
			#unroll for i in 10 ..= 11 {
				#unroll for j in 0 ..< CTR_STRIDE_HW {
					blks[j] = x86._mm_aesenc_si128(blks[j], sks[i])
				}
			}
			#unroll for i in 0 ..< CTR_STRIDE_HW {
				blks[i] = x86._mm_aesenclast_si128(blks[i], sks[12])
			}
		case _aes.ROUNDS_256:
			#unroll for i in 10 ..= 13 {
				#unroll for j in 0 ..< CTR_STRIDE_HW {
					blks[j] = x86._mm_aesenc_si128(blks[j], sks[i])
				}
			}
			#unroll for i in 0 ..< CTR_STRIDE_HW {
				blks[i] = x86._mm_aesenclast_si128(blks[i], sks[14])
			}
		}

		xor_blocks_hw(dst, src, blks[:])

		if src != nil {
			src = src[CTR_STRIDE_BYTES_HW:]
		}
		dst = dst[CTR_STRIDE_BYTES_HW:]
		nr_blocks -= CTR_STRIDE_HW
	}

	// Handle the remainder.
	for nr_blocks > 0 {
		blks[0], ctr_hi, ctr_lo = hw_inc_ctr(ctr_hi, ctr_lo)

		blks[0] = x86._mm_xor_si128(blks[0], sks[0])
		#unroll for i in 1 ..= 9 {
			blks[0] = x86._mm_aesenc_si128(blks[0], sks[i])
		}
		switch hw_ctx._num_rounds {
		case _aes.ROUNDS_128:
			blks[0] = x86._mm_aesenclast_si128(blks[0], sks[10])
		case _aes.ROUNDS_192:
			#unroll for i in 10 ..= 11 {
				blks[0] = x86._mm_aesenc_si128(blks[0], sks[i])
			}
			blks[0] = x86._mm_aesenclast_si128(blks[0], sks[12])
		case _aes.ROUNDS_256:
			#unroll for i in 10 ..= 13 {
				blks[0] = x86._mm_aesenc_si128(blks[0], sks[i])
			}
			blks[0] = x86._mm_aesenclast_si128(blks[0], sks[14])
		}

		xor_blocks_hw(dst, src, blks[:1])

		if src != nil {
			src = src[BLOCK_SIZE:]
		}
		dst = dst[BLOCK_SIZE:]
		nr_blocks -= 1
	}

	// Write back the counter.
	ctx._ctr_hi, ctx._ctr_lo = ctr_hi, ctr_lo
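
	// Purge the expanded round keys and keystream blocks from the stack.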
	mem.zero_explicit(&blks, size_of(blks))
	mem.zero_explicit(&sks, size_of(sks))
}
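
// xor_blocks_hw XORs each keystream block in `blocks` with the corresponding
// block of src and writes the result to dst. A nil src writes the raw
// keystream to dst.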
@(private, enable_target_feature = "sse2")
xor_blocks_hw :: proc(dst, src: []byte, blocks: []x86.__m128i) {
	#no_bounds_check {
		if src != nil {
			for i in 0 ..< len(blocks) {
				off := i * BLOCK_SIZE
				tmp := intrinsics.unaligned_load((^x86.__m128i)(raw_data(src[off:])))
				blocks[i] = x86._mm_xor_si128(blocks[i], tmp)
			}
		}
		for i in 0 ..< len(blocks) {
			intrinsics.unaligned_store((^x86.__m128i)(raw_data(dst[i * BLOCK_SIZE:])), blocks[i])
		}
	}
}
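
// Usage sketch (illustrative, not part of this file): callers never invoke
// ctr_blocks_hw directly; it is selected behind the package's CTR context once
// a key and IV have been set up. The commented example below is a minimal
// sketch that assumes public entry points named `init_ctr` and `xor_bytes_ctr`
// with the signatures shown; `key`, `iv`, `plaintext`, and `ciphertext` are
// placeholder identifiers.
//
//	import "core:crypto/aes"
//
//	encrypt_ctr_example :: proc(key, iv, plaintext, ciphertext: []byte) {
//		ctx: aes.Context_CTR
//		aes.init_ctr(&ctx, key, iv)                    // key: 16/24/32 bytes; iv: 16 bytes (assumed)
//		aes.xor_bytes_ctr(&ctx, ciphertext, plaintext) // dst, src
//	}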