Socket.hx 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. package sys.ssl;
  // Opaque handles produced by the native Neko libraries. This module only
  // stores them and passes them back to native primitives.
  private typedef SSL = Dynamic;          // native SSL session, as returned by ssl_new
  private typedef CTX = Dynamic;          // native SSL configuration, as returned by conf_new
  private typedef SocketHandle = Dynamic; // raw socket handle (not referenced in the code visible here)
  /**
   * haxe.io.Input adapter that reads decrypted data from a sys.ssl.Socket.
   *
   * Every read first drives the TLS handshake through Socket.handshake(),
   * then delegates to the native `ssl_recv` / `ssl_recv_char` primitives.
   */
  private class SocketInput extends haxe.io.Input {
    // Owning socket. Socket.close() sets this to null before closing the stream.
    @:allow(sys.ssl.Socket) private var __s : Socket;

    public function new( s : Socket ) {
      this.__s = s;
    }

    /**
     * Reads one decrypted byte.
     * @throws haxe.io.Error.Blocked if the handshake or the read would block.
     * @throws haxe.io.Error.Custom if the stream is detached (`__s` is null).
     * @throws haxe.io.Eof on any other native read failure.
     */
    public override function readByte() {
      return try {
        __s.handshake();
        ssl_recv_char( @:privateAccess __s.ssl );
      } catch( e : haxe.io.Error ) {
        // Socket.handshake() already turns a native "Blocking" into
        // haxe.io.Error.Blocked. Propagate it unchanged; otherwise the
        // fallback below would misreport it as end-of-file.
        throw e;
      } catch( e : Dynamic ) {
        if( e == "Blocking" )
          throw haxe.io.Error.Blocked;
        else if( __s == null )
          throw haxe.io.Error.Custom(e);
        else
          throw new haxe.io.Eof();
      }
    }

    /**
     * Reads up to `len` decrypted bytes into `buf` at offset `pos`.
     * @return the number of bytes read reported by the native `ssl_recv` (never 0).
     * @throws "Invalid handle" (String) if the stream is detached.
     * @throws haxe.io.Error.Blocked if the handshake or the read would block.
     * @throws haxe.io.Error.Custom wrapping any other native error.
     * @throws haxe.io.Eof when the native read returns 0.
     */
    public override function readBytes( buf : haxe.io.Bytes, pos : Int, len : Int ) : Int {
      var r : Int;
      if( __s == null )
        throw "Invalid handle";
      try {
        __s.handshake();
        r = ssl_recv( @:privateAccess __s.ssl, buf.getData(), pos, len );
      } catch( e : haxe.io.Error ) {
        // Keep handshake()'s haxe.io.Error.Blocked as-is instead of wrapping it in Custom.
        throw e;
      } catch( e : Dynamic ) {
        if( e == "Blocking" )
          throw haxe.io.Error.Blocked;
        else
          throw haxe.io.Error.Custom(e);
      }
      if( r == 0 )
        throw new haxe.io.Eof();
      return r;
    }

    /** Closes the stream and, if still attached, the owning socket. */
    public override function close() {
      super.close();
      // Socket.close() detaches __s before calling this, so there is no recursion back into it.
      if( __s != null ) __s.close();
    }

    private static var ssl_recv = neko.Lib.loadLazy( "ssl", "ssl_recv", 4 );
    private static var ssl_recv_char = neko.Lib.loadLazy( "ssl", "ssl_recv_char", 1 );
  }
  /**
   * haxe.io.Output adapter that writes data through a sys.ssl.Socket's SSL session.
   *
   * Every write first drives the TLS handshake through Socket.handshake(),
   * then delegates to the native `ssl_send` / `ssl_send_char` primitives.
   */
  private class SocketOutput extends haxe.io.Output {
    // Owning socket. Socket.close() sets this to null before closing the stream.
    @:allow(sys.ssl.Socket) private var __s : Socket;

    public function new( s : Socket ) {
      this.__s = s;
    }

    /**
     * Writes a single byte.
     * @throws "Invalid handle" (String) if the stream is detached.
     * @throws haxe.io.Error.Blocked if the handshake or the write would block.
     * @throws haxe.io.Error.Custom wrapping any other native error.
     */
    public override function writeByte( c : Int ) {
      if( __s == null )
        throw "Invalid handle";
      try {
        __s.handshake();
        ssl_send_char( @:privateAccess __s.ssl, c );
      } catch( e : haxe.io.Error ) {
        // Socket.handshake() already reports a would-block as
        // haxe.io.Error.Blocked; propagate it instead of wrapping it in Custom.
        throw e;
      } catch( e : Dynamic ) {
        if( e == "Blocking" )
          throw haxe.io.Error.Blocked;
        else
          throw haxe.io.Error.Custom(e);
      }
    }

    /**
     * Writes up to `len` bytes of `buf` starting at `pos`.
     * @return whatever the native `ssl_send` returns (presumably the byte count written — confirm against the native lib).
     * @throws haxe.io.Error.Blocked if the handshake or the write would block.
     * @throws haxe.io.Error.Custom wrapping any other native error.
     */
    public override function writeBytes( buf : haxe.io.Bytes, pos : Int, len : Int ) : Int {
      return try {
        __s.handshake();
        ssl_send( @:privateAccess __s.ssl, buf.getData(), pos, len );
      } catch( e : haxe.io.Error ) {
        // Keep handshake()'s haxe.io.Error.Blocked as-is.
        throw e;
      } catch( e : Dynamic ) {
        if( e == "Blocking" )
          throw haxe.io.Error.Blocked;
        else
          throw haxe.io.Error.Custom(e);
      }
    }

    /** Closes the stream and, if still attached, the owning socket. */
    public override function close() {
      super.close();
      // Socket.close() detaches __s before calling this, so there is no recursion back into it.
      if( __s != null ) __s.close();
    }

    private static var ssl_send_char = neko.Lib.loadLazy( "ssl", "ssl_send_char", 2 );
    private static var ssl_send = neko.Lib.loadLazy( "ssl", "ssl_send", 4 );
  }
  /**
   * TLS-wrapped socket for the Neko target, built on the native "ssl" library
   * and the "std" socket primitives.
   *
   * Configuration (setCA, setHostname, setCertificate, addSNICertificate,
   * verifyCert) is only read when connect() or bind() builds the SSL
   * context, so it must be set before those calls.
   */
  @:coreApi
  class Socket extends sys.net.Socket {
    /** Initial value of `verifyCert` for sockets created through init(). */
    public static var DEFAULT_VERIFY_CERT : Null<Bool> = true;
    /** Initial CA for new sockets; init() tries to fill it from Certificate.loadDefaults() when null. */
    public static var DEFAULT_CA : Null<Certificate>;

    private var ctx : CTX;                 // native SSL configuration built by connect()/bind(); null on accepted sockets
    private var ssl : SSL;                 // native SSL session attached to __s
    public var verifyCert : Null<Bool>;    // passed to conf_set_verify when the context is built
    private var caCert : Null<Certificate>;
    private var hostname : String;         // passed to ssl_set_hostname; defaults to the host given to connect()
    private var ownCert : Null<Certificate>;
    private var ownKey : Null<Key>;
    private var altSNIContexts : Null<Array<{match: String->Bool, key: Key, cert: Certificate}>>;
    private var sniCallback : Dynamic;     // callback registered with conf_set_servername_callback
    private var handshakeDone : Bool;

    /** Creates the underlying std socket and the SSL-aware input/output streams. */
    private override function init() : Void {
      __s = socket_new( false );
      input = new SocketInput( this );
      output = new SocketOutput( this );
      if( DEFAULT_VERIFY_CERT && DEFAULT_CA == null ){
        try {
          DEFAULT_CA = Certificate.loadDefaults();
        } catch( e : Dynamic ){
          // Best effort: leave DEFAULT_CA null if the default CA store cannot be loaded.
        }
      }
      verifyCert = DEFAULT_VERIFY_CERT;
      caCert = DEFAULT_CA;
    }

    /**
     * Connects to `host:port` as a TLS client and performs the handshake.
     *
     * The hostname sent to the native layer is the one set by setHostname(),
     * or `host.host` if none was set.
     * @throws "Failed to connect on <host>:<port>" when the native connect fails.
     */
    public override function connect(host : sys.net.Host, port : Int) : Void {
      try {
        ctx = buildSSLContext( false );
        ssl = ssl_new( ctx );
        ssl_set_socket( ssl, __s );
        handshakeDone = false;
        if( hostname == null )
          hostname = host.host;
        if( hostname != null )
          ssl_set_hostname( ssl, untyped hostname.__s );
        socket_connect( __s, host.ip, port );
        handshake();
      } catch( s : String ) {
        if( s == "std@socket_connect" )
          throw "Failed to connect on "+host.host+":"+port;
        else
          neko.Lib.rethrow(s);
      } catch( e : Dynamic ) {
        neko.Lib.rethrow(e);
      }
    }

    /**
     * Runs the TLS handshake if it has not completed yet; once it succeeds,
     * later calls do nothing.
     * @throws haxe.io.Error.Blocked when the native layer reports "Blocking".
     */
    public function handshake() : Void {
      if( !handshakeDone ){
        try {
          ssl_handshake( ssl );
          handshakeDone = true;
        } catch( e : Dynamic ) {
          if( e == "Blocking" )
            throw haxe.io.Error.Blocked;
          else
            neko.Lib.rethrow( e );
        }
      }
    }

    /** Sets the CA used by the next connect()/bind(). */
    public function setCA( cert : Certificate ) : Void {
      caCert = cert;
    }

    /** Sets the hostname passed to the native layer by the next connect(). */
    public function setHostname( name : String ) : Void {
      hostname = name;
    }

    /** Sets this endpoint's certificate and key for the next connect()/bind(). */
    public function setCertificate( cert : Certificate, key : Key ) : Void {
      ownCert = cert;
      ownKey = key;
    }

    /** Handshakes if needed, then returns whatever `ssl_read` returns as a String ("" when it returns null). */
    public override function read() : String {
      handshake();
      var b = ssl_read( ssl );
      if( b == null )
        return "";
      return new String(cast b);
    }

    /** Handshakes if needed, then writes `content` through the SSL session. */
    public override function write( content : String ) : Void {
      handshake();
      ssl_write( ssl, untyped content.__s );
    }

    /**
     * Frees the native SSL session and context, closes the std socket, and
     * detaches and closes both streams.
     *
     * The native handles are cleared after they are freed, so later calls
     * never hand an already-released handle back to ssl_close/conf_close.
     */
    public override function close() : Void {
      if( ssl != null ) {
        ssl_close( ssl );
        ssl = null;
      }
      if( ctx != null ) {
        conf_close( ctx );
        ctx = null;
      }
      // sniCallback is only ever assigned when altSNIContexts != null, so clearing it unconditionally is equivalent.
      sniCallback = null;
      socket_close( __s );
      var input : SocketInput = cast input;
      var output : SocketOutput = cast output;
      // Detach first so the streams' close() does not call back into this method.
      @:privateAccess input.__s = output.__s = null;
      input.close();
      output.close();
    }

    /**
     * Registers an alternate certificate/key pair, served when
     * `cbServernameMatch(servername)` returns true. It is only taken into
     * account by a later bind()/connect().
     */
    public function addSNICertificate( cbServernameMatch : String->Bool, cert : Certificate, key : Key ) : Void {
      if( altSNIContexts == null )
        altSNIContexts = [];
      altSNIContexts.push( {match: cbServernameMatch, cert: cert, key: key} );
    }

    /** Builds a server-side SSL context, then binds the std socket to `host:port`. */
    public override function bind( host : sys.net.Host, port : Int ) : Void {
      ctx = buildSSLContext( true );
      socket_bind( __s, host.ip, port );
    }

    /**
     * Accepts a connection and wraps it in a new Socket whose SSL session
     * uses this socket's context.
     *
     * The new socket does not own the context (its `ctx` stays null), so
     * closing it leaves the listener's context alive. Its handshake is
     * deferred to the first read/write.
     */
    public override function accept() : Socket {
      var c = socket_accept( __s );
      var ssl = ssl_new( ctx );
      ssl_set_socket( ssl, c );
      var s = Type.createEmptyInstance( sys.ssl.Socket );
      s.__s = c;
      s.ssl = ssl;
      s.input = new SocketInput(s);
      s.output = new SocketOutput(s);
      s.handshakeDone = false;
      return s;
    }

    /** Returns the peer's certificate, or null if the native layer reports none. */
    public function peerCertificate() : sys.ssl.Certificate {
      var x = ssl_get_peer_certificate( ssl );
      return x==null ? null : new sys.ssl.Certificate( x );
    }

    /**
     * Creates and configures a native SSL context from this socket's settings.
     * @param server true for a server-side context (bind), false for a client (connect).
     */
    private function buildSSLContext( server : Bool ) : CTX {
      var ctx : CTX = conf_new( server );
      if( ownCert != null && ownKey != null )
        conf_set_cert( ctx, @:privateAccess ownCert.__x, @:privateAccess ownKey.__k );
      if ( altSNIContexts != null ) {
        // Picks the first registered pair whose matcher accepts the requested
        // server name, else falls back to the own cert/key, else null.
        sniCallback = function(servername) {
          var servername = new String(cast servername);
          for( c in altSNIContexts ){
            if( c.match(servername) )
              return @:privateAccess {key: c.key.__k, cert: c.cert.__x};
          }
          if( ownKey != null && ownCert != null )
            return @:privateAccess { key: ownKey.__k, cert: ownCert.__x };
          return null;
        }
        conf_set_servername_callback( ctx, sniCallback );
      }
      if ( caCert != null )
        conf_set_ca( ctx, @:privateAccess caCert.__x );
      conf_set_verify( ctx, verifyCert );
      return ctx;
    }

    private static var ssl_new = neko.Lib.loadLazy( "ssl", "ssl_new", 1 );
    private static var ssl_close = neko.Lib.loadLazy( "ssl", "ssl_close", 1 );
    private static var ssl_handshake = neko.Lib.loadLazy( "ssl", "ssl_handshake", 1 );
    private static var ssl_set_socket = neko.Lib.loadLazy( "ssl", "ssl_set_socket", 2 );
    private static var ssl_set_hostname = neko.Lib.loadLazy( "ssl", "ssl_set_hostname", 2 );
    private static var ssl_get_peer_certificate = neko.Lib.loadLazy( "ssl", "ssl_get_peer_certificate", 1 );
    private static var ssl_read = neko.Lib.loadLazy( "ssl", "ssl_read", 1 );
    private static var ssl_write = neko.Lib.loadLazy( "ssl", "ssl_write", 2 );
    private static var conf_new = neko.Lib.loadLazy( "ssl", "conf_new", 1 );
    private static var conf_close = neko.Lib.loadLazy( "ssl", "conf_close", 1 );
    private static var conf_set_ca = neko.Lib.loadLazy( "ssl", "conf_set_ca", 2 );
    private static var conf_set_verify = neko.Lib.loadLazy( "ssl", "conf_set_verify", 2 );
    private static var conf_set_cert = neko.Lib.loadLazy( "ssl", "conf_set_cert", 3 );
    private static var conf_set_servername_callback = neko.Lib.loadLazy( "ssl", "conf_set_servername_callback", 2 );
    private static var socket_new = neko.Lib.load("std","socket_new",1);
    private static var socket_close = neko.Lib.load("std","socket_close",1);
    private static var socket_connect = neko.Lib.load("std","socket_connect",3);
    private static var socket_bind = neko.Lib.load("std","socket_bind",3);
    private static var socket_accept = neko.Lib.load("std","socket_accept",1);
  }