Преглед изворни кода

fixed mbedtls blocking handling + added ssl_set_bio

ncannasse пре 6 година
родитељ
комит
a8cbcc5b99
1 измењених фајлова са 61 додато и 27 уклоњено
  1. 61 27
      libs/ssl/ssl.c

+ 61 - 27
libs/ssl/ssl.c

@@ -59,6 +59,14 @@ static bool ssl_init_done = false;
 static mbedtls_entropy_context entropy;
 static mbedtls_ctr_drbg_context ctr_drbg;
 
+static bool is_ssl_blocking( int r ) {
+	return r == MBEDTLS_ERR_SSL_WANT_READ || r == MBEDTLS_ERR_SSL_WANT_WRITE;
+}
+
+static int ssl_block_error( int r ) {
+	return is_ssl_blocking(r) ? -1 : -2;
+}
+
 static void cert_finalize(hl_ssl_cert *c) {
 	mbedtls_x509_crt_free(c->c);
 	free(c->c);
@@ -71,17 +79,6 @@ static void pkey_finalize(hl_ssl_pkey *k) {
 	k->k = NULL;
 }
 
-static int block_error() {
-#ifdef HL_WIN
-	int err = WSAGetLastError();
-	if (err == WSAEWOULDBLOCK || err == WSAEALREADY || err == WSAETIMEDOUT)
-#else
-	if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS || errno == EALREADY)
-#endif
-		return -1;
-	return -2;
-}
-
 static int ssl_error(int ret) {
 	char buf[128];
 	uchar buf16[128];
@@ -111,25 +108,59 @@ HL_PRIM void HL_NAME(ssl_close)(mbedtls_ssl_context *ssl) {
 HL_PRIM int HL_NAME(ssl_handshake)(mbedtls_ssl_context *ssl) {
 	int r;
 	r = mbedtls_ssl_handshake(ssl);
-	if (r == SOCKET_ERROR)
-		return block_error();
-	else if (r != 0) 
+	if( is_ssl_blocking(r) )
+		return -1;
+	if( r != 0 )
 		return ssl_error(r);
 	return 0;
 }
 
-int net_read(void *fd, unsigned char *buf, size_t len) {
-	return recv((SOCKET)(int_val)fd, (char *)buf, (int)len, 0);
+
+static bool is_block_error() {
+#ifdef HL_WIN
+	int err = WSAGetLastError();
+	if (err == WSAEWOULDBLOCK || err == WSAEALREADY || err == WSAETIMEDOUT)
+#else
+	if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINPROGRESS || errno == EALREADY)
+#endif
+		return true;
+	return false;
+}
+
+static int net_read(void *fd, unsigned char *buf, size_t len) {
+	int r = recv((SOCKET)(int_val)fd, (char *)buf, (int)len, 0);
+	if( r == SOCKET_ERROR && is_block_error() )
+		return MBEDTLS_ERR_SSL_WANT_READ;
+	return r;
 }
 
-int net_write(void *fd, const unsigned char *buf, size_t len) {
-	return send((SOCKET)(int_val)fd, (char *)buf, (int)len, 0);
+static int net_write(void *fd, const unsigned char *buf, size_t len) {
+	int r = send((SOCKET)(int_val)fd, (char *)buf, (int)len, 0);
+	if( r == SOCKET_ERROR && is_block_error() )
+		return MBEDTLS_ERR_SSL_WANT_WRITE;
+	return r;
 }
 
 HL_PRIM void HL_NAME(ssl_set_socket)(mbedtls_ssl_context *ssl, hl_socket *socket) {
 	mbedtls_ssl_set_bio(ssl, (void*)(int_val)socket->sock, net_write, net_read, NULL);
 }
 
+static int arr_read( void *arr, unsigned char *buf, size_t len ) {
+	int r = ((int (*)(vdynamic*,unsigned char*,int))hl_aptr(arr,vclosure*)[1]->fun)( hl_aptr(arr,vdynamic*)[0], buf, (int)len );
+	if( r == -2 ) return MBEDTLS_ERR_SSL_WANT_READ;
+	return r;
+}
+
+static int arr_write( void *arr, const unsigned char *buf, size_t len ) {
+	int r = ((int (*)(vdynamic*,const unsigned char*,int))hl_aptr(arr,vclosure*)[2]->fun)( hl_aptr(arr,vdynamic*)[0], buf, (int)len );
+	if( r == -2 ) return MBEDTLS_ERR_SSL_WANT_WRITE;
+	return r;
+}
+
+HL_PRIM void HL_NAME(ssl_set_bio)( mbedtls_ssl_context *ssl, varray *ctx ) {
+	mbedtls_ssl_set_bio(ssl, ctx, arr_write, arr_read, NULL);	
+}
+
 HL_PRIM void HL_NAME(ssl_set_hostname)(mbedtls_ssl_context *ssl, vbyte *hostname) {
 	int ret;
 	if ((ret = mbedtls_ssl_set_hostname(ssl, (char*)hostname)) != 0)
@@ -145,39 +176,42 @@ HL_PRIM hl_ssl_cert *HL_NAME(ssl_get_peer_certificate)(mbedtls_ssl_context *ssl)
 DEFINE_PRIM(TSSL, ssl_new, TCONF);
 DEFINE_PRIM(_VOID, ssl_close, TSSL);
 DEFINE_PRIM(_I32, ssl_handshake, TSSL);
+DEFINE_PRIM(_VOID, ssl_set_bio, TSSL _DYN);
 DEFINE_PRIM(_VOID, ssl_set_socket, TSSL _SOCK);
 DEFINE_PRIM(_VOID, ssl_set_hostname, TSSL _BYTES);
 DEFINE_PRIM(TCERT, ssl_get_peer_certificate, TSSL);
 
 HL_PRIM int HL_NAME(ssl_send_char)(mbedtls_ssl_context *ssl, int c) {
 	unsigned char cc;
+	int r;
 	cc = (unsigned char)c;
-	if (mbedtls_ssl_write(ssl, &cc, 1) == SOCKET_ERROR)
-		return block_error();
+	r = mbedtls_ssl_write(ssl, &cc, 1);
+	if( r < 0 )
+		return ssl_block_error(r);
 	return 1;
 }
 
 HL_PRIM int HL_NAME(ssl_send)(mbedtls_ssl_context *ssl, vbyte *buf, int pos, int len) {
 	int r = mbedtls_ssl_write(ssl, (const unsigned char *)buf + pos, len);
-	if (r == SOCKET_ERROR) 
-		return block_error();
+	if( r < 0 ) 
+		return ssl_block_error(r);
 	return r;
 }
 
 HL_PRIM int HL_NAME(ssl_recv_char)(mbedtls_ssl_context *ssl) {
 	unsigned char c;
 	int ret = mbedtls_ssl_read(ssl, &c, 1);
-	if (ret == SOCKET_ERROR || ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
-		return block_error();
+	if( ret != 1 )
+		return ssl_block_error(ret);
 	return c;
 }
 
 HL_PRIM int HL_NAME(ssl_recv)(mbedtls_ssl_context *ssl, vbyte *buf, int pos, int len) {
 	int ret = mbedtls_ssl_read(ssl, (unsigned char*)buf+pos, len);
-	if (ret == SOCKET_ERROR)
-		return block_error();
-	else if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY)
+	if( ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY )
 		return 0;
+	if( ret < 0 )
+		return ssl_block_error(ret);
 	return ret;
 }