Browse Source

encoding/cbor: fully support marshal/unmarshal of unions

Laytan Laats 1 year ago
parent
commit
d77ae9abab

+ 25 - 1
core/encoding/cbor/marshal.odin

@@ -506,8 +506,32 @@ marshal_into_encoder :: proc(e: Encoder, v: any) -> (err: Marshal_Error) {
 		if v.data == nil || tag <= 0 {
 			return _encode_nil(e.writer)
 		}
+
 		id := info.variants[tag-1].id
-		return marshal_into(e, any{v.data, id})
+		if len(info.variants) == 1 {
+			id := info.variants[tag-1].id
+			return marshal_into(e, any{v.data, id})
+		}
+
+		// Encode a non-nil multi-variant union using the `TAG_OBJECT_TYPE` tag.
+		// The tag wraps an array of two elements, where the first element is the textual id/type
+		// of the object that follows it.
+
+		err_conv(_encode_u16(e, TAG_OBJECT_TYPE, .Tag)) or_return
+		_encode_u8(e.writer, 2, .Array) or_return
+
+		vti := reflect.union_variant_type_info(v)
+		#partial switch vt in vti.variant {
+		case reflect.Type_Info_Named:
+			err_conv(_encode_text(e, vt.name)) or_return
+		case:
+			builder := strings.builder_make(context.temp_allocator) or_return
+			defer strings.builder_destroy(&builder)
+			reflect.write_type(&builder, vti)
+			err_conv(_encode_text(e, strings.to_string(builder))) or_return
+		}
+
+		return marshal_into(e, any{v.data, vti.id})
 
 	case runtime.Type_Info_Enum:
 		return marshal_into(e, any{v.data, info.base.id})

+ 9 - 0
core/encoding/cbor/tags.odin

@@ -38,6 +38,15 @@ TAG_BASE64_ID :: "base64"
 // given content is definitely CBOR.
 TAG_SELF_DESCRIBED_CBOR :: 55799
 
+// A tag that is used to assign a textual type to the object following it.
+// The tag's value must be an array of 2 items, where the first is text (describing the following type)
+// and the second is any valid CBOR value.
+//
+// See the registration: https://datatracker.ietf.org/doc/draft-rundgren-cotx/05/
+//
+// We use this in Odin to marshal and unmarshal unions.
+TAG_OBJECT_TYPE :: 1010
+
 // A tag implementation that handles marshals and unmarshals for the tag it is registered on.
 Tag_Implementation :: struct {
 	data:      rawptr,

+ 71 - 5
core/encoding/cbor/unmarshal.odin

@@ -8,9 +8,6 @@ import "core:runtime"
 import "core:strings"
 import "core:unicode/utf8"
 
-// `strings` is only used in poly procs, but -vet thinks it is fully unused.
-_ :: strings
-
 /*
 Unmarshals the given CBOR into the given pointer using reflection.
 Types that require allocation are allocated using the given allocator.
@@ -79,7 +76,7 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
 		dst = err_conv(decode(r, hdr)) or_return
 		return
 	}
-	
+
 	switch hdr {
 	case .U8:
 		decoded := _decode_u8(r) or_return
@@ -275,10 +272,12 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
 		}
 
 		nr := err_conv(_decode_tag_nr(r, add)) or_return
-		
+
 		// Custom tag implementations.
 		if impl, ok := _tag_implementations_nr[nr]; ok {
 			return impl->unmarshal(r, nr, v)
+		} else if nr == TAG_OBJECT_TYPE {
+			return _unmarshal_union(r, v, ti, hdr)
 		} else {
 			// Discard the tag info and unmarshal as its value.
 			return _unmarshal_value(r, v, _decode_header(r) or_return)
@@ -717,6 +716,73 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 	}
 }
 
+// Unmarshals the content of a `TAG_OBJECT_TYPE` tag into the union described by `ti`.
+//
+// The tag content has to be an array with a known length of exactly two items: a text item that
+// names a type, followed by the CBOR value itself. The name is compared against every variant of
+// the union (the declared name for named types, the output of `reflect.write_type` for all other
+// types). The first variant that matches gets its raw tag set on `v`, after which the value is
+// unmarshalled into it.
+//
+// Returns `.Bad_Tag_Value` when the array or the name item is malformed, and the result of
+// `_unsupported` when `ti` is not a union or when no variant carries the decoded name.
+_unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header) -> (err: Unmarshal_Error) {
+	union_info, is_union := ti.variant.(reflect.Type_Info_Union)
+	if !is_union {
+		// Not a union.
+		return _unsupported(v, hdr)
+	}
+
+	name_hdr: Header
+	wanted:   string
+	{
+		arr_hdr := _decode_header(r) or_return
+		arr_major, arr_add := _header_split(arr_hdr)
+		if arr_major != .Array {
+			return .Bad_Tag_Value
+		}
+
+		length, unknown_length := err_conv(_decode_container_length(r, arr_add)) or_return
+		if unknown_length || length != 2 {
+			return .Bad_Tag_Value
+		}
+
+		name_hdr = _decode_header(r) or_return
+		name_major, name_add := _header_split(name_hdr)
+		if name_major != .Text {
+			return .Bad_Tag_Value
+		}
+
+		// The name is decoded with the temporary allocator in place for the rest of this scope;
+		// it is handed back to that allocator when the procedure returns.
+		context.allocator = context.temp_allocator
+		wanted = err_conv(_decode_text(r, name_add)) or_return
+	}
+	defer delete(wanted, context.temp_allocator)
+
+	for variant, idx in union_info.variants {
+		matches: bool
+		if named, is_named := variant.variant.(reflect.Type_Info_Named); is_named {
+			matches = named.name == wanted
+		} else {
+			// Unnamed types are identified by their printed type.
+			sb := strings.builder_make(context.temp_allocator)
+			defer strings.builder_destroy(&sb)
+
+			reflect.write_type(&sb, variant)
+			matches = strings.to_string(sb) == wanted
+		}
+
+		if !matches {
+			continue
+		}
+
+		// Raw tags are shifted up by one unless the union is `#no_nil`.
+		raw_tag := i64(idx) if union_info.no_nil else i64(idx) + 1
+		reflect.set_union_variant_raw_tag(v, raw_tag)
+		return _unmarshal_value(r, any{v.data, variant.id}, _decode_header(r) or_return)
+	}
+
+	// No variant matched.
+	return _unsupported(v, name_hdr)
+}
+
 _assign_int :: proc(val: any, i: $T) -> bool {
 	v := reflect.any_core(val)
 

+ 220 - 40
tests/core/encoding/cbor/test_core_cbor.odin

@@ -6,10 +6,96 @@ import "core:fmt"
 import "core:intrinsics"
 import "core:math/big"
 import "core:mem"
+import "core:os"
 import "core:reflect"
 import "core:testing"
 import "core:time"
 
+TEST_count := 0
+TEST_fail  := 0
+
+// Assertion helpers used by the tests in this file.
+//
+// Under `odin test` they are aliases for the procedures in `core:testing`. In a regular build
+// they are small stand-ins that print to stdout and keep score in `TEST_count`/`TEST_fail`,
+// which `main` uses for its summary line and exit code.
+when ODIN_TEST {
+	expect       :: testing.expect
+	expect_value :: testing.expect_value
+	errorf       :: testing.errorf
+	log          :: testing.log
+
+} else {
+	// Counts one check; prints `message` and counts a failure when `condition` is false.
+	expect :: proc(t: ^testing.T, condition: bool, message: string, loc := #caller_location) {
+		TEST_count += 1
+		if !condition {
+			TEST_fail += 1
+			fmt.printf("[%v] %v\n", loc, message)
+		}
+	}
+
+	// Counts one check that passes when `value == expected`, or when both are nil according to
+	// `reflect.is_nil`. Prints both values and counts a failure otherwise.
+	// Returns whether the check passed.
+	expect_value :: proc(t: ^testing.T, value, expected: $T, loc := #caller_location) -> bool where intrinsics.type_is_comparable(T) {
+		TEST_count += 1
+		ok := value == expected || reflect.is_nil(value) && reflect.is_nil(expected)
+		if !ok {
+			TEST_fail += 1
+			fmt.printf("[%v] expected %v, got %v\n", loc, expected, value)
+		}
+		return ok
+	}
+
+	// Reports an unconditional failure with a formatted message.
+	errorf :: proc(t: ^testing.T, fmts: string, args: ..any, loc := #caller_location) {
+		// A reported error is a failed check, so it has to be counted as a check as well.
+		// Otherwise `TEST_count - TEST_fail` in the summary is too low and can become negative.
+		TEST_count += 1
+		TEST_fail += 1
+		fmt.printf("[%v] ERROR: ", loc)
+		fmt.printf(fmts, ..args)
+		fmt.println()
+	}
+
+	// Prints `v` prefixed with the caller's location; does not touch the counters.
+	log :: proc(t: ^testing.T, v: any, loc := #caller_location) {
+		fmt.printf("[%v] ", loc)
+		fmt.printf("log: %v\n", v)
+	}
+}
+
+// Entry point for running this file as a regular program: runs every test procedure listed
+// below against one shared `testing.T`, prints how many checks succeeded, and exits with
+// status 1 when at least one check failed.
+main :: proc() {
+	Test_Proc :: proc(t: ^testing.T)
+
+	// Executed front to back.
+	suite := [?]Test_Proc{
+		test_marshalling,
+
+		test_marshalling_maybe,
+		test_marshalling_nil_maybe,
+
+		test_cbor_marshalling_union,
+
+		test_decode_unsigned,
+		test_encode_unsigned,
+
+		test_decode_negative,
+		test_encode_negative,
+
+		test_decode_simples,
+		test_encode_simples,
+
+		test_decode_floats,
+		test_encode_floats,
+
+		test_decode_bytes,
+		test_encode_bytes,
+
+		test_decode_strings,
+		test_encode_strings,
+
+		test_decode_lists,
+		test_encode_lists,
+
+		test_decode_maps,
+		test_encode_maps,
+
+		test_decode_tags,
+		test_encode_tags,
+	}
+
+	t := testing.T{}
+	for run in suite {
+		run(&t)
+	}
+
+	failed := TEST_fail
+	fmt.printf("%v/%v tests successful.\n", TEST_count - failed, TEST_count)
+	if failed > 0 {
+		os.exit(1)
+	}
+}
+
 Foo :: struct {
 	str: string,
 	cstr: cstring,
@@ -58,7 +144,7 @@ test_marshalling :: proc(t: ^testing.T) {
 	context.temp_allocator = context.allocator
 	defer mem.tracking_allocator_destroy(&tracker)
 
-	ev :: testing.expect_value
+	ev :: expect_value
 
 	{
 		nice := "16 is a nice number"
@@ -228,7 +314,7 @@ test_marshalling :: proc(t: ^testing.T) {
 				}
 			}
 
-		case: testing.error(t, v)
+		case: errorf(t, "wrong type %v", v)
 		}
 
 		ev(t, backf.neg, f.neg)
@@ -258,22 +344,116 @@ test_marshalling :: proc(t: ^testing.T) {
 		s_equals, s_err := big.equals(&backf.smallest, &f.smallest)
 		ev(t, s_err, nil)
 		if !s_equals {
-			testing.errorf(t, "smallest: %v does not equal %v", big.itoa(&backf.smallest), big.itoa(&f.smallest))
+			errorf(t, "smallest: %v does not equal %v", big.itoa(&backf.smallest), big.itoa(&f.smallest))
 		}
 
 		b_equals, b_err := big.equals(&backf.biggest, &f.biggest)
 		ev(t, b_err, nil)
 		if !b_equals {
-			testing.errorf(t, "biggest: %v does not equal %v", big.itoa(&backf.biggest), big.itoa(&f.biggest))
+			errorf(t, "biggest: %v does not equal %v", big.itoa(&backf.biggest), big.itoa(&f.biggest))
 		}
 	}
 
 	for _, leak in tracker.allocation_map {
-		testing.errorf(t, "%v leaked %m\n", leak.location, leak.size)
+		errorf(t, "%v leaked %m\n", leak.location, leak.size)
 	}
 
 	for bad_free in tracker.bad_free_array {
-		testing.errorf(t, "%v allocation %p was freed badly\n", bad_free.location, bad_free.memory)
+		errorf(t, "%v allocation %p was freed badly\n", bad_free.location, bad_free.memory)
+	}
+}
+
+// Round-trips a `Maybe(int)` that holds a value: it must be encoded as the plain integer and
+// unmarshal back into a `Maybe(int)` holding the same value.
+@(test)
+test_marshalling_maybe :: proc(t: ^testing.T) {
+	maybe_test: Maybe(int) = 1
+	data, err := cbor.marshal(maybe_test)
+	defer delete(data)
+	expect_value(t, err, nil)
+
+	val, derr := cbor.decode(string(data))
+	defer cbor.destroy(val)
+	expect_value(t, derr, nil)
+
+	// Keep the diagnostic string in a variable so it can be freed after the comparison.
+	diag := cbor.diagnose(val)
+	defer delete(diag)
+	expect_value(t, diag, "1")
+
+	maybe_dest: Maybe(int)
+	uerr := cbor.unmarshal(string(data), &maybe_dest)
+	expect_value(t, uerr, nil)
+	expect_value(t, maybe_dest, 1)
+}
+
+// Round-trips an empty `Maybe(int)`: it must be encoded as nil and unmarshal back into an
+// empty `Maybe(int)`.
+@(test)
+test_marshalling_nil_maybe :: proc(t: ^testing.T) {
+	maybe_test: Maybe(int)
+	data, err := cbor.marshal(maybe_test)
+	defer delete(data)
+	expect_value(t, err, nil)
+
+	val, derr := cbor.decode(string(data))
+	defer cbor.destroy(val)
+	expect_value(t, derr, nil)
+
+	// Keep the diagnostic string in a variable so it can be freed after the comparison.
+	diag := cbor.diagnose(val)
+	defer delete(diag)
+	expect_value(t, diag, "nil")
+
+	maybe_dest: Maybe(int)
+	uerr := cbor.unmarshal(string(data), &maybe_dest)
+	expect_value(t, uerr, nil)
+	expect_value(t, maybe_dest, nil)
+}
+
+// Round-trips multi-variant unions, with and without `#no_nil`: the encoded form must be tag
+// 1010 wrapping `[type name, value]`, and unmarshalling must select the variant named in it.
+@(test)
+test_cbor_marshalling_union :: proc(t: ^testing.T) {
+	My_Distinct :: distinct string
+
+	My_Enum :: enum {
+		One,
+		Two,
+	}
+
+	My_Struct :: struct {
+		my_enum: My_Enum,
+	}
+
+	My_Union :: union {
+		string,
+		My_Distinct,
+		My_Struct,
+		int,
+	}
+
+	{
+		test: My_Union = My_Distinct("Hello, World!")
+		data, err := cbor.marshal(test)
+		defer delete(data)
+		expect_value(t, err, nil)
+
+		val, derr := cbor.decode(string(data))
+		defer cbor.destroy(val)
+		expect_value(t, derr, nil)
+
+		diag := cbor.diagnose(val, -1)
+		defer delete(diag)
+		expect_value(t, diag, `1010(["My_Distinct", "Hello, World!"])`)
+
+		dest: My_Union
+		uerr := cbor.unmarshal(string(data), &dest)
+		expect_value(t, uerr, nil)
+		expect_value(t, dest, My_Distinct("Hello, World!"))
+
+		// The unmarshalled text was allocated for us; release it once it has been checked.
+		if text, is_distinct := dest.(My_Distinct); is_distinct {
+			delete(string(text))
+		}
+	}
+
+	My_Union_No_Nil :: union #no_nil {
+		string,
+		My_Distinct,
+		My_Struct,
+		int,
+	}
+
+	{
+		test: My_Union_No_Nil = My_Struct{.Two}
+		data, err := cbor.marshal(test)
+		defer delete(data)
+		expect_value(t, err, nil)
+
+		val, derr := cbor.decode(string(data))
+		defer cbor.destroy(val)
+		expect_value(t, derr, nil)
+
+		diag := cbor.diagnose(val, -1)
+		defer delete(diag)
+		expect_value(t, diag, `1010(["My_Struct", {"my_enum": 1}])`)
+
+		dest: My_Union_No_Nil
+		uerr := cbor.unmarshal(string(data), &dest)
+		expect_value(t, uerr, nil)
+		expect_value(t, dest, My_Struct{.Two})
 	}
 }
 
@@ -500,34 +680,34 @@ test_encode_lists :: proc(t: ^testing.T) {
 		
 		err: cbor.Encode_Error
 		err = cbor.encode_stream_begin(stream, .Array)
-		testing.expect_value(t, err, nil)
+		expect_value(t, err, nil)
 
 		{
 			err = cbor.encode_stream_array_item(encoder, u8(1))
-			testing.expect_value(t, err, nil)
+			expect_value(t, err, nil)
 
 			err = cbor.encode_stream_array_item(encoder, &cbor.Array{u8(2), u8(3)})
-			testing.expect_value(t, err, nil)
+			expect_value(t, err, nil)
 
 			err = cbor.encode_stream_begin(stream, .Array)
-			testing.expect_value(t, err, nil)
+			expect_value(t, err, nil)
 
 			{
 				err = cbor.encode_stream_array_item(encoder, u8(4))
-				testing.expect_value(t, err, nil)
+				expect_value(t, err, nil)
 
 				err = cbor.encode_stream_array_item(encoder, u8(5))
-				testing.expect_value(t, err, nil)
+				expect_value(t, err, nil)
 			}
 
 			err = cbor.encode_stream_end(stream)
-			testing.expect_value(t, err, nil)
+			expect_value(t, err, nil)
 		}
 
 		err = cbor.encode_stream_end(stream)
-		testing.expect_value(t, err, nil)
+		expect_value(t, err, nil)
 		
-		testing.expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)string("\x9f\x01\x82\x02\x03\x9f\x04\x05\xff\xff")))
+		expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)string("\x9f\x01\x82\x02\x03\x9f\x04\x05\xff\xff")))
 	}
 	
 	{
@@ -535,26 +715,26 @@ test_encode_lists :: proc(t: ^testing.T) {
 	
 		err: cbor.Encode_Error
 		err = cbor._encode_u8(stream, 2, .Array)
-		testing.expect_value(t, err, nil)
+		expect_value(t, err, nil)
 		
 		a := "a"
 		err = cbor.encode(encoder, &a)
-		testing.expect_value(t, err, nil)
+		expect_value(t, err, nil)
 		
 		{
 			err = cbor.encode_stream_begin(stream, .Map)
-			testing.expect_value(t, err, nil)
+			expect_value(t, err, nil)
 			
 			b := "b"
 			c := "c"
 			err = cbor.encode_stream_map_entry(encoder, &b, &c)
-			testing.expect_value(t, err, nil)
+			expect_value(t, err, nil)
 
 			err = cbor.encode_stream_end(stream)
-			testing.expect_value(t, err, nil)
+			expect_value(t, err, nil)
 		}
 		
-		testing.expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)string("\x82\x61\x61\xbf\x61\x62\x61\x63\xff")))
+		expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)string("\x82\x61\x61\xbf\x61\x62\x61\x63\xff")))
 	}
 }
 
@@ -619,13 +799,13 @@ expect_decoding :: proc(t: ^testing.T, encoded: string, decoded: string, type: t
     res, err := cbor.decode(stream)
 	defer cbor.destroy(res)
 
-	testing.expect_value(t, reflect.union_variant_typeid(res), type, loc)
-    testing.expect_value(t, err, nil, loc)
+	expect_value(t, reflect.union_variant_typeid(res), type, loc)
+    expect_value(t, err, nil, loc)
 
 	str := cbor.diagnose(res, padding=-1)
 	defer delete(str)
 
-    testing.expect_value(t, str, decoded, loc)
+    expect_value(t, str, decoded, loc)
 }
 
 expect_tag :: proc(t: ^testing.T, encoded: string, nr: cbor.Tag_Number, value_decoded: string, loc := #caller_location) {
@@ -635,17 +815,17 @@ expect_tag :: proc(t: ^testing.T, encoded: string, nr: cbor.Tag_Number, value_de
 	res, err := cbor.decode(stream)
 	defer cbor.destroy(res)
 
-	testing.expect_value(t, err, nil, loc)
+	expect_value(t, err, nil, loc)
 	
 	if tag, is_tag := res.(^cbor.Tag); is_tag {
-		testing.expect_value(t, tag.number, nr, loc)
+		expect_value(t, tag.number, nr, loc)
 
 		str := cbor.diagnose(tag, padding=-1)
 		defer delete(str)
 
-		testing.expect_value(t, str, value_decoded, loc)
+		expect_value(t, str, value_decoded, loc)
 	} else {
-		testing.errorf(t, "Value %#v is not a tag", res, loc)
+		errorf(t, "Value %#v is not a tag", res, loc)
 	}
 }
 
@@ -656,16 +836,16 @@ expect_float :: proc(t: ^testing.T, encoded: string, expected: $T, loc := #calle
     res, err := cbor.decode(stream)
 	defer cbor.destroy(res)
 
-	testing.expect_value(t, reflect.union_variant_typeid(res), typeid_of(T), loc)
-    testing.expect_value(t, err, nil, loc)
+	expect_value(t, reflect.union_variant_typeid(res), typeid_of(T), loc)
+    expect_value(t, err, nil, loc)
 
 	#partial switch r in res {
 	case f16:
-		when T == f16 { testing.expect_value(t, res, expected, loc) } else { unreachable() }
+		when T == f16 { expect_value(t, res, expected, loc) } else { unreachable() }
 	case f32:
-		when T == f32 { testing.expect_value(t, res, expected, loc) } else { unreachable() }
+		when T == f32 { expect_value(t, res, expected, loc) } else { unreachable() }
 	case f64:
-		when T == f64 { testing.expect_value(t, res, expected, loc) } else { unreachable() }
+		when T == f64 { expect_value(t, res, expected, loc) } else { unreachable() }
 	case:
 		unreachable()
 	}
@@ -675,8 +855,8 @@ expect_encoding :: proc(t: ^testing.T, val: cbor.Value, encoded: string, loc :=
 	bytes.buffer_reset(&buf)
 
 	err := cbor.encode(encoder, val)
-	testing.expect_value(t, err, nil, loc)
-	testing.expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)encoded), loc)
+	expect_value(t, err, nil, loc)
+	expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)encoded), loc)
 }
 
 expect_streamed_encoding :: proc(t: ^testing.T, encoded: string, values: ..cbor.Value, loc := #caller_location) {
@@ -705,15 +885,15 @@ expect_streamed_encoding :: proc(t: ^testing.T, encoded: string, values: ..cbor.
 				if err2 != nil { break }
 			}
 		case:
-			testing.errorf(t, "%v does not support streamed encoding", reflect.union_variant_typeid(value))
+			errorf(t, "%v does not support streamed encoding", reflect.union_variant_typeid(value))
 		}
 
-		testing.expect_value(t, err, nil, loc)
-		testing.expect_value(t, err2, nil, loc)
+		expect_value(t, err, nil, loc)
+		expect_value(t, err2, nil, loc)
 	}
 
 	err := cbor.encode_stream_end(stream)
-	testing.expect_value(t, err, nil, loc)
+	expect_value(t, err, nil, loc)
 
-	testing.expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)encoded), loc)
+	expect_value(t, fmt.tprint(bytes.buffer_to_bytes(&buf)), fmt.tprint(transmute([]byte)encoded), loc)
 }