Pascal Peridont 8 years ago
parent
commit
3410001a98

+ 1 - 1
std/haxe/Http.hx

@@ -436,7 +436,7 @@ class Http {
 				sock = new php.net.SslSocket();
 				#elseif java
 				sock = new java.net.SslSocket();
-				#elseif (!no_ssl && (hxssl || cpp || (neko && !(macro || interp))))
+				#elseif (!no_ssl && (hxssl || hl || cpp || (neko && !(macro || interp))))
 				sock = new sys.ssl.Socket();
 				#else
 				throw "Https is only supported with -lib hxssl";

+ 7 - 3
std/hl/_std/sys/net/Socket.hx

@@ -23,9 +23,9 @@ package sys.net;
 import haxe.io.Error;
 
 #if doc_gen
-private enum SocketHandle { }
+@:noDoc enum SocketHandle { }
 #else
-private typedef SocketHandle = hl.Abstract<"hl_socket">;
+@:noDoc typedef SocketHandle = hl.Abstract<"hl_socket">;
 #end
 
 private class SocketOutput extends haxe.io.Output {
@@ -114,7 +114,11 @@ class Socket {
 	}
 
 	public function new() : Void {
-		if( __s == null ) __s = socket_new(false);
+		init();
+	}
+	
+	function init() : Void {
+		__s = socket_new(false);
 		input = new SocketInput(this);
 		output = new SocketOutput(this);
 	}

+ 130 - 0
std/hl/_std/sys/ssl/Certificate.hx

@@ -0,0 +1,130 @@
+package sys.ssl;
+import sys.ssl.Lib;
+
+@:noDoc
+typedef CertificatePtr = hl.Abstract<"hl_ssl_cert">;
+
+@:coreApi
+class Certificate {
+	
+	var __h : Null<Certificate>;
+	var __x : CertificatePtr;
+
+	@:allow(sys.ssl.Socket)
+	function new( x : CertificatePtr, ?h: Null<Certificate> ){
+		__x = x;
+		__h = h;
+	}
+
+	public static function loadFile( file : String ) : Certificate {
+		return new Certificate( cert_load_file( @:privateAccess file.toUtf8() ) );
+	}
+	
+	public static function loadPath( path : String ) : Certificate {
+		return new Certificate( cert_load_path( @:privateAccess path.toUtf8() ) );
+	}
+
+	public static function fromString( str : String ) : Certificate {
+		return new Certificate( cert_add_pem(null, @:privateAccess str.toUtf8() ) );
+	}
+	
+	public static function loadDefaults() : Certificate {
+		var x = cert_load_defaults();
+		if ( x != null )
+			return new Certificate( x );
+		
+		var defPaths = null;
+		switch( Sys.systemName() ){
+			case "Linux":
+				defPaths = [
+					"/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu/Gentoo etc.
+					"/etc/pki/tls/certs/ca-bundle.crt",   // Fedora/RHEL
+					"/etc/ssl/ca-bundle.pem",             // OpenSUSE
+					"/etc/pki/tls/cacert.pem",            // OpenELEC
+					"/etc/ssl/certs",                     // SLES10/SLES11
+					"/system/etc/security/cacerts"        // Android
+				];
+			case "BSD":
+				defPaths = [
+					"/usr/local/share/certs/ca-root-nss.crt", // FreeBSD/DragonFly
+					"/etc/ssl/cert.pem",                      // OpenBSD
+					"/etc/openssl/certs/ca-certificates.crt", // NetBSD	
+				];
+			case "Android":
+				defPaths = ["/system/etc/security/cacerts"];
+			default:
+		}
+		if( defPaths != null ){
+			for( path in defPaths ){
+				if( sys.FileSystem.exists(path) ){
+					if( sys.FileSystem.isDirectory(path) )
+						return loadPath(path);
+					else
+						return loadFile(path);
+				}
+			}
+		}
+		return null;
+	}
+
+	public var commonName(get,null) : Null<String>;
+	public var altNames(get, null) : Array<String>;
+	public var notBefore(get,null) : Date;
+	public var notAfter(get,null) : Date;
+
+	function get_commonName() : Null<String> {
+		return subject("CN");
+	}
+
+	function get_altNames() : Array<String> {
+		var a = cert_get_altnames(__x);
+		return [for( e in a ) @:privateAccess String.fromUTF8(e)];
+	}
+	
+	public function subject( field : String ) : Null<String> {
+		var s = cert_get_subject(__x, @:privateAccess field.toUtf8() );
+		return s==null ? null : new String( cast s );
+	}
+	
+	public function issuer( field : String ) : Null<String> {
+		var s = cert_get_issuer(__x, @:privateAccess field.toUtf8());
+		return s==null ? null : new String( cast s );
+	}
+
+	function get_notBefore() : Date {
+		var a = cert_get_notbefore( __x );
+		return new Date( a[0], a[1] - 1, a[2], a[3], a[4], a[5] );
+	}
+
+	function get_notAfter() : Date {
+		var a = cert_get_notafter( __x );
+		return new Date( a[0], a[1] - 1, a[2], a[3], a[4], a[5] );
+	}
+	
+	public function next() : Null<Certificate> {
+		var n = cert_get_next(__x);
+		return n == null ? null : new Certificate( n, __h==null ? this : __h );
+	}
+
+	public function add( pem : String ) : Void {
+		cert_add_pem(__x, @:privateAccess pem.toUtf8());
+	}
+
+	public function addDER( der : haxe.io.Bytes ) : Void {
+		cert_add_der(__x, @:privateAccess der.b, @:privateAccess der.length);
+	}
+
+	@:hlNative("ssl","cert_load_defaults") static function cert_load_defaults() : CertificatePtr { return null; }
+	@:hlNative("ssl","cert_load_file") static function cert_load_file( file : hl.Bytes ) : CertificatePtr { return null; }
+	@:hlNative("ssl","cert_load_path") static function cert_load_path( path : hl.Bytes ) : CertificatePtr { return null; }
+	@:hlNative("ssl","cert_get_subject") static function cert_get_subject( cert : CertificatePtr, obj : hl.Bytes ) : hl.Bytes { return null; }
+	@:hlNative("ssl","cert_get_issuer") static function cert_get_issuer( cert : CertificatePtr, obj : hl.Bytes ) : hl.Bytes { return null; }
+	@:hlNative("ssl","cert_get_altnames") static function cert_get_altnames( cert : CertificatePtr ) : hl.NativeArray<hl.Bytes> { return null; }
+	@:hlNative("ssl","cert_get_notbefore") static function cert_get_notbefore( cert : CertificatePtr ) : hl.NativeArray<Int> { return null; }
+	@:hlNative("ssl","cert_get_notafter") static function cert_get_notafter( cert : CertificatePtr ) : hl.NativeArray<Int> { return null; }
+	@:hlNative("ssl","cert_get_next") static function cert_get_next( cert : CertificatePtr ) : Null<CertificatePtr> { return null; }
+	@:hlNative("ssl","cert_add_pem") static function cert_add_pem( cert : Null<CertificatePtr>, data : hl.Bytes ) : CertificatePtr { return null; }
+	@:hlNative("ssl","cert_add_der") static function cert_add_der( cert : Null<CertificatePtr>, data : hl.Bytes, len : Int ) : CertificatePtr { return null; }
+	
+
+}

+ 27 - 0
std/hl/_std/sys/ssl/Digest.hx

@@ -0,0 +1,27 @@
+package sys.ssl;
+import sys.ssl.Lib;
+
+@:coreApi
+class Digest {
+	
+	public static function make( data : haxe.io.Bytes, alg : DigestAlgorithm ) : haxe.io.Bytes {
+		var size = 0;
+		var b = @:privateAccess dgst_make( data.b, data.length, (alg:String).toUtf8(), size );
+		return @:privateAccess new haxe.io.Bytes(b,size);
+	}
+	
+	public static function sign( data : haxe.io.Bytes, privKey : Key, alg : DigestAlgorithm ) : haxe.io.Bytes {
+		var size = 0;
+		var b = @:privateAccess dgst_sign( data.b, data.length, privKey.__k, (alg:String).toUtf8(), size );
+		return @:privateAccess new haxe.io.Bytes(b,size);
+	}
+	
+	public static function verify( data : haxe.io.Bytes, signature : haxe.io.Bytes, pubKey : Key, alg : DigestAlgorithm ) : Bool{
+		return @:privateAccess dgst_verify( data.b, data.length, signature.b, signature.length, pubKey.__k, (alg:String).toUtf8() );
+	}
+
+	@:hlNative("ssl","dgst_make") static function dgst_make( data : hl.Bytes, len : Int, alg : hl.Bytes, size : hl.Ref<Int> ) : hl.Bytes { return null; }
+	@:hlNative("ssl","dgst_sign") static function dgst_sign( data : hl.Bytes, len : Int, key : sys.ssl.Key.KeyPtr, alg : hl.Bytes, size : hl.Ref<Int> ) : hl.Bytes { return null; }
+	@:hlNative("ssl","dgst_verify") static function dgst_verify( data : hl.Bytes, dlen : Int, sign : hl.Bytes, slen : Int, key : sys.ssl.Key.KeyPtr, alg : hl.Bytes ) : Bool { return false; }
+	
+}

+ 36 - 0
std/hl/_std/sys/ssl/Key.hx

@@ -0,0 +1,36 @@
+package sys.ssl;
+import sys.ssl.Lib;
+
+@:noDoc
+typedef KeyPtr = hl.Abstract<"hl_ssl_pkey">;
+
+@:coreApi
+class Key {
+	
+	private var __k : KeyPtr;
+
+	private function new( k : KeyPtr ){
+		__k = k;
+	}
+	
+	public static function loadFile( file : String, ?isPublic : Bool, ?pass : String ) : Key {
+		var data = sys.io.File.getBytes( file );
+		var start = data.getString(0,11);
+		if( start == "-----BEGIN " )
+			return readPEM( data.toString(), isPublic==true, pass );
+		else
+			return readDER( data, isPublic==true );
+	}
+	
+	public static function readPEM( data : String, isPublic : Bool, ?pass : String ) : Key {
+		return new Key( key_from_pem( @:privateAccess data.toUtf8(), isPublic, pass == null ? null : @:privateAccess pass.toUtf8() ) );
+	}
+
+	public static function readDER( data : haxe.io.Bytes, isPublic : Bool ) : Key {
+		return new Key( key_from_der( @:privateAccess data.b, @:privateAccess data.length, isPublic ) );
+	}
+
+	@:hlNative("ssl","key_from_pem") static function key_from_pem( data : hl.Bytes, pub : Bool, pass : Null<hl.Bytes> ) : KeyPtr { return null; }
+	@:hlNative("ssl","key_from_der") static function key_from_der( data : hl.Bytes, len : Int, pub : Bool ) : KeyPtr { return null; }
+
+}

+ 10 - 0
std/hl/_std/sys/ssl/Lib.hx

@@ -0,0 +1,10 @@
+package sys.ssl;
+
+@:noDoc @:keep
+class Lib {
+	static function __init__() : Void{
+		ssl_init();
+	}
+	
+	@:hlNative("ssl","ssl_init") static function ssl_init(){};
+}

+ 251 - 0
std/hl/_std/sys/ssl/Socket.hx

@@ -0,0 +1,251 @@
+package sys.ssl;
+import sys.ssl.Lib;
+import sys.ssl.Key.KeyPtr;
+import sys.ssl.Certificate.CertificatePtr;
+import sys.net.Socket.SocketHandle;
+
+private typedef ConfigPtr = hl.Abstract<"mbedtls_ssl_config">;
+private typedef ContextPtr = hl.Abstract<"mbedtls_ssl_context">;
+
+@:keep
+private class SNICbResult {
+	public var cert : CertificatePtr;
+	public var key : KeyPtr; 
+	public function new( cert : Certificate, key : Key ){
+		this.cert = @:privateAccess cert.__x;
+		this.key = @:privateAccess key.__k;
+	}
+}
+
+private class SocketInput extends haxe.io.Input {
+	@:allow(sys.ssl.Socket) private var __s : Socket;
+
+	public function new( s : Socket ) {
+		this.__s = s;
+	}
+
+	public override function readByte() {
+		__s.handshake();
+		var r = ssl_recv_char( @:privateAccess __s.ssl );
+		if( r == -1 )
+			throw haxe.io.Error.Blocked;
+		else if( r < 0 )
+			throw new haxe.io.Eof();
+		return r;
+	}
+
+	public override function readBytes( buf : haxe.io.Bytes, pos : Int, len : Int ) : Int {
+		__s.handshake();
+		var r = ssl_recv(  @:privateAccess __s.ssl, @:privateAccess buf.b, pos, len );
+		if( r == -1 )
+			throw haxe.io.Error.Blocked;
+		else if( r < 0 )
+			throw new haxe.io.Eof();
+		return r;
+	}
+
+	public override function close() {
+		super.close();
+		if( __s != null ) __s.close();
+	}
+	
+	@:hlNative("ssl","ssl_recv") static function ssl_recv( ssl : ContextPtr, bytes : hl.Bytes, pos : Int, len : Int ) : Int { return -1; }
+	@:hlNative("ssl","ssl_recv_char") static function ssl_recv_char( ssl : ContextPtr ) : Int { return -1; }
+}
+
+private class SocketOutput extends haxe.io.Output {
+	@:allow(sys.ssl.Socket) private var __s : Socket;
+
+	public function new( s : Socket ) {
+		this.__s = s;
+	}
+
+	public override function writeByte( c : Int ) {
+		__s.handshake();
+		var r = ssl_send_char( @:privateAccess __s.ssl, c);
+		if( r == -1 )
+			throw haxe.io.Error.Blocked;
+		else if( r < 0 )
+			throw new haxe.io.Eof();
+	}
+
+	public override function writeBytes( buf : haxe.io.Bytes, pos : Int, len : Int) : Int {
+		__s.handshake();
+		var r = ssl_send( @:privateAccess __s.ssl, @:privateAccess buf.b, pos, len);
+		if( r == -1 )
+			throw haxe.io.Error.Blocked;
+		else if( r < 0 )
+			throw new haxe.io.Eof();
+		return r;
+	}
+
+	public override function close() {
+		super.close();
+		if( __s != null ) __s.close();
+	}
+
+	@:hlNative("ssl","ssl_send") static function ssl_send( ssl : ContextPtr, bytes : hl.Bytes, pos : Int, len : Int ) : Int { return -1; }
+	@:hlNative("ssl","ssl_send_char") static function ssl_send_char( ssl : ContextPtr, c : Int ) : Int { return -1; }
+
+}
+
+@:coreApi @:access(sys.net.Socket)
+class Socket extends sys.net.Socket {
+	
+	public static var DEFAULT_VERIFY_CERT : Null<Bool> = true;
+
+	public static var DEFAULT_CA : Null<Certificate>;
+	
+	private var conf : ConfigPtr;
+	private var ssl : ContextPtr;
+	
+	public var verifyCert : Null<Bool>;
+	private var caCert : Null<Certificate>;
+	private var hostname : String;
+
+	private var ownCert : Null<Certificate>;
+	private var ownKey : Null<Key>;
+	private var altSNIContexts : Null<Array<{match: String->Bool, key: Key, cert: Certificate}>>;
+	private var sniCallback : hl.Bytes -> SNICbResult;
+	private var handshakeDone : Bool;
+
+	private override function init() : Void {
+		__s = sys.net.Socket.socket_new( false );
+		input = new SocketInput( this );
+		output = new SocketOutput( this );
+		if( DEFAULT_VERIFY_CERT && DEFAULT_CA == null ){
+			try {
+				DEFAULT_CA = Certificate.loadDefaults();
+			}catch( e : Dynamic ){}
+		}
+		verifyCert = DEFAULT_VERIFY_CERT;
+		caCert = DEFAULT_CA;
+	}
+
+	public override function connect(host : sys.net.Host, port : Int) : Void {
+		conf = buildConfig( false );
+		ssl = ssl_new( conf );
+		ssl_set_socket( ssl, __s );
+		handshakeDone = false;
+		if( hostname == null )
+			hostname = host.host;
+		if( hostname != null )
+			ssl_set_hostname( ssl, @:privateAccess hostname.toUtf8() );
+		if( !sys.net.Socket.socket_connect( __s, host.ip, port ) )
+			throw new Sys.SysError("Failed to connect on "+host.toString()+":"+port);
+		handshake();
+	}
+
+	public function handshake() : Void {
+		if( !handshakeDone ){
+			var r = ssl_handshake( ssl );
+			if( r == 0 )
+				handshakeDone = true;
+			else if( r == -1 )
+				throw haxe.io.Error.Blocked;
+			else
+				throw new haxe.io.Eof();
+		}
+	}
+
+	public function setCA( cert : Certificate ) : Void {
+		caCert = cert;
+	}
+
+	public function setHostname( name : String ) : Void {
+		hostname = name;
+	}
+
+	public function setCertificate( cert : Certificate, key : Key ) : Void {
+		ownCert = cert;
+		ownKey = key;
+	}
+
+	public override function close() : Void {
+		if( ssl != null ) ssl_close( ssl );
+		if( conf != null ) conf_close( conf );
+		if( altSNIContexts != null )
+			sniCallback = null;
+		sys.net.Socket.socket_close( __s );
+		var input : SocketInput = cast input;
+		var output : SocketOutput = cast output;
+		@:privateAccess input.__s = output.__s = null;
+		input.close();
+		output.close();
+	}
+
+	public function addSNICertificate( cbServernameMatch : String->Bool, cert : Certificate, key : Key ) : Void {
+		if( altSNIContexts == null )
+			altSNIContexts = [];
+		altSNIContexts.push( {match: cbServernameMatch, cert: cert, key: key} );
+	}
+
+	public override function bind( host : sys.net.Host, port : Int ) : Void {
+		conf = buildConfig( true );
+
+		sys.net.Socket.socket_bind( __s, host.ip, port );
+	}
+
+	public override function accept() : Socket {
+		var c = sys.net.Socket.socket_accept( __s );
+		var cssl = ssl_new( conf );
+		ssl_set_socket( cssl, c );
+
+		var s = Type.createEmptyInstance( sys.ssl.Socket );
+		s.__s = c;
+		s.ssl = cssl;
+		s.input = new SocketInput(s);
+		s.output = new SocketOutput(s);
+		s.handshakeDone = false;
+
+		return s;
+	}
+
+	public function peerCertificate() : sys.ssl.Certificate {
+		var x = ssl_get_peer_certificate( ssl );
+		return x==null ? null : new sys.ssl.Certificate( x );
+	}
+
+	private function buildConfig( server : Bool ) : ConfigPtr {
+		var conf = conf_new( server );
+
+		if( ownCert != null && ownKey != null )
+			conf_set_cert( conf, @:privateAccess ownCert.__x, @:privateAccess ownKey.__k );
+
+		if ( altSNIContexts != null ) {
+			sniCallback = function(servername:hl.Bytes) : SNICbResult {
+				var servername = @:privateAccess String.fromUTF8(servername);
+				for( c in altSNIContexts ){
+					if( c.match(servername) )
+						return new SNICbResult(c.cert, c.key);
+				}
+				if( ownKey != null && ownCert != null )
+					return new SNICbResult(ownCert, ownKey);
+				return null;
+			}
+			conf_set_servername_callback( conf, sniCallback );
+		}
+
+		if ( caCert != null ) 
+			conf_set_ca( conf, caCert == null ? null : @:privateAccess caCert.__x  );
+		conf_set_verify( conf, if( verifyCert ) 1 else if( verifyCert==null ) 2 else 0 );
+		
+		return conf;
+	}
+	
+	
+	@:hlNative("ssl","ssl_new") static function ssl_new( conf : ConfigPtr ) : ContextPtr { return null; }
+	@:hlNative("ssl","ssl_close") static function ssl_close( ssl : ContextPtr ) : Void {}
+	@:hlNative("ssl","ssl_handshake") static function ssl_handshake( ssl : ContextPtr ) : Int { return -1; }
+	@:hlNative("ssl","ssl_set_socket") static function ssl_set_socket( ssl : ContextPtr, socket : SocketHandle ) : Void { }
+	@:hlNative("ssl","ssl_set_hostname") static function ssl_set_hostname( ssl : ContextPtr, name : hl.Bytes ) : Void { }
+	@:hlNative("ssl","ssl_get_peer_certificate") static function ssl_get_peer_certificate( ssl : ContextPtr ) : CertificatePtr { return null; }
+	
+	@:hlNative("ssl","conf_new") static function conf_new( server : Bool ) : ConfigPtr { return null; }
+	@:hlNative("ssl","conf_close") static function conf_close( conf : ConfigPtr ) : Void { }
+	@:hlNative("ssl","conf_set_ca") static function conf_set_ca( conf : ConfigPtr, ca : CertificatePtr ) : Void { }
+	@:hlNative("ssl","conf_set_verify") static function conf_set_verify( conf : ConfigPtr, mode : Int ) : Void { }
+	@:hlNative("ssl","conf_set_cert") static function conf_set_cert( conf : ConfigPtr, cert : CertificatePtr, pkey : KeyPtr ) : Void { }
+	@:hlNative("ssl","conf_set_servername_callback") static function conf_set_servername_callback( conf : ConfigPtr, cb : hl.Bytes -> SNICbResult ) : Void { }
+	
+}