Bläddra i källkod

small refactor of ssl independent of ssl socket

ncannasse 6 år sedan
förälder
incheckning
7e6c0a0a99
2 ändrade filer med 83 tillägg och 65 borttagningar
  1. 51 0
      std/hl/_std/sys/ssl/Context.hx
  2. 32 65
      std/hl/_std/sys/ssl/Socket.hx

+ 51 - 0
std/hl/_std/sys/ssl/Context.hx

@@ -0,0 +1,51 @@
+package sys.ssl;
+
+private typedef ConfigPtr = hl.Abstract<"mbedtls_ssl_config">;
+private typedef ContextPtr = hl.Abstract<"mbedtls_ssl_context">;
+
+@:keep class SNICbResult {
+	public var cert : Certificate.CertificatePtr;
+	public var key : Key.KeyPtr;
+	public function new( cert : Certificate, key : Key ){
+		this.cert = @:privateAccess cert.__x;
+		this.key = @:privateAccess key.__k;
+	}
+}
+
+@:hlNative("ssl","ssl_")
+abstract Context(ContextPtr) {
+
+	public function new(config) {
+		this = ssl_new(config);
+	}
+
+	public function close() : Void {}
+	public function handshake() : Int { return 0; }
+	public function recvChar() : Int { return 0; }
+	public function sendChar( c : Int ) : Int { return 0; }
+	public function getPeerCertificate() : Certificate.CertificatePtr { return null; }
+	public function recv( bytes : hl.Bytes, pos : Int, len : Int ) : Int { return 0; }
+	public function send( bytes : hl.Bytes, pos : Int, len : Int ) : Int { return 0; }
+	public function setSocket( socket : sys.net.Socket.SocketHandle ) : Void { }
+	public function setHostname( name : hl.Bytes ) : Void { }
+
+	@:hlNative("ssl","ssl_new") static function ssl_new( conf : Config ) : ContextPtr { return null; }
+
+}
+
+@:hlNative("ssl","conf_")
+abstract Config(ConfigPtr) {
+
+	public function new( server : Bool ) {
+		this = conf_new(server);
+	}
+
+	public function setCert( cert : Certificate.CertificatePtr, pkey : Key.KeyPtr ) : Void { }
+	public function setCa( ca : Certificate.CertificatePtr ) : Void { }
+	public function close() : Void { }
+	public function setVerify( mode : Int ) : Void { }
+	public function setServernameCallback( cb : hl.Bytes -> SNICbResult ) : Void { }
+
+	@:hlNative("ssl","conf_new") static function conf_new( server : Bool ) : ConfigPtr { return null; }
+
+}

+ 32 - 65
std/hl/_std/sys/ssl/Socket.hx

@@ -4,19 +4,6 @@ import sys.ssl.Key.KeyPtr;
 import sys.ssl.Certificate.CertificatePtr;
 import sys.net.Socket.SocketHandle;
 
-private typedef ConfigPtr = hl.Abstract<"mbedtls_ssl_config">;
-private typedef ContextPtr = hl.Abstract<"mbedtls_ssl_context">;
-
-@:keep
-private class SNICbResult {
-	public var cert : CertificatePtr;
-	public var key : KeyPtr; 
-	public function new( cert : Certificate, key : Key ){
-		this.cert = @:privateAccess cert.__x;
-		this.key = @:privateAccess key.__k;
-	}
-}
-
 private class SocketInput extends haxe.io.Input {
 	@:allow(sys.ssl.Socket) private var __s : Socket;
 
@@ -26,7 +13,7 @@ private class SocketInput extends haxe.io.Input {
 
 	public override function readByte() {
 		__s.handshake();
-		var r = ssl_recv_char( @:privateAccess __s.ssl );
+		var r = @:privateAccess __s.ssl.recvChar();
 		if( r == -1 )
 			throw haxe.io.Error.Blocked;
 		else if( r < 0 )
@@ -38,7 +25,7 @@ private class SocketInput extends haxe.io.Input {
 		if( pos < 0 || len < 0 || ((pos+len):UInt) > (buf.length : UInt) )
 			throw haxe.io.Error.OutsideBounds;
 		__s.handshake();
-		var r = ssl_recv(  @:privateAccess __s.ssl, @:privateAccess buf.b, pos, len );
+		var r = @:privateAccess __s.ssl.recv(buf, pos, len);
 		if( r == -1 )
 			throw haxe.io.Error.Blocked;
 		else if( r < 0 )
@@ -50,9 +37,7 @@ private class SocketInput extends haxe.io.Input {
 		super.close();
 		if( __s != null ) __s.close();
 	}
-	
-	@:hlNative("ssl","ssl_recv") static function ssl_recv( ssl : ContextPtr, bytes : hl.Bytes, pos : Int, len : Int ) : Int { return -1; }
-	@:hlNative("ssl","ssl_recv_char") static function ssl_recv_char( ssl : ContextPtr ) : Int { return -1; }
+
 }
 
 private class SocketOutput extends haxe.io.Output {
@@ -64,7 +49,7 @@ private class SocketOutput extends haxe.io.Output {
 
 	public override function writeByte( c : Int ) {
 		__s.handshake();
-		var r = ssl_send_char( @:privateAccess __s.ssl, c);
+		var r = @:privateAccess __s.ssl.sendChar(c);
 		if( r == -1 )
 			throw haxe.io.Error.Blocked;
 		else if( r < 0 )
@@ -75,7 +60,7 @@ private class SocketOutput extends haxe.io.Output {
 		if( pos < 0 || len < 0 || ((pos+len):UInt) > (buf.length : UInt) )
 			throw haxe.io.Error.OutsideBounds;
 		__s.handshake();
-		var r = ssl_send( @:privateAccess __s.ssl, @:privateAccess buf.b, pos, len);
+		var r = @:privateAccess __s.ssl.send(buf, pos, len);
 		if( r == -1 )
 			throw haxe.io.Error.Blocked;
 		else if( r < 0 )
@@ -88,21 +73,18 @@ private class SocketOutput extends haxe.io.Output {
 		if( __s != null ) __s.close();
 	}
 
-	@:hlNative("ssl","ssl_send") static function ssl_send( ssl : ContextPtr, bytes : hl.Bytes, pos : Int, len : Int ) : Int { return -1; }
-	@:hlNative("ssl","ssl_send_char") static function ssl_send_char( ssl : ContextPtr, c : Int ) : Int { return -1; }
-
 }
 
 @:coreApi @:access(sys.net.Socket)
 class Socket extends sys.net.Socket {
-	
+
 	public static var DEFAULT_VERIFY_CERT : Null<Bool> = true;
 
 	public static var DEFAULT_CA : Null<Certificate>;
-	
-	private var conf : ConfigPtr;
-	private var ssl : ContextPtr;
-	
+
+	private var conf : Context.Config;
+	private var ssl : Context;
+
 	public var verifyCert : Null<Bool>;
 	private var caCert : Null<Certificate>;
 	private var hostname : String;
@@ -110,7 +92,7 @@ class Socket extends sys.net.Socket {
 	private var ownCert : Null<Certificate>;
 	private var ownKey : Null<Key>;
 	private var altSNIContexts : Null<Array<{match: String->Bool, key: Key, cert: Certificate}>>;
-	private var sniCallback : hl.Bytes -> SNICbResult;
+	private var sniCallback : hl.Bytes -> Context.SNICbResult;
 	private var handshakeDone : Bool;
 	private var isBlocking : Bool = true;
 
@@ -129,13 +111,13 @@ class Socket extends sys.net.Socket {
 
 	public override function connect(host : sys.net.Host, port : Int) : Void {
 		conf = buildConfig( false );
-		ssl = ssl_new( conf );
-		ssl_set_socket( ssl, __s );
+		ssl = new Context( conf );
+		ssl.setSocket(__s);
 		handshakeDone = false;
 		if( hostname == null )
 			hostname = host.host;
 		if( hostname != null )
-			ssl_set_hostname( ssl, @:privateAccess hostname.toUtf8() );
+			ssl.setHostname( @:privateAccess hostname.toUtf8() );
 		if( !sys.net.Socket.socket_connect( __s, host.ip, port ) )
 			throw new Sys.SysError("Failed to connect on "+host.toString()+":"+port);
 		if( isBlocking )
@@ -144,7 +126,7 @@ class Socket extends sys.net.Socket {
 
 	public function handshake() : Void {
 		if( !handshakeDone ){
-			var r = ssl_handshake( ssl );
+			var r = ssl.handshake();
 			if( r == 0 )
 				handshakeDone = true;
 			else if( r == -1 )
@@ -173,8 +155,8 @@ class Socket extends sys.net.Socket {
 	}
 
 	public override function close() : Void {
-		if( ssl != null ) ssl_close( ssl );
-		if( conf != null ) conf_close( conf );
+		if( ssl != null ) ssl.close();
+		if( conf != null ) conf.close();
 		if( altSNIContexts != null )
 			sniCallback = null;
 		sys.net.Socket.socket_close( __s );
@@ -199,8 +181,8 @@ class Socket extends sys.net.Socket {
 
 	public override function accept() : Socket {
 		var c = sys.net.Socket.socket_accept( __s );
-		var cssl = ssl_new( conf );
-		ssl_set_socket( cssl, c );
+		var cssl = new Context( conf );
+		cssl.setSocket(c);
 
 		var s = Type.createEmptyInstance( sys.ssl.Socket );
 		s.__s = c;
@@ -213,50 +195,35 @@ class Socket extends sys.net.Socket {
 	}
 
 	public function peerCertificate() : sys.ssl.Certificate {
-		var x = ssl_get_peer_certificate( ssl );
+		var x = ssl.getPeerCertificate();
 		return x==null ? null : new sys.ssl.Certificate( x );
 	}
 
-	private function buildConfig( server : Bool ) : ConfigPtr {
-		var conf = conf_new( server );
+	private function buildConfig( server : Bool ) : Context.Config {
+		var conf = new Context.Config( server );
 
 		if( ownCert != null && ownKey != null )
-			conf_set_cert( conf, @:privateAccess ownCert.__x, @:privateAccess ownKey.__k );
+			conf.setCert( @:privateAccess ownCert.__x, @:privateAccess ownKey.__k );
 
 		if ( altSNIContexts != null ) {
-			sniCallback = function(servername:hl.Bytes) : SNICbResult {
+			sniCallback = function(servername:hl.Bytes) : Context.SNICbResult {
 				var servername = @:privateAccess String.fromUTF8(servername);
 				for( c in altSNIContexts ){
 					if( c.match(servername) )
-						return new SNICbResult(c.cert, c.key);
+						return new Context.SNICbResult(c.cert, c.key);
 				}
 				if( ownKey != null && ownCert != null )
-					return new SNICbResult(ownCert, ownKey);
+					return new Context.SNICbResult(ownCert, ownKey);
 				return null;
 			}
-			conf_set_servername_callback( conf, sniCallback );
+			conf.setServernameCallback(sniCallback);
 		}
 
-		if ( caCert != null ) 
-			conf_set_ca( conf, caCert == null ? null : @:privateAccess caCert.__x  );
-		conf_set_verify( conf, if( verifyCert ) 1 else if( verifyCert==null ) 2 else 0 );
-		
+		if ( caCert != null )
+			conf.setCa( caCert == null ? null : @:privateAccess caCert.__x  );
+		conf.setVerify( if( verifyCert ) 1 else if( verifyCert==null ) 2 else 0 );
+
 		return conf;
 	}
-	
-	
-	@:hlNative("ssl","ssl_new") static function ssl_new( conf : ConfigPtr ) : ContextPtr { return null; }
-	@:hlNative("ssl","ssl_close") static function ssl_close( ssl : ContextPtr ) : Void {}
-	@:hlNative("ssl","ssl_handshake") static function ssl_handshake( ssl : ContextPtr ) : Int { return -1; }
-	@:hlNative("ssl","ssl_set_socket") static function ssl_set_socket( ssl : ContextPtr, socket : SocketHandle ) : Void { }
-	@:hlNative("ssl","ssl_set_hostname") static function ssl_set_hostname( ssl : ContextPtr, name : hl.Bytes ) : Void { }
-	@:hlNative("ssl","ssl_get_peer_certificate") static function ssl_get_peer_certificate( ssl : ContextPtr ) : CertificatePtr { return null; }
-	
-	@:hlNative("ssl","conf_new") static function conf_new( server : Bool ) : ConfigPtr { return null; }
-	@:hlNative("ssl","conf_close") static function conf_close( conf : ConfigPtr ) : Void { }
-	@:hlNative("ssl","conf_set_ca") static function conf_set_ca( conf : ConfigPtr, ca : CertificatePtr ) : Void { }
-	@:hlNative("ssl","conf_set_verify") static function conf_set_verify( conf : ConfigPtr, mode : Int ) : Void { }
-	@:hlNative("ssl","conf_set_cert") static function conf_set_cert( conf : ConfigPtr, cert : CertificatePtr, pkey : KeyPtr ) : Void { }
-	@:hlNative("ssl","conf_set_servername_callback") static function conf_set_servername_callback( conf : ConfigPtr, cb : hl.Bytes -> SNICbResult ) : Void { }
-	
+
 }