Pārlūkot izejas kodu

[mbedtls] Support mbedtls 3.x and fix bugs and compiler warnings. (#11646)

* [mbedtls] Store bio functions in a GC root.

* [mbedtls] Fix incorrect alt name check.

* [mbedtls] Replace String_val with Bytes_val to prevent compiler warnings.

* [mbedtls] use SecTrustCopyAnchorCertificates to get root certs on macOS.

SecKeychainOpen is deprecated.

* [mbedtls] Remove unused includes and use angled brackets.

* [mbedtls] Fix more warnings.

* [mbedtls] Support mbedtls 3.x.
Zeta 1 gadu atpakaļ
vecāks
revīzija
c3258892c3

+ 2 - 2
libs/mbedtls/mbedtls.ml

@@ -43,8 +43,8 @@ external mbedtls_ssl_setup : mbedtls_ssl_context -> mbedtls_ssl_config -> mbedtl
 external mbedtls_ssl_write : mbedtls_ssl_context -> bytes -> int -> int -> mbedtls_result = "ml_mbedtls_ssl_write"
 
 external mbedtls_pk_init : unit -> mbedtls_pk_context = "ml_mbedtls_pk_init"
-external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_key"
-external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile"
+external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_key"
+external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile"
 external mbedtls_pk_parse_public_keyfile : mbedtls_pk_context -> string -> mbedtls_result = "ml_mbedtls_pk_parse_public_keyfile"
 external mbedtls_pk_parse_public_key : mbedtls_pk_context -> bytes -> mbedtls_result = "ml_mbedtls_pk_parse_public_key"
 

+ 51 - 52
libs/mbedtls/mbedtls_stubs.c

@@ -1,4 +1,3 @@
-#include <ctype.h>
 #include <string.h>
 #include <stdio.h>
 
@@ -18,13 +17,10 @@
 #include <caml/callback.h>
 #include <caml/custom.h>
 
-#include "mbedtls/debug.h"
 #include "mbedtls/error.h"
-#include "mbedtls/config.h"
 #include "mbedtls/ssl.h"
 #include "mbedtls/entropy.h"
 #include "mbedtls/ctr_drbg.h"
-#include "mbedtls/certs.h"
 #include "mbedtls/oid.h"
 
 #define PVoid_val(v) (*((void**) Data_custom_val(v)))
@@ -84,7 +80,7 @@ CAMLprim value ml_mbedtls_ctr_drbg_init(void) {
 
 CAMLprim value ml_mbedtls_ctr_drbg_random(value p_rng, value output, value output_len) {
 	CAMLparam3(p_rng, output, output_len);
-	CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), String_val(output), Int_val(output_len))));
+	CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), Bytes_val(output), Int_val(output_len))));
 }
 
 CAMLprim value ml_mbedtls_ctr_drbg_seed(value ctx, value p_entropy, value custom) {
@@ -124,7 +120,7 @@ CAMLprim value ml_mbedtls_entropy_init(void) {
 
 CAMLprim value ml_mbedtls_entropy_func(value data, value output, value len) {
 	CAMLparam3(data, output, len);
-	CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), String_val(output), Int_val(len))));
+	CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), Bytes_val(output), Int_val(len))));
 }
 
 // Certificate
@@ -171,7 +167,7 @@ CAMLprim value ml_mbedtls_x509_next(value chain) {
 
 CAMLprim value ml_mbedtls_x509_crt_parse(value chain, value bytes) {
 	CAMLparam2(chain, bytes);
-	const char* buf = String_val(bytes);
+	const unsigned char* buf = Bytes_val(bytes);
 	int len = caml_string_length(bytes);
 	CAMLreturn(Val_int(mbedtls_x509_crt_parse(X509Crt_val(chain), buf, len + 1)));
 }
@@ -191,8 +187,7 @@ CAMLprim value ml_mbedtls_x509_crt_parse_path(value chain, value path) {
 value caml_string_of_asn1_buf(mbedtls_asn1_buf* dat) {
 	CAMLparam0();
 	CAMLlocal1(s);
-	s = caml_alloc_string(dat->len);
-	memcpy(String_val(s), dat->p, dat->len);
+	s = caml_alloc_initialized_string(dat->len, (const char *)dat->p);
 	CAMLreturn(s);
 }
 
@@ -200,7 +195,11 @@ CAMLprim value hx_cert_get_alt_names(value chain) {
 	CAMLparam1(chain);
 	CAMLlocal1(obj);
 	mbedtls_x509_crt* cert = X509Crt_val(chain);
-	if (cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME == 0 || &cert->subject_alt_names == NULL) {
+#if MBEDTLS_VERSION_MAJOR >= 3
+	if (!mbedtls_x509_crt_has_ext_type(cert, MBEDTLS_X509_EXT_SUBJECT_ALT_NAME)) {
+#else
+	if ((cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME) == 0) {
+#endif
 		obj = Atom(0);
 	} else {
 		mbedtls_asn1_sequence* cur = &cert->subject_alt_names;
@@ -366,29 +365,39 @@ CAMLprim value ml_mbedtls_pk_init(void) {
 	CAMLreturn(obj);
 }
 
-CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password) {
-	CAMLparam3(ctx, key, password);
-	const char* pwd = NULL;
+CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password, value rng) {
+	CAMLparam4(ctx, key, password, rng);
+	const unsigned char* pwd = NULL;
 	size_t pwdlen = 0;
 	if (password != Val_none) {
-		pwd = String_val(Field(password, 0));
+		pwd = Bytes_val(Field(password, 0));
 		pwdlen = caml_string_length(Field(password, 0));
 	}
-	CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1, pwd, pwdlen));
+	#if MBEDTLS_VERSION_MAJOR >= 3
+	mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng);
+	CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen, mbedtls_ctr_drbg_random, NULL));
+	#else
+	CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen));
+	#endif
 }
 
-CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password) {
-	CAMLparam3(ctx, path, password);
+CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password, value rng) {
+	CAMLparam4(ctx, path, password, rng);
 	const char* pwd = NULL;
 	if (password != Val_none) {
 		pwd = String_val(Field(password, 0));
 	}
+	#if MBEDTLS_VERSION_MAJOR >= 3
+	mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng);
+	CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd, mbedtls_ctr_drbg_random, ctr_drbg));
+	#else
 	CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd));
+	#endif
 }
 
 CAMLprim value ml_mbedtls_pk_parse_public_key(value ctx, value key) {
 	CAMLparam2(ctx, key);
-	CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1));
+	CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1));
 }
 
 CAMLprim value ml_mbedtls_pk_parse_public_keyfile(value ctx, value path) {
@@ -446,15 +455,14 @@ CAMLprim value ml_mbedtls_ssl_handshake(value ssl) {
 
 CAMLprim value ml_mbedtls_ssl_read(value ssl, value buf, value pos, value len) {
 	CAMLparam4(ssl, buf, pos, len);
-	CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len))));
+	CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len))));
 }
 
 static int bio_write_cb(void* ctx, const unsigned char* buf, size_t len) {
 	CAMLparam0();
 	CAMLlocal3(r, s, vctx);
-	vctx = (value)ctx;
-	s = caml_alloc_string(len);
-	memcpy(String_val(s), buf, len);
+	vctx = *(value*)ctx;
+	s = caml_alloc_initialized_string(len, (const char*)buf);
 	r = caml_callback2(Field(vctx, 1), Field(vctx, 0), s);
 	CAMLreturn(Int_val(r));
 }
@@ -462,7 +470,7 @@ static int bio_write_cb(void* ctx, const unsigned char* buf, size_t len) {
 static int bio_read_cb(void* ctx, unsigned char* buf, size_t len) {
 	CAMLparam0();
 	CAMLlocal3(r, s, vctx);
-	vctx = (value)ctx;
+	vctx = *(value*)ctx;
 	s = caml_alloc_string(len);
 	r = caml_callback2(Field(vctx, 2), Field(vctx, 0), s);
 	memcpy(buf, String_val(s), len);
@@ -476,7 +484,11 @@ CAMLprim value ml_mbedtls_ssl_set_bio(value ssl, value p_bio, value f_send, valu
 	Store_field(ctx, 0, p_bio);
 	Store_field(ctx, 1, f_send);
 	Store_field(ctx, 2, f_recv);
-	mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)ctx, bio_write_cb, bio_read_cb, NULL);
+	// TODO: this allocation is leaked
+	value *location = malloc(sizeof(value));
+	*location = ctx;
+	caml_register_generational_global_root(location);
+	mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)location, bio_write_cb, bio_read_cb, NULL);
 	CAMLreturn(Val_unit);
 }
 
@@ -492,7 +504,7 @@ CAMLprim value ml_mbedtls_ssl_setup(value ssl, value conf) {
 
 CAMLprim value ml_mbedtls_ssl_write(value ssl, value buf, value pos, value len) {
 	CAMLparam4(ssl, buf, pos, len);
-	CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len))));
+	CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len))));
 }
 
 // glue
@@ -520,36 +532,23 @@ CAMLprim value hx_cert_load_defaults(value certificate) {
 	#endif
 
 	#ifdef __APPLE__
-	CFMutableDictionaryRef search;
-	CFArrayRef result;
-	SecKeychainRef keychain;
-	SecCertificateRef item;
-	CFDataRef dat;
-	// Load keychain
-	if (SecKeychainOpen("/System/Library/Keychains/SystemRootCertificates.keychain", &keychain) == errSecSuccess) {
-		// Search for certificates
-		search = CFDictionaryCreateMutable(NULL, 0, NULL, NULL);
-		CFDictionarySetValue(search, kSecClass, kSecClassCertificate);
-		CFDictionarySetValue(search, kSecMatchLimit, kSecMatchLimitAll);
-		CFDictionarySetValue(search, kSecReturnRef, kCFBooleanTrue);
-		CFDictionarySetValue(search, kSecMatchSearchList, CFArrayCreate(NULL, (const void **)&keychain, 1, NULL));
-		if (SecItemCopyMatching(search, (CFTypeRef *)&result) == errSecSuccess) {
-			CFIndex n = CFArrayGetCount(result);
-			for (CFIndex i = 0; i < n; i++) {
-				item = (SecCertificateRef)CFArrayGetValueAtIndex(result, i);
-
-				// Get certificate in DER format
-				dat = SecCertificateCopyData(item);
-				if (dat) {
-					r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(dat), CFDataGetLength(dat));
-					CFRelease(dat);
-					if (r != 0) {
-						CAMLreturn(Val_int(r));
-					}
+	CFArrayRef certs;
+	if (SecTrustCopyAnchorCertificates(&certs) == errSecSuccess) {
+		CFIndex count = CFArrayGetCount(certs);
+		for(CFIndex i = 0; i < count; i++) {
+			SecCertificateRef item = (SecCertificateRef)CFArrayGetValueAtIndex(certs, i);
+
+			// Get certificate in DER format
+			CFDataRef data = SecCertificateCopyData(item);
+			if(data) {
+				r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(data), CFDataGetLength(data));
+				CFRelease(data);
+				if (r != 0) {
+					CAMLreturn(Val_int(r));
 				}
 			}
 		}
-		CFRelease(keychain);
+		CFRelease(certs);
 	}
 	#endif
 

+ 4 - 4
src/macro/eval/evalSsl.ml

@@ -160,11 +160,11 @@ let init_fields init_fields builtins =
 		"strerror",vfun1 (fun code -> encode_string (mbedtls_strerror (decode_int code)));
 	] [];
 	init_fields builtins (["mbedtls"],"PkContext") [] [
-		"parse_key",vifun2 (fun this key password ->
-			vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password)));
+		"parse_key",vifun3 (fun this key password rng ->
+			vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng));
 		);
-		"parse_keyfile",vifun2 (fun this path password ->
-			vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password)));
+		"parse_keyfile",vifun3 (fun this path password rng ->
+			vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng));
 		);
 		"parse_public_key",vifun1 (fun this key ->
 			vint (mbedtls_pk_parse_public_key (as_pk_context this) (decode_bytes key));

+ 2 - 2
std/eval/_std/mbedtls/PkContext.hx

@@ -5,8 +5,8 @@ import haxe.io.Bytes;
 extern class PkContext {
 	function new():Void;
 
-	function parse_key(key:Bytes, ?pwd:String):Int;
-	function parse_keyfile(path:String, ?password:String):Int;
+	function parse_key(key:Bytes, ?pwd:String, ctr_dbg: CtrDrbg):Int;
+	function parse_keyfile(path:String, ?password:String, ctr_dbg: CtrDrbg):Int;
 	function parse_public_key(key:Bytes):Int;
 	function parse_public_keyfile(path:String):Int;
 }

+ 2 - 2
std/eval/_std/sys/ssl/Key.hx

@@ -38,7 +38,7 @@ class Key {
 		var code = if (isPublic) {
 			key.native.parse_public_keyfile(file);
 		} else {
-			key.native.parse_keyfile(file, pass);
+			key.native.parse_keyfile(file, pass, Mbedtls.getDefaultCtrDrbg());
 		}
 		if (code != 0) {
 			throw(mbedtls.Error.strerror(code));
@@ -51,7 +51,7 @@ class Key {
 		var code = if (isPublic) {
 			key.native.parse_public_key(data);
 		} else {
-			key.native.parse_key(data);
+			key.native.parse_key(data, null, Mbedtls.getDefaultCtrDrbg());
 		}
 		if (code != 0) {
 			throw(mbedtls.Error.strerror(code));