瀏覽代碼

core/crypto/sm3: Cleanups

- Use `encoding/endian`
- Use `math/bits`
- Add `@(private)` annotations to internals
Yawning Angel 1 年之前
父節點
當前提交
b71d3c739a
共有 1 個文件被更改,包括 42 次插入36 次删除
  1. 42 36
      core/crypto/sm3/sm3.odin

+ 42 - 36
core/crypto/sm3/sm3.odin

@@ -10,11 +10,11 @@ package sm3
     Implementation of the SM3 hashing algorithm, as defined in <https://datatracker.ietf.org/doc/html/draft-sca-cfrg-sm3-02>
 */
 
+import "core:encoding/endian"
 import "core:io"
+import "core:math/bits"
 import "core:os"
 
-import "../util"
-
 /*
     High level API
 */
@@ -110,6 +110,9 @@ init :: proc(ctx: ^Sm3_Context) {
 	ctx.state[5] = IV[5]
 	ctx.state[6] = IV[6]
 	ctx.state[7] = IV[7]
+
+	ctx.length = 0
+	ctx.bitlength = 0
 }
 
 update :: proc(ctx: ^Sm3_Context, data: []byte) {
@@ -119,14 +122,14 @@ update :: proc(ctx: ^Sm3_Context, data: []byte) {
 	if ctx.bitlength > 0 {
 		n := copy(ctx.x[ctx.bitlength:], data[:])
 		ctx.bitlength += u64(n)
-		if ctx.bitlength == 64 {
+		if ctx.bitlength == BLOCK_SIZE {
 			block(ctx, ctx.x[:])
 			ctx.bitlength = 0
 		}
 		data = data[n:]
 	}
-	if len(data) >= 64 {
-		n := len(data) &~ (64 - 1)
+	if len(data) >= BLOCK_SIZE {
+		n := len(data) &~ (BLOCK_SIZE - 1)
 		block(ctx, data[:n])
 		data = data[n:]
 	}
@@ -138,45 +141,44 @@ update :: proc(ctx: ^Sm3_Context, data: []byte) {
 final :: proc(ctx: ^Sm3_Context, hash: []byte) {
 	length := ctx.length
 
-	pad: [64]byte
+	pad: [BLOCK_SIZE]byte
 	pad[0] = 0x80
-	if length % 64 < 56 {
-		update(ctx, pad[0:56 - length % 64])
+	if length % BLOCK_SIZE < 56 {
+		update(ctx, pad[0:56 - length % BLOCK_SIZE])
 	} else {
-		update(ctx, pad[0:64 + 56 - length % 64])
+		update(ctx, pad[0:BLOCK_SIZE + 56 - length % BLOCK_SIZE])
 	}
 
 	length <<= 3
-	util.PUT_U64_BE(pad[:], length)
+	endian.unchecked_put_u64be(pad[:], length)
 	update(ctx, pad[0:8])
 	assert(ctx.bitlength == 0)
 
-	util.PUT_U32_BE(hash[0:], ctx.state[0])
-	util.PUT_U32_BE(hash[4:], ctx.state[1])
-	util.PUT_U32_BE(hash[8:], ctx.state[2])
-	util.PUT_U32_BE(hash[12:], ctx.state[3])
-	util.PUT_U32_BE(hash[16:], ctx.state[4])
-	util.PUT_U32_BE(hash[20:], ctx.state[5])
-	util.PUT_U32_BE(hash[24:], ctx.state[6])
-	util.PUT_U32_BE(hash[28:], ctx.state[7])
+	for i := 0; i < DIGEST_SIZE / 4; i += 1 {
+		endian.unchecked_put_u32be(hash[i * 4:], ctx.state[i])
+	}
 }
 
 /*
     SM3 implementation
 */
 
+BLOCK_SIZE :: 64
+
 Sm3_Context :: struct {
 	state:     [8]u32,
-	x:         [64]byte,
+	x:         [BLOCK_SIZE]byte,
 	bitlength: u64,
 	length:    u64,
 }
 
+@(private)
 IV := [8]u32 {
 	0x7380166f, 0x4914b2b9, 0x172442d7, 0xda8a0600,
 	0xa96f30bc, 0x163138aa, 0xe38dee4d, 0xb0fb0e4e,
 }
 
+@(private)
 block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) {
 	buf := buf
 
@@ -186,20 +188,18 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) {
 	state0, state1, state2, state3 := ctx.state[0], ctx.state[1], ctx.state[2], ctx.state[3]
 	state4, state5, state6, state7 := ctx.state[4], ctx.state[5], ctx.state[6], ctx.state[7]
 
-	for len(buf) >= 64 {
+	for len(buf) >= BLOCK_SIZE {
 		for i := 0; i < 16; i += 1 {
-			j := i * 4
-			w[i] =
-				u32(buf[j]) << 24 | u32(buf[j + 1]) << 16 | u32(buf[j + 2]) << 8 | u32(buf[j + 3])
+			w[i] = endian.unchecked_get_u32be(buf[i * 4:])
 		}
 		for i := 16; i < 68; i += 1 {
-			p1v := w[i - 16] ~ w[i - 9] ~ util.ROTL32(w[i - 3], 15)
+			p1v := w[i - 16] ~ w[i - 9] ~ bits.rotate_left32(w[i - 3], 15)
 			// @note(zh): inlined P1
 			w[i] =
 				p1v ~
-				util.ROTL32(p1v, 15) ~
-				util.ROTL32(p1v, 23) ~
-				util.ROTL32(w[i - 13], 7) ~
+				bits.rotate_left32(p1v, 15) ~
+				bits.rotate_left32(p1v, 23) ~
+				bits.rotate_left32(w[i - 13], 7) ~
 				w[i - 6]
 		}
 		for i := 0; i < 64; i += 1 {
@@ -210,8 +210,8 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) {
 		e, f, g, h := state4, state5, state6, state7
 
 		for i := 0; i < 16; i += 1 {
-			v1 := util.ROTL32(u32(a), 12)
-			ss1 := util.ROTL32(v1 + u32(e) + util.ROTL32(0x79cc4519, i), 7)
+			v1 := bits.rotate_left32(u32(a), 12)
+			ss1 := bits.rotate_left32(v1 + u32(e) + bits.rotate_left32(0x79cc4519, i), 7)
 			ss2 := ss1 ~ v1
 
 			// @note(zh): inlined FF1
@@ -219,15 +219,18 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) {
 			// @note(zh): inlined GG1
 			tt2 := u32(e ~ f ~ g) + u32(h) + ss1 + w[i]
 
-			a, b, c, d = tt1, a, util.ROTL32(u32(b), 9), c
+			a, b, c, d = tt1, a, bits.rotate_left32(u32(b), 9), c
 			// @note(zh): inlined P0
 			e, f, g, h =
-				(tt2 ~ util.ROTL32(tt2, 9) ~ util.ROTL32(tt2, 17)), e, util.ROTL32(u32(f), 19), g
+				(tt2 ~ bits.rotate_left32(tt2, 9) ~ bits.rotate_left32(tt2, 17)),
+				e,
+				bits.rotate_left32(u32(f), 19),
+				g
 		}
 
 		for i := 16; i < 64; i += 1 {
-			v := util.ROTL32(u32(a), 12)
-			ss1 := util.ROTL32(v + u32(e) + util.ROTL32(0x7a879d8a, i % 32), 7)
+			v := bits.rotate_left32(u32(a), 12)
+			ss1 := bits.rotate_left32(v + u32(e) + bits.rotate_left32(0x7a879d8a, i % 32), 7)
 			ss2 := ss1 ~ v
 
 			// @note(zh): inlined FF2
@@ -235,10 +238,13 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) {
 			// @note(zh): inlined GG2
 			tt2 := u32(((e & f) | ((~e) & g)) + h) + ss1 + w[i]
 
-			a, b, c, d = tt1, a, util.ROTL32(u32(b), 9), c
+			a, b, c, d = tt1, a, bits.rotate_left32(u32(b), 9), c
 			// @note(zh): inlined P0
 			e, f, g, h =
-				(tt2 ~ util.ROTL32(tt2, 9) ~ util.ROTL32(tt2, 17)), e, util.ROTL32(u32(f), 19), g
+				(tt2 ~ bits.rotate_left32(tt2, 9) ~ bits.rotate_left32(tt2, 17)),
+				e,
+				bits.rotate_left32(u32(f), 19),
+				g
 		}
 
 		state0 ~= a
@@ -250,7 +256,7 @@ block :: proc "contextless" (ctx: ^Sm3_Context, buf: []byte) {
 		state6 ~= g
 		state7 ~= h
 
-		buf = buf[64:]
+		buf = buf[BLOCK_SIZE:]
 	}
 
 	ctx.state[0], ctx.state[1], ctx.state[2], ctx.state[3] = state0, state1, state2, state3