
encoding/cbor: add decoder flags and protect from malicious untrusted input

Laytan Laats 1 year ago
parent
commit
21e6e28a3a
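
For context, a minimal usage sketch of the decoder flags this commit introduces (illustrative only: the CBOR literal and the chosen flag combination are made up, the API names are taken from the diff below):

    package cbor_flags_example

    import "core:encoding/cbor"
    import "core:fmt"

    main :: proc() {
        // {"foo": 42}, hand-assembled CBOR for the example.
        input := "\xa1\x63foo\x18\x2a"

        // Untrusted input is the default: lengths claimed by the input are only
        // pre-allocated up to Decoder.max_pre_alloc (DEFAULT_MAX_PRE_ALLOC when unset).
        v, err := cbor.decode(input, {.Disallow_Streaming, .Shrink_Excess})
        if err != nil {
            fmt.eprintln("decode failed:", err)
            return
        }
        defer cbor.destroy(v)
        fmt.println(v)
    }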

+ 7 - 1
core/encoding/cbor/cbor.odin

@@ -10,8 +10,13 @@ import "core:strings"
 
 // If we are decoding a stream of either a map or list, the initial capacity will be this value.
 INITIAL_STREAMED_CONTAINER_CAPACITY :: 8
+
 // If we are decoding a stream of either text or bytes, the initial capacity will be this value.
-INITIAL_STREAMED_BYTES_CAPACITY     :: 16
+INITIAL_STREAMED_BYTES_CAPACITY :: 16
+
+// The default maximum amount of bytes to allocate on a buffer/container at once to prevent
+// malicious input from causing massive allocations.
+DEFAULT_MAX_PRE_ALLOC :: mem.Kilobyte
 
 // Known/common headers are defined, undefined headers can still be valid.
 // Higher 3 bits is for the major type and lower 5 bits for the additional information.
@@ -157,6 +162,7 @@ Decode_Data_Error :: enum {
 	Nested_Indefinite_Length, // When an streamed/indefinite length container nests another, this is not allowed.
 	Nested_Tag,               // When a tag's value is another tag, this is not allowed.
 	Length_Too_Big,           // When the length of a container (map, array, bytes, string) is more than `max(int)`.
+	Disallowed_Streaming,     // When the `.Disallow_Streaming` flag is set and a streaming header is encountered.
 	Break,
 }
 
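
With the flag set, the new `.Disallowed_Streaming` error above is what a caller sees for indefinite-length input (sketch; the streamed-map bytes are hand-assembled and the `cbor`/`fmt` imports are assumed):

    // 0xbf opens an indefinite-length ("streamed") map, 0xff is the break marker.
    streamed := "\xbf\x63foo\x18\x2a\xff"

    _, err := cbor.decode(streamed, {.Disallow_Streaming})
    fmt.println(err) // expected to report Disallowed_Streaming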

+ 169 - 106
core/encoding/cbor/coding.odin

@@ -33,16 +33,40 @@ Encoder_Flags :: bit_set[Encoder_Flag]
 
 // Flags for fully deterministic output (if you are not using streaming/indeterminate length).
 ENCODE_FULLY_DETERMINISTIC :: Encoder_Flags{.Deterministic_Int_Size, .Deterministic_Float_Size, .Deterministic_Map_Sorting}
+
 // Flags for the smallest encoding output.
-ENCODE_SMALL               :: Encoder_Flags{.Deterministic_Int_Size, .Deterministic_Float_Size}
-// Flags for the fastest encoding output.
-ENCODE_FAST                :: Encoder_Flags{}
+ENCODE_SMALL :: Encoder_Flags{.Deterministic_Int_Size, .Deterministic_Float_Size}
 
 Encoder :: struct {
 	flags:  Encoder_Flags,
 	writer: io.Writer,
 }
 
+Decoder_Flag :: enum {
+	// Rejects (with an error `.Disallowed_Streaming`) when a streaming CBOR header is encountered.
+	Disallow_Streaming,
+
+	// Pre-allocates buffers and containers with the size that was set in the CBOR header.
+	// This should only be enabled when you control both ends of the encoding; if you don't,
+	// attackers can craft input that causes massive (up to `max(u64)` bytes) allocations from only
+	// a few bytes of CBOR.
+	Trusted_Input,
+	
+	// Makes the decoder shrink excess capacity of allocated buffers/containers before returning.
+	Shrink_Excess,
+}
+
+Decoder_Flags :: bit_set[Decoder_Flag]
+
+Decoder :: struct {
+	// The max amount of bytes allowed to pre-allocate when `.Trusted_Input` is not set on the
+	// flags.
+	max_pre_alloc: int,
+
+	flags:  Decoder_Flags,
+	reader: io.Reader,
+}
+
 /*
 Decodes both deterministic and non-deterministic CBOR into a `Value` variant.
 
@@ -52,28 +76,60 @@ Allocations are done using the given allocator,
 *no* allocations are done on the `context.temp_allocator`.
 
 A value can be (fully and recursively) deallocated using the `destroy` proc in this package.
+
+Disable streaming/indeterminate lengths with the `.Disallow_Streaming` flag.
+
+Shrink excess bytes in buffers and containers with the `.Shrink_Excess` flag.
+
+Mark the input as trusted with the `.Trusted_Input` flag; this turns off the safety feature of
+not pre-allocating more than `max_pre_alloc` bytes before reading. You should only do this when
+you own both sides of the encoding and are sure there can't be malicious bytes used as
+an input.
 */
-decode :: proc {
-	decode_string,
-	decode_reader,
+decode_from :: proc {
+	decode_from_string,
+	decode_from_reader,
+	decode_from_decoder,
 }
+decode :: decode_from
 
 // Decodes the given string as CBOR.
 // See docs on the proc group `decode` for more information.
-decode_string :: proc(s: string, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
+decode_from_string :: proc(s: string, flags: Decoder_Flags = {}, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
 	context.allocator = allocator
-
 	r: strings.Reader
 	strings.reader_init(&r, s)
-	return decode(strings.reader_to_stream(&r), allocator=allocator)
+	return decode_from_reader(strings.reader_to_stream(&r), flags)
 }
 
 // Reads a CBOR value from the given reader.
 // See docs on the proc group `decode` for more information.
-decode_reader :: proc(r: io.Reader, hdr: Header = Header(0), allocator := context.allocator) -> (v: Value, err: Decode_Error) {
+decode_from_reader :: proc(r: io.Reader, flags: Decoder_Flags = {}, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
+	return decode_from_decoder(
+		Decoder{ DEFAULT_MAX_PRE_ALLOC, flags, r },
+		allocator=allocator,
+	)
+}
+
+// Reads a CBOR value from the given decoder.
+// See docs on the proc group `decode` for more information.
+decode_from_decoder :: proc(d: Decoder, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
 	context.allocator = allocator
 	
+	d := d
+	if d.max_pre_alloc <= 0 {
+		d.max_pre_alloc = DEFAULT_MAX_PRE_ALLOC
+	}
+
+	v, err = _decode_from_decoder(d)
+	// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
+	if err == .EOF { err = .Unexpected_EOF }
+	return
+}
+
+_decode_from_decoder :: proc(d: Decoder, hdr: Header = Header(0)) -> (v: Value, err: Decode_Error) {
 	hdr := hdr
+	r := d.reader
 	if hdr == Header(0) { hdr = _decode_header(r) or_return }
 	switch hdr {
 	case .U8:  return _decode_u8 (r)
@@ -105,11 +161,11 @@ decode_reader :: proc(r: io.Reader, hdr: Header = Header(0), allocator := contex
 	switch maj {
 	case .Unsigned: return _decode_tiny_u8(add)
 	case .Negative: return Negative_U8(_decode_tiny_u8(add) or_return), nil
-	case .Bytes:    return _decode_bytes_ptr(r, add)
-	case .Text:     return _decode_text_ptr(r, add)
-	case .Array:    return _decode_array_ptr(r, add)
-	case .Map:      return _decode_map_ptr(r, add)
-	case .Tag:      return _decode_tag_ptr(r, add)
+	case .Bytes:    return _decode_bytes_ptr(d, add)
+	case .Text:     return _decode_text_ptr(d, add)
+	case .Array:    return _decode_array_ptr(d, add)
+	case .Map:      return _decode_map_ptr(d, add)
+	case .Tag:      return _decode_tag_ptr(d, add)
 	case .Other:    return _decode_tiny_simple(add)
 	case:           return nil, .Bad_Major
 	}
@@ -246,7 +302,7 @@ _encode_u8 :: proc(w: io.Writer, v: u8, major: Major = .Unsigned) -> (err: io.Er
 }
 
 _decode_tiny_u8 :: proc(additional: Add) -> (u8, Decode_Data_Error) {
-	if intrinsics.expect(additional < .One_Byte, true) {
+	if additional < .One_Byte {
 		return u8(additional), nil
 	}
 
@@ -316,64 +372,53 @@ _encode_u64_exact :: proc(w: io.Writer, v: u64, major: Major = .Unsigned) -> (er
 	return
 }
 
-_decode_bytes_ptr :: proc(r: io.Reader, add: Add, type: Major = .Bytes) -> (v: ^Bytes, err: Decode_Error) {
+_decode_bytes_ptr :: proc(d: Decoder, add: Add, type: Major = .Bytes) -> (v: ^Bytes, err: Decode_Error) {
 	v = new(Bytes) or_return
 	defer if err != nil { free(v) }
 
-	v^ = _decode_bytes(r, add, type) or_return
+	v^ = _decode_bytes(d, add, type) or_return
 	return
 }
 
-_decode_bytes :: proc(r: io.Reader, add: Add, type: Major = .Bytes) -> (v: Bytes, err: Decode_Error) {
-	_n_items, length_is_unknown := _decode_container_length(r, add) or_return
-
-	n_items := _n_items.? or_else INITIAL_STREAMED_BYTES_CAPACITY
-
-	if length_is_unknown {
-		buf: strings.Builder
-		buf.buf = make([dynamic]byte, 0, n_items) or_return
-		defer if err != nil { strings.builder_destroy(&buf) }
-
-		buf_stream := strings.to_stream(&buf)
+_decode_bytes :: proc(d: Decoder, add: Add, type: Major = .Bytes) -> (v: Bytes, err: Decode_Error) {
+	n, scap := _decode_len_str(d, add) or_return
+	
+	buf := strings.builder_make(0, scap) or_return
+	defer if err != nil { strings.builder_destroy(&buf) }
+	buf_stream := strings.to_stream(&buf)
 
-		for {
-			header   := _decode_header(r) or_return
+	if n == -1 {
+		indefinite_loop: for {
+			header   := _decode_header(d.reader) or_return
 			maj, add := _header_split(header)
-
 			#partial switch maj {
 			case type:
-				_n_items, length_is_unknown := _decode_container_length(r, add) or_return
-				if length_is_unknown {
+				iter_n, iter_cap := _decode_len_str(d, add) or_return
+				if iter_n == -1 {
 					return nil, .Nested_Indefinite_Length
 				}
-				n_items := i64(_n_items.?)
+				reserve(&buf.buf, len(buf.buf) + iter_cap) or_return
+				io.copy_n(buf_stream, d.reader, i64(iter_n)) or_return
 
-				copied := io.copy_n(buf_stream, r, n_items) or_return
-				assert(copied == n_items)
-					
 			case .Other:
 				if add != .Break { return nil, .Bad_Argument }
-				
-				v = buf.buf[:]
- 				
-				// Write zero byte so this can be converted to cstring.
-				io.write_full(buf_stream, {0}) or_return
-				shrink(&buf.buf) // Ignoring error, this is not critical to succeed.
-				return
+				break indefinite_loop
 
 			case:
 				return nil, .Bad_Major
 			}
 		}
 	} else {
-		v = make([]byte, n_items + 1) or_return // Space for the bytes and a zero byte.
-		defer if err != nil { delete(v) }
+		io.copy_n(buf_stream, d.reader, i64(n)) or_return
+	}
 
-		io.read_full(r, v[:n_items]) or_return
+	v = buf.buf[:]
 
-		v = v[:n_items] // Take off zero byte.
-		return
-	}
+	// Write zero byte so this can be converted to cstring.
+	strings.write_byte(&buf, 0)
+
+	if .Shrink_Excess in d.flags { shrink(&buf.buf) }
+	return
 }
 
 _encode_bytes :: proc(e: Encoder, val: Bytes, major: Major = .Bytes) -> (err: Encode_Error) {
@@ -383,43 +428,41 @@ _encode_bytes :: proc(e: Encoder, val: Bytes, major: Major = .Bytes) -> (err: En
 	return
 }
 
-_decode_text_ptr :: proc(r: io.Reader, add: Add) -> (v: ^Text, err: Decode_Error) {
+_decode_text_ptr :: proc(d: Decoder, add: Add) -> (v: ^Text, err: Decode_Error) {
 	v = new(Text) or_return
 	defer if err != nil { free(v) }
 
-	v^ = _decode_text(r, add) or_return
+	v^ = _decode_text(d, add) or_return
 	return
 }
 
-_decode_text :: proc(r: io.Reader, add: Add) -> (v: Text, err: Decode_Error) {
-	return (Text)(_decode_bytes(r, add, .Text) or_return), nil
+_decode_text :: proc(d: Decoder, add: Add) -> (v: Text, err: Decode_Error) {
+	return (Text)(_decode_bytes(d, add, .Text) or_return), nil
 }
 
 _encode_text :: proc(e: Encoder, val: Text) -> Encode_Error {
     return _encode_bytes(e, transmute([]byte)val, .Text)
 }
 
-_decode_array_ptr :: proc(r: io.Reader, add: Add) -> (v: ^Array, err: Decode_Error) {
+_decode_array_ptr :: proc(d: Decoder, add: Add) -> (v: ^Array, err: Decode_Error) {
 	v = new(Array) or_return
 	defer if err != nil { free(v) }
 
-	v^ = _decode_array(r, add) or_return
+	v^ = _decode_array(d, add) or_return
 	return
 }
 
-_decode_array :: proc(r: io.Reader, add: Add) -> (v: Array, err: Decode_Error) {
-	_n_items, length_is_unknown := _decode_container_length(r, add) or_return
-	n_items := _n_items.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
-
-	array := make([dynamic]Value, 0, n_items) or_return
+_decode_array :: proc(d: Decoder, add: Add) -> (v: Array, err: Decode_Error) {
+	n, scap := _decode_len_container(d, add) or_return
+	array := make([dynamic]Value, 0, scap) or_return
 	defer if err != nil {
 		for entry in array { destroy(entry) }
 		delete(array)
 	}
 	
-	for i := 0; length_is_unknown || i < n_items; i += 1 {
-		val, verr := decode(r)
-		if length_is_unknown && verr == .Break {
+	for i := 0; n == -1 || i < n; i += 1 {
+		val, verr := _decode_from_decoder(d)
+		if n == -1 && verr == .Break {
 			break
 		} else if verr != nil {
 			err = verr
@@ -428,8 +471,9 @@ _decode_array :: proc(r: io.Reader, add: Add) -> (v: Array, err: Decode_Error) {
 
 		append(&array, val) or_return
 	}
+
+	if .Shrink_Excess in d.flags { shrink(&array) }
 	
-	shrink(&array)
 	v = array[:]
 	return
 }
@@ -443,19 +487,17 @@ _encode_array :: proc(e: Encoder, arr: Array) -> Encode_Error {
     return nil
 }
 
-_decode_map_ptr :: proc(r: io.Reader, add: Add) -> (v: ^Map, err: Decode_Error) {
+_decode_map_ptr :: proc(d: Decoder, add: Add) -> (v: ^Map, err: Decode_Error) {
 	v = new(Map) or_return
 	defer if err != nil { free(v) }
 
-	v^ = _decode_map(r, add) or_return
+	v^ = _decode_map(d, add) or_return
 	return
 }
 
-_decode_map :: proc(r: io.Reader, add: Add) -> (v: Map, err: Decode_Error) {
-	_n_items, length_is_unknown := _decode_container_length(r, add) or_return
-	n_items := _n_items.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
-	
-	items := make([dynamic]Map_Entry, 0, n_items) or_return
+_decode_map :: proc(d: Decoder, add: Add) -> (v: Map, err: Decode_Error) {
+	n, scap := _decode_len_container(d, add) or_return
+	items := make([dynamic]Map_Entry, 0, scap) or_return
 	defer if err != nil { 
 		for entry in items {
 			destroy(entry.key)
@@ -464,23 +506,24 @@ _decode_map :: proc(r: io.Reader, add: Add) -> (v: Map, err: Decode_Error) {
 		delete(items)
 	}
 
-	for i := 0; length_is_unknown || i < n_items; i += 1 {
-		key, kerr := decode(r)
-		if length_is_unknown && kerr == .Break {
+	for i := 0; n == -1 || i < n; i += 1 {
+		key, kerr := _decode_from_decoder(d)
+		if n == -1 && kerr == .Break {
 			break
 		} else if kerr != nil {
 			return nil, kerr
 		} 
 
-		value := decode(r) or_return
+		value := decode_from_decoder(d) or_return
 
 		append(&items, Map_Entry{
 			key   = key,
 			value = value,
 		}) or_return
 	}
+
+	if .Shrink_Excess in d.flags { shrink(&items) }
 	
-	shrink(&items)
 	v = items[:]
 	return
 }
@@ -537,8 +580,8 @@ _encode_map :: proc(e: Encoder, m: Map) -> (err: Encode_Error) {
     return nil
 }
 
-_decode_tag_ptr :: proc(r: io.Reader, add: Add) -> (v: Value, err: Decode_Error) {
-	tag := _decode_tag(r, add) or_return
+_decode_tag_ptr :: proc(d: Decoder, add: Add) -> (v: Value, err: Decode_Error) {
+	tag := _decode_tag(d, add) or_return
 	if t, ok := tag.?; ok {
 		defer if err != nil { destroy(t.value) }
 		tp := new(Tag) or_return
@@ -547,11 +590,11 @@ _decode_tag_ptr :: proc(r: io.Reader, add: Add) -> (v: Value, err: Decode_Error)
 	}
 
 	// no error, no tag, this was the self described CBOR tag, skip it.
-	return decode(r)
+	return _decode_from_decoder(d)
 }
 
-_decode_tag :: proc(r: io.Reader, add: Add) -> (v: Maybe(Tag), err: Decode_Error) {
-	num := _decode_tag_nr(r, add) or_return
+_decode_tag :: proc(d: Decoder, add: Add) -> (v: Maybe(Tag), err: Decode_Error) {
+	num := _decode_uint_as_u64(d.reader, add) or_return
 
 	// CBOR can be wrapped in a tag that decoders can use to see/check if the binary data is CBOR.
 	// We can ignore it here.
@@ -561,7 +604,7 @@ _decode_tag :: proc(r: io.Reader, add: Add) -> (v: Maybe(Tag), err: Decode_Error
 
 	t := Tag{
 		number = num,
-		value = decode(r) or_return,
+		value = _decode_from_decoder(d) or_return,
 	}
 
 	if nested, ok := t.value.(^Tag); ok {
@@ -572,7 +615,7 @@ _decode_tag :: proc(r: io.Reader, add: Add) -> (v: Maybe(Tag), err: Decode_Error
 	return t, nil
 }
 
-_decode_tag_nr :: proc(r: io.Reader, add: Add) -> (nr: Tag_Number, err: Decode_Error) {
+_decode_uint_as_u64 :: proc(r: io.Reader, add: Add) -> (nr: u64, err: Decode_Error) {
 	#partial switch add {
 	case .One_Byte:    return u64(_decode_u8(r) or_return), nil
 	case .Two_Bytes:   return u64(_decode_u16(r) or_return), nil
@@ -719,30 +762,50 @@ encode_stream_map_entry :: proc(e: Encoder, key: Value, val: Value) -> Encode_Er
     return encode(e, val)
 }
 
-//
-
-_decode_container_length :: proc(r: io.Reader, add: Add) -> (length: Maybe(int), is_unknown: bool, err: Decode_Error) {
-	if add == Add.Length_Unknown { return nil, true, nil }
-	#partial switch add {
-	case .One_Byte:  length = int(_decode_u8(r) or_return)
-	case .Two_Bytes: length = int(_decode_u16(r) or_return)
-	case .Four_Bytes:
-		big_length := _decode_u32(r) or_return
-		if u64(big_length) > u64(max(int)) {
-			err = .Length_Too_Big
-			return
+// For `Bytes` and `Text` strings: Decodes the number of items the header says follows.
+// If the number is not specified -1 is returned and streaming should be initiated.
+// A suitable starting capacity is also returned for a buffer that is allocated up the stack.
+_decode_len_str :: proc(d: Decoder, add: Add) -> (n: int, scap: int, err: Decode_Error) {
+	if add == .Length_Unknown {
+		if .Disallow_Streaming in d.flags {
+			return -1, -1, .Disallowed_Streaming
 		}
-		length = int(big_length)
-	case .Eight_Bytes:
-		big_length := _decode_u64(r) or_return
-		if big_length > u64(max(int)) {
-			err = .Length_Too_Big
-			return
+		return -1, INITIAL_STREAMED_BYTES_CAPACITY, nil
+	}
+
+	_n := _decode_uint_as_u64(d.reader, add) or_return
+	if _n > u64(max(int)) { return -1, -1, .Length_Too_Big }
+	n = int(_n)
+
+	scap = n + 1 // Space for zero byte.
+	if .Trusted_Input not_in d.flags {
+		scap = min(d.max_pre_alloc, scap)
+	}
+
+	return
+}
+
+// For `Array` and `Map` types: Decodes the number of items the header says follows.
+// If the number is not specified -1 is returned and streaming should be initiated.
+// A suitable starting capacity is also returned for a buffer that is allocated up the stack.
+_decode_len_container :: proc(d: Decoder, add: Add) -> (n: int, scap: int, err: Decode_Error) {
+	if add == .Length_Unknown {
+		if .Disallow_Streaming in d.flags {
+			return -1, -1, .Disallowed_Streaming
 		}
-		length = int(big_length)
-	case:
-		length = int(_decode_tiny_u8(add) or_return)
+		return -1, INITIAL_STREAMED_CONTAINER_CAPACITY, nil
+	}
+
+	_n := _decode_uint_as_u64(d.reader, add) or_return
+	if _n > u64(max(int)) { return -1, -1, .Length_Too_Big }
+	n = int(_n)
+
+	scap = n
+	if .Trusted_Input not_in d.flags {
+		// NOTE: if this is a map it will be twice this.
+		scap = min(d.max_pre_alloc / size_of(Value), scap)
 	}
+
 	return
 }
 
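
A sketch of the new decode_from_decoder entry point with a custom pre-allocation limit (the 4 KiB value is illustrative; assumes `import "core:encoding/cbor"`, `import "core:mem"` and `import "core:strings"`):

    decode_with_limit :: proc(input: string) -> (v: cbor.Value, err: cbor.Decode_Error) {
        r: strings.Reader
        d := cbor.Decoder{
            max_pre_alloc = 4 * mem.Kilobyte, // <= 0 falls back to DEFAULT_MAX_PRE_ALLOC
            flags         = {.Shrink_Excess},
            reader        = strings.to_reader(&r, input),
        }
        return cbor.decode_from_decoder(d)
    }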

+ 16 - 16
core/encoding/cbor/tags.odin

@@ -55,7 +55,7 @@ Tag_Implementation :: struct {
 }
 
 // Procedure responsible for umarshalling the tag out of the reader into the given `any`.
-Tag_Unmarshal_Proc :: #type proc(self: ^Tag_Implementation, r: io.Reader, tag_nr: Tag_Number, v: any) -> Unmarshal_Error
+Tag_Unmarshal_Proc :: #type proc(self: ^Tag_Implementation, d: Decoder, tag_nr: Tag_Number, v: any) -> Unmarshal_Error
 
 // Procedure responsible for marshalling the tag in the given `any` into the given encoder.
 Tag_Marshal_Proc   :: #type proc(self: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_Error
@@ -121,30 +121,30 @@ tags_register_defaults :: proc() {
 //
 // See RFC 8949 section 3.4.2.
 @(private)
-tag_time_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
-	hdr := _decode_header(r) or_return
+tag_time_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
+	hdr := _decode_header(d.reader) or_return
 	#partial switch hdr {
 	case .U8, .U16, .U32, .U64, .Neg_U8, .Neg_U16, .Neg_U32, .Neg_U64:
 		switch &dst in v {
 		case time.Time:
 			i: i64
-			_unmarshal_any_ptr(r, &i, hdr) or_return
+			_unmarshal_any_ptr(d, &i, hdr) or_return
 			dst = time.unix(i64(i), 0)
 			return
 		case:
-			return _unmarshal_value(r, v, hdr)
+			return _unmarshal_value(d, v, hdr)
 		}
 
 	case .F16, .F32, .F64:
 		switch &dst in v {
 		case time.Time:
 			f: f64
-			_unmarshal_any_ptr(r, &f, hdr) or_return
+			_unmarshal_any_ptr(d, &f, hdr) or_return
 			whole, fract := math.modf(f)
 			dst = time.unix(i64(whole), i64(fract * 1e9))
 			return
 		case:
-			return _unmarshal_value(r, v, hdr)
+			return _unmarshal_value(d, v, hdr)
 		}
 
 	case:
@@ -182,8 +182,8 @@ tag_time_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_
 }
 
 @(private)
-tag_big_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, tnr: Tag_Number, v: any) -> (err: Unmarshal_Error) {
-	hdr := _decode_header(r) or_return
+tag_big_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, tnr: Tag_Number, v: any) -> (err: Unmarshal_Error) {
+	hdr := _decode_header(d.reader) or_return
 	maj, add := _header_split(hdr)
 	if maj != .Bytes {
 		// Only bytes are supported in this tag.
@@ -192,7 +192,7 @@ tag_big_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, tnr: Tag_Number,
 
 	switch &dst in v {
 	case big.Int:
-		bytes := err_conv(_decode_bytes(r, add)) or_return
+		bytes := err_conv(_decode_bytes(d, add)) or_return
 		defer delete(bytes)
 
 		if err := big.int_from_bytes_big(&dst, bytes); err != nil {
@@ -246,13 +246,13 @@ tag_big_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_E
 }
 
 @(private)
-tag_cbor_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number, v: any) -> Unmarshal_Error {
-	hdr := _decode_header(r) or_return
+tag_cbor_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, v: any) -> Unmarshal_Error {
+	hdr := _decode_header(d.reader) or_return
 	major, add := _header_split(hdr)
 	#partial switch major {
 	case .Bytes:
 		ti := reflect.type_info_base(type_info_of(v.id))
-		return _unmarshal_bytes(r, v, ti, hdr, add)
+		return _unmarshal_bytes(d, v, ti, hdr, add)
 		
 	case: return .Bad_Tag_Value
 	}
@@ -283,8 +283,8 @@ tag_cbor_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_
 }
 
 @(private)
-tag_base64_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
-	hdr        := _decode_header(r) or_return
+tag_base64_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
+	hdr        := _decode_header(d.reader) or_return
 	major, add := _header_split(hdr)
 	ti         := reflect.type_info_base(type_info_of(v.id))
 
@@ -294,7 +294,7 @@ tag_base64_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number
 
 	bytes: string; {
 		context.allocator = context.temp_allocator
-		bytes = string(err_conv(_decode_bytes(r, add)) or_return)
+		bytes = string(err_conv(_decode_bytes(d, add)) or_return)
 	}
 	defer delete(bytes, context.temp_allocator)
 
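
Tag implementations now receive the whole `Decoder` instead of a bare `io.Reader`, so custom hooks can respect the decoder's flags and limits. A signature-only sketch (the body is illustrative and assumes the tag wraps a plain integer; registration is unchanged):

    my_tag_unmarshal :: proc(self: ^cbor.Tag_Implementation, d: cbor.Decoder, nr: cbor.Tag_Number, v: any) -> cbor.Unmarshal_Error {
        // Decode the tag's wrapped value straight into the destination, reusing the same decoder.
        return cbor.unmarshal_from_decoder(d, (^int)(v.data))
    }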

+ 143 - 103
core/encoding/cbor/unmarshal.odin

@@ -15,25 +15,56 @@ Types that require allocation are allocated using the given allocator.
 Some temporary allocations are done on the `context.temp_allocator`, but, if you want to,
 this can be set to a "normal" allocator, because the necessary `delete` and `free` calls are still made.
 This is helpful when the CBOR size is so big that you don't want to collect all the temporary allocations until the end.
+
+Disable streaming/indeterminate lengths with the `.Disallow_Streaming` flag.
+
+Shrink excess bytes in buffers and containers with the `.Shrink_Excess` flag.
+
+Mark the input as trusted with the `.Trusted_Input` flag; this turns off the safety feature of
+not pre-allocating more than `max_pre_alloc` bytes before reading. You should only do this when
+you own both sides of the encoding and are sure there can't be malicious bytes used as
+an input.
 */
 unmarshal :: proc {
 	unmarshal_from_reader,
 	unmarshal_from_string,
 }
 
-// Unmarshals from a reader, see docs on the proc group `Unmarshal` for more info.
-unmarshal_from_reader :: proc(r: io.Reader, ptr: ^$T, allocator := context.allocator) -> Unmarshal_Error {
-	return _unmarshal_any_ptr(r, ptr, allocator=allocator)
+unmarshal_from_reader :: proc(r: io.Reader, ptr: ^$T, flags := Decoder_Flags{}, allocator := context.allocator) -> (err: Unmarshal_Error) {
+	err = unmarshal_from_decoder(Decoder{ DEFAULT_MAX_PRE_ALLOC, flags, r }, ptr, allocator=allocator)
+
+	// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
+	if err == .EOF { err = .Unexpected_EOF }
+	return
 }
 
 // Unmarshals from a string, see docs on the proc group `Unmarshal` for more info.
-unmarshal_from_string :: proc(s: string, ptr: ^$T, allocator := context.allocator) -> Unmarshal_Error {
+unmarshal_from_string :: proc(s: string, ptr: ^$T, flags := Decoder_Flags{}, allocator := context.allocator) -> (err: Unmarshal_Error) {
 	sr: strings.Reader
 	r := strings.to_reader(&sr, s)
-	return _unmarshal_any_ptr(r, ptr, allocator=allocator)
+
+	err = unmarshal_from_reader(r, ptr, flags, allocator)
+
+	// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
+	if err == .EOF { err = .Unexpected_EOF }
+	return
 }
 
-_unmarshal_any_ptr :: proc(r: io.Reader, v: any, hdr: Maybe(Header) = nil, allocator := context.allocator) -> Unmarshal_Error {
+unmarshal_from_decoder :: proc(d: Decoder, ptr: ^$T, allocator := context.allocator) -> (err: Unmarshal_Error) {
+	d := d
+	if d.max_pre_alloc <= 0 {
+		d.max_pre_alloc = DEFAULT_MAX_PRE_ALLOC
+	}
+
+	err = _unmarshal_any_ptr(d, ptr, allocator=allocator)
+
+	// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
+	if err == .EOF { err = .Unexpected_EOF }
+	return
+
+}
+
+_unmarshal_any_ptr :: proc(d: Decoder, v: any, hdr: Maybe(Header) = nil, allocator := context.allocator) -> Unmarshal_Error {
 	context.allocator = allocator
 	v := v
 
@@ -48,12 +79,13 @@ _unmarshal_any_ptr :: proc(r: io.Reader, v: any, hdr: Maybe(Header) = nil, alloc
 	}
 	
 	data := any{(^rawptr)(v.data)^, ti.variant.(reflect.Type_Info_Pointer).elem.id}
-	return _unmarshal_value(r, data, hdr.? or_else (_decode_header(r) or_return))
+	return _unmarshal_value(d, data, hdr.? or_else (_decode_header(d.reader) or_return))
 }
 
-_unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_Error) {
+_unmarshal_value :: proc(d: Decoder, v: any, hdr: Header) -> (err: Unmarshal_Error) {
 	v := v
 	ti := reflect.type_info_base(type_info_of(v.id))
+	r := d.reader
 
 	// If it's a union with only one variant, then treat it as that variant
 	if u, ok := ti.variant.(reflect.Type_Info_Union); ok && len(u.variants) == 1 {
@@ -73,7 +105,7 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
 	// Allow generic unmarshal by doing it into a `Value`.
 	switch &dst in v {
 	case Value:
-		dst = err_conv(decode(r, hdr)) or_return
+		dst = err_conv(_decode_from_decoder(d, hdr)) or_return
 		return
 	}
 
@@ -253,7 +285,7 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
 	case .Tag:
 		switch &dst in v {
 		case ^Tag:
-			tval := err_conv(_decode_tag_ptr(r, add)) or_return
+			tval := err_conv(_decode_tag_ptr(d, add)) or_return
 			if t, is_tag := tval.(^Tag); is_tag {
 				dst = t
 				return
@@ -262,7 +294,7 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
 			destroy(tval)
 			return .Bad_Tag_Value
 		case Tag:
-			t := err_conv(_decode_tag(r, add)) or_return
+			t := err_conv(_decode_tag(d, add)) or_return
 			if t, is_tag := t.?; is_tag {
 				dst = t
 				return
@@ -271,33 +303,33 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
 			return .Bad_Tag_Value
 		}
 
-		nr := err_conv(_decode_tag_nr(r, add)) or_return
+		nr := err_conv(_decode_uint_as_u64(r, add)) or_return
 
 		// Custom tag implementations.
 		if impl, ok := _tag_implementations_nr[nr]; ok {
-			return impl->unmarshal(r, nr, v)
+			return impl->unmarshal(d, nr, v)
 		} else if nr == TAG_OBJECT_TYPE {
-			return _unmarshal_union(r, v, ti, hdr)
+			return _unmarshal_union(d, v, ti, hdr)
 		} else {
 			// Discard the tag info and unmarshal as its value.
-			return _unmarshal_value(r, v, _decode_header(r) or_return)
+			return _unmarshal_value(d, v, _decode_header(r) or_return)
 		}
 
 		return _unsupported(v, hdr, add)
 
-	case .Bytes: return _unmarshal_bytes(r, v, ti, hdr, add)
-	case .Text:  return _unmarshal_string(r, v, ti, hdr, add)
-	case .Array: return _unmarshal_array(r, v, ti, hdr, add)
-	case .Map:   return _unmarshal_map(r, v, ti, hdr, add)
+	case .Bytes: return _unmarshal_bytes(d, v, ti, hdr, add)
+	case .Text:  return _unmarshal_string(d, v, ti, hdr, add)
+	case .Array: return _unmarshal_array(d, v, ti, hdr, add)
+	case .Map:   return _unmarshal_map(d, v, ti, hdr, add)
 
 	case:        return .Bad_Major
 	}
 }
 
-_unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
+_unmarshal_bytes :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
 	#partial switch t in ti.variant {
 	case reflect.Type_Info_String:
-		bytes := err_conv(_decode_bytes(r, add)) or_return
+		bytes := err_conv(_decode_bytes(d, add)) or_return
 
 		if t.is_cstring {
 			raw  := (^cstring)(v.data)
@@ -316,7 +348,7 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 
 		if elem_base.id != byte { return _unsupported(v, hdr) }
 
-		bytes := err_conv(_decode_bytes(r, add)) or_return
+		bytes := err_conv(_decode_bytes(d, add)) or_return
 		raw   := (^mem.Raw_Slice)(v.data)
 		raw^   = transmute(mem.Raw_Slice)bytes
 		return
@@ -326,7 +358,7 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 
 		if elem_base.id != byte { return _unsupported(v, hdr) }
 		
-		bytes         := err_conv(_decode_bytes(r, add)) or_return
+		bytes         := err_conv(_decode_bytes(d, add)) or_return
 		raw           := (^mem.Raw_Dynamic_Array)(v.data)
 		raw.data       = raw_data(bytes)
 		raw.len        = len(bytes)
@@ -339,11 +371,9 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 
 		if elem_base.id != byte { return _unsupported(v, hdr) }
 
-		bytes: []byte; {
-			context.allocator = context.temp_allocator
-			bytes = err_conv(_decode_bytes(r, add)) or_return
-		}
-		defer delete(bytes, context.temp_allocator)
+		context.allocator = context.temp_allocator
+		bytes := err_conv(_decode_bytes(d, add)) or_return
+		defer delete(bytes)
 
 		if len(bytes) > t.count { return _unsupported(v, hdr) }
 		
@@ -357,10 +387,10 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 	return _unsupported(v, hdr)
 }
 
-_unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
+_unmarshal_string :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
 	#partial switch t in ti.variant {
 	case reflect.Type_Info_String:
-		text := err_conv(_decode_text(r, add)) or_return
+		text := err_conv(_decode_text(d, add)) or_return
 
 		if t.is_cstring {
 			raw := (^cstring)(v.data)
@@ -376,8 +406,8 @@ _unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Hea
 	// Enum by its variant name.
 	case reflect.Type_Info_Enum:
 		context.allocator = context.temp_allocator
-		text := err_conv(_decode_text(r, add)) or_return
-		defer delete(text, context.temp_allocator)
+		text := err_conv(_decode_text(d, add)) or_return
+		defer delete(text)
 
 		for name, i in t.names {
 			if name == text {
@@ -388,8 +418,8 @@ _unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Hea
 	
 	case reflect.Type_Info_Rune:
 		context.allocator = context.temp_allocator
-		text := err_conv(_decode_text(r, add)) or_return
-		defer delete(text, context.temp_allocator)
+		text := err_conv(_decode_text(d, add)) or_return
+		defer delete(text)
 
 		r := (^rune)(v.data)
 		dr, n := utf8.decode_rune(text)
@@ -404,21 +434,19 @@ _unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Hea
 	return _unsupported(v, hdr)
 }
 
-_unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
-
+_unmarshal_array :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
 	assign_array :: proc(
-		r: io.Reader,
+		d: Decoder,
 		da: ^mem.Raw_Dynamic_Array,
 		elemt: ^reflect.Type_Info,
-		_length: Maybe(int),
+		length: int,
 		growable := true,
 	) -> (out_of_space: bool, err: Unmarshal_Error) {
-		length, has_length := _length.?
-		for idx: uintptr = 0; !has_length || idx < uintptr(length); idx += 1 {
+		for idx: uintptr = 0; length == -1 || idx < uintptr(length); idx += 1 {
 			elem_ptr := rawptr(uintptr(da.data) + idx*uintptr(elemt.size))
 			elem     := any{elem_ptr, elemt.id}
 
-			hdr := _decode_header(r) or_return
+			hdr := _decode_header(d.reader) or_return
 			
 			// Double size if out of capacity.
 			if da.cap <= da.len {
@@ -432,8 +460,8 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 				if !ok { return false, .Out_Of_Memory }
 			}
 			
-			err = _unmarshal_value(r, elem, hdr)
-			if !has_length && err == .Break { break }
+			err = _unmarshal_value(d, elem, hdr)
+			if length == -1 && err == .Break { break }
 			if err != nil { return }
 
 			da.len += 1
@@ -445,26 +473,25 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 	// Allow generically storing the values array.
 	switch &dst in v {
 	case ^Array:
-		dst = err_conv(_decode_array_ptr(r, add)) or_return
+		dst = err_conv(_decode_array_ptr(d, add)) or_return
 		return
 	case Array:
-		dst = err_conv(_decode_array(r, add)) or_return
+		dst = err_conv(_decode_array(d, add)) or_return
 		return
 	}
 
 	#partial switch t in ti.variant {
 	case reflect.Type_Info_Slice:
-		_length, unknown := err_conv(_decode_container_length(r, add)) or_return
-		length := _length.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
+		length, scap := err_conv(_decode_len_container(d, add)) or_return
 
-		data := mem.alloc_bytes_non_zeroed(t.elem.size * length, t.elem.align) or_return
+		data := mem.alloc_bytes_non_zeroed(t.elem.size * scap, t.elem.align) or_return
 		defer if err != nil { mem.free_bytes(data) }
 
 		da := mem.Raw_Dynamic_Array{raw_data(data), 0, length, context.allocator }
 
-		assign_array(r, &da, t.elem, _length) or_return
+		assign_array(d, &da, t.elem, length) or_return
 
-		if da.len < da.cap {
+		if .Shrink_Excess in d.flags {
 			// Ignoring an error here, but this is not critical to succeed.
 			_ = runtime.__dynamic_array_shrink(&da, t.elem.size, t.elem.align, da.len)
 		}
@@ -475,54 +502,58 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 		return
 
 	case reflect.Type_Info_Dynamic_Array:
-		_length, unknown := err_conv(_decode_container_length(r, add)) or_return
-		length := _length.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
+		length, scap := err_conv(_decode_len_container(d, add)) or_return
 
-		data := mem.alloc_bytes_non_zeroed(t.elem.size * length, t.elem.align) or_return
+		data := mem.alloc_bytes_non_zeroed(t.elem.size * scap, t.elem.align) or_return
 		defer if err != nil { mem.free_bytes(data) }
 
-		raw := (^mem.Raw_Dynamic_Array)(v.data)
-		raw.data = raw_data(data) 
-		raw.len = 0
-		raw.cap = length
-		raw.allocator = context.allocator
+		raw           := (^mem.Raw_Dynamic_Array)(v.data)
+		raw.data       = raw_data(data) 
+		raw.len        = 0
+		raw.cap        = length
+		raw.allocator  = context.allocator
+
+		_ = assign_array(d, raw, t.elem, length) or_return
 
-		_ = assign_array(r, raw, t.elem, _length) or_return
+		if .Shrink_Excess in d.flags {
+			// Ignoring an error here, but this is not critical to succeed.
+			_ = runtime.__dynamic_array_shrink(raw, t.elem.size, t.elem.align, raw.len)
+		}
 		return
 
 	case reflect.Type_Info_Array:
-		_length, unknown := err_conv(_decode_container_length(r, add)) or_return
-		length := _length.? or_else t.count
+		_length, scap := err_conv(_decode_len_container(d, add)) or_return
+		length := min(scap, t.count)
 	
-		if !unknown && length > t.count {
+		if length > t.count {
 			return _unsupported(v, hdr)
 		}
 
 		da := mem.Raw_Dynamic_Array{rawptr(v.data), 0, length, context.allocator }
 
-		out_of_space := assign_array(r, &da, t.elem, _length, growable=false) or_return
+		out_of_space := assign_array(d, &da, t.elem, length, growable=false) or_return
 		if out_of_space { return _unsupported(v, hdr) }
 		return
 
 	case reflect.Type_Info_Enumerated_Array:
-		_length, unknown := err_conv(_decode_container_length(r, add)) or_return
-		length := _length.? or_else t.count
+		_length, scap := err_conv(_decode_len_container(d, add)) or_return
+		length := min(scap, t.count)
 	
-		if !unknown && length > t.count {
+		if length > t.count {
 			return _unsupported(v, hdr)
 		}
 
 		da := mem.Raw_Dynamic_Array{rawptr(v.data), 0, length, context.allocator }
 
-		out_of_space := assign_array(r, &da, t.elem, _length, growable=false) or_return
+		out_of_space := assign_array(d, &da, t.elem, length, growable=false) or_return
 		if out_of_space { return _unsupported(v, hdr) }
 		return
 
 	case reflect.Type_Info_Complex:
-		_length, unknown := err_conv(_decode_container_length(r, add)) or_return
-		length := _length.? or_else 2
+		_length, scap := err_conv(_decode_len_container(d, add)) or_return
+		length := min(scap, 2)
 	
-		if !unknown && length > 2 {
+		if length > 2 {
 			return _unsupported(v, hdr)
 		}
 
@@ -536,15 +567,15 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 		case:            unreachable()
 		}
 
-		out_of_space := assign_array(r, &da, info, 2, growable=false) or_return
+		out_of_space := assign_array(d, &da, info, 2, growable=false) or_return
 		if out_of_space { return _unsupported(v, hdr) }
 		return
 	
 	case reflect.Type_Info_Quaternion:
-		_length, unknown := err_conv(_decode_container_length(r, add)) or_return
-		length := _length.? or_else 4
+		_length, scap := err_conv(_decode_len_container(d, add)) or_return
+		length := min(scap, 4)
 	
-		if !unknown && length > 4 {
+		if length > 4 {
 			return _unsupported(v, hdr)
 		}
 
@@ -558,7 +589,7 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 		case:            unreachable()
 		}
 
-		out_of_space := assign_array(r, &da, info, 4, growable=false) or_return
+		out_of_space := assign_array(d, &da, info, 4, growable=false) or_return
 		if out_of_space { return _unsupported(v, hdr) }
 		return
 
@@ -566,17 +597,17 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 	}
 }
 
-_unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
-
-	decode_key :: proc(r: io.Reader, v: any) -> (k: string, err: Unmarshal_Error) {
-		entry_hdr := _decode_header(r) or_return
+_unmarshal_map :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
+	r := d.reader
+	decode_key :: proc(d: Decoder, v: any) -> (k: string, err: Unmarshal_Error) {
+		entry_hdr := _decode_header(d.reader) or_return
 		entry_maj, entry_add := _header_split(entry_hdr)
 		#partial switch entry_maj {
 		case .Text:
-			k = err_conv(_decode_text(r, entry_add)) or_return
+			k = err_conv(_decode_text(d, entry_add)) or_return
 			return
 		case .Bytes:
-			bytes := err_conv(_decode_bytes(r, entry_add)) or_return
+			bytes := err_conv(_decode_bytes(d, entry_add)) or_return
 			k = string(bytes)
 			return
 		case:
@@ -588,10 +619,10 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 	// Allow generically storing the map array.
 	switch &dst in v {
 	case ^Map:
-		dst = err_conv(_decode_map_ptr(r, add)) or_return
+		dst = err_conv(_decode_map_ptr(d, add)) or_return
 		return
 	case Map:
-		dst = err_conv(_decode_map(r, add)) or_return
+		dst = err_conv(_decode_map(d, add)) or_return
 		return
 	}
 
@@ -601,14 +632,15 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 			return _unsupported(v, hdr)
 		}
 
-		length, unknown := err_conv(_decode_container_length(r, add)) or_return
+		length, scap := err_conv(_decode_len_container(d, add)) or_return
+		unknown := length == -1
 		fields := reflect.struct_fields_zipped(ti.id)
 	
-		for idx := 0; unknown || idx < length.?; idx += 1 {
+		for idx := 0; idx < len(fields) && (unknown || idx < length); idx += 1 {
 			// Decode key, keys can only be strings.
 			key: string; {
 				context.allocator = context.temp_allocator
-				if keyv, kerr := decode_key(r, v); unknown && kerr == .Break {
+				if keyv, kerr := decode_key(d, v); unknown && kerr == .Break {
 					break
 				} else if kerr != nil {
 					err = kerr
@@ -641,11 +673,11 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 				}
 			}
 
-			field  := fields[use_field_idx]
-			name   := field.name
-			ptr    := rawptr(uintptr(v.data) + field.offset)
-			fany   := any{ptr, field.type.id}
-			_unmarshal_value(r, fany, _decode_header(r) or_return) or_return
+			field := fields[use_field_idx]
+			name  := field.name
+			ptr   := rawptr(uintptr(v.data) + field.offset)
+			fany  := any{ptr, field.type.id}
+			_unmarshal_value(d, fany, _decode_header(r) or_return) or_return
 		}
 		return
 
@@ -654,6 +686,8 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 			return _unsupported(v, hdr)
 		}
 
+		// TODO: shrink excess.
+
 		raw_map := (^mem.Raw_Map)(v.data)
 		if raw_map.allocator.procedure == nil {
 			raw_map.allocator = context.allocator
@@ -663,10 +697,11 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 			_ = runtime.map_free_dynamic(raw_map^, t.map_info)
 		}
 
-		length, unknown := err_conv(_decode_container_length(r, add)) or_return
+		length, scap := err_conv(_decode_len_container(d, add)) or_return
+		unknown := length == -1
 		if !unknown {
 			// Reserve space before setting so we can return allocation errors and be efficient on big maps.
-			new_len := uintptr(runtime.map_len(raw_map^)+length.?)
+			new_len := uintptr(min(scap, runtime.map_len(raw_map^)+length))
 			runtime.map_reserve_dynamic(raw_map, t.map_info, new_len) or_return
 		}
 		
@@ -676,10 +711,10 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 
 		map_backing_value := any{raw_data(elem_backing), t.value.id}
 
-		for idx := 0; unknown || idx < length.?; idx += 1 {
+		for idx := 0; unknown || idx < length; idx += 1 {
 			// Decode key, keys can only be strings.
 			key: string
-			if keyv, kerr := decode_key(r, v); unknown && kerr == .Break {
+			if keyv, kerr := decode_key(d, v); unknown && kerr == .Break {
 				break
 			} else if kerr != nil {
 				err = kerr
@@ -688,14 +723,14 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 				key = keyv
 			}
 
-			if unknown {
+			if unknown || idx > scap {
 				// Reserve space for new element so we can return allocator errors.
 				new_len := uintptr(runtime.map_len(raw_map^)+1)
 				runtime.map_reserve_dynamic(raw_map, t.map_info, new_len) or_return
 			}
 
 			mem.zero_slice(elem_backing)
-			_unmarshal_value(r, map_backing_value, _decode_header(r) or_return) or_return
+			_unmarshal_value(d, map_backing_value, _decode_header(r) or_return) or_return
 
 			key_ptr := rawptr(&key)
 			key_cstr: cstring
@@ -709,6 +744,10 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 			// We already reserved space for it, so this shouldn't fail.
 			assert(set_ptr != nil)
 		}
+	
+		if .Shrink_Excess in d.flags {
+			_, _ = runtime.map_shrink_dynamic(raw_map, t.map_info)
+		}
 		return
 
 		case:
@@ -719,7 +758,8 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
 // Unmarshal into a union, based on the `TAG_OBJECT_TYPE` tag of the spec, it denotes a tag which
 // contains an array of exactly two elements, the first is a textual representation of the following
 // CBOR value's type.
-_unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header) -> (err: Unmarshal_Error) {
+_unmarshal_union :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header) -> (err: Unmarshal_Error) {
+	r := d.reader
 	#partial switch t in ti.variant {
 	case reflect.Type_Info_Union:
 		idhdr: Header
@@ -731,8 +771,8 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 				return .Bad_Tag_Value
 			}
 
-			n_items, unknown := err_conv(_decode_container_length(r, vadd)) or_return
-			if unknown || n_items != 2 {
+			n_items, _ := err_conv(_decode_len_container(d, vadd)) or_return
+			if n_items != 2 {
 				return .Bad_Tag_Value
 			}
 			
@@ -743,7 +783,7 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 			}
 
 			context.allocator = context.temp_allocator
-			target_name = err_conv(_decode_text(r, idadd)) or_return
+			target_name = err_conv(_decode_text(d, idadd)) or_return
 		}
 		defer delete(target_name, context.temp_allocator)
 
@@ -757,7 +797,7 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 			case reflect.Type_Info_Named:
 				if vti.name == target_name {
 					reflect.set_union_variant_raw_tag(v, tag)
-					return _unmarshal_value(r, any{v.data, variant.id}, _decode_header(r) or_return)
+					return _unmarshal_value(d, any{v.data, variant.id}, _decode_header(r) or_return)
 				}
 
 			case:
@@ -769,7 +809,7 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
 				
 				if variant_name == target_name {
 					reflect.set_union_variant_raw_tag(v, tag)
-					return _unmarshal_value(r, any{v.data, variant.id}, _decode_header(r) or_return)
+					return _unmarshal_value(d, any{v.data, variant.id}, _decode_header(r) or_return)
 				}
 			}
 		}
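
The unmarshal entry points gain the same flags parameter; a sketch (the struct, its field values, and the CBOR bytes are made up; assumes `import "core:encoding/cbor"` and `import "core:fmt"`):

    Config :: struct {
        name:  string,
        count: int,
    }

    // {"name": "odin", "count": 3}, hand-assembled CBOR.
    data := "\xa2\x64name\x64odin\x65count\x03"

    c: Config
    if err := cbor.unmarshal(data, &c, {.Disallow_Streaming, .Shrink_Excess}); err == nil {
        fmt.println(c)
    }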

+ 14 - 3
tests/core/encoding/cbor/test_core_cbor.odin

@@ -4,6 +4,7 @@ import "core:bytes"
 import "core:encoding/cbor"
 import "core:fmt"
 import "core:intrinsics"
+import "core:io"
 import "core:math/big"
 import "core:mem"
 import "core:os"
@@ -61,7 +62,9 @@ main :: proc() {
 	test_marshalling_maybe(&t)
 	test_marshalling_nil_maybe(&t)
 
-	test_cbor_marshalling_union(&t)
+	test_marshalling_union(&t)
+
+	test_lying_length_array(&t)
 
 	test_decode_unsigned(&t)
 	test_encode_unsigned(&t)
@@ -202,7 +205,7 @@ test_marshalling :: proc(t: ^testing.T) {
 		ev(t, err, nil)
 		defer delete(data)
 
-		decoded, derr := cbor.decode_string(string(data))
+		decoded, derr := cbor.decode(string(data))
 		ev(t, derr, nil)
 		defer cbor.destroy(decoded)
 
@@ -398,7 +401,7 @@ test_marshalling_nil_maybe :: proc(t: ^testing.T) {
 }
 
 @(test)
-test_cbor_marshalling_union :: proc(t: ^testing.T) {
+test_marshalling_union :: proc(t: ^testing.T) {
 	My_Distinct :: distinct string
 
 	My_Enum :: enum {
@@ -457,6 +460,14 @@ test_cbor_marshalling_union :: proc(t: ^testing.T) {
 	}
 }
 
+@(test)
+test_lying_length_array :: proc(t: ^testing.T) {
+	// Input claims an array with a huge declared length; decoding it must not allocate that amount up front.
+	input := []byte{0x9B, 0x00, 0x00, 0x42, 0xFA, 0x42, 0xFA, 0x42, 0xFA, 0x42}
+	_, err := cbor.decode(string(input))
+	expect_value(t, err, io.Error.Unexpected_EOF) // .Out_Of_Memory would be bad.
+}
+
 @(test)
 test_decode_unsigned :: proc(t: ^testing.T) {
 	expect_decoding(t, "\x00", "0", u8)