Browse Source

encoding/base64: add decode_into, add tests

Laytan Laats 1 year ago
parent
commit
b6c47e7963

+ 90 - 43
core/encoding/base64/base64.odin

@@ -44,71 +44,80 @@ DEC_TABLE := [128]int {
 }
 }
 
 
 encode :: proc(data: []byte, ENC_TBL := ENC_TABLE, allocator := context.allocator) -> (encoded: string, err: mem.Allocator_Error) #optional_allocator_error {
 encode :: proc(data: []byte, ENC_TBL := ENC_TABLE, allocator := context.allocator) -> (encoded: string, err: mem.Allocator_Error) #optional_allocator_error {
-	out_length := encoded_length(data)
+	out_length := encoded_len(data)
 	if out_length == 0 {
 	if out_length == 0 {
 		return
 		return
 	}
 	}
 
 
-	out: strings.Builder
-	strings.builder_init(&out, 0, out_length, allocator) or_return
-
+	out   := strings.builder_make(0, out_length, allocator) or_return
 	ioerr := encode_into(strings.to_stream(&out), data, ENC_TBL)
 	ioerr := encode_into(strings.to_stream(&out), data, ENC_TBL)
-	assert(ioerr == nil)
+
+	assert(ioerr == nil,                           "string builder should not IO error")
+	assert(strings.builder_cap(out) == out_length, "buffer resized, `encoded_len` was wrong")
 
 
 	return strings.to_string(out), nil
 	return strings.to_string(out), nil
 }
 }
 
 
-encoded_length :: #force_inline proc(data: []byte) -> int {
+encode_into :: proc(w: io.Writer, data: []byte, ENC_TBL := ENC_TABLE) -> io.Error {
 	length := len(data)
 	length := len(data)
 	if length == 0 {
 	if length == 0 {
-		return 0
+		return nil
 	}
 	}
 
 
-	return ((4 * length / 3) + 3) &~ 3
+	c0, c1, c2, block: int
+	out: [4]byte
+	for i := 0; i < length; i += 3 {
+		#no_bounds_check {
+			c0, c1, c2 = int(data[i]), -1, -1
+
+			if i + 1 < length { c1 = int(data[i + 1]) }
+			if i + 2 < length { c2 = int(data[i + 2]) }
+
+			block = (c0 << 16) | (max(c1, 0) << 8) | max(c2, 0)
+			
+			out[0] = ENC_TBL[block >> 18 & 63]
+			out[1] = ENC_TBL[block >> 12 & 63]
+			out[2] = c1 == -1 ? PADDING : ENC_TBL[block >> 6 & 63]
+			out[3] = c2 == -1 ? PADDING : ENC_TBL[block & 63]
+		}
+		io.write_full(w, out[:]) or_return
+	}
+	return nil
 }
 }
 
 
-encode_into :: proc(w: io.Writer, data: []byte, ENC_TBL := ENC_TABLE) -> (err: io.Error) #no_bounds_check {
+encoded_len :: proc(data: []byte) -> int {
 	length := len(data)
 	length := len(data)
 	if length == 0 {
 	if length == 0 {
-		return
+		return 0
 	}
 	}
 
 
-	c0, c1, c2, block: int
+	return ((4 * length / 3) + 3) &~ 3
+}
 
 
-	for i, d := 0, 0; i < length; i, d = i + 3, d + 4 {
-		c0, c1, c2 = int(data[i]), -1, -1
+decode :: proc(data: string, DEC_TBL := DEC_TABLE, allocator := context.allocator) -> (decoded: []byte, err: mem.Allocator_Error) #optional_allocator_error {
+	out_length := decoded_len(data)
 
 
-		if i + 1 < length { c1 = int(data[i + 1]) }
-		if i + 2 < length { c2 = int(data[i + 2]) }
+	out   := strings.builder_make(0, out_length, allocator) or_return
+	ioerr := decode_into(strings.to_stream(&out), data, DEC_TBL)
 
 
-		block = (c0 << 16) | (max(c1, 0) << 8) | max(c2, 0)
-		
-		out: [4]byte
-		out[0] = ENC_TBL[block >> 18 & 63]
-		out[1] = ENC_TBL[block >> 12 & 63]
-		out[2] = c1 == -1 ? PADDING : ENC_TBL[block >> 6 & 63]
-		out[3] = c2 == -1 ? PADDING : ENC_TBL[block & 63]
+	assert(ioerr == nil,                           "string builder should not IO error")
+	assert(strings.builder_cap(out) == out_length, "buffer resized, `decoded_len` was wrong")
 
 
-		#bounds_check { io.write_full(w, out[:]) or_return }
-	}
-	return
+	return out.buf[:], nil
 }
 }
 
 
-decode :: proc(data: string, DEC_TBL := DEC_TABLE, allocator := context.allocator) -> (out: []byte, err: mem.Allocator_Error) #optional_allocator_error {
-	#no_bounds_check {
-		length := len(data)
-		if length == 0 {
-			return
-		}
-
-		pad_count := data[length - 1] == PADDING ? (data[length - 2] == PADDING ? 2 : 1) : 0
-		out_length := ((length * 6) >> 3) - pad_count
-		out = make([]byte, out_length, allocator) or_return
-
-		c0, c1, c2, c3: int
-		b0, b1, b2: int
+decode_into :: proc(w: io.Writer, data: string, DEC_TBL := DEC_TABLE) -> io.Error {
+	length := decoded_len(data)
+	if length == 0 {
+		return nil
+	}
 
 
-		for i, j := 0, 0; i < length; i, j = i + 4, j + 3 {
+	c0, c1, c2, c3: int
+	b0, b1, b2: int
+	buf: [3]byte
+	i, j: int
+	for ; j + 3 <= length; i, j = i + 4, j + 3 {
+		#no_bounds_check {
 			c0 = DEC_TBL[data[i]]
 			c0 = DEC_TBL[data[i]]
 			c1 = DEC_TBL[data[i + 1]]
 			c1 = DEC_TBL[data[i + 1]]
 			c2 = DEC_TBL[data[i + 2]]
 			c2 = DEC_TBL[data[i + 2]]
@@ -118,10 +127,48 @@ decode :: proc(data: string, DEC_TBL := DEC_TABLE, allocator := context.allocato
 			b1 = (c1 << 4) | (c2 >> 2)
 			b1 = (c1 << 4) | (c2 >> 2)
 			b2 = (c2 << 6) | c3
 			b2 = (c2 << 6) | c3
 
 
-			out[j]     = byte(b0)
-			out[j + 1] = byte(b1)
-			out[j + 2] = byte(b2)
+			buf[0] = byte(b0)
+			buf[1] = byte(b1)
+			buf[2] = byte(b2)
 		}
 		}
-		return 
+
+		io.write_full(w, buf[:]) or_return
 	}
 	}
+
+	rest := length - j
+	if rest > 0 {
+		#no_bounds_check {
+			c0 = DEC_TBL[data[i]]
+			c1 = DEC_TBL[data[i + 1]]
+			c2 = DEC_TBL[data[i + 2]]
+
+			b0 = (c0 << 2) | (c1 >> 4)
+			b1 = (c1 << 4) | (c2 >> 2)
+		}
+
+		switch rest {
+		case 1: io.write_byte(w, byte(b0))             or_return
+		case 2: io.write_full(w, {byte(b0), byte(b1)}) or_return
+		}
+	}
+
+	return nil
+}
+
+decoded_len :: proc(data: string) -> int {
+	length := len(data)
+	if length == 0 {
+		return 0
+	}
+
+	padding: int
+	if data[length - 1] == PADDING {
+		if length > 1 && data[length - 2] == PADDING {
+			padding = 2
+		} else {
+			padding = 1
+		}
+	}
+
+	return ((length * 6) >> 3) - padding
 }
 }

+ 3 - 0
tests/core/Makefile

@@ -51,11 +51,14 @@ noise_test:
 	$(ODIN) run math/noise $(COMMON) -out:test_noise
 	$(ODIN) run math/noise $(COMMON) -out:test_noise
 
 
 encoding_test:
 encoding_test:
+<<<<<<< HEAD
 	$(ODIN) run encoding/hxa    $(COMMON) $(COLLECTION) -out:test_hxa
 	$(ODIN) run encoding/hxa    $(COMMON) $(COLLECTION) -out:test_hxa
 	$(ODIN) run encoding/json   $(COMMON) -out:test_json
 	$(ODIN) run encoding/json   $(COMMON) -out:test_json
 	$(ODIN) run encoding/varint $(COMMON) -out:test_varint
 	$(ODIN) run encoding/varint $(COMMON) -out:test_varint
 	$(ODIN) run encoding/xml    $(COMMON) -out:test_xml
 	$(ODIN) run encoding/xml    $(COMMON) -out:test_xml
 	$(ODIN) run encoding/cbor   $(COMMON) -out:test_cbor
 	$(ODIN) run encoding/cbor   $(COMMON) -out:test_cbor
+	$(ODIN) run encoding/hex    $(COMMON) -out:test_hex
+	$(ODIN) run encoding/base64 $(COMMON) -out:test_base64
 
 
 math_test:
 math_test:
 	$(ODIN) run math $(COMMON) $(COLLECTION) -out:test_core_math
 	$(ODIN) run math $(COMMON) $(COLLECTION) -out:test_core_math

+ 2 - 0
tests/core/build.bat

@@ -41,6 +41,8 @@ rem %PATH_TO_ODIN% run encoding/hxa    %COMMON% %COLLECTION% -out:test_hxa.exe |
 %PATH_TO_ODIN% run encoding/varint %COMMON% -out:test_varint.exe || exit /b
 %PATH_TO_ODIN% run encoding/varint %COMMON% -out:test_varint.exe || exit /b
 %PATH_TO_ODIN% run encoding/xml    %COMMON% -out:test_xml.exe || exit /b
 %PATH_TO_ODIN% run encoding/xml    %COMMON% -out:test_xml.exe || exit /b
 %PATH_TO_ODIN% test encoding/cbor  %COMMON% -out:test_cbor.exe || exit /b
 %PATH_TO_ODIN% test encoding/cbor  %COMMON% -out:test_cbor.exe || exit /b
+%PATH_TO_ODIN% run encoding/hex    %COMMON% -out:test_hex.exe || exit /b
+%PATH_TO_ODIN% run encoding/base64 %COMMON% -out:test_base64.exe || exit /b
 
 
 echo ---
 echo ---
 echo Running core:math/noise tests
 echo Running core:math/noise tests

+ 60 - 0
tests/core/encoding/base64/base64.odin

@@ -0,0 +1,60 @@
+package test_encoding_base64
+
+import "core:encoding/base64"
+import "core:fmt"
+import "core:intrinsics"
+import "core:os"
+import "core:reflect"
+import "core:testing"
+
+TEST_count := 0
+TEST_fail  := 0
+
+when ODIN_TEST {
+	expect_value :: testing.expect_value
+
+} else {
+	expect_value :: proc(t: ^testing.T, value, expected: $T, loc := #caller_location) -> bool where intrinsics.type_is_comparable(T) {
+		TEST_count += 1
+		ok := value == expected || reflect.is_nil(value) && reflect.is_nil(expected)
+		if !ok {
+			TEST_fail += 1
+			fmt.printf("[%v] expected %v, got %v\n", loc, expected, value)
+		}
+		return ok
+	}
+}
+
+main :: proc() {
+	t := testing.T{}
+
+	test_encoding(&t)
+	test_decoding(&t)
+
+	fmt.printf("%v/%v tests successful.\n", TEST_count - TEST_fail, TEST_count)
+	if TEST_fail > 0 {
+		os.exit(1)
+	}
+}
+
+@(test)
+test_encoding :: proc(t: ^testing.T) {
+	expect_value(t, base64.encode(transmute([]byte)string("")), "")
+	expect_value(t, base64.encode(transmute([]byte)string("f")), "Zg==")
+	expect_value(t, base64.encode(transmute([]byte)string("fo")), "Zm8=")
+	expect_value(t, base64.encode(transmute([]byte)string("foo")), "Zm9v")
+	expect_value(t, base64.encode(transmute([]byte)string("foob")), "Zm9vYg==")
+	expect_value(t, base64.encode(transmute([]byte)string("fooba")), "Zm9vYmE=")
+	expect_value(t, base64.encode(transmute([]byte)string("foobar")), "Zm9vYmFy")
+}
+
+@(test)
+test_decoding :: proc(t: ^testing.T) {
+	expect_value(t, string(base64.decode("")), "")
+	expect_value(t, string(base64.decode("Zg==")), "f")
+	expect_value(t, string(base64.decode("Zm8=")), "fo")
+	expect_value(t, string(base64.decode("Zm9v")), "foo")
+	expect_value(t, string(base64.decode("Zm9vYg==")), "foob")
+	expect_value(t, string(base64.decode("Zm9vYmE=")), "fooba")
+	expect_value(t, string(base64.decode("Zm9vYmFy")), "foobar")
+}