Browse Source

encoding/cbor: cleanup base64 tag

Laytan Laats 1 year ago
parent
commit
363769d4d3
1 changed files with 62 additions and 50 deletions
  1. 62 50
      core/encoding/cbor/tags.odin

+ 62 - 50
core/encoding/cbor/tags.odin

@@ -213,20 +213,20 @@ tag_big_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_E
 		// is uninitialized (which we checked).
 
 		is_neg, err := big.is_negative(&vv, mem.panic_allocator())
-		assert(err == nil, "only errors if not initialized, which has been checked")
+		assert(err == nil, "should only error if not initialized, which has been checked")
 		
 		tnr: u8 = TAG_NEGATIVE_BIG_NR if is_neg else TAG_UNSIGNED_BIG_NR
 		_encode_u8(e.writer, tnr, .Tag) or_return
 
 		size_in_bytes, berr := big.int_to_bytes_size(&vv, false, mem.panic_allocator())
-		assert(berr == nil, "only errors if not initialized, which has been checked")
+		assert(berr == nil, "should only error if not initialized, which has been checked")
 		assert(size_in_bytes >= 0)
 
 		err_conv(_encode_u64(e, u64(size_in_bytes), .Bytes)) or_return
 
 		for offset := (size_in_bytes*8)-8; offset >= 0; offset -= 8 {
 			bits, derr := big.int_bitfield_extract(&vv, offset, 8, mem.panic_allocator())
-			assert(derr == nil, "only errors if not initialized or invalid argument (offset and count), which won't happen")
+			assert(derr == nil, "should only error if not initialized or invalid argument (offset and count), which won't happen")
 
 			io.write_full(e.writer, {u8(bits & 255)}) or_return
 		}
@@ -273,63 +273,75 @@ tag_cbor_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_
 	}
 }
 
-// NOTE: this could probably be more efficient by decoding bytes from CBOR and then from base64 at the same time.
 @(private)
 tag_base64_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
-	hdr := _decode_header(r) or_return
+	hdr        := _decode_header(r) or_return
 	major, add := _header_split(hdr)
-	#partial switch major {
-	case .Text:
-		ti := reflect.type_info_base(type_info_of(v.id))
-		_unmarshal_bytes(r, v, ti, hdr, add) or_return
-		#partial switch t in ti.variant {
-		case runtime.Type_Info_String:
-			switch t.is_cstring {
-			case true:
-				str := string((^cstring)(v.data)^)
-				decoded := base64.decode(str) or_return
-				(^cstring)(v.data)^ = strings.clone_to_cstring(string(decoded)) or_return
-				delete(decoded)
-				delete(str)
-			case false:
-				str := (^string)(v.data)^
-				decoded := base64.decode(str) or_return
-				(^string)(v.data)^ = string(decoded)
-				delete(str)
-			}
-			return
+	ti         := reflect.type_info_base(type_info_of(v.id))
 
-		case runtime.Type_Info_Array:
-			raw := ([^]byte)(v.data)
-			decoded := base64.decode(string(raw[:t.count])) or_return
-			copy(raw[:t.count], decoded)
-			delete(decoded)
-			return
+	if major != .Text && major != .Bytes {
+		return .Bad_Tag_Value
+	}
 
-		case runtime.Type_Info_Slice:
-			raw := (^[]byte)(v.data)
-			decoded := base64.decode(string(raw^)) or_return
-			delete(raw^)
-			raw^ = decoded
-			return
+	bytes: string; {
+		context.allocator = context.temp_allocator
+		bytes = string(err_conv(_decode_bytes(r, add)) or_return)
+	}
+	defer delete(bytes, context.temp_allocator)
 
-		case runtime.Type_Info_Dynamic_Array:
-			raw := (^mem.Raw_Dynamic_Array)(v.data)
-			str := string(((^[dynamic]byte)(v.data)^)[:])
+	#partial switch t in ti.variant {
+	case reflect.Type_Info_String:
+
+		if t.is_cstring {
+			length  := base64.decoded_len(bytes)
+			builder := strings.builder_make(0, length+1)
+			base64.decode_into(strings.to_stream(&builder), bytes) or_return
+
+			raw  := (^cstring)(v.data)
+			raw^  = cstring(raw_data(builder.buf))
+		} else {
+			raw  := (^string)(v.data)
+			raw^  = string(base64.decode(bytes) or_return)
+		}
 
-			decoded := base64.decode(str) or_return
-			delete(str)
+		return
 
-			raw.data = raw_data(decoded)
-			raw.len  = len(decoded)
-			raw.cap  = len(decoded)
-			return
+	case reflect.Type_Info_Slice:
+		elem_base := reflect.type_info_base(t.elem)
 
-		case: unreachable()
-		}
+		if elem_base.id != byte { return _unsupported(v, hdr) }
 
-	case: return .Bad_Tag_Value
+		raw  := (^[]byte)(v.data)
+		raw^  = base64.decode(bytes) or_return
+		return
+		
+	case reflect.Type_Info_Dynamic_Array:
+		elem_base := reflect.type_info_base(t.elem)
+
+		if elem_base.id != byte { return _unsupported(v, hdr) }
+
+		decoded := base64.decode(bytes) or_return
+		
+		raw           := (^mem.Raw_Dynamic_Array)(v.data)
+		raw.data       = raw_data(decoded)
+		raw.len        = len(decoded)
+		raw.cap        = len(decoded)
+		raw.allocator  = context.allocator
+		return
+
+	case reflect.Type_Info_Array:
+		elem_base := reflect.type_info_base(t.elem)
+
+		if elem_base.id != byte { return _unsupported(v, hdr) }
+
+		if base64.decoded_len(bytes) > t.count { return _unsupported(v, hdr) }
+		
+		slice := ([^]byte)(v.data)[:len(bytes)]
+		copy(slice, base64.decode(bytes) or_return)
+		return
 	}
+
+	return _unsupported(v, hdr)
 }
 
 @(private)
@@ -355,7 +367,7 @@ tag_base64_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marsha
 		}
 	}
 
-	out_len := base64.encoded_length(bytes)
+	out_len := base64.encoded_len(bytes)
 	err_conv(_encode_u64(e, u64(out_len), .Text)) or_return
 	return base64.encode_into(e.writer, bytes)
 }