Jelajahi Sumber

core/crypto/sha2: Refactor update/final

This is largely modeled on the SM3 versions of these routines, since
the buffering and padding logic is the same for SHA-256 and SM3, and
the alterations required to support SHA-512 are relatively simple.

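The prior versions of update and the transform would leak memory (each
call allocated a scratch slice with `make` that was never freed), and
doing things this way also reduces the context buffer sizes by 1 block
(128 -> 64 bytes for SHA-256, 256 -> 128 bytes for SHA-512).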
Yawning Angel 1 tahun lalu
induk
melakukan
b71afdc3ee
1 mengubah file dengan 78 tambahan dan 76 penghapusan
  1. 78 76
      core/crypto/sha2/sha2.odin

+ 78 - 76
core/crypto/sha2/sha2.odin

@@ -14,7 +14,6 @@ package sha2
 import "core:encoding/endian"
 import "core:io"
 import "core:math/bits"
-import "core:mem"
 import "core:os"
 
 /*
@@ -482,8 +481,8 @@ init :: proc(ctx: ^$T) {
 		}
 	}
 
-	ctx.tot_len = 0
 	ctx.length = 0
+	ctx.bitlength = 0
 
 	ctx.is_initialized = true
 }
@@ -491,65 +490,72 @@ init :: proc(ctx: ^$T) {
 update :: proc(ctx: ^$T, data: []byte) {
 	assert(ctx.is_initialized)
 
-	length := uint(len(data))
-	block_nb: uint
-	new_len, rem_len, tmp_len: uint
-	shifted_message := make([]byte, length)
-
 	when T == Sha256_Context {
 		CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE
 	} else when T == Sha512_Context {
 		CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE
 	}
 
-	tmp_len = CURR_BLOCK_SIZE - ctx.length
-	rem_len = length < tmp_len ? length : tmp_len
-	copy(ctx.block[ctx.length:], data[:rem_len])
+	data := data
+	ctx.length += u64(len(data))
 
-	if ctx.length + length < CURR_BLOCK_SIZE {
-		ctx.length += length
-		return
+	if ctx.bitlength > 0 {
+		n := copy(ctx.block[ctx.bitlength:], data[:])
+		ctx.bitlength += u64(n)
+		if ctx.bitlength == CURR_BLOCK_SIZE {
+			sha2_transf(ctx, ctx.block[:])
+			ctx.bitlength = 0
+		}
+		data = data[n:]
 	}
-
-	new_len = length - rem_len
-	block_nb = new_len / CURR_BLOCK_SIZE
-	shifted_message = data[rem_len:]
-
-	sha2_transf(ctx, ctx.block[:], 1)
-	sha2_transf(ctx, shifted_message, block_nb)
-
-	rem_len = new_len % CURR_BLOCK_SIZE
-	if rem_len > 0 {
-		when T == Sha256_Context {copy(ctx.block[:], shifted_message[block_nb << 6:rem_len])} else when T == Sha512_Context {copy(ctx.block[:], shifted_message[block_nb << 7:rem_len])}
+	if len(data) >= CURR_BLOCK_SIZE {
+		n := len(data) &~ (CURR_BLOCK_SIZE - 1)
+		sha2_transf(ctx, data[:n])
+		data = data[n:]
+	}
+	if len(data) > 0 {
+		ctx.bitlength = u64(copy(ctx.block[:], data[:]))
 	}
-
-	ctx.length = rem_len
-	when T == Sha256_Context {ctx.tot_len += (block_nb + 1) << 6} else when T == Sha512_Context {ctx.tot_len += (block_nb + 1) << 7}
 }
 
 final :: proc(ctx: ^$T, hash: []byte) {
 	assert(ctx.is_initialized)
 
-	block_nb, pm_len: uint
-	len_b: u64
-
 	if len(hash) * 8 < ctx.md_bits {
 		panic("crypto/sha2: invalid destination digest size")
 	}
 
-	when T == Sha256_Context {CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE} else when T == Sha512_Context {CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE}
-
-	when T == Sha256_Context {block_nb = 1 + ((CURR_BLOCK_SIZE - 9) < (ctx.length % CURR_BLOCK_SIZE) ? 1 : 0)} else when T == Sha512_Context {block_nb = 1 + ((CURR_BLOCK_SIZE - 17) < (ctx.length % CURR_BLOCK_SIZE) ? 1 : 0)}
+	length := ctx.length
 
-	len_b = u64(ctx.tot_len + ctx.length) << 3
-	when T == Sha256_Context {pm_len = block_nb << 6} else when T == Sha512_Context {pm_len = block_nb << 7}
-
-	mem.set(rawptr(&(ctx.block[ctx.length:])[0]), 0, int(pm_len - ctx.length))
-	ctx.block[ctx.length] = 0x80
+	raw_pad: [SHA512_BLOCK_SIZE]byte
+	when T == Sha256_Context {
+		CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE
+		pm_len := 8 // 64-bits for length
+	} else when T == Sha512_Context {
+		CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE
+		pm_len := 16 // 128-bits for length
+	}
+	pad := raw_pad[:CURR_BLOCK_SIZE]
+	pad_len := u64(CURR_BLOCK_SIZE - pm_len)
 
-	endian.unchecked_put_u64be(ctx.block[pm_len - 8:], len_b)
+	pad[0] = 0x80
+	if length % CURR_BLOCK_SIZE < pad_len {
+		update(ctx, pad[0:pad_len - length % CURR_BLOCK_SIZE])
+	} else {
+		update(ctx, pad[0:CURR_BLOCK_SIZE + pad_len - length % CURR_BLOCK_SIZE])
+	}
 
-	sha2_transf(ctx, ctx.block[:], block_nb)
+	length_hi, length_lo := bits.mul_u64(length, 8) // Length in bits
+	when T == Sha256_Context {
+		_ = length_hi
+		endian.unchecked_put_u64be(pad[:], length_lo)
+		update(ctx, pad[:8])
+	} else when T == Sha512_Context {
+		endian.unchecked_put_u64be(pad[:], length_hi)
+		endian.unchecked_put_u64be(pad[8:], length_lo)
+		update(ctx, pad[0:16])
+	}
+	assert(ctx.bitlength == 0)
 
 	when T == Sha256_Context {
 		for i := 0; i < ctx.md_bits / 32; i += 1 {
@@ -572,21 +578,21 @@ SHA256_BLOCK_SIZE :: 64
 SHA512_BLOCK_SIZE :: 128
 
 Sha256_Context :: struct {
-	tot_len: uint,
-	length:  uint,
-	block:   [128]byte,
-	h:       [8]u32,
-	md_bits: int,
+	block:     [SHA256_BLOCK_SIZE]byte,
+	h:         [8]u32,
+	bitlength: u64,
+	length:    u64,
+	md_bits:   int,
 
 	is_initialized: bool,
 }
 
 Sha512_Context :: struct {
-	tot_len: uint,
-	length:  uint,
-	block:   [256]byte,
-	h:       [8]u64,
-	md_bits: int,
+	block:     [SHA512_BLOCK_SIZE]byte,
+	h:         [8]u64,
+	bitlength: u64,
+	length:    u64,
+	md_bits:   int,
 
 	is_initialized: bool,
 }
@@ -716,52 +722,46 @@ SHA512_F4 :: #force_inline proc "contextless" (x: u64) -> u64 {
 }
 
 @(private)
-sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) {
+sha2_transf :: proc "contextless" (ctx: ^$T, data: []byte) {
 	when T == Sha256_Context {
 		w: [64]u32
 		wv: [8]u32
 		t1, t2: u32
+		CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE
 	} else when T == Sha512_Context {
 		w: [80]u64
 		wv: [8]u64
 		t1, t2: u64
+		CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE
 	}
 
-	sub_block := make([]byte, len(data))
-	i, j: i32
-
-	for i = 0; i < i32(block_nb); i += 1 {
-		when T == Sha256_Context {
-			sub_block = data[i << 6:]
-		} else when T == Sha512_Context {
-			sub_block = data[i << 7:]
-		}
-
-		for j = 0; j < 16; j += 1 {
+	data := data
+	for len(data) >= CURR_BLOCK_SIZE {
+		for i := 0; i < 16; i += 1 {
 			when T == Sha256_Context {
-				w[j] = endian.unchecked_get_u32be(sub_block[j << 2:])
+				w[i] = endian.unchecked_get_u32be(data[i * 4:])
 			} else when T == Sha512_Context {
-				w[j] = endian.unchecked_get_u64be(sub_block[j << 3:])
+				w[i] = endian.unchecked_get_u64be(data[i * 8:])
 			}
 		}
 
 		when T == Sha256_Context {
-			for j = 16; j < 64; j += 1 {
-				w[j] = SHA256_F4(w[j - 2]) + w[j - 7] + SHA256_F3(w[j - 15]) + w[j - 16]
+			for i := 16; i < 64; i += 1 {
+				w[i] = SHA256_F4(w[i - 2]) + w[i - 7] + SHA256_F3(w[i - 15]) + w[i - 16]
 			}
 		} else when T == Sha512_Context {
-			for j = 16; j < 80; j += 1 {
-				w[j] = SHA512_F4(w[j - 2]) + w[j - 7] + SHA512_F3(w[j - 15]) + w[j - 16]
+			for i := 16; i < 80; i += 1 {
+				w[i] = SHA512_F4(w[i - 2]) + w[i - 7] + SHA512_F3(w[i - 15]) + w[i - 16]
 			}
 		}
 
-		for j = 0; j < 8; j += 1 {
-			wv[j] = ctx.h[j]
+		for i := 0; i < 8; i += 1 {
+			wv[i] = ctx.h[i]
 		}
 
 		when T == Sha256_Context {
-			for j = 0; j < 64; j += 1 {
-				t1 = wv[7] + SHA256_F2(wv[4]) + SHA256_CH(wv[4], wv[5], wv[6]) + sha256_k[j] + w[j]
+			for i := 0; i < 64; i += 1 {
+				t1 = wv[7] + SHA256_F2(wv[4]) + SHA256_CH(wv[4], wv[5], wv[6]) + sha256_k[i] + w[i]
 				t2 = SHA256_F1(wv[0]) + SHA256_MAJ(wv[0], wv[1], wv[2])
 				wv[7] = wv[6]
 				wv[6] = wv[5]
@@ -773,8 +773,8 @@ sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) {
 				wv[0] = t1 + t2
 			}
 		} else when T == Sha512_Context {
-			for j = 0; j < 80; j += 1 {
-				t1 = wv[7] + SHA512_F2(wv[4]) + SHA512_CH(wv[4], wv[5], wv[6]) + sha512_k[j] + w[j]
+			for i := 0; i < 80; i += 1 {
+				t1 = wv[7] + SHA512_F2(wv[4]) + SHA512_CH(wv[4], wv[5], wv[6]) + sha512_k[i] + w[i]
 				t2 = SHA512_F1(wv[0]) + SHA512_MAJ(wv[0], wv[1], wv[2])
 				wv[7] = wv[6]
 				wv[6] = wv[5]
@@ -787,8 +787,10 @@ sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) {
 			}
 		}
 
-		for j = 0; j < 8; j += 1 {
-			ctx.h[j] += wv[j]
+		for i := 0; i < 8; i += 1 {
+			ctx.h[i] += wv[i]
 		}
+
+		data = data[CURR_BLOCK_SIZE:]
 	}
 }