Browse Source

Merge pull request #64 from paullouisageneau/websocket

Add optional WebSocket with the same API
Paul-Louis Ageneau 5 years ago
parent
commit
7c93c698da

+ 33 - 7
CMakeLists.txt

@@ -6,6 +6,7 @@ project (libdatachannel
 
 option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
 option(USE_JUICE "Use libjuice instead of libnice" OFF)
+option(RTC_ENABLE_WEBSOCKET "Build WebSocket support" ON)
 
 if(USE_GNUTLS)
 	option(USE_NETTLE "Use Nettle instead of OpenSSL in libjuice" ON)
@@ -39,6 +40,14 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
 )
 
+set(LIBDATACHANNEL_WEBSOCKET_SOURCES
+	${CMAKE_CURRENT_SOURCE_DIR}/src/base64.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/tcptransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/tlstransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/wstransport.cpp
+)
+
 set(LIBDATACHANNEL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/candidate.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/channel.hpp
@@ -55,6 +64,7 @@ set(LIBDATACHANNEL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/reliability.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.h
 	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/rtc.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/include/rtc/websocket.hpp
 )
 
 set(TESTS_SOURCES
@@ -89,26 +99,42 @@ endif()
 add_library(Usrsctp::Usrsctp ALIAS usrsctp)
 add_library(Usrsctp::UsrsctpStatic ALIAS usrsctp-static)
 
-add_library(datachannel SHARED ${LIBDATACHANNEL_SOURCES})
+if (RTC_ENABLE_WEBSOCKET)
+	add_library(datachannel SHARED
+		${LIBDATACHANNEL_SOURCES}
+		${LIBDATACHANNEL_WEBSOCKET_SOURCES})
+	add_library(datachannel-static STATIC EXCLUDE_FROM_ALL
+		${LIBDATACHANNEL_SOURCES}
+		${LIBDATACHANNEL_WEBSOCKET_SOURCES})
+	target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=1)
+	target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=1)
+else()
+	add_library(datachannel SHARED
+		${LIBDATACHANNEL_SOURCES})
+	add_library(datachannel-static STATIC EXCLUDE_FROM_ALL
+		${LIBDATACHANNEL_SOURCES})
+	target_compile_definitions(datachannel PUBLIC RTC_ENABLE_WEBSOCKET=0)
+	target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_WEBSOCKET=0)
+endif()
+
 set_target_properties(datachannel PROPERTIES
 	VERSION ${PROJECT_VERSION}
 	CXX_STANDARD 17)
+set_target_properties(datachannel-static PROPERTIES
+	VERSION ${PROJECT_VERSION}
+	CXX_STANDARD 17)
 
 target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
 target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
-target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic)
-
-add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES})
-set_target_properties(datachannel-static PROPERTIES
-	VERSION ${PROJECT_VERSION}
-	CXX_STANDARD 17)
 
 target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
 target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
 target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/deps/plog/include)
+
+target_link_libraries(datachannel Threads::Threads Usrsctp::UsrsctpStatic)
 target_link_libraries(datachannel-static Threads::Threads Usrsctp::UsrsctpStatic)
 
 if(WIN32)

+ 1 - 0
Jamfile

@@ -10,6 +10,7 @@ lib libdatachannel
 	<cxxstd>17
 	<include>./include/rtc
 	<define>USE_JUICE=1
+	<define>RTC_ENABLE_WEBSOCKET=0
 	<library>/libdatachannel//usrsctp
 	<library>/libdatachannel//juice
 	<library>/libdatachannel//plog

+ 8 - 0
Makefile

@@ -38,6 +38,14 @@ else
         LIBS+=glib-2.0 gobject-2.0 nice
 endif
 
+RTC_ENABLE_WEBSOCKET ?= 1
+ifneq ($(RTC_ENABLE_WEBSOCKET), 0)
+        CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=1
+else
+        CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=0
+endif
+
+
 INCLUDES+=$(shell pkg-config --cflags $(LIBS))
 LDLIBS+=$(LOCALLIBS) $(shell pkg-config --libs $(LIBS))
 

+ 0 - 1
include/rtc/datachannel.hpp

@@ -82,7 +82,6 @@ private:
 	std::atomic<bool> mIsClosed = false;
 
 	Queue<message_ptr> mRecvQueue;
-	std::atomic<size_t> mRecvAmount = 0;
 
 	friend class PeerConnection;
 };

+ 16 - 1
include/rtc/include.hpp

@@ -19,6 +19,10 @@
 #ifndef RTC_INCLUDE_H
 #define RTC_INCLUDE_H
 
+#ifndef RTC_ENABLE_WEBSOCKET
+#define RTC_ENABLE_WEBSOCKET 1
+#endif
+
 #ifdef _WIN32
 #ifndef _WIN32_WINNT
 #define _WIN32_WINNT 0x0602
@@ -56,10 +60,21 @@ const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
 const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536;    // Remote max message size if not specified in SDP
 const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
 
-
+// overloaded helper
 template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
 template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
 
+// weak_ptr bind helper
+template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t, Args &&... _args) {
+	return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
+		using result_type = typename decltype(bound)::result_type;
+		if (auto shared_this = weak_this.lock())
+			return bound(args...);
+		else
+			return (result_type) false;
+	};
+}
+
 template <typename... P> class synchronized_callback {
 public:
 	synchronized_callback() = default;

+ 1 - 0
include/rtc/message.hpp

@@ -30,6 +30,7 @@ namespace rtc {
 struct Message : binary {
 	enum Type { Binary, String, Control, Reset };
 
+	Message(const Message &message) = default;
 	Message(size_t size, Type type_ = Binary) : binary(size), type(type_) {}
 
 	template <typename Iterator>

+ 2 - 2
include/rtc/peerconnection.hpp

@@ -98,8 +98,6 @@ public:
 	std::string connectionInfo;
 
 private:
-	init_token mInitToken = Init::Token();
-
 	std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
 	std::shared_ptr<DtlsTransport> initDtlsTransport();
 	std::shared_ptr<SctpTransport> initSctpTransport();
@@ -130,6 +128,8 @@ private:
 	const Configuration mConfig;
 	const std::shared_ptr<Certificate> mCertificate;
 
+	init_token mInitToken = Init::Token();
+
 	std::optional<Description> mLocalDescription, mRemoteDescription;
 	mutable std::recursive_mutex mLocalDescriptionMutex, mRemoteDescriptionMutex;
 

+ 11 - 0
include/rtc/queue.hpp

@@ -44,6 +44,7 @@ public:
 	void push(T element);
 	std::optional<T> pop();
 	std::optional<T> peek();
+	std::optional<T> exchange(T element);
 	bool wait(const std::optional<std::chrono::milliseconds> &duration = nullopt);
 
 private:
@@ -118,6 +119,16 @@ template <typename T> std::optional<T> Queue<T>::peek() {
 	}
 }
 
+template <typename T> std::optional<T> Queue<T>::exchange(T element) {
+	std::unique_lock lock(mMutex);
+	if (!mQueue.empty()) {
+		std::swap(mQueue.front(), element);
+		return std::optional<T>{element};
+	} else {
+		return nullopt;
+	}
+}
+
 template <typename T>
 bool Queue<T>::wait(const std::optional<std::chrono::milliseconds> &duration) {
 	std::unique_lock lock(mMutex);

+ 30 - 19
include/rtc/rtc.h

@@ -27,6 +27,10 @@ extern "C" {
 
 // libdatachannel C API
 
+#ifndef RTC_ENABLE_WEBSOCKET
+#define RTC_ENABLE_WEBSOCKET 1
+#endif
+
 typedef enum {
 	RTC_NEW = 0,
 	RTC_CONNECTING = 1,
@@ -42,8 +46,7 @@ typedef enum {
 	RTC_GATHERING_COMPLETE = 2
 } rtcGatheringState;
 
-// Don't change, it must match plog severity
-typedef enum {
+typedef enum { // Don't change, it must match plog severity
 	RTC_LOG_NONE = 0,
 	RTC_LOG_FATAL = 1,
 	RTC_LOG_ERROR = 2,
@@ -76,10 +79,10 @@ typedef void (*availableCallbackFunc)(void *ptr);
 void rtcInitLogger(rtcLogLevel level);
 
 // User pointer
-void rtcSetUserPointer(int i, void *ptr);
+void rtcSetUserPointer(int id, void *ptr);
 
 // PeerConnection
-int rtcCreatePeerConnection(const rtcConfiguration *config);
+int rtcCreatePeerConnection(const rtcConfiguration *config); // returns pc id
 int rtcDeletePeerConnection(int pc);
 
 int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb);
@@ -95,24 +98,32 @@ int rtcGetLocalAddress(int pc, char *buffer, int size);
 int rtcGetRemoteAddress(int pc, char *buffer, int size);
 
 // DataChannel
-int rtcCreateDataChannel(int pc, const char *label);
+int rtcCreateDataChannel(int pc, const char *label); // returns dc id
 int rtcDeleteDataChannel(int dc);
 
 int rtcGetDataChannelLabel(int dc, char *buffer, int size);
-int rtcSetOpenCallback(int dc, openCallbackFunc cb);
-int rtcSetClosedCallback(int dc, closedCallbackFunc cb);
-int rtcSetErrorCallback(int dc, errorCallbackFunc cb);
-int rtcSetMessageCallback(int dc, messageCallbackFunc cb);
-int rtcSendMessage(int dc, const char *data, int size);
-
-int rtcGetBufferedAmount(int dc); // total size buffered to send
-int rtcSetBufferedAmountLowThreshold(int dc, int amount);
-int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb);
-
-// DataChannel extended API
-int rtcGetAvailableAmount(int dc); // total size available to receive
-int rtcSetAvailableCallback(int dc, availableCallbackFunc cb);
-int rtcReceiveMessage(int dc, char *buffer, int *size);
+
+// WebSocket
+#if RTC_ENABLE_WEBSOCKET
+int rtcCreateWebSocket(const char *url); // returns ws id
+int rtcDeleteWebsocket(int ws);
+#endif
+
+// DataChannel and WebSocket common API
+int rtcSetOpenCallback(int id, openCallbackFunc cb);
+int rtcSetClosedCallback(int id, closedCallbackFunc cb);
+int rtcSetErrorCallback(int id, errorCallbackFunc cb);
+int rtcSetMessageCallback(int id, messageCallbackFunc cb);
+int rtcSendMessage(int id, const char *data, int size);
+
+int rtcGetBufferedAmount(int id); // total size buffered to send
+int rtcSetBufferedAmountLowThreshold(int id, int amount);
+int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb);
+
+// DataChannel and WebSocket common extended API
+int rtcGetAvailableAmount(int id); // total size available to receive
+int rtcSetAvailableCallback(int id, availableCallbackFunc cb);
+int rtcReceiveMessage(int id, char *buffer, int *size);
 
 // Cleanup
 void rtcCleanup();

+ 1 - 0
include/rtc/rtc.hpp

@@ -23,6 +23,7 @@
 //
 #include "datachannel.hpp"
 #include "peerconnection.hpp"
+#include "websocket.hpp"
 
 // C API
 #include "rtc.h"

+ 95 - 0
include/rtc/websocket.hpp

@@ -0,0 +1,95 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_WEBSOCKET_H
+#define RTC_WEBSOCKET_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "channel.hpp"
+#include "include.hpp"
+#include "init.hpp"
+#include "message.hpp"
+#include "queue.hpp"
+
+#include <atomic>
+#include <optional>
+#include <thread>
+#include <variant>
+
+namespace rtc {
+
+class TcpTransport;
+class TlsTransport;
+class WsTransport;
+
+class WebSocket final : public Channel, public std::enable_shared_from_this<WebSocket> {
+public:
+	enum class State : int {
+		Connecting = 0,
+		Open = 1,
+		Closing = 2,
+		Closed = 3,
+	};
+
+	WebSocket();
+	WebSocket(const string &url);
+	~WebSocket();
+
+	State readyState() const;
+
+	void open(const string &url);
+	void close() override;
+	bool send(const std::variant<binary, string> &data) override;
+
+	bool isOpen() const override;
+	bool isClosed() const override;
+	size_t maxMessageSize() const override;
+
+	// Extended API
+	std::optional<std::variant<binary, string>> receive() override;
+	size_t availableAmount() const override; // total size available to receive
+
+private:
+	bool changeState(State state);
+	void remoteClose();
+	bool outgoing(mutable_message_ptr message);
+	void incoming(message_ptr message);
+
+	std::shared_ptr<TcpTransport> initTcpTransport();
+	std::shared_ptr<TlsTransport> initTlsTransport();
+	std::shared_ptr<WsTransport> initWsTransport();
+	void closeTransports();
+
+	init_token mInitToken = Init::Token();
+
+	std::shared_ptr<TcpTransport> mTcpTransport;
+	std::shared_ptr<TlsTransport> mTlsTransport;
+	std::shared_ptr<WsTransport> mWsTransport;
+	std::recursive_mutex mInitMutex;
+
+	string mScheme, mHost, mHostname, mService, mPath;
+	std::atomic<State> mState = State::Closed;
+
+	Queue<message_ptr> mRecvQueue;
+};
+} // namespace rtc
+
+#endif
+
+#endif // RTC_WEBSOCKET_H

+ 65 - 0
src/base64.cpp

@@ -0,0 +1,65 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "base64.hpp"
+
+namespace rtc {
+
+using std::to_integer;
+
+string to_base64(const binary &data) {
+	static const char tab[] =
+	    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+	string out;
+	out.reserve(3 * ((data.size() + 3) / 4));
+	int i = 0;
+	while (data.size() - i >= 3) {
+		auto d0 = to_integer<uint8_t>(data[i]);
+		auto d1 = to_integer<uint8_t>(data[i + 1]);
+		auto d2 = to_integer<uint8_t>(data[i + 2]);
+		out += tab[d0 >> 2];
+		out += tab[((d0 & 3) << 4) | (d1 >> 4)];
+		out += tab[((d1 & 0x0F) << 2) | (d2 >> 6)];
+		out += tab[d2 & 0x3F];
+		i += 3;
+	}
+
+	int left = data.size() - i;
+	if (left) {
+		auto d0 = to_integer<uint8_t>(data[i]);
+		out += tab[d0 >> 2];
+		if (left == 1) {
+			out += tab[(d0 & 3) << 4];
+			out += '=';
+		} else { // left == 2
+			auto d1 = to_integer<uint8_t>(data[i + 1]);
+			out += tab[((d0 & 3) << 4) | (d1 >> 4)];
+			out += tab[(d1 & 0x0F) << 2];
+		}
+		out += '=';
+	}
+
+	return out;
+}
+
+} // namespace rtc
+
+#endif

+ 34 - 0
src/base64.hpp

@@ -0,0 +1,34 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_BASE64_H
+#define RTC_BASE64_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "include.hpp"
+
+namespace rtc {
+
+string to_base64(const binary &data);
+
+}
+
+#endif
+
+#endif

+ 3 - 0
src/datachannel.cpp

@@ -214,6 +214,9 @@ bool DataChannel::outgoing(mutable_message_ptr message) {
 }
 
 void DataChannel::incoming(message_ptr message) {
+	if (!message)
+		return;
+
 	switch (message->type) {
 	case Message::Control: {
 		auto raw = reinterpret_cast<const uint8_t *>(message->data());

+ 104 - 110
src/dtlstransport.cpp

@@ -18,9 +18,7 @@
 
 #include "dtlstransport.hpp"
 #include "icetransport.hpp"
-#include "message.hpp"
 
-#include <cassert>
 #include <chrono>
 #include <cstring>
 #include <exception>
@@ -64,11 +62,9 @@ void DtlsTransport::Cleanup() {
 }
 
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
-                             verifier_callback verifierCallback,
-                             state_callback stateChangeCallback)
-    : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
-      mVerifierCallback(std::move(verifierCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)) {
+                             verifier_callback verifierCallback, state_callback stateChangeCallback)
+    : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
+      mVerifierCallback(std::move(verifierCallback)) {
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
@@ -76,31 +72,37 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 	unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER);
 	check_gnutls(gnutls_init(&mSession, flags));
 
-	// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
-	// Therefore, the DTLS layer MUST NOT use any compression algorithm.
-	// See https://tools.ietf.org/html/rfc8261#section-5
-	const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
-	const char *err_pos = NULL;
-	check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
-	             "Unable to set TLS priorities");
-
-	gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
-	check_gnutls(
-	    gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCertificate->credentials()));
-
-	gnutls_dtls_set_timeouts(mSession,
-	                         1000,   // 1s retransmission timeout recommended by RFC 6347
-	                         30000); // 30s total timeout
-	gnutls_handshake_set_timeout(mSession, 30000);
-
-	gnutls_session_set_ptr(mSession, this);
-	gnutls_transport_set_ptr(mSession, this);
-	gnutls_transport_set_push_function(mSession, WriteCallback);
-	gnutls_transport_set_pull_function(mSession, ReadCallback);
-	gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
-
-	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
-	registerIncoming();
+	try {
+		// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
+		// Therefore, the DTLS layer MUST NOT use any compression algorithm.
+		// See https://tools.ietf.org/html/rfc8261#section-5
+		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
+		const char *err_pos = NULL;
+		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
+		             "Failed to set TLS priorities");
+
+		gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
+		check_gnutls(
+		    gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCertificate->credentials()));
+
+		gnutls_dtls_set_timeouts(mSession,
+		                         1000,   // 1s retransmission timeout recommended by RFC 6347
+		                         30000); // 30s total timeout
+		gnutls_handshake_set_timeout(mSession, 30000);
+
+		gnutls_session_set_ptr(mSession, this);
+		gnutls_transport_set_ptr(mSession, this);
+		gnutls_transport_set_push_function(mSession, WriteCallback);
+		gnutls_transport_set_pull_function(mSession, ReadCallback);
+		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
+
+		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+		registerIncoming();
+
+	} catch (...) {
+		gnutls_deinit(mSession);
+		throw;
+	}
 }
 
 DtlsTransport::~DtlsTransport() {
@@ -109,8 +111,6 @@ DtlsTransport::~DtlsTransport() {
 	gnutls_deinit(mSession);
 }
 
-DtlsTransport::State DtlsTransport::state() const { return mState; }
-
 bool DtlsTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -122,7 +122,7 @@ bool DtlsTransport::stop() {
 }
 
 bool DtlsTransport::send(message_ptr message) {
-	if (!message || mState != State::Connected)
+	if (!message || state() != State::Connected)
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -148,11 +148,6 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
-void DtlsTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(state);
-}
-
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 
@@ -169,7 +164,7 @@ void DtlsTransport::runRecvLoop() {
 				throw std::runtime_error("MTU is too low");
 
 		} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
-		         !check_gnutls(ret, "TLS handshake failed"));
+		         !check_gnutls(ret, "DTLS handshake failed"));
 
 		// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
 		// See https://tools.ietf.org/html/rfc8261#section-5
@@ -183,7 +178,7 @@ void DtlsTransport::runRecvLoop() {
 
 	// Receive loop
 	try {
-		PLOG_INFO << "DTLS handshake done";
+		PLOG_INFO << "DTLS handshake finished";
 		changeState(State::Connected);
 
 		const size_t bufferSize = maxMtu;
@@ -218,7 +213,7 @@ void DtlsTransport::runRecvLoop() {
 
 	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
 
-	PLOG_INFO << "DTLS disconnected";
+	PLOG_INFO << "DTLS closed";
 	changeState(State::Disconnected);
 	recv(nullptr);
 }
@@ -341,7 +336,7 @@ void DtlsTransport::Init() {
 	if (!BioMethods) {
 		BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer");
 		if (!BioMethods)
-			throw std::runtime_error("Unable to BIO methods for DTLS writer");
+			throw std::runtime_error("Failed to create BIO methods for DTLS writer");
 		BIO_meth_set_create(BioMethods, BioMethodNew);
 		BIO_meth_set_destroy(BioMethods, BioMethodFree);
 		BIO_meth_set_write(BioMethods, BioMethodWrite);
@@ -358,60 +353,68 @@ void DtlsTransport::Cleanup() {
 
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
-    : Transport(lower), mCertificate(certificate), mState(State::Disconnected),
-      mVerifierCallback(std::move(verifierCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)) {
+    : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
+      mVerifierCallback(std::move(verifierCallback)) {
 
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
-	if (!(mCtx = SSL_CTX_new(DTLS_method())))
-		throw std::runtime_error("Unable to create SSL context");
-
-	check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
-	              "Unable to set SSL priorities");
-
-	// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
-	// Therefore, the DTLS layer MUST NOT use any compression algorithm.
-	// See https://tools.ietf.org/html/rfc8261#section-5
-	SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_QUERY_MTU);
-	SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION);
-	SSL_CTX_set_read_ahead(mCtx, 1);
-	SSL_CTX_set_quiet_shutdown(mCtx, 1);
-	SSL_CTX_set_info_callback(mCtx, InfoCallback);
-	SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
-	                   CertificateCallback);
-	SSL_CTX_set_verify_depth(mCtx, 1);
-
-	auto [x509, pkey] = mCertificate->credentials();
-	SSL_CTX_use_certificate(mCtx, x509);
-	SSL_CTX_use_PrivateKey(mCtx, pkey);
-
-	check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
-
-	if (!(mSsl = SSL_new(mCtx)))
-		throw std::runtime_error("Unable to create SSL instance");
-
-	SSL_set_ex_data(mSsl, TransportExIndex, this);
-
-	if (lower->role() == Description::Role::Active)
-		SSL_set_connect_state(mSsl);
-	else
-		SSL_set_accept_state(mSsl);
-
-	if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods)))
-		throw std::runtime_error("Unable to create BIO");
-
-	BIO_set_mem_eof_return(mInBio, BIO_EOF);
-	BIO_set_data(mOutBio, this);
-	SSL_set_bio(mSsl, mInBio, mOutBio);
+	try {
+		if (!(mCtx = SSL_CTX_new(DTLS_method())))
+			throw std::runtime_error("Failed to create SSL context");
 
-	auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
-	    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
-	SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
-	SSL_set_tmp_ecdh(mSsl, ecdh.get());
+		check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
+		              "Failed to set SSL priorities");
 
-	mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
-	registerIncoming();
+		// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
+		// Therefore, the DTLS layer MUST NOT use any compression algorithm.
+		// See https://tools.ietf.org/html/rfc8261#section-5
+		SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION | SSL_OP_NO_QUERY_MTU);
+		SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION);
+		SSL_CTX_set_read_ahead(mCtx, 1);
+		SSL_CTX_set_quiet_shutdown(mCtx, 1);
+		SSL_CTX_set_info_callback(mCtx, InfoCallback);
+		SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
+		                   CertificateCallback);
+		SSL_CTX_set_verify_depth(mCtx, 1);
+
+		auto [x509, pkey] = mCertificate->credentials();
+		SSL_CTX_use_certificate(mCtx, x509);
+		SSL_CTX_use_PrivateKey(mCtx, pkey);
+
+		check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
+
+		if (!(mSsl = SSL_new(mCtx)))
+			throw std::runtime_error("Failed to create SSL instance");
+
+		SSL_set_ex_data(mSsl, TransportExIndex, this);
+
+		if (lower->role() == Description::Role::Active)
+			SSL_set_connect_state(mSsl);
+		else
+			SSL_set_accept_state(mSsl);
+
+		if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BioMethods)))
+			throw std::runtime_error("Failed to create BIO");
+
+		BIO_set_mem_eof_return(mInBio, BIO_EOF);
+		BIO_set_data(mOutBio, this);
+		SSL_set_bio(mSsl, mInBio, mOutBio);
+
+		auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
+		    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
+		SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
+		SSL_set_tmp_ecdh(mSsl, ecdh.get());
+
+		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
+		registerIncoming();
+
+	} catch (...) {
+		if (mSsl)
+			SSL_free(mSsl);
+		if (mCtx)
+			SSL_CTX_free(mCtx);
+		throw;
+	}
 }
 
 DtlsTransport::~DtlsTransport() {
@@ -432,18 +435,14 @@ bool DtlsTransport::stop() {
 	return true;
 }
 
-DtlsTransport::State DtlsTransport::state() const { return mState; }
-
 bool DtlsTransport::send(message_ptr message) {
-	if (!message || mState != State::Connected)
+	if (!message || state() != State::Connected)
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 
 	int ret = SSL_write(mSsl, message->data(), message->size());
-	if (!check_openssl_ret(mSsl, ret))
-		return false;
-	return true;
+	return check_openssl_ret(mSsl, ret);
 }
 
 void DtlsTransport::incoming(message_ptr message) {
@@ -456,11 +455,6 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
-void DtlsTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(state);
-}
-
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 	try {
@@ -479,7 +473,7 @@ void DtlsTransport::runRecvLoop() {
 				auto message = *mIncomingQueue.pop();
 				BIO_write(mInBio, message->data(), message->size());
 
-				if (mState == State::Connecting) {
+				if (state() == State::Connecting) {
 					// Continue the handshake
 					int ret = SSL_do_handshake(mSsl);
 					if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
@@ -490,7 +484,7 @@ void DtlsTransport::runRecvLoop() {
 						// MTU See https://tools.ietf.org/html/rfc8261#section-5
 						SSL_set_mtu(mSsl, maxMtu + 1);
 
-						PLOG_INFO << "DTLS handshake done";
+						PLOG_INFO << "DTLS handshake finished";
 						changeState(State::Connected);
 					}
 				} else {
@@ -504,7 +498,7 @@ void DtlsTransport::runRecvLoop() {
 
 			// No more messages pending, retransmit and rearm timeout if connecting
 			std::optional<milliseconds> duration;
-			if (mState == State::Connecting) {
+			if (state() == State::Connecting) {
 				// Warning: This function breaks the usual return value convention
 				int ret = DTLSv1_handle_timeout(mSsl);
 				if (ret < 0) {
@@ -514,7 +508,7 @@ void DtlsTransport::runRecvLoop() {
 				}
 
 				struct timeval timeout = {};
-				if (mState == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
+				if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
 					duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
 					// Also handle handshake timeout manually because OpenSSL actually doesn't...
 					// OpenSSL backs off exponentially in base 2 starting from the recommended 1s
@@ -535,8 +529,8 @@ void DtlsTransport::runRecvLoop() {
 		PLOG_ERROR << "DTLS recv: " << e.what();
 	}
 
-	if (mState == State::Connected) {
-		PLOG_INFO << "DTLS disconnected";
+	if (state() == State::Connected) {
+		PLOG_INFO << "DTLS closed";
 		changeState(State::Disconnected);
 		recv(nullptr);
 	} else {

+ 2 - 10
src/dtlstransport.hpp

@@ -46,33 +46,25 @@ public:
 	static void Init();
 	static void Cleanup();
 
-	enum class State { Disconnected, Connecting, Connected, Failed };
-
 	using verifier_callback = std::function<bool(const std::string &fingerprint)>;
-	using state_callback = std::function<void(State state)>;
 
 	DtlsTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
 	              verifier_callback verifierCallback, state_callback stateChangeCallback);
 	~DtlsTransport();
 
-	State state() const;
-
 	bool stop() override;
 	bool send(message_ptr message) override; // false if dropped
 
 private:
 	void incoming(message_ptr message) override;
-	void changeState(State state);
 	void runRecvLoop();
 
 	const std::shared_ptr<Certificate> mCertificate;
 
 	Queue<message_ptr> mIncomingQueue;
-	std::atomic<State> mState;
 	std::thread mRecvThread;
 
 	verifier_callback mVerifierCallback;
-	state_callback mStateChangeCallback;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;
@@ -82,8 +74,8 @@ private:
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 #else
-	SSL_CTX *mCtx;
-	SSL *mSsl;
+	SSL_CTX *mCtx = NULL;
+	SSL *mSsl = NULL;
 	BIO *mInBio, *mOutBio;
 
 	static BIO_METHOD *BioMethods;

+ 43 - 24
src/icetransport.cpp

@@ -48,9 +48,8 @@ namespace rtc {
 IceTransport::IceTransport(const Configuration &config, Description::Role role,
                            candidate_callback candidateCallback, state_callback stateChangeCallback,
                            gathering_state_callback gatheringStateChangeCallback)
-    : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
-      mCandidateCallback(std::move(candidateCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)),
+    : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
+      mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
       mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
       mAgent(nullptr, nullptr) {
 
@@ -84,6 +83,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
 			mStunService = server.service;
 			jconfig.stun_server_host = mStunHostname.c_str();
 			jconfig.stun_server_port = std::stoul(mStunService);
+			break;
 		}
 	}
 
@@ -108,8 +108,6 @@ bool IceTransport::stop() {
 
 Description::Role IceTransport::role() const { return mRole; }
 
-IceTransport::State IceTransport::state() const { return mState; }
-
 Description IceTransport::getLocalDescription(Description::Type type) const {
 	char sdp[JUICE_MAX_SDP_STRING_LEN];
 	if (juice_get_local_description(mAgent.get(), sdp, JUICE_MAX_SDP_STRING_LEN) < 0)
@@ -161,7 +159,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
 }
 
 bool IceTransport::send(message_ptr message) {
-	if (!message || (mState != State::Connected && mState != State::Completed))
+	auto s = state();
+	if (!message || (s != State::Connected && s != State::Completed))
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -173,18 +172,29 @@ bool IceTransport::outgoing(message_ptr message) {
 	                  message->size()) >= 0;
 }
 
-void IceTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(mState);
-}
-
 void IceTransport::changeGatheringState(GatheringState state) {
 	if (mGatheringState.exchange(state) != state)
 		mGatheringStateChangeCallback(mGatheringState);
 }
 
 void IceTransport::processStateChange(unsigned int state) {
-	changeState(static_cast<State>(state));
+	switch (state) {
+	case JUICE_STATE_DISCONNECTED:
+		changeState(State::Disconnected);
+		break;
+	case JUICE_STATE_CONNECTING:
+		changeState(State::Connecting);
+		break;
+	case JUICE_STATE_CONNECTED:
+		changeState(State::Connected);
+		break;
+	case JUICE_STATE_COMPLETED:
+		changeState(State::Completed);
+		break;
+	case JUICE_STATE_FAILED:
+		changeState(State::Failed);
+		break;
+	};
 }
 
 void IceTransport::processCandidate(const string &candidate) {
@@ -263,9 +273,8 @@ namespace rtc {
 IceTransport::IceTransport(const Configuration &config, Description::Role role,
                            candidate_callback candidateCallback, state_callback stateChangeCallback,
                            gathering_state_callback gatheringStateChangeCallback)
-    : mRole(role), mMid("0"), mState(State::Disconnected), mGatheringState(GatheringState::New),
-      mCandidateCallback(std::move(candidateCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)),
+    : Transport(nullptr, std::move(stateChangeCallback)), mRole(role), mMid("0"),
+      mGatheringState(GatheringState::New), mCandidateCallback(std::move(candidateCallback)),
       mGatheringStateChangeCallback(std::move(gatheringStateChangeCallback)),
       mNiceAgent(nullptr, nullptr), mMainLoop(nullptr, nullptr) {
 
@@ -457,8 +466,6 @@ bool IceTransport::stop() {
 
 Description::Role IceTransport::role() const { return mRole; }
 
-IceTransport::State IceTransport::state() const { return mState; }
-
 Description IceTransport::getLocalDescription(Description::Type type) const {
 	// RFC 8445: The initiating agent that started the ICE processing MUST take the controlling
 	// role, and the other MUST take the controlled role.
@@ -529,7 +536,8 @@ std::optional<string> IceTransport::getRemoteAddress() const {
 }
 
 bool IceTransport::send(message_ptr message) {
-	if (!message || (mState != State::Connected && mState != State::Completed))
+	auto s = state();
+	if (!message || (s != State::Connected && s != State::Completed))
 		return false;
 
 	PLOG_VERBOSE << "Send size=" << message->size();
@@ -541,11 +549,6 @@ bool IceTransport::outgoing(message_ptr message) {
 	                       reinterpret_cast<const char *>(message->data())) >= 0;
 }
 
-void IceTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(mState);
-}
-
 void IceTransport::changeGatheringState(GatheringState state) {
 	if (mGatheringState.exchange(state) != state)
 		mGatheringStateChangeCallback(mGatheringState);
@@ -576,7 +579,23 @@ void IceTransport::processStateChange(unsigned int state) {
 		mTimeoutId = 0;
 	}
 
-	changeState(static_cast<State>(state));
+	switch (state) {
+	case NICE_COMPONENT_STATE_DISCONNECTED:
+		changeState(State::Disconnected);
+		break;
+	case NICE_COMPONENT_STATE_CONNECTING:
+		changeState(State::Connecting);
+		break;
+	case NICE_COMPONENT_STATE_CONNECTED:
+		changeState(State::Connected);
+		break;
+	case NICE_COMPONENT_STATE_READY:
+		changeState(State::Completed);
+		break;
+	case NICE_COMPONENT_STATE_FAILED:
+		changeState(State::Failed);
+		break;
+	};
 }
 
 string IceTransport::AddressToString(const NiceAddress &addr) {

+ 4 - 24
src/icetransport.hpp

@@ -40,29 +40,9 @@ namespace rtc {
 
 class IceTransport : public Transport {
 public:
-#if USE_JUICE
-	enum class State : unsigned int{
-	    Disconnected = JUICE_STATE_DISCONNECTED,
-	    Connecting = JUICE_STATE_CONNECTING,
-	    Connected = JUICE_STATE_CONNECTED,
-	    Completed = JUICE_STATE_COMPLETED,
-	    Failed = JUICE_STATE_FAILED,
-	};
-#else
-	enum class State : unsigned int {
-		Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
-		Connecting = NICE_COMPONENT_STATE_CONNECTING,
-		Connected = NICE_COMPONENT_STATE_CONNECTED,
-		Completed = NICE_COMPONENT_STATE_READY,
-		Failed = NICE_COMPONENT_STATE_FAILED,
-	};
-
-	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
-#endif
 	enum class GatheringState { New = 0, InProgress = 1, Complete = 2 };
 
 	using candidate_callback = std::function<void(const Candidate &candidate)>;
-	using state_callback = std::function<void(State state)>;
 	using gathering_state_callback = std::function<void(GatheringState state)>;
 
 	IceTransport(const Configuration &config, Description::Role role,
@@ -71,7 +51,6 @@ public:
 	~IceTransport();
 
 	Description::Role role() const;
-	State state() const;
 	GatheringState gatheringState() const;
 	Description getLocalDescription(Description::Type type) const;
 	void setRemoteDescription(const Description &description);
@@ -84,10 +63,13 @@ public:
 	bool stop() override;
 	bool send(message_ptr message) override; // false if dropped
 
+#if !USE_JUICE
+	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
+#endif
+
 private:
 	bool outgoing(message_ptr message) override;
 
-	void changeState(State state);
 	void changeGatheringState(GatheringState state);
 
 	void processStateChange(unsigned int state);
@@ -98,11 +80,9 @@ private:
 	Description::Role mRole;
 	string mMid;
 	std::chrono::milliseconds mTrickleTimeout;
-	std::atomic<State> mState;
 	std::atomic<GatheringState> mGatheringState;
 
 	candidate_callback mCandidateCallback;
-	state_callback mStateChangeCallback;
 	gathering_state_callback mGatheringStateChangeCallback;
 
 #if USE_JUICE

+ 12 - 2
src/init.cpp

@@ -21,6 +21,10 @@
 #include "dtlstransport.hpp"
 #include "sctptransport.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+#include "tlstransport.hpp"
+#endif
+
 #ifdef _WIN32
 #include <winsock2.h>
 #endif
@@ -69,13 +73,19 @@ Init::Init() {
 	ERR_load_crypto_strings();
 #endif
 
-	DtlsTransport::Init();
 	SctpTransport::Init();
+	DtlsTransport::Init();
+#if RTC_ENABLE_WEBSOCKET
+	TlsTransport::Init();
+#endif
 }
 
 Init::~Init() {
-	DtlsTransport::Cleanup();
 	SctpTransport::Cleanup();
+	DtlsTransport::Cleanup();
+#if RTC_ENABLE_WEBSOCKET
+	TlsTransport::Cleanup();
+#endif
 
 #ifdef _WIN32
 	WSACleanup();

+ 1 - 19
src/peerconnection.cpp

@@ -23,7 +23,6 @@
 #include "include.hpp"
 #include "sctptransport.hpp"
 
-#include <iostream>
 #include <thread>
 
 namespace rtc {
@@ -33,23 +32,6 @@ using namespace std::placeholders;
 using std::shared_ptr;
 using std::weak_ptr;
 
-template <typename F, typename T, typename... Args> auto weak_bind(F &&f, T *t, Args &&... _args) {
-	return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
-		if (auto shared_this = weak_this.lock())
-			bound(args...);
-	};
-}
-
-template <typename F, typename T, typename... Args>
-auto weak_bind_verifier(F &&f, T *t, Args &&... _args) {
-	return [bound = std::bind(f, t, _args...), weak_this = t->weak_from_this()](auto &&... args) {
-		if (auto shared_this = weak_this.lock())
-			return bound(args...);
-		else
-			return false;
-	};
-}
-
 PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
 
 PeerConnection::PeerConnection(const Configuration &config)
@@ -271,7 +253,7 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 
 		auto lower = std::atomic_load(&mIceTransport);
 		auto transport = std::make_shared<DtlsTransport>(
-		    lower, mCertificate, weak_bind_verifier(&PeerConnection::checkFingerprint, this, _1),
+		    lower, mCertificate, weak_bind(&PeerConnection::checkFingerprint, this, _1),
 		    [this, weak_this = weak_from_this()](DtlsTransport::State state) {
 			    auto shared_this = weak_this.lock();
 			    if (!shared_this)

+ 123 - 58
src/rtc.cpp

@@ -16,10 +16,15 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-#include "datachannel.hpp"
 #include "include.hpp"
+
+#include "datachannel.hpp"
 #include "peerconnection.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+#include "websocket.hpp"
+#endif
+
 #include <rtc.h>
 
 #include <exception>
@@ -43,6 +48,9 @@ namespace {
 
 std::unordered_map<int, shared_ptr<PeerConnection>> peerConnectionMap;
 std::unordered_map<int, shared_ptr<DataChannel>> dataChannelMap;
+#if RTC_ENABLE_WEBSOCKET
+std::unordered_map<int, shared_ptr<WebSocket>> webSocketMap;
+#endif
 std::unordered_map<int, void *> userPointerMap;
 std::mutex mutex;
 int lastId = 0;
@@ -103,6 +111,40 @@ bool eraseDataChannel(int dc) {
 	return true;
 }
 
+#if RTC_ENABLE_WEBSOCKET
+shared_ptr<WebSocket> getWebSocket(int id) {
+	std::lock_guard lock(mutex);
+	auto it = webSocketMap.find(id);
+	return it != webSocketMap.end() ? it->second : nullptr;
+}
+
+int emplaceWebSocket(shared_ptr<WebSocket> ptr) {
+	std::lock_guard lock(mutex);
+	int ws = ++lastId;
+	webSocketMap.emplace(std::make_pair(ws, ptr));
+	return ws;
+}
+
+bool eraseWebSocket(int ws) {
+	std::lock_guard lock(mutex);
+	if (webSocketMap.erase(ws) == 0)
+		return false;
+	userPointerMap.erase(ws);
+	return true;
+}
+#endif
+
+shared_ptr<Channel> getChannel(int id) {
+	std::lock_guard lock(mutex);
+	if (auto it = dataChannelMap.find(id); it != dataChannelMap.end())
+		return it->second;
+#if RTC_ENABLE_WEBSOCKET
+	if (auto it = webSocketMap.find(id); it != webSocketMap.end())
+		return it->second;
+#endif
+	return nullptr;
+}
+
 } // namespace
 
 void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); }
@@ -164,6 +206,29 @@ int rtcDeleteDataChannel(int dc) {
 	return 0;
 }
 
+#if RTC_ENABLE_WEBSOCKET
+int rtcCreateWebSocket(const char *url) {
+	return emplaceWebSocket(std::make_shared<WebSocket>(url));
+}
+
+int rtcDeleteWebsocket(int ws) {
+	auto webSocket = getWebSocket(ws);
+	if (!webSocket)
+		return -1;
+
+	webSocket->onOpen(nullptr);
+	webSocket->onClosed(nullptr);
+	webSocket->onError(nullptr);
+	webSocket->onMessage(nullptr);
+	webSocket->onBufferedAmountLow(nullptr);
+	webSocket->onAvailable(nullptr);
+
+	eraseWebSocket(ws);
+	return 0;
+}
+
+#endif
+
 int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb) {
 	auto peerConnection = getPeerConnection(pc);
 	if (!peerConnection)
@@ -298,135 +363,135 @@ int rtcGetDataChannelLabel(int dc, char *buffer, int size) {
 	return size + 1;
 }
 
-int rtcSetOpenCallback(int dc, openCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetOpenCallback(int id, openCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onOpen(nullptr);
+		channel->onOpen(nullptr);
 	return 0;
 }
 
-int rtcSetClosedCallback(int dc, closedCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetClosedCallback(int id, closedCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onClosed([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onClosed([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onClosed(nullptr);
+		channel->onClosed(nullptr);
 	return 0;
 }
 
-int rtcSetErrorCallback(int dc, errorCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetErrorCallback(int id, errorCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onError(
-		    [dc, cb](const string &error) { cb(error.c_str(), getUserPointer(dc)); });
+		channel->onError([id, cb](const string &error) { cb(error.c_str(), getUserPointer(id)); });
 	else
-		dataChannel->onError(nullptr);
+		channel->onError(nullptr);
 	return 0;
 }
 
-int rtcSetMessageCallback(int dc, messageCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetMessageCallback(int id, messageCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onMessage(
-		    [dc, cb](const binary &b) {
-			    cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(dc));
+		channel->onMessage(
+		    [id, cb](const binary &b) {
+			    cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(id));
 		    },
-		    [dc, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(dc)); });
+		    [id, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(id)); });
 	else
-		dataChannel->onMessage(nullptr);
+		channel->onMessage(nullptr);
 
 	return 0;
 }
 
-int rtcSendMessage(int dc, const char *data, int size) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSendMessage(int id, const char *data, int size) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (size >= 0) {
 		auto b = reinterpret_cast<const byte *>(data);
-		CATCH(dataChannel->send(b, size));
+		CATCH(channel->send(binary(b, b + size)));
 		return size;
 	} else {
-		string s(data);
-		CATCH(dataChannel->send(s));
-		return s.size();
+		string str(data);
+		int len = str.size();
+		CATCH(channel->send(std::move(str)));
+		return len;
 	}
 }
 
-int rtcGetBufferedAmount(int dc) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcGetBufferedAmount(int id) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
-	CATCH(return int(dataChannel->bufferedAmount()));
+	CATCH(return int(channel->bufferedAmount()));
 }
 
-int rtcSetBufferedAmountLowThreshold(int dc, int amount) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetBufferedAmountLowThreshold(int id, int amount) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
-	CATCH(dataChannel->setBufferedAmountLowThreshold(size_t(amount)));
+	CATCH(channel->setBufferedAmountLowThreshold(size_t(amount)));
 	return 0;
 }
 
-int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onBufferedAmountLow([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onBufferedAmountLow([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onBufferedAmountLow(nullptr);
+		channel->onBufferedAmountLow(nullptr);
 	return 0;
 }
 
-int rtcGetAvailableAmount(int dc) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcGetAvailableAmount(int id) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
-	CATCH(return int(dataChannel->availableAmount()));
+	CATCH(return int(channel->availableAmount()));
 }
 
-int rtcSetAvailableCallback(int dc, availableCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetAvailableCallback(int id, availableCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onOpen(nullptr);
+		channel->onOpen(nullptr);
 	return 0;
 }
 
-int rtcReceiveMessage(int dc, char *buffer, int *size) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcReceiveMessage(int id, char *buffer, int *size) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (!size)
 		return -1;
 
 	CATCH({
-		auto message = dataChannel->receive();
+		auto message = channel->receive();
 		if (!message)
 			return 0;
 

+ 7 - 14
src/sctptransport.cpp

@@ -71,9 +71,8 @@ void SctpTransport::Cleanup() {
 SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
                              message_callback recvCallback, amount_callback bufferedAmountCallback,
                              state_callback stateChangeCallback)
-    : Transport(lower), mPort(port), mSendQueue(0, message_size_func),
-      mBufferedAmountCallback(std::move(bufferedAmountCallback)),
-      mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
+    : Transport(lower, std::move(stateChangeCallback)), mPort(port),
+      mSendQueue(0, message_size_func), mBufferedAmountCallback(std::move(bufferedAmountCallback)) {
 	onRecv(recvCallback);
 
 	PLOG_DEBUG << "Initializing SCTP transport";
@@ -180,8 +179,6 @@ SctpTransport::~SctpTransport() {
 	usrsctp_deregister_address(this);
 }
 
-SctpTransport::State SctpTransport::state() const { return mState; }
-
 bool SctpTransport::stop() {
 	if (!Transport::stop())
 		return false;
@@ -240,6 +237,7 @@ void SctpTransport::shutdown() {
 
 bool SctpTransport::send(message_ptr message) {
 	std::lock_guard lock(mSendMutex);
+
 	if (!message)
 		return mSendQueue.empty();
 
@@ -269,7 +267,7 @@ void SctpTransport::incoming(message_ptr message) {
 	// to be sent on our side (i.e. the local INIT) before proceeding.
 	{
 		std::unique_lock lock(mWriteMutex);
-		mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || mState != State::Connected; });
+		mWrittenCondition.wait(lock, [&]() { return mWrittenOnce || state() != State::Connected; });
 	}
 
 	if (!message) {
@@ -283,11 +281,6 @@ void SctpTransport::incoming(message_ptr message) {
 	usrsctp_conninput(this, message->data(), message->size(), 0);
 }
 
-void SctpTransport::changeState(State state) {
-	if (mState.exchange(state) != state)
-		mStateChangeCallback(state);
-}
-
 bool SctpTransport::trySendQueue() {
 	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
@@ -302,7 +295,7 @@ bool SctpTransport::trySendQueue() {
 
 bool SctpTransport::trySendMessage(message_ptr message) {
 	// Requires mSendMutex to be locked
-	if (!mSock || mState != State::Connected)
+	if (!mSock || state() != State::Connected)
 		return false;
 
 	uint32_t ppid;
@@ -414,7 +407,7 @@ void SctpTransport::sendReset(uint16_t streamId) {
 	if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len) == 0) {
 		std::unique_lock lock(mWriteMutex); // locking before setsockopt might deadlock usrsctp...
 		mWrittenCondition.wait_for(lock, 1000ms,
-		                           [&]() { return mWritten || mState != State::Connected; });
+		                           [&]() { return mWritten || state() != State::Connected; });
 	} else if (errno == EINVAL) {
 		PLOG_VERBOSE << "SCTP stream " << streamId << " already reset";
 	} else {
@@ -571,7 +564,7 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
 			PLOG_INFO << "SCTP connected";
 			changeState(State::Connected);
 		} else {
-			if (mState == State::Connecting) {
+			if (state() == State::Connecting) {
 				PLOG_ERROR << "SCTP connection failed";
 				changeState(State::Failed);
 			} else {

+ 1 - 10
src/sctptransport.hpp

@@ -38,17 +38,12 @@ public:
 	static void Init();
 	static void Cleanup();
 
-	enum class State { Disconnected, Connecting, Connected, Failed };
-
 	using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
-	using state_callback = std::function<void(State state)>;
 
 	SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
 	              amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
 	~SctpTransport();
 
-	State state() const;
-
 	bool stop() override;
 	bool send(message_ptr message) override; // false if buffered
 	void close(unsigned int stream);
@@ -76,7 +71,6 @@ private:
 	void connect();
 	void shutdown();
 	void incoming(message_ptr message) override;
-	void changeState(State state);
 
 	bool trySendQueue();
 	bool trySendMessage(message_ptr message);
@@ -105,14 +99,11 @@ private:
 	std::atomic<bool> mWritten = false; // written outside lock
 	bool mWrittenOnce = false;
 
-	state_callback mStateChangeCallback;
-	std::atomic<State> mState;
+	binary mPartialRecv, mPartialStringData, mPartialBinaryData;
 
 	// Stats
 	std::atomic<size_t> mBytesSent = 0, mBytesReceived = 0;
 
-	binary mPartialRecv, mPartialStringData, mPartialBinaryData;
-
 	static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
 	                        struct sctp_rcvinfo recv_info, int flags, void *user_data);
 	static int SendCallback(struct socket *sock, uint32_t sb_free);

+ 320 - 0
src/tcptransport.cpp

@@ -0,0 +1,320 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "tcptransport.hpp"
+
+#include <exception>
+#ifndef _WIN32
+#include <fcntl.h>
+#include <unistd.h>
+#endif
+
+namespace rtc {
+
+using std::to_string;
+
+SelectInterrupter::SelectInterrupter() {
+#ifndef _WIN32
+	int pipefd[2];
+	if (::pipe(pipefd) != 0)
+		throw std::runtime_error("Failed to create pipe");
+	::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
+	::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
+	mPipeOut = pipefd[0]; // read
+	mPipeIn = pipefd[1];  // write
+#endif
+}
+
+SelectInterrupter::~SelectInterrupter() {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock != INVALID_SOCKET)
+		::closesocket(mDummySock);
+#else
+	::close(mPipeIn);
+	::close(mPipeOut);
+#endif
+}
+
+int SelectInterrupter::prepare(fd_set &readfds, fd_set &writefds) {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock == INVALID_SOCKET)
+		mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
+	FD_SET(mDummySock, &readfds);
+	return SOCK_TO_INT(mDummySock) + 1;
+#else
+	int ret;
+	do {
+		char dummy;
+		ret = ::read(mPipeIn, &dummy, 1);
+	} while (ret > 0);
+	FD_SET(mPipeIn, &readfds);
+	return mPipeIn + 1;
+#endif
+}
+
+void SelectInterrupter::interrupt() {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock != INVALID_SOCKET) {
+		::closesocket(mDummySock);
+		mDummySock = INVALID_SOCKET;
+	}
+#else
+	char dummy = 0;
+	::write(mPipeOut, &dummy, 1);
+#endif
+}
+
+TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
+    : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
+
+	PLOG_DEBUG << "Initializing TCP transport";
+	mThread = std::thread(&TcpTransport::runLoop, this);
+}
+
+TcpTransport::~TcpTransport() {
+	stop();
+}
+
+bool TcpTransport::stop() {
+	if (!Transport::stop())
+		return false;
+
+	PLOG_DEBUG << "Waiting TCP recv thread";
+	close();
+	mThread.join();
+	return true;
+}
+
+bool TcpTransport::send(message_ptr message) {
+	if (!message)
+		return mSendQueue.empty();
+
+	PLOG_VERBOSE << "Send size=" << (message ? message->size() : 0);
+
+	return outgoing(message);
+}
+
+void TcpTransport::incoming(message_ptr message) { recv(message); }
+
+bool TcpTransport::outgoing(message_ptr message) {
+	// If nothing is pending, try to send directly
+	// It's safe because if the queue is empty, the thread is not sending
+	if (mSendQueue.empty() && trySendMessage(message))
+		return true;
+
+	mSendQueue.push(message);
+	interruptSelect(); // so the thread waits for writability
+	return false;
+}
+
+void TcpTransport::connect(const string &hostname, const string &service) {
+	PLOG_DEBUG << "Connecting to " << hostname << ":" << service;
+
+	struct addrinfo hints = {};
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_STREAM;
+	hints.ai_protocol = IPPROTO_TCP;
+	hints.ai_flags = AI_ADDRCONFIG;
+
+	struct addrinfo *result = nullptr;
+	if (getaddrinfo(hostname.c_str(), service.c_str(), &hints, &result))
+		throw std::runtime_error("Resolution failed for \"" + hostname + ":" + service + "\"");
+
+	for (auto p = result; p; p = p->ai_next)
+		try {
+			connect(p->ai_addr, p->ai_addrlen);
+			freeaddrinfo(result);
+			return;
+		} catch (const std::runtime_error &e) {
+			PLOG_WARNING << e.what();
+		}
+
+	freeaddrinfo(result);
+	throw std::runtime_error("Connection failed to \"" + hostname + ":" + service + "\"");
+}
+
+void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
+	try {
+		PLOG_DEBUG << "Creating TCP socket";
+
+		// Create socket
+		mSock = ::socket(addr->sa_family, SOCK_STREAM, IPPROTO_TCP);
+		if (mSock == INVALID_SOCKET)
+			throw std::runtime_error("TCP socket creation failed");
+
+		ctl_t b = 1;
+		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+			throw std::runtime_error("Failed to set socket non-blocking mode");
+
+		IF_PLOG(plog::debug) {
+			char node[MAX_NUMERICNODE_LEN];
+			char serv[MAX_NUMERICSERV_LEN];
+			if (getnameinfo(addr, addrlen, node, MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
+			                NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
+				PLOG_DEBUG << "Trying address " << node << ":" << serv;
+			}
+		}
+
+		// Initiate connection
+		::connect(mSock, addr, addrlen);
+
+		fd_set writefds;
+		FD_ZERO(&writefds);
+		FD_SET(mSock, &writefds);
+		struct timeval tv;
+		tv.tv_sec = 10; // TODO
+		tv.tv_usec = 0;
+		int ret = ::select(SOCKET_TO_INT(mSock) + 1, NULL, &writefds, NULL, &tv);
+
+		if (ret < 0)
+			throw std::runtime_error("Failed to wait for socket connection");
+
+		if (ret == 0 || ::send(mSock, NULL, 0, MSG_NOSIGNAL) != 0)
+			throw std::runtime_error("Connection failed");
+
+	} catch (...) {
+		if (mSock != INVALID_SOCKET) {
+			::closesocket(mSock);
+			mSock = INVALID_SOCKET;
+		}
+		throw;
+	}
+}
+
+void TcpTransport::close() {
+	if (mSock != INVALID_SOCKET) {
+		PLOG_DEBUG << "Closing TCP socket";
+		::closesocket(mSock);
+		mSock = INVALID_SOCKET;
+	}
+	changeState(State::Disconnected);
+}
+
+bool TcpTransport::trySendQueue() {
+	while (auto next = mSendQueue.peek()) {
+		auto message = *next;
+		if (!trySendMessage(message)) {
+			mSendQueue.exchange(message);
+			return false;
+		}
+		mSendQueue.pop();
+	}
+	return true;
+}
+
+bool TcpTransport::trySendMessage(message_ptr &message) {
+	auto data = reinterpret_cast<const char *>(message->data());
+	auto size = message->size();
+	while (size) {
+		int len = ::send(mSock, data, size, MSG_NOSIGNAL);
+		if (len < 0) {
+			if (errno == EAGAIN || errno == EWOULDBLOCK) {
+				message = make_message(message->end() - size, message->end());
+				return false;
+			} else {
+				throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
+			}
+		}
+
+		data += len;
+		size -= len;
+	}
+	message = nullptr;
+	return true;
+}
+
+void TcpTransport::runLoop() {
+	const size_t bufferSize = 4096;
+
+	// Connect
+	try {
+		changeState(State::Connecting);
+		connect(mHostname, mService);
+
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TCP connect: " << e.what();
+		changeState(State::Failed);
+		return;
+	}
+
+
+	// Receive loop
+	try {
+		PLOG_INFO << "TCP connected";
+		changeState(State::Connected);
+
+		while (true) {
+			fd_set readfds, writefds;
+			int n = prepareSelect(readfds, writefds);
+			int ret = ::select(n, &readfds, &writefds, NULL, NULL);
+			if (ret < 0)
+				throw std::runtime_error("Failed to wait on socket");
+
+			if (FD_ISSET(mSock, &writefds))
+				trySendQueue();
+
+			if (FD_ISSET(mSock, &readfds)) {
+				char buffer[bufferSize];
+				int len = ::recv(mSock, buffer, bufferSize, 0);
+				if (len < 0) {
+					if (errno == EAGAIN || errno == EWOULDBLOCK) {
+						continue;
+					} else {
+						throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
+					}
+				}
+
+				if (len == 0)
+					break; // clean close
+
+				auto *b = reinterpret_cast<byte *>(buffer);
+				incoming(make_message(b, b + len));
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TCP recv: " << e.what();
+	}
+
+	PLOG_INFO << "TCP disconnected";
+	changeState(State::Disconnected);
+	recv(nullptr);
+}
+
+int TcpTransport::prepareSelect(fd_set &readfds, fd_set &writefds) {
+	FD_ZERO(&readfds);
+	FD_ZERO(&writefds);
+	FD_SET(mSock, &readfds);
+
+	if (!mSendQueue.empty())
+		FD_SET(mSock, &writefds);
+
+	int n = SOCKET_TO_INT(mSock) + 1;
+	int m = mInterrupter.prepare(readfds, writefds);
+	return std::max(n, m);
+}
+
+void TcpTransport::interruptSelect() { mInterrupter.interrupt(); }
+
+} // namespace rtc
+
+#endif

+ 90 - 0
src/tcptransport.hpp

@@ -0,0 +1,90 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_TCP_TRANSPORT_H
+#define RTC_TCP_TRANSPORT_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "include.hpp"
+#include "queue.hpp"
+#include "transport.hpp"
+
+#include <mutex>
+#include <thread>
+
+// Use the socket defines from libjuice
+#include "../deps/libjuice/src/socket.h"
+
+namespace rtc {
+
+// Utility class to interrupt select()
+class SelectInterrupter {
+public:
+	SelectInterrupter();
+	~SelectInterrupter();
+
+	int prepare(fd_set &readfds, fd_set &writefds);
+	void interrupt();
+
+private:
+	std::mutex mMutex;
+#ifdef _WIN32
+	socket_t mDummySock = INVALID_SOCKET;
+#else // assume POSIX
+	int mPipeIn, mPipeOut;
+#endif
+};
+
+class TcpTransport : public Transport {
+public:
+	TcpTransport(const string &hostname, const string &service, state_callback callback);
+	~TcpTransport();
+
+	bool stop() override;
+	bool send(message_ptr message) override;
+
+	void incoming(message_ptr message) override;
+	bool outgoing(message_ptr message) override;
+
+private:
+	void connect(const string &hostname, const string &service);
+	void connect(const sockaddr *addr, socklen_t addrlen);
+	void close();
+
+	bool trySendQueue();
+	bool trySendMessage(message_ptr &message);
+
+	void runLoop();
+
+	int prepareSelect(fd_set &readfds, fd_set &writefds);
+	void interruptSelect();
+
+	string mHostname, mService;
+
+	socket_t mSock = INVALID_SOCKET;
+	std::thread mThread;
+	SelectInterrupter mInterrupter;
+	Queue<message_ptr> mSendQueue;
+};
+
+} // namespace rtc
+
+#endif
+
+#endif

+ 432 - 0
src/tlstransport.cpp

@@ -0,0 +1,432 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "tlstransport.hpp"
+#include "tcptransport.hpp"
+
+#include <chrono>
+#include <cstring>
+#include <exception>
+#include <iostream>
+
+using namespace std::chrono;
+
+using std::shared_ptr;
+using std::string;
+using std::unique_ptr;
+using std::weak_ptr;
+
+#if USE_GNUTLS
+
+namespace {
+
+static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
+	if (ret < 0) {
+		if (!gnutls_error_is_fatal(ret)) {
+			PLOG_INFO << gnutls_strerror(ret);
+			return false;
+		}
+		PLOG_ERROR << message << ": " << gnutls_strerror(ret);
+		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
+	}
+	return true;
+}
+
+} // namespace
+
+namespace rtc {
+
+void TlsTransport::Init() {
+	// Nothing to do
+}
+
+void TlsTransport::Cleanup() {
+	// Nothing to do
+}
+
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
+    : Transport(lower, std::move(callback)) {
+
+	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
+
+	check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
+
+	try {
+		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
+		const char *err_pos = NULL;
+		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
+		             "Failed to set TLS priorities");
+
+		gnutls_session_set_ptr(mSession, this);
+		gnutls_transport_set_ptr(mSession, this);
+		gnutls_transport_set_push_function(mSession, WriteCallback);
+		gnutls_transport_set_pull_function(mSession, ReadCallback);
+		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
+
+		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
+
+		mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+		registerIncoming();
+
+	} catch (...) {
+
+		gnutls_deinit(mSession);
+		throw;
+	}
+}
+
+TlsTransport::~TlsTransport() {
+	stop();
+	gnutls_deinit(mSession);
+}
+
+bool TlsTransport::stop() {
+	if (!Transport::stop())
+		return false;
+
+	PLOG_DEBUG << "Stopping TLS recv thread";
+	mIncomingQueue.stop();
+	mRecvThread.join();
+	return true;
+}
+
+bool TlsTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
+	PLOG_VERBOSE << "Send size=" << message->size();
+
+	ssize_t ret;
+	do {
+		ret = gnutls_record_send(mSession, message->data(), message->size());
+	} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
+
+	return check_gnutls(ret);
+}
+
+void TlsTransport::incoming(message_ptr message) {
+	if (message)
+		mIncomingQueue.push(message);
+	else
+		mIncomingQueue.stop();
+}
+
+void TlsTransport::runRecvLoop() {
+	const size_t bufferSize = 4096;
+	char buffer[bufferSize];
+
+	// Handshake loop
+	try {
+		changeState(State::Connecting);
+
+		int ret;
+		do {
+			ret = gnutls_handshake(mSession);
+		} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
+		         !check_gnutls(ret, "TLS handshake failed"));
+
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TLS handshake: " << e.what();
+		changeState(State::Failed);
+		return;
+	}
+
+	// Receive loop
+	try {
+		PLOG_INFO << "TLS handshake finished";
+		changeState(State::Connected);
+
+		while (true) {
+			ssize_t ret;
+			do {
+				ret = gnutls_record_recv(mSession, buffer, bufferSize);
+			} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
+
+			// Consider premature termination as remote closing
+			if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
+				PLOG_DEBUG << "TLS connection terminated";
+				break;
+			}
+
+			if (check_gnutls(ret)) {
+				if (ret == 0) {
+					// Closed
+					PLOG_DEBUG << "TLS connection cleanly closed";
+					break;
+				}
+				auto *b = reinterpret_cast<byte *>(buffer);
+				recv(make_message(b, b + ret));
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TLS recv: " << e.what();
+	}
+
+	gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
+
+	PLOG_INFO << "TLS closed";
+	changeState(State::Disconnected);
+	recv(nullptr);
+}
+
+ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
+	TlsTransport *t = static_cast<TlsTransport *>(ptr);
+	if (len > 0) {
+		auto b = reinterpret_cast<const byte *>(data);
+		t->outgoing(make_message(b, b + len));
+	}
+	gnutls_transport_set_errno(t->mSession, 0);
+	return ssize_t(len);
+}
+
+ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
+	TlsTransport *t = static_cast<TlsTransport *>(ptr);
+	if (auto next = t->mIncomingQueue.pop()) {
+		auto message = *next;
+		ssize_t len = std::min(maxlen, message->size());
+		std::memcpy(data, message->data(), len);
+		gnutls_transport_set_errno(t->mSession, 0);
+		return len;
+	}
+	// Closed
+	gnutls_transport_set_errno(t->mSession, 0);
+	return 0;
+}
+
+int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
+	TlsTransport *t = static_cast<TlsTransport *>(ptr);
+	if (ms != GNUTLS_INDEFINITE_TIMEOUT)
+		t->mIncomingQueue.wait(milliseconds(ms));
+	else
+		t->mIncomingQueue.wait();
+	return !t->mIncomingQueue.empty() ? 1 : 0;
+}
+
+} // namespace rtc
+
+#else // USE_GNUTLS==0
+
+#include <openssl/bio.h>
+#include <openssl/ec.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+namespace {
+
+const int BIO_EOF = -1;
+
+string openssl_error_string(unsigned long err) {
+	const size_t bufferSize = 256;
+	char buffer[bufferSize];
+	ERR_error_string_n(err, buffer, bufferSize);
+	return string(buffer);
+}
+
+bool check_openssl(int success, const string &message = "OpenSSL error") {
+	if (success)
+		return true;
+
+	string str = openssl_error_string(ERR_get_error());
+	PLOG_ERROR << message << ": " << str;
+	throw std::runtime_error(message + ": " + str);
+}
+
+bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
+	if (ret == BIO_EOF)
+		return true;
+
+	unsigned long err = SSL_get_error(ssl, ret);
+	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
+		return true;
+	}
+	if (err == SSL_ERROR_ZERO_RETURN) {
+		PLOG_DEBUG << "TLS connection cleanly closed";
+		return false;
+	}
+	string str = openssl_error_string(err);
+	PLOG_ERROR << str;
+	throw std::runtime_error(message + ": " + str);
+}
+
+} // namespace
+
+namespace rtc {
+
+int TlsTransport::TransportExIndex = -1;
+
+void TlsTransport::Init() {
+	if (TransportExIndex < 0) {
+		TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
+	}
+}
+
+void TlsTransport::Cleanup() {
+	// Nothing to do
+}
+
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
+    : Transport(lower, std::move(callback)) {
+
+	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
+
+	if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
+		throw std::runtime_error("Failed to create SSL context");
+
+	check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
+	              "Failed to set SSL priorities");
+
+	SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
+	SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
+	SSL_CTX_set_read_ahead(mCtx, 1);
+	SSL_CTX_set_quiet_shutdown(mCtx, 1);
+	SSL_CTX_set_info_callback(mCtx, InfoCallback);
+
+	SSL_CTX_set_default_verify_paths(mCtx);
+	SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
+	SSL_CTX_set_verify_depth(mCtx, 4);
+
+	if (!(mSsl = SSL_new(mCtx)))
+		throw std::runtime_error("Failed to create SSL instance");
+
+	SSL_set_ex_data(mSsl, TransportExIndex, this);
+	SSL_set_tlsext_host_name(mSsl, host.c_str());
+
+	SSL_set_connect_state(mSsl);
+
+	if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
+		throw std::runtime_error("Failed to create BIO");
+
+	BIO_set_mem_eof_return(mInBio, BIO_EOF);
+	BIO_set_mem_eof_return(mOutBio, BIO_EOF);
+	SSL_set_bio(mSsl, mInBio, mOutBio);
+
+	auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
+	    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
+	SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
+	SSL_set_tmp_ecdh(mSsl, ecdh.get());
+
+	mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
+}
+
+TlsTransport::~TlsTransport() {
+	stop();
+
+	SSL_free(mSsl);
+	SSL_CTX_free(mCtx);
+}
+
+bool TlsTransport::stop() {
+	if (!Transport::stop())
+		return false;
+
+	PLOG_DEBUG << "Stopping TLS recv thread";
+	mIncomingQueue.stop();
+	mRecvThread.join();
+	SSL_shutdown(mSsl);
+	return true;
+}
+
+bool TlsTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
+	int ret = SSL_write(mSsl, message->data(), message->size());
+	if (!check_openssl_ret(mSsl, ret))
+		return false;
+
+	const size_t bufferSize = 4096;
+	byte buffer[bufferSize];
+	while (int len = BIO_read(mOutBio, buffer, bufferSize))
+		outgoing(make_message(buffer, buffer + len));
+
+	return true;
+}
+
+void TlsTransport::incoming(message_ptr message) {
+	if (message)
+		mIncomingQueue.push(message);
+	else
+		mIncomingQueue.stop();
+}
+
+void TlsTransport::runRecvLoop() {
+	const size_t bufferSize = 4096;
+	byte buffer[bufferSize];
+
+	try {
+		changeState(State::Connecting);
+
+		SSL_do_handshake(mSsl);
+		while (int len = BIO_read(mOutBio, buffer, bufferSize))
+			outgoing(make_message(buffer, buffer + len));
+
+		while (auto next = mIncomingQueue.pop()) {
+			message_ptr message = *next;
+			message_ptr decrypted;
+
+			BIO_write(mInBio, message->data(), message->size());
+
+			int ret = SSL_read(mSsl, buffer, bufferSize);
+			if (!check_openssl_ret(mSsl, ret))
+				break;
+
+			if (ret > 0)
+				decrypted = make_message(buffer, buffer + ret);
+
+			while (int len = BIO_read(mOutBio, buffer, bufferSize))
+				outgoing(make_message(buffer, buffer + len));
+
+			if (state() == State::Connecting && SSL_is_init_finished(mSsl)) {
+				PLOG_INFO << "TLS handshake finished";
+				changeState(State::Connected);
+			}
+
+			if (decrypted)
+				recv(decrypted);
+		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TLS recv: " << e.what();
+	}
+
+	if (state() == State::Connected) {
+		PLOG_INFO << "TLS closed";
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "TLS handshake failed";
+	}
+}
+
+void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
+	TlsTransport *t =
+	    static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
+
+	if (where & SSL_CB_ALERT) {
+		if (ret != 256) { // Close Notify
+			PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
+		}
+		t->mIncomingQueue.stop(); // Close the connection
+	}
+}
+
+} // namespace rtc
+
+#endif
+
+#endif

+ 83 - 0
src/tlstransport.hpp

@@ -0,0 +1,83 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_TLS_TRANSPORT_H
+#define RTC_TLS_TRANSPORT_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "include.hpp"
+#include "queue.hpp"
+#include "transport.hpp"
+
+#include <memory>
+#include <mutex>
+#include <thread>
+
+#if USE_GNUTLS
+#include <gnutls/gnutls.h>
+#else
+#include <openssl/ssl.h>
+#endif
+
+namespace rtc {
+
+class TcpTransport;
+
+class TlsTransport : public Transport {
+public:
+	static void Init();
+	static void Cleanup();
+
+	TlsTransport(std::shared_ptr<TcpTransport> lower, string host, state_callback callback);
+	~TlsTransport();
+
+	bool stop() override;
+	bool send(message_ptr message) override;
+
+	void incoming(message_ptr message) override;
+
+protected:
+	void runRecvLoop();
+
+	Queue<message_ptr> mIncomingQueue;
+	std::thread mRecvThread;
+
+#if USE_GNUTLS
+	gnutls_session_t mSession;
+
+	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
+	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
+	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
+#else
+	SSL_CTX *mCtx;
+	SSL *mSsl;
+	BIO *mInBio, *mOutBio;
+
+	static int TransportExIndex;
+
+	static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
+	static void InfoCallback(const SSL *ssl, int where, int ret);
+#endif
+};
+
+} // namespace rtc
+
+#endif
+
+#endif

+ 15 - 1
src/transport.hpp

@@ -32,7 +32,13 @@ using namespace std::placeholders;
 
 class Transport {
 public:
-	Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {}
+	enum class State { Disconnected, Connecting, Connected, Completed, Failed };
+	using state_callback = std::function<void(State state)>;
+
+	Transport(std::shared_ptr<Transport> lower = nullptr, state_callback callback = nullptr)
+	    : mLower(std::move(lower)), mStateChangeCallback(std::move(callback)) {
+	}
+
 	virtual ~Transport() {
 		stop();
 		if (mLower)
@@ -49,11 +55,16 @@ public:
 	}
 
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
+	State state() const { return mState; }
 
 	virtual bool send(message_ptr message) { return outgoing(message); }
 
 protected:
 	void recv(message_ptr message) { mRecvCallback(message); }
+	void changeState(State state) {
+		if (mState.exchange(state) != state)
+			mStateChangeCallback(state);
+	}
 
 	virtual void incoming(message_ptr message) { recv(message); }
 	virtual bool outgoing(message_ptr message) {
@@ -65,7 +76,10 @@ protected:
 
 private:
 	std::shared_ptr<Transport> mLower;
+	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<message_ptr> mRecvCallback;
+
+	std::atomic<State> mState = State::Disconnected;
 	std::atomic<bool> mShutdown = false;
 };
 

+ 311 - 0
src/websocket.cpp

@@ -0,0 +1,311 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "include.hpp"
+#include "websocket.hpp"
+
+#include "tcptransport.hpp"
+#include "tlstransport.hpp"
+#include "wstransport.hpp"
+
+#include <regex>
+
+#ifdef _WIN32
+#include <winsock2.h>
+#endif
+
+namespace rtc {
+
+WebSocket::WebSocket() {}
+
+WebSocket::WebSocket(const string &url) : WebSocket() { open(url); }
+
+WebSocket::~WebSocket() { remoteClose(); }
+
+WebSocket::State WebSocket::readyState() const { return mState; }
+
+void WebSocket::open(const string &url) {
+	if (mState != State::Closed)
+		throw std::runtime_error("WebSocket must be closed before opening");
+
+	static const char *rs = R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)";
+	static std::regex regex(rs, std::regex::extended);
+
+	std::smatch match;
+	if (!std::regex_match(url, match, regex))
+		throw std::invalid_argument("Malformed WebSocket URL: " + url);
+
+	mScheme = match[2];
+	if (mScheme != "ws" && mScheme != "wss")
+		throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
+
+	mHost = match[4];
+	if (auto pos = mHost.find(':'); pos != string::npos) {
+		mHostname = mHost.substr(0, pos);
+		mService = mHost.substr(pos + 1);
+	} else {
+		mHostname = mHost;
+		mService = mScheme == "ws" ? "80" : "443";
+	}
+
+	mPath = match[5];
+	if (string query = match[7]; !query.empty())
+		mPath += "?" + query;
+
+	changeState(State::Connecting);
+	initTcpTransport();
+}
+
+void WebSocket::close() {
+	auto state = mState.load();
+	if (state == State::Connecting || state == State::Open) {
+		changeState(State::Closing);
+		if (auto transport = std::atomic_load(&mWsTransport))
+			transport->close();
+		else
+			changeState(State::Closed);
+	}
+}
+
+void WebSocket::remoteClose() {
+	close();
+	closeTransports();
+}
+
+bool WebSocket::send(const std::variant<binary, string> &data) {
+	return std::visit(
+	    [&](const auto &d) {
+		    using T = std::decay_t<decltype(d)>;
+		    constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
+		    auto *b = reinterpret_cast<const byte *>(d.data());
+		    return outgoing(std::make_shared<Message>(b, b + d.size(), type));
+	    },
+	    data);
+}
+
+bool WebSocket::isOpen() const { return mState == State::Open; }
+
+bool WebSocket::isClosed() const { return mState == State::Closed; }
+
+size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
+
+std::optional<std::variant<binary, string>> WebSocket::receive() {
+	while (!mRecvQueue.empty()) {
+		auto message = *mRecvQueue.pop();
+		switch (message->type) {
+		case Message::String:
+			return std::make_optional(
+			    string(reinterpret_cast<const char *>(message->data()), message->size()));
+		case Message::Binary:
+			return std::make_optional(std::move(*message));
+		default:
+			// Ignore
+			break;
+		}
+	}
+	return nullopt;
+}
+
+size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
+
+bool WebSocket::changeState(State state) { return mState.exchange(state) != state; }
+
+bool WebSocket::outgoing(mutable_message_ptr message) {
+	if (mState != State::Open || !mWsTransport)
+		throw std::runtime_error("WebSocket is not open");
+
+	if (message->size() > maxMessageSize())
+		throw std::runtime_error("Message size exceeds limit");
+
+	return mWsTransport->send(message);
+}
+
+void WebSocket::incoming(message_ptr message) {
+	if (message->type == Message::String || message->type == Message::Binary) {
+		mRecvQueue.push(message);
+		triggerAvailable(mRecvQueue.size());
+	}
+}
+
+std::shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
+	using State = TcpTransport::State;
+	try {
+		std::lock_guard lock(mInitMutex);
+		if (auto transport = std::atomic_load(&mTcpTransport))
+			return transport;
+
+		auto transport = std::make_shared<TcpTransport>(
+		    mHostname, mService, [this, weak_this = weak_from_this()](State state) {
+			    auto shared_this = weak_this.lock();
+			    if (!shared_this)
+				    return;
+			    switch (state) {
+			    case State::Connected:
+				    if (mScheme == "ws")
+					    initWsTransport();
+				    else
+					    initTlsTransport();
+				    break;
+			    case State::Failed:
+				    triggerError("TCP connection failed");
+				    remoteClose();
+				    break;
+			    case State::Disconnected:
+				    remoteClose();
+				    break;
+			    default:
+				    // Ignore
+				    break;
+			    }
+		    });
+		std::atomic_store(&mTcpTransport, transport);
+		if (mState == WebSocket::State::Closed) {
+			mTcpTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
+		return transport;
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+		remoteClose();
+		throw std::runtime_error("TCP transport initialization failed");
+	}
+}
+
+std::shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
+	using State = TlsTransport::State;
+	try {
+		std::lock_guard lock(mInitMutex);
+		if (auto transport = std::atomic_load(&mTlsTransport))
+			return transport;
+
+		auto lower = std::atomic_load(&mTcpTransport);
+		auto transport = std::make_shared<TlsTransport>(
+		    lower, mHost, [this, weak_this = weak_from_this()](State state) {
+			    auto shared_this = weak_this.lock();
+			    if (!shared_this)
+				    return;
+			    switch (state) {
+			    case State::Connected:
+				    initWsTransport();
+				    break;
+			    case State::Failed:
+				    triggerError("TCP connection failed");
+				    remoteClose();
+				    break;
+			    case State::Disconnected:
+				    remoteClose();
+				    break;
+			    default:
+				    // Ignore
+				    break;
+			    }
+		    });
+		std::atomic_store(&mTlsTransport, transport);
+		if (mState == WebSocket::State::Closed) {
+			mTlsTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
+		return transport;
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+		remoteClose();
+		throw std::runtime_error("TLS transport initialization failed");
+	}
+}
+
+std::shared_ptr<WsTransport> WebSocket::initWsTransport() {
+	using State = WsTransport::State;
+	try {
+		std::lock_guard lock(mInitMutex);
+		if (auto transport = std::atomic_load(&mWsTransport))
+			return transport;
+
+		std::shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
+		if (!lower)
+			lower = std::atomic_load(&mTcpTransport);
+		auto transport = std::make_shared<WsTransport>(
+		    lower, mHost, mPath, weak_bind(&WebSocket::incoming, this, _1),
+		    [this, weak_this = weak_from_this()](State state) {
+			    auto shared_this = weak_this.lock();
+			    if (!shared_this)
+				    return;
+			    switch (state) {
+			    case State::Connected:
+				    if (mState == WebSocket::State::Connecting) {
+					    PLOG_DEBUG << "WebSocket open";
+					    changeState(WebSocket::State::Open);
+					    triggerOpen();
+				    }
+				    break;
+			    case State::Failed:
+				    triggerError("WebSocket connection failed");
+				    remoteClose();
+				    break;
+			    case State::Disconnected:
+				    remoteClose();
+				    break;
+			    default:
+				    // Ignore
+				    break;
+			    }
+		    });
+		std::atomic_store(&mWsTransport, transport);
+		if (mState == WebSocket::State::Closed) {
+			mWsTransport.reset();
+			transport->stop();
+			throw std::runtime_error("Connection is closed");
+		}
+		return transport;
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+		remoteClose();
+		throw std::runtime_error("WebSocket transport initialization failed");
+	}
+}
+
+void WebSocket::closeTransports() {
+	changeState(State::Closed);
+
+	// Pass the references to a thread, allowing to terminate a transport from its own thread
+	auto ws = std::atomic_exchange(&mWsTransport, decltype(mWsTransport)(nullptr));
+	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
+	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
+	if (ws || tls || tcp) {
+		std::thread t([ws, tls, tcp]() mutable {
+			if (ws)
+				ws->stop();
+			if (tls)
+				tls->stop();
+			if (tcp)
+				tcp->stop();
+
+			ws.reset();
+			tls.reset();
+			tcp.reset();
+		});
+		t.detach();
+	}
+}
+
+} // namespace rtc
+
+#endif

+ 372 - 0
src/wstransport.cpp

@@ -0,0 +1,372 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "wstransport.hpp"
+#include "tcptransport.hpp"
+#include "tlstransport.hpp"
+
+#include "base64.hpp"
+
+#include <chrono>
+#include <list>
+#include <map>
+#include <random>
+#include <regex>
+
+#ifdef _WIN32
+#include <winsock2.h>
+#else
+#include <arpa/inet.h>
+#endif
+
+#ifndef htonll
+#define htonll(x)                                                                                  \
+	((uint64_t)htonl(((uint64_t)(x)&0xFFFFFFFF) << 32) | (uint64_t)htonl((uint64_t)(x) >> 32))
+#endif
+#ifndef ntohll
+#define ntohll(x) htonll(x)
+#endif
+
+namespace rtc {
+
+using namespace std::chrono;
+using std::to_integer;
+using std::to_string;
+
+using random_bytes_engine =
+    std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
+
+WsTransport::WsTransport(std::shared_ptr<Transport> lower, string host, string path,
+                         message_callback recvCallback, state_callback stateCallback)
+    : Transport(lower, std::move(stateCallback)), mHost(std::move(host)), mPath(std::move(path)) {
+	onRecv(recvCallback);
+
+	PLOG_DEBUG << "Initializing WebSocket transport";
+
+	registerIncoming();
+	sendHttpRequest();
+}
+
+WsTransport::~WsTransport() { stop(); }
+
+bool WsTransport::stop() {
+	if (!Transport::stop())
+		return false;
+
+	close();
+	return true;
+}
+
+bool WsTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
+	// Call the mutable message overload with a copy
+	return send(std::make_shared<Message>(*message));
+}
+
+bool WsTransport::send(mutable_message_ptr message) {
+	if (!message || state() != State::Connected)
+		return false;
+
+	PLOG_VERBOSE << "Send size=" << message->size();
+
+	return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
+	                  message->size(), true, true});
+}
+
+void WsTransport::incoming(message_ptr message) {
+	try {
+		mBuffer.insert(mBuffer.end(), message->begin(), message->end());
+
+		if (state() == State::Connecting) {
+			if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
+				mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				PLOG_INFO << "WebSocket open";
+				changeState(State::Connected);
+			}
+		}
+
+		if (state() == State::Connected) {
+			Frame frame = {};
+			while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
+				mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				recvFrame(frame);
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << e.what();
+	}
+
+	if (state() == State::Connected) {
+		PLOG_INFO << "WebSocket disconnected";
+		changeState(State::Disconnected);
+		recv(nullptr);
+	} else {
+		PLOG_ERROR << "WebSocket handshake failed";
+		changeState(State::Failed);
+	}
+}
+
+void WsTransport::close() {
+	if (state() == State::Connected) {
+		sendFrame({CLOSE, NULL, 0, true, true});
+		PLOG_INFO << "WebSocket closing";
+		changeState(State::Completed);
+	}
+}
+
+bool WsTransport::sendHttpRequest() {
+	changeState(State::Connecting);
+
+	auto seed = system_clock::now().time_since_epoch().count();
+	random_bytes_engine generator(seed);
+
+	binary key(16);
+	std::generate(reinterpret_cast<uint8_t *>(key.data()),
+	              reinterpret_cast<uint8_t *>(key.data() + key.size()), generator);
+
+	const string request = "GET " + mPath +
+	                       " HTTP/1.1\r\n"
+	                       "Host: " +
+	                       mHost +
+	                       "\r\n"
+	                       "Connection: Upgrade\r\n"
+	                       "Upgrade: websocket\r\n"
+	                       "Sec-WebSocket-Version: 13\r\n"
+	                       "Sec-WebSocket-Key: " +
+	                       to_base64(key) +
+	                       "\r\n"
+	                       "\r\n";
+
+	auto data = reinterpret_cast<const byte *>(request.data());
+	auto size = request.size();
+	return outgoing(make_message(data, data + size));
+}
+
+size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
+	std::list<string> lines;
+	auto begin = reinterpret_cast<const char *>(buffer);
+	auto end = begin + size;
+	auto cur = begin;
+	while (true) {
+		auto last = cur;
+		cur = std::find(cur, end, '\n');
+		if (cur == end)
+			return 0;
+		string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
+		if (line.empty())
+			break;
+		lines.emplace_back(std::move(line));
+	}
+	size_t length = cur - begin;
+
+	if (lines.empty())
+		throw std::runtime_error("Invalid HTTP response for WebSocket");
+
+	string status = std::move(lines.front());
+	lines.pop_front();
+
+	std::istringstream ss(status);
+	string protocol;
+	unsigned int code = 0;
+	ss >> protocol >> code;
+	PLOG_DEBUG << "WebSocket response code: " << code;
+	if (code != 101)
+		throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
+
+	std::multimap<string, string> headers;
+	for (const auto &line : lines) {
+		if (size_t pos = line.find_first_of(':'); pos != string::npos) {
+			string key = line.substr(0, pos);
+			string value = line.substr(line.find_first_not_of(' ', pos + 1));
+			std::transform(key.begin(), key.end(), key.begin(),
+			               [](char c) { return std::tolower(c); });
+			headers.emplace(std::move(key), std::move(value));
+		} else {
+			headers.emplace(line, "");
+		}
+	}
+
+	auto h = headers.find("upgrade");
+	if (h == headers.end() || h->second != "websocket")
+		throw std::runtime_error("WebSocket update header missing or mismatching");
+
+	h = headers.find("sec-websocket-accept");
+	if (h == headers.end())
+		throw std::runtime_error("WebSocket accept header missing");
+
+	// TODO: Verify Sec-WebSocket-Accept
+
+	return length;
+}
+
+// http://tools.ietf.org/html/rfc6455#section-5.2  Base Framing Protocol
+//
+//  0                   1                   2                   3
+//  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+// +-+-+-+-+-------+-+-------------+-------------------------------+
+// |F|R|R|R| opcode|M| Payload len |    Extended payload length    |
+// |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           |
+// |N|V|V|V|       |S|             |   (if payload len==126/127)   |
+// | |1|2|3|       |K|             |                               |
+// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+// |    Extended payload length continued, if payload len == 127   |
+// + - - - - - - - - - - - - - - - +-------------------------------+
+// |                               | Masking-key, if MASK set to 1 |
+// +-------------------------------+-------------------------------+
+// |    Masking-key (continued)    |          Payload Data         |
+// +-------------------------------+ - - - - - - - - - - - - - - - +
+// :                     Payload Data continued ...                :
+// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+// |                     Payload Data continued ...                |
+// +---------------------------------------------------------------+
+
+size_t WsTransport::readFrame(byte *buffer, size_t size, Frame &frame) {
+	const byte *end = buffer + size;
+	if (end - buffer < 2)
+		return 0;
+
+	byte *cur = buffer;
+	auto b1 = to_integer<uint8_t>(*cur++);
+	auto b2 = to_integer<uint8_t>(*cur++);
+
+	frame.fin = (b1 & 0x80) != 0;
+	frame.mask = (b2 & 0x80) != 0;
+	frame.opcode = static_cast<Opcode>(b1 & 0x0F);
+	frame.length = b2 & 0x7F;
+
+	if (frame.length == 0x7E) {
+		if (end - cur < 2)
+			return 0;
+		frame.length = ntohs(*reinterpret_cast<const uint16_t *>(cur));
+		cur += 2;
+	} else if (frame.length == 0x7F) {
+		if (end - cur < 8)
+			return false;
+		frame.length = ntohll(*reinterpret_cast<const uint64_t *>(cur));
+		cur += 8;
+	}
+
+	const byte *maskingKey = nullptr;
+	if (frame.mask) {
+		if (end - cur < 4)
+			return 0;
+		maskingKey = cur;
+		cur += 4;
+	}
+
+	if (end - cur < frame.length)
+		return false;
+
+	frame.payload = cur;
+	if (maskingKey)
+		for (size_t i = 0; i < frame.length; ++i)
+			frame.payload[i] ^= maskingKey[i % 4];
+
+	return end - buffer;
+}
+
+void WsTransport::recvFrame(const Frame &frame) {
+	switch (frame.opcode) {
+	case TEXT_FRAME:
+	case BINARY_FRAME: {
+		if (!mPartial.empty()) {
+			auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
+			recv(make_message(mPartial.begin(), mPartial.end(), type));
+			mPartial.clear();
+		}
+		if (frame.fin) {
+			auto type = frame.opcode == TEXT_FRAME ? Message::String : Message::Binary;
+			recv(make_message(frame.payload, frame.payload + frame.length));
+		} else {
+			mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
+			mPartialOpcode = frame.opcode;
+		}
+		break;
+	}
+	case CONTINUATION: {
+		mPartial.insert(mPartial.end(), frame.payload, frame.payload + frame.length);
+		if (frame.fin) {
+			auto type = mPartialOpcode == TEXT_FRAME ? Message::String : Message::Binary;
+			recv(make_message(mPartial.begin(), mPartial.end()));
+			mPartial.clear();
+		}
+		break;
+	}
+	case PING: {
+		sendFrame({PONG, frame.payload, frame.length, true, true});
+		break;
+	}
+	case PONG: {
+		// TODO
+		break;
+	}
+	case CLOSE: {
+		close();
+		PLOG_INFO << "WebSocket closed";
+		changeState(State::Disconnected);
+		break;
+	}
+	default: {
+		close();
+		throw std::invalid_argument("Unknown WebSocket opcode: " + to_string(frame.opcode));
+	}
+	}
+}
+
+bool WsTransport::sendFrame(const Frame &frame) {
+	byte buffer[14];
+	byte *cur = buffer;
+
+	*cur++ = byte((frame.opcode & 0x0F) | (frame.fin ? 0x80 : 0));
+
+	if (frame.length < 0x7E) {
+		*cur++ = byte((frame.length & 0x7F) | (frame.mask ? 0x80 : 0));
+	} else if (frame.length <= 0xFF) {
+		*cur++ = byte(0x7E | (frame.mask ? 0x80 : 0));
+		*reinterpret_cast<uint16_t *>(cur) = uint16_t(frame.length);
+		cur += 2;
+	} else {
+		*cur++ = byte(0x7F | (frame.mask ? 0x80 : 0));
+		*reinterpret_cast<uint64_t *>(cur) = uint64_t(frame.length);
+		cur += 8;
+	}
+
+	if (frame.mask) {
+		auto seed = system_clock::now().time_since_epoch().count();
+		random_bytes_engine generator(seed);
+
+		auto *maskingKey = cur;
+		std::generate(reinterpret_cast<uint8_t *>(maskingKey),
+		              reinterpret_cast<uint8_t *>(maskingKey + 4), generator);
+		cur += 4;
+
+		for (size_t i = 0; i < frame.length; ++i)
+			frame.payload[i] ^= maskingKey[i % 4];
+	}
+
+	outgoing(make_message(buffer, cur));                                        // header
+	return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
+}
+
+} // namespace rtc
+
+#endif

+ 83 - 0
src/wstransport.hpp

@@ -0,0 +1,83 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_WS_TRANSPORT_H
+#define RTC_WS_TRANSPORT_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "include.hpp"
+#include "transport.hpp"
+
+namespace rtc {
+
+class TcpTransport;
+class TlsTransport;
+
+class WsTransport : public Transport {
+public:
+	WsTransport(std::shared_ptr<Transport> lower, string host, string path,
+	            message_callback recvCallback, state_callback stateCallback);
+	~WsTransport();
+
+	bool stop() override;
+	bool send(message_ptr message) override;
+	bool send(mutable_message_ptr message);
+
+	void incoming(message_ptr message) override;
+
+	void close();
+
+private:
+	enum Opcode : uint8_t {
+		CONTINUATION = 0,
+		TEXT_FRAME = 1,
+		BINARY_FRAME = 2,
+		CLOSE = 8,
+		PING = 9,
+		PONG = 10,
+	};
+
+	struct Frame {
+		Opcode opcode = BINARY_FRAME;
+		byte *payload = nullptr;
+		size_t length = 0;
+		bool fin = true;
+		bool mask = true;
+	};
+
+	bool sendHttpRequest();
+	size_t readHttpResponse(const byte *buffer, size_t size);
+
+	size_t readFrame(byte *buffer, size_t size, Frame &frame);
+	void recvFrame(const Frame &frame);
+	bool sendFrame(const Frame &frame);
+
+	const string mHost;
+	const string mPath;
+
+	binary mBuffer;
+	binary mPartial;
+	Opcode mPartialOpcode;
+};
+
+} // namespace rtc
+
+#endif
+
+#endif

+ 6 - 6
test/main.cpp

@@ -25,19 +25,19 @@ void test_capi();
 
 int main(int argc, char **argv) {
 	try {
-		std::cout << "*** Running connectivity test..." << std::endl;
+		cout << endl << "*** Running connectivity test..." << endl;
 		test_connectivity();
-		std::cout << "*** Finished connectivity test" << std::endl;
+		cout << "*** Finished connectivity test" << endl;
 	} catch (const exception &e) {
-		std::cerr << "Connectivity test failed: " << e.what() << endl;
+		cerr << "Connectivity test failed: " << e.what() << endl;
 		return -1;
 	}
 	try {
-		std::cout << "*** Running C API test..." << std::endl;
+		cout << endl << "*** Running C API test..." << endl;
 		test_capi();
-		std::cout << "*** Finished C API test" << std::endl;
+		cout << "*** Finished C API test" << endl;
 	} catch (const exception &e) {
-		std::cerr << "C API test failed: " << e.what() << endl;
+		cerr << "C API test failed: " << e.what() << endl;
 		return -1;
 	}
 	return 0;