Browse Source

Merge pull request #447 from paullouisageneau/websocket-server

Add WebSocket server
Paul-Louis Ageneau 4 years ago
parent
commit
6542c26325

+ 31 - 8
CMakeLists.txt

@@ -49,6 +49,7 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpreceivingsession.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/track.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/websocketserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizationconfig.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpsrreporter.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizer.cpp
@@ -115,11 +116,16 @@ set(LIBDATACHANNEL_IMPL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/track.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/selectinterrupter.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocketserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wshandshake.cpp
 )
 
 set(LIBDATACHANNEL_IMPL_HEADERS
@@ -140,11 +146,16 @@ set(LIBDATACHANNEL_IMPL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/track.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/selectinterrupter.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocketserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wshandshake.hpp
 )
 
 set(TESTS_SOURCES
@@ -155,6 +166,8 @@ set(TESTS_SOURCES
     ${CMAKE_CURRENT_SOURCE_DIR}/test/capi_connectivity.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/capi_track.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/websocket.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/test/websocketserver.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/test/capi_websocketserver.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/test/benchmark.cpp
 )
 
@@ -282,9 +295,15 @@ if (USE_GNUTLS)
 	target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=1)
 	target_link_libraries(datachannel PRIVATE GnuTLS::GnuTLS)
 	target_link_libraries(datachannel-static PRIVATE GnuTLS::GnuTLS)
+	if (NOT NO_WEBSOCKET)
+		# Needed for SHA1, it should be present as GnuTLS cryptography backend
+		find_package(Nettle REQUIRED)
+		target_link_libraries(datachannel PRIVATE Nettle::Nettle)
+		target_link_libraries(datachannel-static PRIVATE Nettle::Nettle)
+	endif()
 else()
 	if(APPLE)
-		# This is a bug in CMake that causes it to prefer the system version over 
+		# This is a bug in CMake that causes it to prefer the system version over
 		# the one in the specified ROOT folder
 		if(EXISTS ${OPENSSL_ROOT_DIR})
 			set(OPENSSL_CRYPTO_LIBRARY "${OPENSSL_ROOT_DIR}/lib/libcrypto.dylib" CACHE FILEPATH "" FORCE)
@@ -390,13 +409,17 @@ endif()
 if(NOT NO_EXAMPLES)
 	set(JSON_BuildTests OFF CACHE INTERNAL "")
 	add_subdirectory(deps/json EXCLUDE_FROM_ALL)
-	add_subdirectory(examples/client)
-	add_subdirectory(examples/client-benchmark)
-if(NOT NO_MEDIA)
-	add_subdirectory(examples/media)
-	add_subdirectory(examples/sfu-media)
-    add_subdirectory(examples/streamer)
-endif()
+	if(NOT NO_WEBSOCKET)
+		add_subdirectory(examples/client)
+		add_subdirectory(examples/client-benchmark)
+	endif()
+	if(NOT NO_MEDIA)
+		add_subdirectory(examples/media)
+		add_subdirectory(examples/sfu-media)
+	endif()
+	if(NOT NO_MEDIA AND NOT NO_WEBSOCKET)
+    	add_subdirectory(examples/streamer)
+	endif()
 	add_subdirectory(examples/copy-paste)
 	add_subdirectory(examples/copy-paste-capi)
 endif()

+ 96 - 3
DOC.md

@@ -504,6 +504,10 @@ If `buffer` is `NULL`, the description is not copied but the size is still retur
 ```
 int rtcCreateWebSocket(const char *url)
 int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config)
+
+typedef struct {
+	bool disableTlsVerification;    // if true, disable TLS certificate verification
+} rtcWsConfiguration;
 ```
 
 Creates a new client WebSocket.
@@ -526,11 +530,100 @@ int rtcDeleteWebSocket(int ws)
 Arguments:
 - `ws`: the identifier of the WebSocket to delete
 
-Return value: the identifier of the new WebSocket or a negative error code
-
 After this function has been called, `ws` must not be used in a function call anymore. This function will block until all scheduled callbacks of `ws` return (except the one this function might be called in) and no other callback will be called for `ws` after it returns.
 
-### Channel (Data Channel, Track, and WebSocket)
+#### rtcGetWebSocketRemoteAddress
+
+```
+int rtcGetWebSocketRemoteAddress(int ws, char *buffer, int size)
+```
+
+Retrieves the remote address, i.e. the network address of the remote endpoint. The address will have the format `"HOST:PORT"`. The call may fail if the underlying TCP transport of the WebSocket is not connected. This function is useful for a client WebSocket received by a WebSocket Server.
+
+Arguments:
+- `ws`: the identifier of the WebSocket
+- `buffer`: a user-supplied buffer to store the description
+- `size`: the size of `buffer`
+
+Return value: the length of the string copied in buffer (including the terminating null character) or a negative error code
+
+If `buffer` is `NULL`, the address is not copied but the size is still returned.
+
+#### rtcGetWebSocketPath
+
+```
+int rtcGetWebSocketPath(int ws, char *buffer, int size)
+```
+
+Retrieves the path of the WebSocket, i.e. the HTTP requested path. This function is useful for a client WebSocket received by a WebSocket Server. Warning: The WebSocket must be open for the call to succeed.
+
+Arguments:
+- `ws`: the identifier of the WebSocket
+- `buffer`: a user-supplied buffer to store the description
+- `size`: the size of `buffer`
+
+Return value: the length of the string copied in buffer (including the terminating null character) or a negative error code
+
+If `buffer` is `NULL`, the path is not copied but the size is still returned.
+
+### WebSocket Server
+
+#### rtcCreateWebSocketServer
+
+```
+int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config, rtcWebSocketClientCallbackFunc cb);
+
+typedef struct {
+	uint16_t port;
+	bool enableTls;
+	const char *certificatePemFile; // NULL for autogenerated certificate
+	const char *keyPemFile;         // NULL for autogenerated certificate
+	const char *keyPemPass;         // NULL if no pass
+} rtcWsServerConfiguration;
+
+```
+
+Creates a new WebSocket server.
+
+Arguments:
+- `config`: a structure with the following parameters:
+  - `uint16_t port`: the port to listen on (if 0, automatically select an available port)
+  - `bool enableTls`: if true, enable the TLS layer (WSS)
+  - `const char *certificatePemFile`: path of the file containing the TLS PEM certificate (`NULL` for an autogenerated certificate)
+  - `const char *keyPemFile`: path of the file containing the TLS PEM key (`NULL` for an autogenerated certificate)
+  - `const char *keyPemPass`: the TLS PEM key passphrase (NULL if no passphrase)
+- `cb`: the callback for incoming client WebSocket connections (must not be `NULL`)
+
+`cb` must have the following signature: `void rtcWebSocketClientCallbackFunc(int wsserver, int ws, void *user_ptr)`
+
+Return value: the identifier of the new WebSocket Server or a negative error code
+
+The new WebSocket Server must be deleted with `rtcDeleteWebSocketServer`.
+
+#### rtcDeleteWebSocketServer
+
+```
+int rtcDeleteWebSocketServer(int wsserver)
+```
+
+Arguments:
+- `wsserver`: the identifier of the WebSocket Server to delete
+
+After this function has been called, `wsserver` must not be used in a function call anymore. This function will block until all scheduled callbacks of `wsserver` return (except the one this function might be called in) and no other callback will be called for `wsserver` after it returns.
+
+#### rtcGetWebSocketServerPort
+```
+int rtcGetWebSocketServerPort(int wsserver);
+```
+
+Retrieves the port which the WebSocket Server is listening on.
+
+Arguments:
+- `wsserver`: the identifier of the WebSocket Server
+
+Return value: The port of the WebSocket Server or a negative error code
+
+### Channel (Common API for Data Channel, Track, and WebSocket)
 
 The following common functions might be called with a generic channel identifier. It may be the identifier of either a Data Channel, a Track, or a WebSocket.
 

+ 142 - 0
cmake/Modules/FindNettle.cmake

@@ -0,0 +1,142 @@
+# Copyright (C) 2020 Dieter Baron and Thomas Klausner
+#
+# The authors can be contacted at <[email protected]>
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#
+# 1. Redistributions of source code must retain the above copyright
+#   notice, this list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright
+#   notice, this list of conditions and the following disclaimer in
+#   the documentation and/or other materials provided with the
+#   distribution.
+#
+# 3. The names of the authors may not be used to endorse or promote
+#   products derived from this software without specific prior
+#   written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS
+# OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
+# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+# GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
+# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN
+# IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+#[=======================================================================[.rst:
+FindNettle
+-------
+
+Finds the Nettle library.
+
+Imported Targets
+^^^^^^^^^^^^^^^^
+
+This module provides the following imported targets, if found:
+
+``Nettle::Nettle``
+  The Nettle library
+
+Result Variables
+^^^^^^^^^^^^^^^^
+
+This will define the following variables:
+
+``Nettle_FOUND``
+  True if the system has the Nettle library.
+``Nettle_VERSION``
+  The version of the Nettle library which was found.
+``Nettle_INCLUDE_DIRS``
+  Include directories needed to use Nettle.
+``Nettle_LIBRARIES``
+  Libraries needed to link to Nettle.
+
+Cache Variables
+^^^^^^^^^^^^^^^
+
+The following cache variables may also be set:
+
+``Nettle_INCLUDE_DIR``
+  The directory containing ``nettle/aes.h``.
+``Nettle_LIBRARY``
+  The path to the Nettle library.
+
+#]=======================================================================]
+
+find_package(PkgConfig)
+pkg_check_modules(PC_Nettle QUIET nettle)
+
+find_path(Nettle_INCLUDE_DIR
+  NAMES nettle/aes.h nettle/md5.h nettle/pbkdf2.h nettle/ripemd160.h nettle/sha.h
+  PATHS ${PC_Nettle_INCLUDE_DIRS}
+)
+find_library(Nettle_LIBRARY
+  NAMES nettle
+  PATHS ${PC_Nettle_LIBRARY_DIRS}
+)
+
+# Extract version information from the header file
+if(Nettle_INCLUDE_DIR)
+  # This file only exists in nettle>=3.0
+  if(EXISTS ${Nettle_INCLUDE_DIR}/nettle/version.h)
+    file(STRINGS ${Nettle_INCLUDE_DIR}/nettle/version.h _ver_major_line
+         REGEX "^#define NETTLE_VERSION_MAJOR  *[0-9]+"
+         LIMIT_COUNT 1)
+    string(REGEX MATCH "[0-9]+"
+           Nettle_MAJOR_VERSION "${_ver_major_line}")
+    file(STRINGS ${Nettle_INCLUDE_DIR}/nettle/version.h _ver_minor_line
+         REGEX "^#define NETTLE_VERSION_MINOR  *[0-9]+"
+         LIMIT_COUNT 1)
+    string(REGEX MATCH "[0-9]+"
+           Nettle_MINOR_VERSION "${_ver_minor_line}")
+    set(Nettle_VERSION "${Nettle_MAJOR_VERSION}.${Nettle_MINOR_VERSION}")
+    unset(_ver_major_line)
+    unset(_ver_minor_line)
+  else()
+    if(PC_Nettle_VERSION)
+      set(Nettle_VERSION ${PC_Nettle_VERSION})
+    else()
+      set(Nettle_VERSION "1.0")
+    endif()
+  endif()
+endif()
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(Nettle
+  FOUND_VAR Nettle_FOUND
+  REQUIRED_VARS
+    Nettle_LIBRARY
+    Nettle_INCLUDE_DIR
+  VERSION_VAR Nettle_VERSION
+)
+
+if(Nettle_FOUND)
+  set(Nettle_LIBRARIES ${Nettle_LIBRARY})
+  set(Nettle_INCLUDE_DIRS ${Nettle_INCLUDE_DIR})
+  set(Nettle_DEFINITIONS ${PC_Nettle_CFLAGS_OTHER})
+endif()
+
+if(Nettle_FOUND AND NOT TARGET Nettle::Nettle)
+  add_library(Nettle::Nettle UNKNOWN IMPORTED)
+  set_target_properties(Nettle::Nettle PROPERTIES
+    IMPORTED_LOCATION "${Nettle_LIBRARY}"
+    INTERFACE_COMPILE_OPTIONS "${PC_Nettle_CFLAGS_OTHER}"
+    INTERFACE_INCLUDE_DIRECTORIES "${Nettle_INCLUDE_DIR}"
+  )
+endif()
+
+mark_as_advanced(
+  Nettle_INCLUDE_DIR
+  Nettle_LIBRARY
+)
+
+# compatibility variables
+set(Nettle_VERSION_STRING ${Nettle_VERSION})
+

+ 2 - 2
include/rtc/common.hpp

@@ -48,9 +48,9 @@
 #include <memory>
 #include <mutex>
 #include <optional>
-#include <variant>
 #include <string>
 #include <string_view>
+#include <variant>
 #include <vector>
 
 namespace rtc {
@@ -68,8 +68,8 @@ using std::weak_ptr;
 using binary = std::vector<byte>;
 using binary_ptr = std::shared_ptr<binary>;
 
-using std::size_t;
 using std::ptrdiff_t;
+using std::size_t;
 using std::uint16_t;
 using std::uint32_t;
 using std::uint64_t;

+ 30 - 7
include/rtc/rtc.h

@@ -306,21 +306,23 @@ RTC_EXPORT int rtcGetPreviousTrackSenderReportTimestamp(int id, uint32_t *timest
 // Set NeedsToReport flag in RtcpSrReporter handler identified by given track id
 RTC_EXPORT int rtcSetNeedsToSendRtcpSr(int id);
 
-/// Get all available payload types for given codec and stores them in buffer, does nothing if buffer is NULL
-int rtcGetTrackPayloadTypesForCodec(int tr, const char * ccodec, int * buffer, int size);
+/// Get all available payload types for given codec and stores them in buffer, does nothing if
+/// buffer is NULL
+int rtcGetTrackPayloadTypesForCodec(int tr, const char *ccodec, int *buffer, int size);
 
 /// Get all SSRCs for given track
-int rtcGetSsrcsForTrack(int tr, uint32_t * buffer, int count);
+int rtcGetSsrcsForTrack(int tr, uint32_t *buffer, int count);
 
 /// Get CName for SSRC
-int rtcGetCNameForSsrc(int tr, uint32_t ssrc, char * cname, int cnameSize);
+int rtcGetCNameForSsrc(int tr, uint32_t ssrc, char *cname, int cnameSize);
 
 /// Get all SSRCs for given media type in given SDP
 /// @param mediaType Media type (audio/video)
-int rtcGetSsrcsForType(const char * mediaType, const char * sdp, uint32_t * buffer, int bufferSize);
+int rtcGetSsrcsForType(const char *mediaType, const char *sdp, uint32_t *buffer, int bufferSize);
 
 /// Set SSRC for given media type in given SDP
-int rtcSetSsrcForType(const char * mediaType, const char * sdp, char * buffer, const int bufferSize, rtcSsrcForTypeInit * init);
+int rtcSetSsrcForType(const char *mediaType, const char *sdp, char *buffer, const int bufferSize,
+                      rtcSsrcForTypeInit *init);
 
 #endif // RTC_ENABLE_MEDIA
 
@@ -334,7 +336,28 @@ typedef struct {
 
 RTC_EXPORT int rtcCreateWebSocket(const char *url); // returns ws id
 RTC_EXPORT int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config);
-RTC_EXPORT int rtcDeleteWebsocket(int ws);
+RTC_EXPORT int rtcDeleteWebSocket(int ws);
+
+RTC_EXPORT int rtcGetWebSocketRemoteAddress(int ws, char *buffer, int size);
+RTC_EXPORT int rtcGetWebSocketPath(int ws, char *buffer, int size);
+
+// WebSocketServer
+
+typedef void(RTC_API *rtcWebSocketClientCallbackFunc)(int wsserver, int ws, void *ptr);
+
+typedef struct {
+	uint16_t port;                  // 0 means automatic selection
+	bool enableTls;                 // if true, enable TLS (WSS)
+	const char *certificatePemFile; // NULL for autogenerated certificate
+	const char *keyPemFile;         // NULL for autogenerated certificate
+	const char *keyPemPass;         // NULL if no pass
+} rtcWsServerConfiguration;
+
+RTC_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config,
+                                        rtcWebSocketClientCallbackFunc cb); // returns wsserver id
+RTC_EXPORT int rtcDeleteWebSocketServer(int wsserver);
+
+RTC_EXPORT int rtcGetWebSocketServerPort(int wsserver);
 
 #endif
 

+ 1 - 0
include/rtc/rtc.hpp

@@ -31,6 +31,7 @@
 
 // WebSocket
 #include "websocket.hpp"
+#include "websocketserver.hpp"
 
 #endif // RTC_ENABLE_WEBSOCKET
 

+ 4 - 0
include/rtc/websocket.hpp

@@ -49,6 +49,7 @@ public:
 
 	WebSocket();
 	WebSocket(Configuration config);
+	WebSocket(impl_ptr<impl::WebSocket> impl);
 	~WebSocket();
 
 	State readyState() const;
@@ -62,6 +63,9 @@ public:
 	bool send(const message_variant data) override;
 	bool send(const byte *data, size_t size) override;
 
+	optional<string> remoteAddress() const;
+	optional<string> path() const;
+
 private:
 	using CheshireCat<impl::WebSocket>::impl;
 };

+ 63 - 0
include/rtc/websocketserver.hpp

@@ -0,0 +1,63 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_WEBSOCKETSERVER_H
+#define RTC_WEBSOCKETSERVER_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "common.hpp"
+#include "websocket.hpp"
+
+namespace rtc {
+
+namespace impl {
+
+struct WebSocketServer;
+
+}
+
+class RTC_CPP_EXPORT WebSocketServer final : private CheshireCat<impl::WebSocketServer> {
+public:
+	struct Configuration {
+		uint16_t port = 8080;
+		bool enableTls = false;
+		optional<string> certificatePemFile;
+		optional<string> keyPemFile;
+		optional<string> keyPemPass;
+	};
+
+	WebSocketServer();
+	WebSocketServer(Configuration config);
+	~WebSocketServer();
+
+	void stop();
+
+	uint16_t port() const;
+
+	void onClient(std::function<void(shared_ptr<WebSocket>)> callback);
+
+private:
+	using CheshireCat<impl::WebSocketServer>::impl;
+};
+
+} // namespace rtc
+
+#endif
+
+#endif // RTC_WEBSOCKET_H

+ 125 - 23
src/capi.cpp

@@ -43,6 +43,7 @@ std::unordered_map<int, shared_ptr<RtpPacketizationConfig>> rtpConfigMap;
 #endif
 #if RTC_ENABLE_WEBSOCKET
 std::unordered_map<int, shared_ptr<WebSocket>> webSocketMap;
+std::unordered_map<int, shared_ptr<WebSocketServer>> webSocketServerMap;
 #endif
 std::unordered_map<int, void *> userPointerMap;
 std::mutex mutex;
@@ -193,6 +194,7 @@ createRtpPacketizationConfig(const rtcPacketizationHandlerInit *init) {
 #endif // RTC_ENABLE_MEDIA
 
 #if RTC_ENABLE_WEBSOCKET
+
 shared_ptr<WebSocket> getWebSocket(int id) {
 	std::lock_guard lock(mutex);
 	if (auto it = webSocketMap.find(id); it != webSocketMap.end())
@@ -215,6 +217,30 @@ void eraseWebSocket(int ws) {
 		throw std::invalid_argument("WebSocket ID does not exist");
 	userPointerMap.erase(ws);
 }
+
+shared_ptr<WebSocketServer> getWebSocketServer(int id) {
+	std::lock_guard lock(mutex);
+	if (auto it = webSocketServerMap.find(id); it != webSocketServerMap.end())
+		return it->second;
+	else
+		throw std::invalid_argument("WebSocketServer ID does not exist");
+}
+
+int emplaceWebSocketServer(shared_ptr<WebSocketServer> ptr) {
+	std::lock_guard lock(mutex);
+	int wsserver = ++lastId;
+	webSocketServerMap.emplace(std::make_pair(wsserver, ptr));
+	userPointerMap.emplace(std::make_pair(wsserver, nullptr));
+	return wsserver;
+}
+
+void eraseWebSocketServer(int wsserver) {
+	std::lock_guard lock(mutex);
+	if (webSocketServerMap.erase(wsserver) == 0)
+		throw std::invalid_argument("WebSocketServer ID does not exist");
+	userPointerMap.erase(wsserver);
+}
+
 #endif
 
 shared_ptr<Channel> getChannel(int id) {
@@ -268,21 +294,21 @@ int copyAndReturn(binary b, char *buffer, int size) {
 	return int(b.size());
 }
 
-template<typename T>
-int copyAndReturn(std::vector<T> b, T *buffer, int size) {
+template <typename T> int copyAndReturn(std::vector<T> b, T *buffer, int size) {
 	if (!buffer)
 		return int(b.size());
 
 	if (size < int(b.size()))
 		return RTC_ERR_TOO_SMALL;
-    std::copy(b.begin(), b.end(), buffer);
+	std::copy(b.begin(), b.end(), buffer);
 	return int(b.size());
 }
 
 #if RTC_ENABLE_MEDIA
 // function is used in RTC_ENABLE_MEDIA only
 string lowercased(string str) {
-	std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { return std::tolower(c); });
+	std::transform(str.begin(), str.end(), str.begin(),
+	               [](unsigned char c) { return std::tolower(c); });
 	return str;
 }
 #endif // RTC_ENABLE_MEDIA
@@ -326,7 +352,7 @@ int rtcCreatePeerConnection(const rtcConfiguration *config) {
 		if (config->maxMessageSize)
 			c.maxMessageSize = size_t(config->maxMessageSize);
 
-		return emplacePeerConnection(std::make_shared<PeerConnection>(c));
+		return emplacePeerConnection(std::make_shared<PeerConnection>(std::move(c)));
 	});
 }
 
@@ -384,9 +410,7 @@ int rtcCreateDataChannelEx(int pc, const char *label, const rtcDataChannelInit *
 }
 
 int rtcIsOpen(int cid) {
-	return wrap([cid] {
-        return getChannel(cid)->isOpen();
-	});
+	return wrap([cid] { return getChannel(cid)->isOpen(); });
 }
 
 int rtcDeleteDataChannel(int dc) {
@@ -527,7 +551,8 @@ int rtcGetTrackDescription(int tr, char *buffer, int size) {
 
 #if RTC_ENABLE_MEDIA
 
-void setSSRC(Description::Media *description, uint32_t ssrc, const char *_name, const char *_msid, const char *_trackID) {
+void setSSRC(Description::Media *description, uint32_t ssrc, const char *_name, const char *_msid,
+             const char *_trackID) {
 
 	optional<string> name = nullopt;
 	if (_name) {
@@ -677,7 +702,7 @@ int rtcSetNeedsToSendRtcpSr(int id) {
 	});
 }
 
-int rtcGetTrackPayloadTypesForCodec(int tr, const char * ccodec, int * buffer, int size) {
+int rtcGetTrackPayloadTypesForCodec(int tr, const char *ccodec, int *buffer, int size) {
 	return wrap([&] {
 		auto track = getTrack(tr);
 		auto codec = lowercased(string(ccodec));
@@ -694,7 +719,7 @@ int rtcGetTrackPayloadTypesForCodec(int tr, const char * ccodec, int * buffer, i
 	});
 }
 
-int rtcGetSsrcsForTrack(int tr, uint32_t * buffer, int count) {
+int rtcGetSsrcsForTrack(int tr, uint32_t *buffer, int count) {
 	return wrap([&] {
 		auto track = getTrack(tr);
 		auto ssrcs = track->description().getSSRCs();
@@ -702,7 +727,7 @@ int rtcGetSsrcsForTrack(int tr, uint32_t * buffer, int count) {
 	});
 }
 
-int rtcGetCNameForSsrc(int tr, uint32_t ssrc, char * cname, int cnameSize) {
+int rtcGetCNameForSsrc(int tr, uint32_t ssrc, char *cname, int cnameSize) {
 	return wrap([&] {
 		auto track = getTrack(tr);
 		auto description = track->description();
@@ -715,7 +740,7 @@ int rtcGetCNameForSsrc(int tr, uint32_t ssrc, char * cname, int cnameSize) {
 	});
 }
 
-int rtcGetSsrcsForType(const char * mediaType, const char * sdp, uint32_t * buffer, int bufferSize) {
+int rtcGetSsrcsForType(const char *mediaType, const char *sdp, uint32_t *buffer, int bufferSize) {
 	return wrap([&] {
 		auto type = lowercased(string(mediaType));
 		auto oldSDP = string(sdp);
@@ -735,8 +760,8 @@ int rtcGetSsrcsForType(const char * mediaType, const char * sdp, uint32_t * buff
 	});
 }
 
-int rtcSetSsrcForType(const char * mediaType, const char * sdp, char * buffer, const int bufferSize,
-					  rtcSsrcForTypeInit * init) {
+int rtcSetSsrcForType(const char *mediaType, const char *sdp, char *buffer, const int bufferSize,
+                      rtcSsrcForTypeInit *init) {
 	return wrap([&] {
 		auto type = lowercased(string(mediaType));
 		auto prevSDP = string(sdp);
@@ -761,23 +786,29 @@ int rtcSetSsrcForType(const char * mediaType, const char * sdp, char * buffer, c
 
 int rtcCreateWebSocket(const char *url) {
 	return wrap([&] {
-		auto ws = std::make_shared<WebSocket>();
-		ws->open(url);
-		return emplaceWebSocket(ws);
+		auto webSocket = std::make_shared<WebSocket>();
+		webSocket->open(url);
+		return emplaceWebSocket(webSocket);
 	});
 }
 
 int rtcCreateWebSocketEx(const char *url, const rtcWsConfiguration *config) {
 	return wrap([&] {
+		if (!url)
+			throw std::invalid_argument("Unexpected null pointer for URL");
+
+		if (!config)
+			throw std::invalid_argument("Unexpected null pointer for config");
+
 		WebSocket::Configuration c;
 		c.disableTlsVerification = config->disableTlsVerification;
-		auto ws = std::make_shared<WebSocket>(c);
-		ws->open(url);
-		return emplaceWebSocket(ws);
+		auto webSocket = std::make_shared<WebSocket>(std::move(c));
+		webSocket->open(url);
+		return emplaceWebSocket(webSocket);
 	});
 }
 
-int rtcDeleteWebsocket(int ws) {
+int rtcDeleteWebSocket(int ws) {
 	return wrap([&] {
 		auto webSocket = getWebSocket(ws);
 		webSocket->onOpen(nullptr);
@@ -792,6 +823,76 @@ int rtcDeleteWebsocket(int ws) {
 	});
 }
 
+int rtcGetWebSocketRemoteAddress(int ws, char *buffer, int size) {
+	return wrap([&] {
+		auto webSocket = getWebSocket(ws);
+		if (auto remoteAddress = webSocket->remoteAddress())
+			return copyAndReturn(*remoteAddress, buffer, size);
+		else
+			return RTC_ERR_NOT_AVAIL;
+	});
+}
+
+int rtcGetWebSocketPath(int ws, char *buffer, int size) {
+	return wrap([&] {
+		auto webSocket = getWebSocket(ws);
+		if (auto path = webSocket->path())
+			return copyAndReturn(*path, buffer, size);
+		else
+			return RTC_ERR_NOT_AVAIL;
+	});
+}
+
+RTC_EXPORT int rtcCreateWebSocketServer(const rtcWsServerConfiguration *config,
+                                        rtcWebSocketClientCallbackFunc cb) {
+	return wrap([&] {
+		if (!config)
+			throw std::invalid_argument("Unexpected null pointer for config");
+
+		if (!cb)
+			throw std::invalid_argument("Unexpected null pointer for client callback");
+
+		WebSocketServer::Configuration c;
+		c.port = config->port;
+		c.enableTls = config->enableTls;
+		c.certificatePemFile = config->certificatePemFile
+		                           ? make_optional(string(config->certificatePemFile))
+		                           : nullopt;
+		c.keyPemFile = config->keyPemFile ? make_optional(string(config->keyPemFile)) : nullopt;
+		c.keyPemPass = config->keyPemPass ? make_optional(string(config->keyPemPass)) : nullopt;
+		auto webSocketServer = std::make_shared<WebSocketServer>(std::move(c));
+		int wsserver = emplaceWebSocketServer(webSocketServer);
+
+		webSocketServer->onClient([wsserver, cb](shared_ptr<WebSocket> webSocket) {
+			int ws = emplaceWebSocket(webSocket);
+			if (auto ptr = getUserPointer(wsserver)) {
+				rtcSetUserPointer(wsserver, *ptr);
+				cb(wsserver, ws, *ptr);
+			}
+		});
+
+		return wsserver;
+	});
+}
+
+RTC_EXPORT int rtcDeleteWebSocketServer(int wsserver) {
+	return wrap([&] {
+		auto webSocketServer = getWebSocketServer(wsserver);
+		webSocketServer->onClient(nullptr);
+		webSocketServer->stop();
+
+		eraseWebSocketServer(wsserver);
+		return RTC_ERR_SUCCESS;
+	});
+}
+
+RTC_EXPORT int rtcGetWebSocketServerPort(int wsserver) {
+	return wrap([&] {
+		auto webSocketServer = getWebSocketServer(wsserver);
+		return int(webSocketServer->port());
+	});
+}
+
 #endif
 
 int rtcSetLocalDescriptionCallback(int pc, rtcDescriptionCallbackFunc cb) {
@@ -1272,7 +1373,8 @@ int rtcSetSctpSettings(const rtcSctpSettings *settings) {
 			s.maxRetransmitTimeout = std::chrono::milliseconds(settings->maxRetransmitTimeoutMs);
 
 		if (settings->initialRetransmitTimeoutMs > 0)
-			s.initialRetransmitTimeout = std::chrono::milliseconds(settings->initialRetransmitTimeoutMs);
+			s.initialRetransmitTimeout =
+			    std::chrono::milliseconds(settings->initialRetransmitTimeoutMs);
 
 		if (settings->maxRetransmitAttempts > 0)
 			s.maxRetransmitAttempts = settings->maxRetransmitAttempts;

+ 144 - 90
src/impl/certificate.cpp

@@ -28,70 +28,37 @@
 
 namespace rtc::impl {
 
-const string COMMON_NAME = "libdatachannel";
-
 #if USE_GNUTLS
 
-Certificate::Certificate(string crt_pem, string key_pem)
-    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials) {
+Certificate Certificate::FromString(string crt_pem, string key_pem) {
+	PLOG_DEBUG << "Importing certificate from PEM string (GnuTLS)";
 
+	shared_ptr<gnutls_certificate_credentials_t> creds(gnutls::new_credentials(),
+	                                                   gnutls::free_credentials);
 	gnutls_datum_t crt_datum = gnutls::make_datum(crt_pem.data(), crt_pem.size());
 	gnutls_datum_t key_datum = gnutls::make_datum(key_pem.data(), key_pem.size());
+	gnutls::check(
+	    gnutls_certificate_set_x509_key_mem(*creds, &crt_datum, &key_datum, GNUTLS_X509_FMT_PEM),
+	    "Unable to import PEM certificate and key");
 
-	gnutls::check(gnutls_certificate_set_x509_key_mem(*mCredentials, &crt_datum, &key_datum,
-	                                                  GNUTLS_X509_FMT_PEM),
-	              "Unable to import PEM");
-
-	auto new_crt_list = [this]() -> gnutls_x509_crt_t * {
-		gnutls_x509_crt_t *crt_list = nullptr;
-		unsigned int crt_list_size = 0;
-		gnutls::check(gnutls_certificate_get_x509_crt(*mCredentials, 0, &crt_list, &crt_list_size));
-		assert(crt_list_size == 1);
-		return crt_list;
-	};
-
-	auto free_crt_list = [](gnutls_x509_crt_t *crt_list) {
-		gnutls_x509_crt_deinit(crt_list[0]);
-		gnutls_free(crt_list);
-	};
-
-	unique_ptr<gnutls_x509_crt_t, decltype(free_crt_list)> crt_list(new_crt_list(), free_crt_list);
-
-	mFingerprint = make_fingerprint(*crt_list);
-}
-
-Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
-    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
-      mFingerprint(make_fingerprint(crt)) {
-
-	gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
-	              "Unable to set certificate and key pair in credentials");
+	return Certificate(std::move(creds));
 }
 
-gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
+Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
+                                  const string &pass) {
+	PLOG_DEBUG << "Importing certificate from PEM file (GnuTLS): " << crt_pem_file;
 
-string Certificate::fingerprint() const { return mFingerprint; }
+	shared_ptr<gnutls_certificate_credentials_t> creds(gnutls::new_credentials(),
+	                                                   gnutls::free_credentials);
+	gnutls::check(gnutls_certificate_set_x509_key_file2(*creds, crt_pem_file.c_str(),
+	                                                    key_pem_file.c_str(), GNUTLS_X509_FMT_PEM,
+	                                                    pass.c_str(), 0),
+	              "Unable to import PEM certificate and key from file");
 
-string make_fingerprint(gnutls_x509_crt_t crt) {
-	const size_t size = 32;
-	unsigned char buffer[size];
-	size_t len = size;
-	gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len),
-	              "X509 fingerprint error");
-
-	std::ostringstream oss;
-	oss << std::hex << std::uppercase << std::setfill('0');
-	for (size_t i = 0; i < len; ++i) {
-		if (i)
-			oss << std::setw(1) << ':';
-		oss << std::setw(2) << unsigned(buffer[i]);
-	}
-	return oss.str();
+	return Certificate(std::move(creds));
 }
 
-namespace {
-
-certificate_ptr make_certificate_impl(CertificateType type) {
+Certificate Certificate::Generate(CertificateType type, const string &commonName) {
 	PLOG_DEBUG << "Generating certificate (GnuTLS)";
 
 	using namespace gnutls;
@@ -127,8 +94,8 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	gnutls_x509_crt_set_expiration_time(*crt, (now + hours(24 * 365)).time_since_epoch().count());
 	gnutls_x509_crt_set_version(*crt, 1);
 	gnutls_x509_crt_set_key(*crt, *privkey);
-	gnutls_x509_crt_set_dn_by_oid(*crt, GNUTLS_OID_X520_COMMON_NAME, 0, COMMON_NAME.data(),
-	                              COMMON_NAME.size());
+	gnutls_x509_crt_set_dn_by_oid(*crt, GNUTLS_OID_X520_COMMON_NAME, 0, commonName.data(),
+	                              commonName.size());
 
 	const size_t serialSize = 16;
 	char serial[serialSize];
@@ -138,48 +105,49 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	gnutls::check(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0),
 	              "Unable to auto-sign certificate");
 
-	return std::make_shared<Certificate>(*crt, *privkey);
+	return Certificate(*crt, *privkey);
 }
 
-} // namespace
+Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
+    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
+      mFingerprint(make_fingerprint(crt)) {
 
-#else // USE_GNUTLS==0
+	gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
+	              "Unable to set certificate and key pair in credentials");
+}
 
-Certificate::Certificate(string crt_pem, string key_pem) {
-	BIO *bio = BIO_new(BIO_s_mem());
-	BIO_write(bio, crt_pem.data(), int(crt_pem.size()));
-	mX509 = shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, 0, 0), X509_free);
-	BIO_free(bio);
-	if (!mX509)
-		throw std::invalid_argument("Unable to import certificate PEM");
+Certificate::Certificate(shared_ptr<gnutls_certificate_credentials_t> creds)
+    : mCredentials(std::move(creds)), mFingerprint(make_fingerprint(*mCredentials)) {}
 
-	bio = BIO_new(BIO_s_mem());
-	BIO_write(bio, key_pem.data(), int(key_pem.size()));
-	mPKey = shared_ptr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio, nullptr, 0, 0), EVP_PKEY_free);
-	BIO_free(bio);
-	if (!mPKey)
-		throw std::invalid_argument("Unable to import PEM key PEM");
+gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
 
-	mFingerprint = make_fingerprint(mX509.get());
-}
+string Certificate::fingerprint() const { return mFingerprint; }
 
-Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
-    : mX509(std::move(x509)), mPKey(std::move(pkey)) {
-	mFingerprint = make_fingerprint(mX509.get());
-}
+string make_fingerprint(gnutls_certificate_credentials_t credentials) {
+	auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * {
+		gnutls_x509_crt_t *crt_list = nullptr;
+		unsigned int crt_list_size = 0;
+		gnutls::check(gnutls_certificate_get_x509_crt(credentials, 0, &crt_list, &crt_list_size));
+		assert(crt_list_size == 1);
+		return crt_list;
+	};
 
-string Certificate::fingerprint() const { return mFingerprint; }
+	auto free_crt_list = [](gnutls_x509_crt_t *crt_list) {
+		gnutls_x509_crt_deinit(crt_list[0]);
+		gnutls_free(crt_list);
+	};
 
-std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
-	return {mX509.get(), mPKey.get()};
+	unique_ptr<gnutls_x509_crt_t, decltype(free_crt_list)> crt_list(new_crt_list(), free_crt_list);
+
+	return make_fingerprint(*crt_list);
 }
 
-string make_fingerprint(X509 *x509) {
+string make_fingerprint(gnutls_x509_crt_t crt) {
 	const size_t size = 32;
 	unsigned char buffer[size];
-	unsigned int len = size;
-	if (!X509_digest(x509, EVP_sha256(), buffer, &len))
-		throw std::runtime_error("X509 fingerprint error");
+	size_t len = size;
+	gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len),
+	              "X509 fingerprint error");
 
 	std::ostringstream oss;
 	oss << std::hex << std::uppercase << std::setfill('0');
@@ -191,9 +159,69 @@ string make_fingerprint(X509 *x509) {
 	return oss.str();
 }
 
+#else // USE_GNUTLS==0
+
+#include <cstdio>
+
 namespace {
 
-certificate_ptr make_certificate_impl(CertificateType type) {
+// Dummy password callback that copies the password from user data
+int dummy_pass_cb(char *buf, int size, int /*rwflag*/, void *u) {
+	const char *pass = static_cast<char *>(u);
+	return snprintf(buf, size, "%s", pass);
+}
+
+} // namespace
+
+Certificate Certificate::FromString(string crt_pem, string key_pem) {
+	PLOG_DEBUG << "Importing certificate from PEM string (OpenSSL)";
+
+	BIO *bio = BIO_new(BIO_s_mem());
+	BIO_write(bio, crt_pem.data(), int(crt_pem.size()));
+	auto x509 = shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, nullptr, nullptr), X509_free);
+	BIO_free(bio);
+	if (!x509)
+		throw std::invalid_argument("Unable to import PEM certificate");
+
+	bio = BIO_new(BIO_s_mem());
+	BIO_write(bio, key_pem.data(), int(key_pem.size()));
+	auto pkey = shared_ptr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr),
+	                                 EVP_PKEY_free);
+	BIO_free(bio);
+	if (!pkey)
+		throw std::invalid_argument("Unable to import PEM key");
+
+	return Certificate(x509, pkey);
+}
+
+Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
+                                  const string &pass) {
+	PLOG_DEBUG << "Importing certificate from PEM file (OpenSSL): " << crt_pem_file;
+
+	FILE *file = fopen(crt_pem_file.c_str(), "r");
+	if (!file)
+		throw std::invalid_argument("Unable to open PEM certificate file");
+
+	auto x509 = shared_ptr<X509>(PEM_read_X509(file, nullptr, nullptr, nullptr), X509_free);
+	fclose(file);
+	if (!x509)
+		throw std::invalid_argument("Unable to import PEM certificate from file");
+
+	file = fopen(key_pem_file.c_str(), "r");
+	if (!file)
+		throw std::invalid_argument("Unable to open PEM key file");
+
+	auto pkey = shared_ptr<EVP_PKEY>(
+	    PEM_read_PrivateKey(file, nullptr, dummy_pass_cb, const_cast<char *>(pass.c_str())),
+	    EVP_PKEY_free);
+	fclose(file);
+	if (!pkey)
+		throw std::invalid_argument("Unable to import PEM key from file");
+
+	return Certificate(x509, pkey);
+}
+
+Certificate Certificate::Generate(CertificateType type, const string &commonName) {
 	PLOG_DEBUG << "Generating certificate (OpenSSL)";
 
 	shared_ptr<X509> x509(X509_new(), X509_free);
@@ -212,8 +240,8 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	case CertificateType::Ecdsa: {
 		PLOG_VERBOSE << "Generating ECDSA P-256 key pair";
 
-		unique_ptr<EC_KEY, decltype(&EC_KEY_free)> ecc(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1),
-		                                               EC_KEY_free);
+		unique_ptr<EC_KEY, decltype(&EC_KEY_free)> ecc(
+		    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
 		if (!ecc)
 			throw std::runtime_error("Unable to allocate structure for ECDSA P-256 key pair");
 
@@ -250,7 +278,7 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 
 	const size_t serialSize = 16;
 	auto *commonNameBytes =
-	    reinterpret_cast<unsigned char *>(const_cast<char *>(COMMON_NAME.c_str()));
+	    reinterpret_cast<unsigned char *>(const_cast<char *>(commonName.c_str()));
 
 	if (!X509_set_pubkey(x509.get(), pkey.get()))
 		throw std::runtime_error("Unable to set certificate public key");
@@ -269,17 +297,43 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	if (!X509_sign(x509.get(), pkey.get(), EVP_sha256()))
 		throw std::runtime_error("Unable to auto-sign certificate");
 
-	return std::make_shared<Certificate>(x509, pkey);
+	return Certificate(x509, pkey);
 }
 
-} // namespace
+Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
+    : mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {}
+
+string Certificate::fingerprint() const { return mFingerprint; }
+
+std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
+	return {mX509.get(), mPKey.get()};
+}
+
+string make_fingerprint(X509 *x509) {
+	const size_t size = 32;
+	unsigned char buffer[size];
+	unsigned int len = size;
+	if (!X509_digest(x509, EVP_sha256(), buffer, &len))
+		throw std::runtime_error("X509 fingerprint error");
+
+	std::ostringstream oss;
+	oss << std::hex << std::uppercase << std::setfill('0');
+	for (size_t i = 0; i < len; ++i) {
+		if (i)
+			oss << std::setw(1) << ':';
+		oss << std::setw(2) << unsigned(buffer[i]);
+	}
+	return oss.str();
+}
 
 #endif
 
 // Common for GnuTLS and OpenSSL
 
 future_certificate_ptr make_certificate(CertificateType type) {
-	return ThreadPool::Instance().enqueue(make_certificate_impl, type);
+	return ThreadPool::Instance().enqueue([type]() {
+		return std::make_shared<Certificate>(Certificate::Generate(type, "libdatachannel"));
+	});
 }
 
 } // namespace rtc::impl

+ 11 - 6
src/impl/certificate.hpp

@@ -20,8 +20,8 @@
 #define RTC_IMPL_CERTIFICATE_H
 
 #include "common.hpp"
-#include "tls.hpp"
 #include "configuration.hpp" // for CertificateType
+#include "tls.hpp"
 
 #include <future>
 #include <tuple>
@@ -30,7 +30,10 @@ namespace rtc::impl {
 
 class Certificate {
 public:
-	Certificate(string crt_pem, string key_pem);
+	static Certificate FromString(string crt_pem, string key_pem);
+	static Certificate FromFile(const string &crt_pem_file, const string &key_pem_file,
+	                            const string &pass = "");
+	static Certificate Generate(CertificateType type, const string &commonName);
 
 #if USE_GNUTLS
 	Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
@@ -44,16 +47,18 @@ public:
 
 private:
 #if USE_GNUTLS
-	shared_ptr<gnutls_certificate_credentials_t> mCredentials;
+	Certificate(shared_ptr<gnutls_certificate_credentials_t> creds);
+	const shared_ptr<gnutls_certificate_credentials_t> mCredentials;
 #else
-	shared_ptr<X509> mX509;
-	shared_ptr<EVP_PKEY> mPKey;
+	const shared_ptr<X509> mX509;
+	const shared_ptr<EVP_PKEY> mPKey;
 #endif
 
-	string mFingerprint;
+	const string mFingerprint;
 };
 
 #if USE_GNUTLS
+string make_fingerprint(gnutls_certificate_credentials_t credentials);
 string make_fingerprint(gnutls_x509_crt_t crt);
 #else
 string make_fingerprint(X509 *x509);

+ 68 - 36
src/impl/dtlstransport.cpp

@@ -17,8 +17,8 @@
  */
 
 #include "dtlstransport.hpp"
-#include "internals.hpp"
 #include "icetransport.hpp"
+#include "internals.hpp"
 
 #include <chrono>
 #include <cstring>
@@ -54,6 +54,9 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
+	if (!mCertificate)
+		throw std::invalid_argument("DTLS certificate is null");
+
 	gnutls_certificate_credentials_t creds = mCertificate->credentials();
 	gnutls_certificate_set_verify_function(creds, CertificateCallback);
 
@@ -244,61 +247,87 @@ void DtlsTransport::runRecvLoop() {
 
 int DtlsTransport::CertificateCallback(gnutls_session_t session) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(gnutls_session_get_ptr(session));
+	try {
+		if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
+			return GNUTLS_E_CERTIFICATE_ERROR;
+		}
 
-	if (gnutls_certificate_type_get(session) != GNUTLS_CRT_X509) {
-		return GNUTLS_E_CERTIFICATE_ERROR;
-	}
+		unsigned int count = 0;
+		const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count);
+		if (!array || count == 0) {
+			return GNUTLS_E_CERTIFICATE_ERROR;
+		}
 
-	unsigned int count = 0;
-	const gnutls_datum_t *array = gnutls_certificate_get_peers(session, &count);
-	if (!array || count == 0) {
-		return GNUTLS_E_CERTIFICATE_ERROR;
-	}
+		gnutls_x509_crt_t crt;
+		gnutls::check(gnutls_x509_crt_init(&crt));
+		int ret = gnutls_x509_crt_import(crt, &array[0], GNUTLS_X509_FMT_DER);
+		if (ret != GNUTLS_E_SUCCESS) {
+			gnutls_x509_crt_deinit(crt);
+			return GNUTLS_E_CERTIFICATE_ERROR;
+		}
 
-	gnutls_x509_crt_t crt;
-	gnutls::check(gnutls_x509_crt_init(&crt));
-	int ret = gnutls_x509_crt_import(crt, &array[0], GNUTLS_X509_FMT_DER);
-	if (ret != GNUTLS_E_SUCCESS) {
+		string fingerprint = make_fingerprint(crt);
 		gnutls_x509_crt_deinit(crt);
-		return GNUTLS_E_CERTIFICATE_ERROR;
-	}
 
-	string fingerprint = make_fingerprint(crt);
-	gnutls_x509_crt_deinit(crt);
+		bool success = t->mVerifierCallback(fingerprint);
+		return success ? GNUTLS_E_SUCCESS : GNUTLS_E_CERTIFICATE_ERROR;
 
-	bool success = t->mVerifierCallback(fingerprint);
-	return success ? GNUTLS_E_SUCCESS : GNUTLS_E_CERTIFICATE_ERROR;
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return GNUTLS_E_CERTIFICATE_ERROR;
+	}
 }
 
 ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
-	if (len > 0) {
-		auto b = reinterpret_cast<const byte *>(data);
-		t->outgoing(make_message(b, b + len));
+	try {
+		if (len > 0) {
+			auto b = reinterpret_cast<const byte *>(data);
+			t->outgoing(make_message(b, b + len));
+		}
+		gnutls_transport_set_errno(t->mSession, 0);
+		return ssize_t(len);
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
-	gnutls_transport_set_errno(t->mSession, 0);
-	return ssize_t(len);
 }
 
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
-	if (auto next = t->mIncomingQueue.pop()) {
-		message_ptr message = std::move(*next);
-		ssize_t len = std::min(maxlen, message->size());
-		std::memcpy(data, message->data(), len);
+	try {
+		if (auto next = t->mIncomingQueue.pop()) {
+			message_ptr message = std::move(*next);
+			ssize_t len = std::min(maxlen, message->size());
+			std::memcpy(data, message->data(), len);
+			gnutls_transport_set_errno(t->mSession, 0);
+			return len;
+		}
+
+		// Closed
 		gnutls_transport_set_errno(t->mSession, 0);
-		return len;
+		return 0;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
-	// Closed
-	gnutls_transport_set_errno(t->mSession, 0);
-	return 0;
 }
 
 int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
-	bool notEmpty = t->mIncomingQueue.wait(
-	    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
-	return notEmpty ? 1 : 0;
+	try {
+		bool notEmpty = t->mIncomingQueue.wait(
+		    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
+		return notEmpty ? 1 : 0;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return 1;
+	}
 }
 
 #else // USE_GNUTLS==0
@@ -330,7 +359,7 @@ void DtlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
+DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr certificate,
                              optional<size_t> mtu, verifier_callback verifierCallback,
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
@@ -338,6 +367,9 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
       mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
+	if (!mCertificate)
+		throw std::invalid_argument("DTLS certificate is null");
+
 	try {
 		mCtx = SSL_CTX_new(DTLS_method());
 		if (!mCtx)

+ 2 - 0
src/impl/dtlstransport.hpp

@@ -51,6 +51,8 @@ public:
 	virtual bool stop() override;
 	virtual bool send(message_ptr message) override; // false if dropped
 
+	bool isClient() const { return mIsClient; }
+
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;

+ 10 - 3
src/impl/peerconnection.cpp

@@ -186,8 +186,11 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 
 		PLOG_VERBOSE << "Starting DTLS transport";
 
-		auto certificate = mCertificate.get();
 		auto lower = std::atomic_load(&mIceTransport);
+		if(!lower)
+			throw std::logic_error("No underlying ICE transport for DTLS transport");
+
+		auto certificate = mCertificate.get();
 		auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1);
 		auto dtlsStateChangeCallback =
 		    [this, weak_this = weak_from_this()](DtlsTransport::State transportState) {
@@ -258,15 +261,19 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 
 		PLOG_VERBOSE << "Starting SCTP transport";
 
+		auto lower = std::atomic_load(&mDtlsTransport);
+		if(!lower)
+			throw std::logic_error("No underlying DTLS transport for SCTP transport");
+
 		auto remote = remoteDescription();
 		if (!remote || !remote->application())
 			throw std::logic_error("Starting SCTP transport without application description");
 
+		uint16_t sctpPort = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT);
+
 		// This is the last occasion to ensure the stream numbers are coherent with the role
 		shiftDataChannels();
 
-		uint16_t sctpPort = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT);
-		auto lower = std::atomic_load(&mDtlsTransport);
 		auto transport = std::make_shared<SctpTransport>(
 		    lower, config, sctpPort, weak_bind(&PeerConnection::forwardMessage, this, _1),
 		    weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),

+ 88 - 0
src/impl/selectinterrupter.cpp

@@ -0,0 +1,88 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "selectinterrupter.hpp"
+#include "internals.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#ifndef _WIN32
+#include <fcntl.h>
+#include <unistd.h>
+#endif
+
+namespace rtc::impl {
+
+SelectInterrupter::SelectInterrupter() {
+#ifndef _WIN32
+	int pipefd[2];
+	if (::pipe(pipefd) != 0)
+		throw std::runtime_error("Failed to create pipe");
+	::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
+	::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
+	mPipeOut = pipefd[1]; // read
+	mPipeIn = pipefd[0];  // write
+#endif
+}
+
+SelectInterrupter::~SelectInterrupter() {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock != INVALID_SOCKET)
+		::closesocket(mDummySock);
+#else
+	::close(mPipeIn);
+	::close(mPipeOut);
+#endif
+}
+
+int SelectInterrupter::prepare(fd_set &readfds) {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock == INVALID_SOCKET)
+		mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
+	FD_SET(mDummySock, &readfds);
+	return SOCKET_TO_INT(mDummySock) + 1;
+#else
+	char dummy;
+	if (::read(mPipeIn, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
+		PLOG_WARNING << "Reading from interrupter pipe failed, errno=" << errno;
+	}
+	FD_SET(mPipeIn, &readfds);
+	return mPipeIn + 1;
+#endif
+}
+
+void SelectInterrupter::interrupt() {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock != INVALID_SOCKET) {
+		::closesocket(mDummySock);
+		mDummySock = INVALID_SOCKET;
+	}
+#else
+	char dummy = 0;
+	if (::write(mPipeOut, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
+		PLOG_WARNING << "Writing to interrupter pipe failed, errno=" << errno;
+	}
+#endif
+}
+
+} // namespace rtc::impl
+
+#endif

+ 55 - 0
src/impl/selectinterrupter.hpp

@@ -0,0 +1,55 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_SELECT_INTERRUPTER_H
+#define RTC_IMPL_SELECT_INTERRUPTER_H
+
+#include "common.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <mutex>
+
+// Use the socket defines from libjuice
+#include "../deps/libjuice/src/socket.h"
+
+namespace rtc::impl {
+
+// Utility class to interrupt select()
+class SelectInterrupter final {
+public:
+	SelectInterrupter();
+	~SelectInterrupter();
+
+	int prepare(fd_set &readfds);
+	void interrupt();
+
+private:
+	std::mutex mMutex;
+#ifdef _WIN32
+	socket_t mDummySock = INVALID_SOCKET;
+#else // assume POSIX
+	int mPipeIn, mPipeOut;
+#endif
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 69 - 0
src/impl/sha.cpp

@@ -0,0 +1,69 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "sha.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#if USE_GNUTLS
+#include <nettle/sha1.h>
+#else
+#include <openssl/sha.h>
+#endif
+
+namespace rtc::impl {
+
+namespace {
+
+binary Sha1(const byte *data, size_t size) {
+#if USE_GNUTLS
+
+binary output(SHA1_DIGEST_SIZE);
+struct sha1_ctx ctx;
+sha1_init(&ctx);
+sha1_update(&ctx, size, reinterpret_cast<const uint8_t*>(data));
+sha1_digest(&ctx, SHA1_DIGEST_SIZE, reinterpret_cast<uint8_t*>(output.data()));
+return output;
+
+#else // USE_GNUTLS==0
+
+binary output(SHA_DIGEST_LENGTH);
+SHA_CTX ctx;
+SHA1_Init(&ctx);
+SHA1_Update(&ctx, data, size);
+SHA1_Final(reinterpret_cast<unsigned char*>(output.data()), &ctx);
+return output;
+
+#endif
+}
+
+}
+
+binary Sha1(const binary &input) {
+	return Sha1(input.data(), input.size());
+}
+
+
+binary Sha1(const string &input) {
+	return Sha1(reinterpret_cast<const byte*>(input.data()), input.size());
+}
+
+} // namespace rtc::impl
+
+#endif
+

+ 35 - 0
src/impl/sha.hpp

@@ -0,0 +1,35 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_SHA_H
+#define RTC_IMPL_SHA_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "common.hpp"
+
+namespace rtc::impl {
+
+binary Sha1(const binary &input);
+binary Sha1(const string &input);
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 185 - 0
src/impl/tcpserver.cpp

@@ -0,0 +1,185 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "tcpserver.hpp"
+#include "internals.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#ifdef _WIN32
+#include <winsock2.h>
+#else
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <unistd.h>
+#endif
+
+namespace rtc::impl {
+
+TcpServer::TcpServer(uint16_t port) {
+	PLOG_DEBUG << "Initializing TCP server";
+	listen(port);
+}
+
+TcpServer::~TcpServer() { close(); }
+
+shared_ptr<TcpTransport> TcpServer::accept() {
+	while (true) {
+		std::unique_lock lock(mSockMutex);
+
+		if (mSock == INVALID_SOCKET)
+			break;
+
+		fd_set readfds;
+		FD_ZERO(&readfds);
+		FD_SET(mSock, &readfds);
+		int n = std::max(mInterrupter.prepare(readfds), SOCKET_TO_INT(mSock) + 1);
+		lock.unlock();
+		int ret = ::select(n, &readfds, NULL, NULL, NULL);
+		lock.lock();
+		if (mSock == INVALID_SOCKET)
+			break;
+
+		if (ret < 0) {
+			if (sockerrno == SEINTR || sockerrno == SEAGAIN) // interrupted
+				continue;
+			else
+				throw std::runtime_error("Failed to wait for socket connection");
+		}
+
+		if (FD_ISSET(mSock, &readfds)) {
+			struct sockaddr_storage addr;
+			socklen_t addrlen = sizeof(addr);
+			socket_t incomingSock = ::accept(mSock, (struct sockaddr *)&addr, &addrlen);
+
+			if (incomingSock != INVALID_SOCKET) {
+				return std::make_shared<TcpTransport>(incomingSock, nullptr); // no state callback
+
+			} else if (sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) {
+				PLOG_ERROR << "TCP server failed, errno=" << sockerrno;
+				throw std::runtime_error("TCP server failed");
+			}
+		}
+	}
+
+	PLOG_DEBUG << "TCP server closed";
+	return nullptr;
+}
+
+void TcpServer::close() {
+	std::unique_lock lock(mSockMutex);
+	if (mSock != INVALID_SOCKET) {
+		PLOG_DEBUG << "Closing TCP server socket";
+		::closesocket(mSock);
+		mSock = INVALID_SOCKET;
+		mInterrupter.interrupt();
+	}
+}
+
+void TcpServer::listen(uint16_t port) {
+	PLOG_DEBUG << "Listening on port " << port;
+
+	struct addrinfo hints = {};
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_STREAM;
+	hints.ai_protocol = IPPROTO_TCP;
+	hints.ai_flags = AI_ADDRCONFIG;
+
+	struct addrinfo *result = nullptr;
+	if (::getaddrinfo(nullptr, std::to_string(port).c_str(), &hints, &result))
+		throw std::runtime_error("Resolution failed for local address");
+
+	static const auto find_family = [](struct addrinfo *ai_list, int family) {
+		struct addrinfo *ai = ai_list;
+		while (ai && ai->ai_family != family)
+			ai = ai->ai_next;
+		return ai;
+	};
+
+	struct addrinfo *ai;
+	if ((ai = find_family(result, AF_INET6)) == NULL && (ai = find_family(result, AF_INET)) == NULL)
+		throw std::runtime_error("No suitable address family found");
+
+	try {
+		std::unique_lock lock(mSockMutex);
+		PLOG_VERBOSE << "Creating TCP server socket";
+
+		// Create socket
+		mSock = ::socket(ai->ai_family, SOCK_STREAM, IPPROTO_TCP);
+		if (mSock == INVALID_SOCKET)
+			throw std::runtime_error("TCP server socket creation failed");
+
+		// Listen on both IPv6 and IPv4
+		const sockopt_t disabled = 0;
+		if (ai->ai_family == AF_INET6)
+			::setsockopt(mSock, IPPROTO_IPV6, IPV6_V6ONLY, (const char *)&disabled,
+			             sizeof(disabled));
+
+		// Set non-blocking
+		ctl_t b = 1;
+		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+			throw std::runtime_error("Failed to set socket non-blocking mode");
+
+		// Bind socket
+		if (::bind(mSock, ai->ai_addr, ai->ai_addrlen) < 0) {
+			PLOG_WARNING << "TCP server socket binding on port " << port
+			             << " failed, errno=" << sockerrno;
+			throw std::runtime_error("TCP server socket binding failed");
+		}
+
+		// Listen
+		const int backlog = 10;
+		if (::listen(mSock, backlog) < 0) {
+			PLOG_WARNING << "TCP server socket listening failed, errno=" << sockerrno;
+			throw std::runtime_error("TCP server socket listening failed");
+		}
+
+		if (port != 0) {
+			mPort = port;
+		} else {
+			struct sockaddr_storage addr;
+			socklen_t addrlen = sizeof(addr);
+			if (::getsockname(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
+				throw std::runtime_error("getsockname failed");
+
+			switch (addr.ss_family) {
+			case AF_INET:
+				mPort = ntohs(reinterpret_cast<struct sockaddr_in *>(&addr)->sin_port);
+				break;
+			case AF_INET6:
+				mPort = ntohs(reinterpret_cast<struct sockaddr_in6 *>(&addr)->sin6_port);
+				break;
+			default:
+				throw std::logic_error("Unknown address family");
+			}
+		}
+	} catch (...) {
+		freeaddrinfo(result);
+		if (mSock != INVALID_SOCKET) {
+			::closesocket(mSock);
+			mSock = INVALID_SOCKET;
+		}
+		throw;
+	}
+
+	freeaddrinfo(result);
+}
+
+} // namespace rtc::impl
+
+#endif

+ 56 - 0
src/impl/tcpserver.hpp

@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_TCP_SERVER_H
+#define RTC_IMPL_TCP_SERVER_H
+
+#include "common.hpp"
+#include "queue.hpp"
+#include "tcptransport.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+// Use the socket defines from libjuice
+#include "../deps/libjuice/src/socket.h"
+
+namespace rtc::impl {
+
+class TcpServer {
+public:
+	TcpServer(uint16_t port);
+	~TcpServer();
+
+	shared_ptr<TcpTransport> accept();
+	void close();
+
+	uint16_t port() const { return mPort; }
+
+private:
+	void listen(uint16_t port);
+
+	uint16_t mPort;
+	socket_t mSock = INVALID_SOCKET;
+	std::mutex mSockMutex;
+	SelectInterrupter mInterrupter;
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 46 - 77
src/impl/tcptransport.cpp

@@ -21,8 +21,6 @@
 
 #if RTC_ENABLE_WEBSOCKET
 
-#include <exception>
-
 #ifndef _WIN32
 #include <fcntl.h>
 #include <unistd.h>
@@ -30,67 +28,38 @@
 
 namespace rtc::impl {
 
-using std::to_string;
+TcpTransport::TcpTransport(string hostname, string service, state_callback callback)
+    : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)),
+      mService(std::move(service)) {
 
-SelectInterrupter::SelectInterrupter() {
-#ifndef _WIN32
-	int pipefd[2];
-	if (::pipe(pipefd) != 0)
-		throw std::runtime_error("Failed to create pipe");
-	::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
-	::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
-	mPipeOut = pipefd[1]; // read
-	mPipeIn = pipefd[0];  // write
-#endif
+	PLOG_DEBUG << "Initializing TCP transport";
 }
 
-SelectInterrupter::~SelectInterrupter() {
-	std::lock_guard lock(mMutex);
-#ifdef _WIN32
-	if (mDummySock != INVALID_SOCKET)
-		::closesocket(mDummySock);
-#else
-	::close(mPipeIn);
-	::close(mPipeOut);
-#endif
-}
+TcpTransport::TcpTransport(socket_t sock, state_callback callback)
+    : Transport(nullptr, std::move(callback)), mIsActive(false), mSock(sock) {
 
-int SelectInterrupter::prepare(fd_set &readfds, [[maybe_unused]] fd_set &writefds) {
-	std::lock_guard lock(mMutex);
-#ifdef _WIN32
-	if (mDummySock == INVALID_SOCKET)
-		mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
-	FD_SET(mDummySock, &readfds);
-	return SOCKET_TO_INT(mDummySock) + 1;
-#else
-	char dummy;
-	if (::read(mPipeIn, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
-		PLOG_WARNING << "Reading from interrupter pipe failed, errno=" << errno;
-	}
-	FD_SET(mPipeIn, &readfds);
-	return mPipeIn + 1;
-#endif
-}
+	PLOG_DEBUG << "Initializing TCP transport with socket";
 
-void SelectInterrupter::interrupt() {
-	std::lock_guard lock(mMutex);
-#ifdef _WIN32
-	if (mDummySock != INVALID_SOCKET) {
-		::closesocket(mDummySock);
-		mDummySock = INVALID_SOCKET;
-	}
-#else
-	char dummy = 0;
-	if (::write(mPipeOut, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
-		PLOG_WARNING << "Writing to interrupter pipe failed, errno=" << errno;
-	}
-#endif
-}
+	// Set non-blocking
+	ctl_t b = 1;
+	if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+		throw std::runtime_error("Failed to set socket non-blocking mode");
 
-TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
-    : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
+	// Retrieve hostname and service
+	struct sockaddr_storage addr;
+	socklen_t addrlen = sizeof(addr);
+	if (::getpeername(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
+		throw std::runtime_error("getsockname failed");
 
-	PLOG_DEBUG << "Initializing TCP transport";
+	char node[MAX_NUMERICNODE_LEN];
+	char serv[MAX_NUMERICSERV_LEN];
+	if (::getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), addrlen, node,
+	                  MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
+	                  NI_NUMERICHOST | NI_NUMERICSERV) != 0)
+		throw std::runtime_error("getnameinfo failed");
+
+	mHostname = node;
+	mService = serv;
 }
 
 TcpTransport::~TcpTransport() { stop(); }
@@ -114,6 +83,9 @@ bool TcpTransport::stop() {
 
 bool TcpTransport::send(message_ptr message) {
 	std::unique_lock lock(mSockMutex);
+	if(state() == State::Connecting)
+		throw std::runtime_error("Connection is not open");
+
 	if (state() != State::Connected)
 		return false;
 
@@ -139,10 +111,12 @@ bool TcpTransport::outgoing(message_ptr message) {
 		return true;
 
 	mSendQueue.push(message);
-	interruptSelect(); // so the thread waits for writability
+	mInterrupter.interrupt(); // so the thread waits for writability
 	return false;
 }
 
+string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
+
 void TcpTransport::connect(const string &hostname, const string &service) {
 	PLOG_DEBUG << "Connecting to " << hostname << ":" << service;
 
@@ -197,6 +171,7 @@ void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
 		if (mSock == INVALID_SOCKET)
 			throw std::runtime_error("TCP socket creation failed");
 
+		// Set non-blocking
 		ctl_t b = 1;
 		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
 			throw std::runtime_error("Failed to set socket non-blocking mode");
@@ -269,7 +244,7 @@ void TcpTransport::close() {
 		mSock = INVALID_SOCKET;
 	}
 	changeState(State::Disconnected);
-	interruptSelect();
+	mInterrupter.interrupt();
 }
 
 bool TcpTransport::trySendQueue() {
@@ -301,7 +276,8 @@ bool TcpTransport::trySendMessage(message_ptr &message) {
 				message = make_message(message->end() - size, message->end());
 				return false;
 			} else {
-				throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
+				PLOG_ERROR << "Connection closed, errno=" << sockerrno;
+				throw std::runtime_error("Connection closed");
 			}
 		}
 
@@ -318,7 +294,8 @@ void TcpTransport::runLoop() {
 	// Connect
 	try {
 		changeState(State::Connecting);
-		connect(mHostname, mService);
+		if (mSock == INVALID_SOCKET)
+			connect(mHostname, mService);
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TCP connect: " << e.what();
@@ -337,7 +314,13 @@ void TcpTransport::runLoop() {
 				break;
 
 			fd_set readfds, writefds;
-			int n = prepareSelect(readfds, writefds);
+			FD_ZERO(&readfds);
+			FD_ZERO(&writefds);
+			FD_SET(mSock, &readfds);
+			if (!mSendQueue.empty())
+				FD_SET(mSock, &writefds);
+
+			int n = std::max(mInterrupter.prepare(readfds), SOCKET_TO_INT(mSock) + 1);
 
 			struct timeval tv;
 			tv.tv_sec = 10;
@@ -367,7 +350,8 @@ void TcpTransport::runLoop() {
 					if (sockerrno == SEAGAIN || sockerrno == SEWOULDBLOCK) {
 						continue;
 					} else {
-						throw std::runtime_error("Connection lost");
+						PLOG_WARNING << "TCP connection lost";
+						break;
 					}
 				}
 
@@ -388,21 +372,6 @@ void TcpTransport::runLoop() {
 	recv(nullptr);
 }
 
-int TcpTransport::prepareSelect(fd_set &readfds, fd_set &writefds) {
-	FD_ZERO(&readfds);
-	FD_ZERO(&writefds);
-	FD_SET(mSock, &readfds);
-
-	if (!mSendQueue.empty())
-		FD_SET(mSock, &writefds);
-
-	int n = SOCKET_TO_INT(mSock) + 1;
-	int m = mInterrupter.prepare(readfds, writefds);
-	return std::max(n, m);
-}
-
-void TcpTransport::interruptSelect() { mInterrupter.interrupt(); }
-
 } // namespace rtc::impl
 
 #endif

+ 8 - 22
src/impl/tcptransport.hpp

@@ -22,6 +22,7 @@
 #include "common.hpp"
 #include "queue.hpp"
 #include "transport.hpp"
+#include "selectinterrupter.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
 
@@ -33,27 +34,10 @@
 
 namespace rtc::impl {
 
-// Utility class to interrupt select()
-class SelectInterrupter {
-public:
-	SelectInterrupter();
-	~SelectInterrupter();
-
-	int prepare(fd_set &readfds, fd_set &writefds);
-	void interrupt();
-
-private:
-	std::mutex mMutex;
-#ifdef _WIN32
-	socket_t mDummySock = INVALID_SOCKET;
-#else // assume POSIX
-	int mPipeIn, mPipeOut;
-#endif
-};
-
 class TcpTransport : public Transport {
 public:
-	TcpTransport(const string &hostname, const string &service, state_callback callback);
+	TcpTransport(string hostname, string service, state_callback callback); // active
+	TcpTransport(socket_t sock, state_callback callback);                   // passive
 	~TcpTransport();
 
 	void start() override;
@@ -63,6 +47,10 @@ public:
 	void incoming(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 
+	bool isActive() const { return mIsActive; }
+
+	string remoteAddress() const;
+
 private:
 	void connect(const string &hostname, const string &service);
 	void connect(const sockaddr *addr, socklen_t addrlen);
@@ -73,9 +61,7 @@ private:
 
 	void runLoop();
 
-	int prepareSelect(fd_set &readfds, fd_set &writefds);
-	void interruptSelect();
-
+	const bool mIsActive;
 	string mHostname, mService;
 
 	socket_t mSock = INVALID_SOCKET;

+ 122 - 64
src/impl/tlstransport.cpp

@@ -32,6 +32,23 @@ namespace rtc::impl {
 
 #if USE_GNUTLS
 
+namespace {
+
+gnutls_certificate_credentials_t default_certificate_credentials() {
+	static std::mutex mutex;
+	static shared_ptr<gnutls_certificate_credentials_t> creds;
+
+	std::lock_guard lock(mutex);
+	if (!creds) {
+		creds = shared_ptr<gnutls_certificate_credentials_t>(gnutls::new_credentials(),
+		                                                     gnutls::free_credentials);
+		gnutls::check(gnutls_certificate_set_x509_system_trust(*creds));
+	}
+	return *creds;
+}
+
+} // namespace
+
 void TlsTransport::Init() {
 	// Nothing to do
 }
@@ -40,25 +57,28 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
-    : Transport(lower, std::move(callback)), mHost(std::move(host)) {
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
+                           certificate_ptr certificate, state_callback callback)
+    : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
-	gnutls::check(gnutls_certificate_allocate_credentials(&mCreds));
-	gnutls::check(gnutls_init(&mSession, GNUTLS_CLIENT));
+	gnutls::check(gnutls_init(&mSession, mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER));
 
 	try {
-		gnutls::check(gnutls_certificate_set_x509_system_trust(mCreds));
-		gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
-
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
 		const char *err_pos = NULL;
 		gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
 		              "Failed to set TLS priorities");
 
-		PLOG_VERBOSE << "Server Name Indication: " << mHost;
-		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost.data(), mHost.size());
+		gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE,
+		                                     certificate ? certificate->credentials()
+		                                                 : default_certificate_credentials()));
+
+		if (mIsClient && mHost) {
+			PLOG_VERBOSE << "Server Name Indication: " << *mHost;
+			gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost->data(), mHost->size());
+		}
 
 		gnutls_session_set_ptr(mSession, this);
 		gnutls_transport_set_ptr(mSession, this);
@@ -68,7 +88,6 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 
 	} catch (...) {
 		gnutls_deinit(mSession);
-		gnutls_certificate_free_credentials(mCreds);
 		throw;
 	}
 }
@@ -77,7 +96,6 @@ TlsTransport::~TlsTransport() {
 	stop();
 
 	gnutls_deinit(mSession);
-	gnutls_certificate_free_credentials(mCreds);
 }
 
 void TlsTransport::start() {
@@ -117,10 +135,13 @@ bool TlsTransport::send(message_ptr message) {
 }
 
 void TlsTransport::incoming(message_ptr message) {
-	if (message)
-		mIncomingQueue.push(message);
-	else
+	if (!message) {
 		mIncomingQueue.stop();
+		return;
+	}
+
+	PLOG_VERBOSE << "Incoming size=" << message->size();
+	mIncomingQueue.push(message);
 }
 
 void TlsTransport::postHandshake() {
@@ -188,53 +209,72 @@ void TlsTransport::runRecvLoop() {
 
 ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
-	if (len > 0) {
-		auto b = reinterpret_cast<const byte *>(data);
-		t->outgoing(make_message(b, b + len));
+	try {
+		if (len > 0) {
+			auto b = reinterpret_cast<const byte *>(data);
+			t->outgoing(make_message(b, b + len));
+		}
+		gnutls_transport_set_errno(t->mSession, 0);
+		return ssize_t(len);
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
-	gnutls_transport_set_errno(t->mSession, 0);
-	return ssize_t(len);
 }
 
 ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
+	try {
+		message_ptr &message = t->mIncomingMessage;
+		size_t &position = t->mIncomingMessagePosition;
 
-	message_ptr &message = t->mIncomingMessage;
-	size_t &position = t->mIncomingMessagePosition;
+		if (message && position >= message->size())
+			message.reset();
 
-	if (message && position >= message->size())
-		message.reset();
+		if (!message) {
+			position = 0;
+			while (auto next = t->mIncomingQueue.pop()) {
+				message = *next;
+				if (message->size() > 0)
+					break;
+				else
+					t->recv(message); // Pass zero-sized messages through
+			}
+		}
 
-	if (!message) {
-		position = 0;
-		while (auto next = t->mIncomingQueue.pop()) {
-			message = *next;
-			if (message->size() > 0)
-				break;
-			else
-				t->recv(message); // Pass zero-sized messages through
+		if (message) {
+			size_t available = message->size() - position;
+			ssize_t len = std::min(maxlen, available);
+			std::memcpy(data, message->data() + position, len);
+			position += len;
+			gnutls_transport_set_errno(t->mSession, 0);
+			return len;
+		} else {
+			// Closed
+			gnutls_transport_set_errno(t->mSession, 0);
+			return 0;
 		}
-	}
 
-	if (message) {
-		size_t available = message->size() - position;
-		ssize_t len = std::min(maxlen, available);
-		std::memcpy(data, message->data() + position, len);
-		position += len;
-		gnutls_transport_set_errno(t->mSession, 0);
-		return len;
-	} else {
-		// Closed
-		gnutls_transport_set_errno(t->mSession, 0);
-		return 0;
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		gnutls_transport_set_errno(t->mSession, ECONNRESET);
+		return -1;
 	}
 }
 
 int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
-	bool notEmpty = t->mIncomingQueue.wait(
-	    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
-	return notEmpty ? 1 : 0;
+	try {
+		bool notEmpty = t->mIncomingQueue.wait(
+		    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
+		return notEmpty ? 1 : 0;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return 1;
+	}
 }
 
 #else // USE_GNUTLS==0
@@ -253,8 +293,9 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 }
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
-    : Transport(lower, std::move(callback)), mHost(std::move(host)) {
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
+                           certificate_ptr certificate, state_callback callback)
+    : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
 
 	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
 
@@ -265,8 +306,14 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
 		               "Failed to set SSL priorities");
 
-		if (!SSL_CTX_set_default_verify_paths(mCtx)) {
-			PLOG_WARNING << "SSL root CA certificates unavailable";
+		if (certificate) {
+			auto [x509, pkey] = certificate->credentials();
+			SSL_CTX_use_certificate(mCtx, x509);
+			SSL_CTX_use_PrivateKey(mCtx, pkey);
+		} else {
+			if (!SSL_CTX_set_default_verify_paths(mCtx)) {
+				PLOG_WARNING << "SSL root CA certificates unavailable";
+			}
 		}
 
 		SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
@@ -281,13 +328,18 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
-		SSL_set_hostflags(mSsl, 0);
-		openssl::check(SSL_set1_host(mSsl, mHost.c_str()), "Failed to set SSL host");
+		if (mIsClient && mHost) {
+			SSL_set_hostflags(mSsl, 0);
+			openssl::check(SSL_set1_host(mSsl, mHost->c_str()), "Failed to set SSL host");
 
-		PLOG_VERBOSE << "Server Name Indication: " << mHost;
-		SSL_set_tlsext_host_name(mSsl, mHost.c_str());
+			PLOG_VERBOSE << "Server Name Indication: " << *mHost;
+			SSL_set_tlsext_host_name(mSsl, mHost->c_str());
+		}
 
-		SSL_set_connect_state(mSsl);
+		if (mIsClient)
+			SSL_set_connect_state(mSsl);
+		else
+			SSL_set_accept_state(mSsl);
 
 		if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
 			throw std::runtime_error("Failed to create BIO");
@@ -359,10 +411,13 @@ bool TlsTransport::send(message_ptr message) {
 }
 
 void TlsTransport::incoming(message_ptr message) {
-	if (message)
-		mIncomingQueue.push(message);
-	else
+	if (!message) {
 		mIncomingQueue.stop();
+		return;
+	}
+
+	PLOG_VERBOSE << "Incoming size=" << message->size();
+	mIncomingQueue.push(message);
 }
 
 void TlsTransport::postHandshake() {
@@ -376,10 +431,11 @@ void TlsTransport::runRecvLoop() {
 	try {
 		changeState(State::Connecting);
 
+		int ret;
 		while (true) {
 			if (state() == State::Connecting) {
 				// Initiate or continue the handshake
-				int ret = SSL_do_handshake(mSsl);
+				ret = SSL_do_handshake(mSsl);
 				if (!openssl::check(mSsl, ret, "Handshake failed"))
 					break;
 
@@ -392,13 +448,15 @@ void TlsTransport::runRecvLoop() {
 					changeState(State::Connected);
 					postHandshake();
 				}
-			} else {
-				int ret = SSL_read(mSsl, buffer, bufferSize);
-				if (!openssl::check(mSsl, ret))
-					break;
+			}
 
-				if (ret > 0)
+			if (state() == State::Connected) {
+				// Input
+				while ((ret = SSL_read(mSsl, buffer, bufferSize)) > 0)
 					recv(make_message(buffer, buffer + ret));
+
+				if (!openssl::check(mSsl, ret))
+					break;
 			}
 
 			auto next = mIncomingQueue.pop();

+ 7 - 3
src/impl/tlstransport.hpp

@@ -19,6 +19,7 @@
 #ifndef RTC_IMPL_TLS_TRANSPORT_H
 #define RTC_IMPL_TLS_TRANSPORT_H
 
+#include "certificate.hpp"
 #include "common.hpp"
 #include "queue.hpp"
 #include "tls.hpp"
@@ -37,26 +38,29 @@ public:
 	static void Init();
 	static void Cleanup();
 
-	TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback);
+	TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host, certificate_ptr certificate,
+	             state_callback callback);
 	virtual ~TlsTransport();
 
 	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override;
 
+	bool isClient() const { return mIsClient; }
+
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual void postHandshake();
 	void runRecvLoop();
 
-	string mHost;
+	const optional<string> mHost;
+	const bool mIsClient;
 
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
 
 #if USE_GNUTLS
 	gnutls_session_t mSession;
-	gnutls_certificate_credentials_t mCreds;
 
 	message_ptr mIncomingMessage;
 	size_t mIncomingMessagePosition = 0;

+ 1 - 0
src/impl/transport.hpp

@@ -61,6 +61,7 @@ public:
 	}
 
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
+	void onStateChange(state_callback callback) { mStateChangeCallback = std::move(callback); }
 	State state() const { return mState; }
 
 	virtual bool send(message_ptr message) { return outgoing(message); }

+ 3 - 3
src/impl/verifiedtlstransport.cpp

@@ -24,12 +24,12 @@
 namespace rtc::impl {
 
 VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
-                                           state_callback callback)
-    : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {
+                                           certificate_ptr certificate, state_callback callback)
+    : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) {
 
 #if USE_GNUTLS
 	PLOG_DEBUG << "Setting up TLS certificate verification";
-	gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
+	gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0);
 #else
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);

+ 1 - 1
src/impl/verifiedtlstransport.hpp

@@ -27,7 +27,7 @@ namespace rtc::impl {
 
 class VerifiedTlsTransport final : public TlsTransport {
 public:
-	VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback);
+	VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host, certificate_ptr certificate, state_callback callback);
 	~VerifiedTlsTransport();
 };
 

+ 127 - 92
src/impl/websocket.cpp

@@ -19,8 +19,8 @@
 #if RTC_ENABLE_WEBSOCKET
 
 #include "websocket.hpp"
-#include "internals.hpp"
 #include "common.hpp"
+#include "internals.hpp"
 #include "threadpool.hpp"
 
 #include "tcptransport.hpp"
@@ -38,8 +38,10 @@ namespace rtc::impl {
 
 using namespace std::placeholders;
 
-WebSocket::WebSocket(Configuration config_)
-    : config(std::move(config_)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
+WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
+    : config(optConfig ? std::move(*optConfig) : Configuration()),
+      mCertificate(std::move(certificate)), mIsSecure(mCertificate != nullptr),
+      mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
 	PLOG_VERBOSE << "Creating WebSocket";
 }
 
@@ -48,7 +50,7 @@ WebSocket::~WebSocket() {
 	remoteClose();
 }
 
-void WebSocket::parse(const string &url) {
+void WebSocket::open(const string &url) {
 	PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
 
 	if (state != State::Closed)
@@ -64,34 +66,42 @@ void WebSocket::parse(const string &url) {
 	if (!std::regex_match(url, m, r) || m[10].length() == 0)
 		throw std::invalid_argument("Invalid WebSocket URL: " + url);
 
-	mScheme = m[2];
-	if (mScheme.empty())
-		mScheme = "ws";
-	else if (mScheme != "ws" && mScheme != "wss")
-		throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
-
-	mHostname = m[10];
-	mService = m[12];
-	if (mService.empty()) {
-		mService = mScheme == "ws" ? "80" : "443";
-		mHost = mHostname;
+	string scheme = m[2];
+	if (scheme.empty())
+		scheme = "ws";
+
+	if (scheme != "ws" && scheme != "wss")
+		throw std::invalid_argument("Invalid WebSocket scheme: " + scheme);
+
+	mIsSecure = (scheme != "ws");
+
+	string host;
+	string hostname = m[10];
+	string service = m[12];
+	if (service.empty()) {
+		service = mIsSecure ? "443" : "80";
+		host = hostname;
 	} else {
-		mHost = mHostname + ':' + mService;
+		host = hostname + ':' + service;
 	}
 
-	while (!mHostname.empty() && mHostname.front() == '[')
-		mHostname.erase(mHostname.begin());
-	while (!mHostname.empty() && mHostname.back() == ']')
-		mHostname.pop_back();
+	while (!hostname.empty() && hostname.front() == '[')
+		hostname.erase(hostname.begin());
+	while (!hostname.empty() && hostname.back() == ']')
+		hostname.pop_back();
+
+	string path = m[13];
+	if (path.empty())
+		path += '/';
 
-	mPath = m[13];
-	if (mPath.empty())
-		mPath += '/';
 	if (string query = m[15]; !query.empty())
-		mPath += "?" + query;
+		path += "?" + query;
+
+	mHostname = hostname; // for TLS SNI
+	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 
 	changeState(State::Connecting);
-	initTcpTransport();
+	setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
 }
 
 void WebSocket::close() {
@@ -165,37 +175,41 @@ void WebSocket::incoming(message_ptr message) {
 	}
 }
 
-shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
+shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> transport) {
 	PLOG_VERBOSE << "Starting TCP transport";
+
+	if (!transport)
+		throw std::logic_error("TCP transport is null");
+
 	using State = TcpTransport::State;
 	try {
-		if (auto transport = std::atomic_load(&mTcpTransport))
-			return transport;
+		if (std::atomic_load(&mTcpTransport))
+			throw std::logic_error("TCP transport is already set");
+
+		transport->onStateChange([this, weak_this = weak_from_this()](State transportState) {
+			auto shared_this = weak_this.lock();
+			if (!shared_this)
+				return;
+			switch (transportState) {
+			case State::Connected:
+				if (mIsSecure)
+					initTlsTransport();
+				else
+					initWsTransport();
+				break;
+			case State::Failed:
+				triggerError("TCP connection failed");
+				remoteClose();
+				break;
+			case State::Disconnected:
+				remoteClose();
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		});
 
-		auto transport = std::make_shared<TcpTransport>(
-		    mHostname, mService, [this, weak_this = weak_from_this()](State transportState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (transportState) {
-			    case State::Connected:
-				    if (mScheme == "ws")
-					    initWsTransport();
-				    else
-					    initTlsTransport();
-				    break;
-			    case State::Failed:
-				    triggerError("TCP connection failed");
-				    remoteClose();
-				    break;
-			    case State::Disconnected:
-				    remoteClose();
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    });
 		std::atomic_store(&mTcpTransport, transport);
 		if (state == WebSocket::State::Closed) {
 			mTcpTransport.reset();
@@ -219,6 +233,9 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			return transport;
 
 		auto lower = std::atomic_load(&mTcpTransport);
+		if (!lower)
+			throw std::logic_error("No underlying TCP transport for TLS transport");
+
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
 			auto shared_this = weak_this.lock();
 			if (!shared_this)
@@ -240,19 +257,21 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			}
 		};
 
-		shared_ptr<TlsTransport> transport;
+		bool verify = mHostname.has_value() && !config.disableTlsVerification;
+
 #ifdef _WIN32
-		if (!config.disableTlsVerification) {
+		if (std::exchange(verify, false)) {
 			PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows";
 		}
-		transport = std::make_shared<TlsTransport>(lower, mHostname, stateChangeCallback);
-#else
-		if (config.disableTlsVerification)
-			transport = std::make_shared<TlsTransport>(lower, mHostname, stateChangeCallback);
+#endif
+
+		shared_ptr<TlsTransport> transport;
+		if (verify)
+			transport = std::make_shared<VerifiedTlsTransport>(lower, mHostname.value(), mCertificate,
+			                                                   stateChangeCallback);
 		else
 			transport =
-			    std::make_shared<VerifiedTlsTransport>(lower, mHostname, stateChangeCallback);
-#endif
+			    std::make_shared<TlsTransport>(lower, mHostname, mCertificate, stateChangeCallback);
 
 		std::atomic_store(&mTlsTransport, transport);
 		if (state == WebSocket::State::Closed) {
@@ -276,41 +295,52 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		if (auto transport = std::atomic_load(&mWsTransport))
 			return transport;
 
-		shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
-		if (!lower)
-			lower = std::atomic_load(&mTcpTransport);
+		variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower;
+		if (mIsSecure) {
+			auto transport = std::atomic_load(&mTlsTransport);
+			if (!transport)
+				throw std::logic_error("No underlying TLS transport for WebSocket transport");
+
+			lower = transport;
+		} else {
+			auto transport = std::atomic_load(&mTcpTransport);
+			if (!transport)
+				throw std::logic_error("No underlying TCP transport for WebSocket transport");
+
+			lower = transport;
+		}
 
-		WsTransport::Configuration wsConfig = {};
-		wsConfig.host = mHost;
-		wsConfig.path = mPath;
-		wsConfig.protocols = config.protocols;
+		if (!atomic_load(&mWsHandshake))
+			atomic_store(&mWsHandshake, std::make_shared<WsHandshake>());
+
+		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
+			auto shared_this = weak_this.lock();
+			if (!shared_this)
+				return;
+			switch (transportState) {
+			case State::Connected:
+				if (state == WebSocket::State::Connecting) {
+					PLOG_DEBUG << "WebSocket open";
+					changeState(WebSocket::State::Open);
+					triggerOpen();
+				}
+				break;
+			case State::Failed:
+				triggerError("WebSocket connection failed");
+				remoteClose();
+				break;
+			case State::Disconnected:
+				remoteClose();
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		};
 
 		auto transport = std::make_shared<WsTransport>(
-		    lower, wsConfig, weak_bind(&WebSocket::incoming, this, _1),
-		    [this, weak_this = weak_from_this()](State transportState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (transportState) {
-			    case State::Connected:
-				    if (state == WebSocket::State::Connecting) {
-					    PLOG_DEBUG << "WebSocket open";
-					    changeState(WebSocket::State::Open);
-					    triggerOpen();
-				    }
-				    break;
-			    case State::Failed:
-				    triggerError("WebSocket connection failed");
-				    remoteClose();
-				    break;
-			    case State::Disconnected:
-				    remoteClose();
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    });
+		    lower, mWsHandshake, weak_bind(&WebSocket::incoming, this, _1), stateChangeCallback);
+
 		std::atomic_store(&mWsTransport, transport);
 		if (state == WebSocket::State::Closed) {
 			mWsTransport.reset();
@@ -318,6 +348,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		}
 		transport->start();
 		return transport;
+
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		remoteClose();
@@ -337,6 +368,10 @@ shared_ptr<WsTransport> WebSocket::getWsTransport() const {
 	return std::atomic_load(&mWsTransport);
 }
 
+shared_ptr<WsHandshake> WebSocket::getWsHandshake() const {
+	return std::atomic_load(&mWsHandshake);
+}
+
 void WebSocket::closeTransports() {
 	PLOG_VERBOSE << "Closing transports";
 

+ 11 - 5
src/impl/websocket.hpp

@@ -41,10 +41,10 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	using State = rtc::WebSocket::State;
 	using Configuration = rtc::WebSocket::Configuration;
 
-	WebSocket(Configuration config_);
+	WebSocket(optional<Configuration> optConfig = nullopt, certificate_ptr certificate = nullptr);
 	~WebSocket();
 
-	void parse(const string &url);
+	void open(const string &url);
 	void close();
 	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
@@ -60,26 +60,32 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	bool changeState(State state);
 	void remoteClose();
 
-	shared_ptr<TcpTransport> initTcpTransport();
+	shared_ptr<TcpTransport> setTcpTransport(shared_ptr<TcpTransport> transport);
 	shared_ptr<TlsTransport> initTlsTransport();
 	shared_ptr<WsTransport> initWsTransport();
 	shared_ptr<TcpTransport> getTcpTransport() const;
 	shared_ptr<TlsTransport> getTlsTransport() const;
 	shared_ptr<WsTransport> getWsTransport() const;
+	shared_ptr<WsHandshake> getWsHandshake() const;
 
 	void closeTransports();
 
 	const Configuration config;
+
 	std::atomic<State> state = State::Closed;
 
 private:
 	const init_token mInitToken = Init::Token();
 
+	const certificate_ptr mCertificate;
+	bool mIsSecure;
+
+	optional<string> mHostname; // for TLS SNI
+
 	shared_ptr<TcpTransport> mTcpTransport;
 	shared_ptr<TlsTransport> mTlsTransport;
 	shared_ptr<WsTransport> mWsTransport;
-
-	string mScheme, mHost, mHostname, mService, mPath;
+	shared_ptr<WsHandshake> mWsHandshake;
 
 	Queue<message_ptr> mRecvQueue;
 };

+ 93 - 0
src/impl/websocketserver.cpp

@@ -0,0 +1,93 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "websocketserver.hpp"
+#include "common.hpp"
+#include "internals.hpp"
+#include "threadpool.hpp"
+
+namespace rtc::impl {
+
+using namespace std::placeholders;
+
+WebSocketServer::WebSocketServer(Configuration config_)
+    : config(std::move(config_)), tcpServer(std::make_unique<TcpServer>(config.port)),
+      mStopped(false) {
+	PLOG_VERBOSE << "Creating WebSocketServer";
+
+	if (config.enableTls) {
+		if (config.certificatePemFile && config.keyPemFile) {
+			mCertificate = std::make_shared<Certificate>(Certificate::FromFile(
+			    *config.certificatePemFile, *config.keyPemFile, config.keyPemPass.value_or("")));
+
+		} else if (!config.certificatePemFile && !config.keyPemFile) {
+			mCertificate = std::make_shared<Certificate>(
+			    Certificate::Generate(CertificateType::Default, "localhost"));
+		} else {
+			throw std::invalid_argument(
+			    "Either none or both certificate and key PEM files must be specified");
+		}
+	}
+
+	mThread = std::thread(&WebSocketServer::runLoop, this);
+}
+
+WebSocketServer::~WebSocketServer() {
+	PLOG_VERBOSE << "Destroying WebSocketServer";
+	stop();
+}
+
+void WebSocketServer::stop() {
+	if (mStopped.exchange(true))
+		return;
+
+	PLOG_DEBUG << "Stopping WebSocketServer thread";
+	tcpServer->close();
+	mThread.join();
+}
+
+void WebSocketServer::runLoop() {
+	PLOG_INFO << "Starting WebSocketServer";
+
+	try {
+		while (auto incoming = tcpServer->accept()) {
+			try {
+				if (!clientCallback)
+					continue;
+
+				auto impl = std::make_shared<WebSocket>(nullopt, mCertificate);
+				impl->changeState(WebSocket::State::Connecting);
+				impl->setTcpTransport(incoming);
+				clientCallback(std::make_shared<rtc::WebSocket>(impl));
+
+			} catch (const std::exception &e) {
+				PLOG_ERROR << "WebSocketServer: " << e.what();
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_FATAL << "WebSocketServer: " << e.what();
+	}
+
+	PLOG_INFO << "Stopped WebSocketServer";
+}
+
+} // namespace rtc::impl
+
+#endif

+ 66 - 0
src/impl/websocketserver.hpp

@@ -0,0 +1,66 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_WEBSOCKETSERVER_H
+#define RTC_IMPL_WEBSOCKETSERVER_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "certificate.hpp"
+#include "common.hpp"
+#include "init.hpp"
+#include "message.hpp"
+#include "tcpserver.hpp"
+#include "websocket.hpp"
+
+#include "rtc/websocket.hpp"
+#include "rtc/websocketserver.hpp"
+
+#include <atomic>
+#include <thread>
+
+namespace rtc::impl {
+
+struct WebSocketServer final : public std::enable_shared_from_this<WebSocketServer> {
+	using Configuration = rtc::WebSocketServer::Configuration;
+
+	WebSocketServer(Configuration config_);
+	~WebSocketServer();
+
+	void stop();
+
+	const Configuration config;
+	const unique_ptr<TcpServer> tcpServer;
+
+	synchronized_callback<shared_ptr<rtc::WebSocket>> clientCallback;
+
+private:
+	const init_token mInitToken = Init::Token();
+
+	void runLoop();
+
+	certificate_ptr mCertificate;
+	std::thread mThread;
+	std::atomic<bool> mStopped;
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif // RTC_IMPL_WEBSOCKET_H

+ 322 - 0
src/impl/wshandshake.cpp

@@ -0,0 +1,322 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "wshandshake.hpp"
+#include "base64.hpp"
+#include "internals.hpp"
+#include "sha.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <algorithm>
+#include <chrono>
+#include <climits>
+#include <iostream>
+#include <iterator>
+#include <random>
+#include <sstream>
+
+using std::string;
+
+namespace {
+
+std::vector<string> explode(const string &str, char delim) {
+	std::vector<std::string> result;
+	std::istringstream ss(str);
+	string token;
+	while (std::getline(ss, token, delim))
+		result.push_back(token);
+
+	return result;
+}
+
+string implode(const std::vector<string> &tokens, char delim) {
+	string sdelim(1, delim);
+	std::ostringstream ss;
+	std::copy(tokens.begin(), tokens.end(), std::ostream_iterator<string>(ss, sdelim.c_str()));
+	string result = ss.str();
+	if (result.size() > 0)
+		result.resize(result.size() - 1);
+
+	return result;
+}
+
+} // namespace
+
+namespace rtc::impl {
+
+using std::to_string;
+using std::chrono::system_clock;
+using random_bytes_engine =
+    std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
+
+WsHandshake::WsHandshake() {}
+
+WsHandshake::WsHandshake(string host, string path, std::vector<string> protocols)
+    : mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)) {
+
+	if (mHost.empty())
+		throw std::invalid_argument("WebSocket HTTP host cannot be empty");
+
+	if (mPath.empty())
+		throw std::invalid_argument("WebSocket HTTP path cannot be empty");
+}
+
+string WsHandshake::host() const {
+	std::unique_lock lock(mMutex);
+	return mHost;
+}
+
+string WsHandshake::path() const {
+	std::unique_lock lock(mMutex);
+	return mPath;
+}
+
+std::vector<string> WsHandshake::protocols() const {
+	std::unique_lock lock(mMutex);
+	return mProtocols;
+}
+
+string WsHandshake::generateHttpRequest() {
+	std::unique_lock lock(mMutex);
+	mKey = generateKey();
+
+	string out = "GET " + mPath +
+	             " HTTP/1.1\r\n"
+	             "Host: " +
+	             mHost +
+	             "\r\n"
+	             "Connection: upgrade\r\n"
+	             "Upgrade: websocket\r\n"
+	             "Sec-WebSocket-Version: 13\r\n"
+	             "Sec-WebSocket-Key: " +
+	             mKey + "\r\n";
+
+	if (!mProtocols.empty())
+		out += "Sec-WebSocket-Protocol: " + implode(mProtocols, ',') + "\r\n";
+
+	out += "\r\n";
+
+	return out;
+}
+
+string WsHandshake::generateHttpResponse() {
+	std::unique_lock lock(mMutex);
+	const string out = "HTTP/1.1 101 Switching Protocols\r\n"
+	                   "Server: libdatachannel\r\n"
+	                   "Connection: upgrade\r\n"
+	                   "Upgrade: websocket\r\n"
+	                   "Sec-WebSocket-Accept: " +
+	                   computeAcceptKey(mKey) + "\r\n\r\n";
+
+	return out;
+}
+
+namespace {
+
+string GetHttpErrorName(int responseCode) {
+	switch(responseCode) {
+	case 400:
+		return "Bad Request";
+	case 404:
+		return "Not Found";
+	case 405:
+		return "Method Not Allowed";
+	case 426:
+		return "Upgrade Required";
+	case 500:
+		return "Internal Server Error";
+	default:
+		return "Error";
+	}
+}
+
+}
+
+string WsHandshake::generateHttpError(int responseCode) {
+	std::unique_lock lock(mMutex);
+
+	const string error = to_string(responseCode) + " " + GetHttpErrorName(responseCode);
+
+	const string out = "HTTP/1.1 " + error + "\r\n"
+	                   "Server: libdatachannel\r\n"
+	                   "Connection: upgrade\r\n"
+	                   "Upgrade: websocket\r\n"
+	                   "Content-Type: text/plain\r\n"
+	                   "Content-Length: " + to_string(error.size()) + "\r\n"
+	                   "Access-Control-Allow-Origin: *\r\n\r\n" + error;
+
+	return out;
+}
+
+size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
+	std::unique_lock lock(mMutex);
+	std::list<string> lines;
+	size_t length = parseHttpLines(buffer, size, lines);
+	if (length == 0)
+		return 0;
+
+	if (lines.empty())
+		throw RequestError("Invalid HTTP request for WebSocket", 400);
+
+	std::istringstream requestLine(std::move(lines.front()));
+	lines.pop_front();
+
+	string method, path, protocol;
+	requestLine >> method >> path >> protocol;
+	PLOG_DEBUG << "WebSocket request method \"" << method << "\" for path: " << path;
+	if (method != "GET")
+		throw RequestError("Invalid request method \"" + method + "\" for WebSocket", 405);
+
+	mPath = std::move(path);
+
+	auto headers = parseHttpHeaders(lines);
+
+	auto h = headers.find("host");
+	if (h == headers.end())
+		throw RequestError("WebSocket host header missing in request", 400);
+
+	mHost = std::move(h->second);
+
+	h = headers.find("upgrade");
+	if (h == headers.end())
+		throw RequestError("WebSocket upgrade header missing in request", 426);
+
+	string upgrade;
+	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
+	               [](char c) { return std::tolower(c); });
+	if (upgrade != "websocket")
+		throw RequestError("WebSocket upgrade header mismatching: " + h->second, 426);
+
+	h = headers.find("sec-websocket-key");
+	if (h == headers.end())
+		throw RequestError("WebSocket key header missing in request", 400);
+
+	mKey = std::move(h->second);
+
+	h = headers.find("sec-websocket-protocol");
+	if (h != headers.end())
+		mProtocols = explode(h->second, ',');
+
+	return length;
+}
+
+size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) {
+	std::unique_lock lock(mMutex);
+	std::list<string> lines;
+	size_t length = parseHttpLines(buffer, size, lines);
+	if (length == 0)
+		return 0;
+
+	if (lines.empty())
+		throw Error("Invalid HTTP response for WebSocket");
+
+	std::istringstream status(std::move(lines.front()));
+	lines.pop_front();
+
+	string protocol;
+	unsigned int code = 0;
+	status >> protocol >> code;
+	PLOG_DEBUG << "WebSocket response code: " << code;
+	if (code != 101)
+		throw std::runtime_error("Unexpected response code " + to_string(code) + " for WebSocket");
+
+	auto headers = parseHttpHeaders(lines);
+
+	auto h = headers.find("upgrade");
+	if (h == headers.end())
+		throw Error("WebSocket update header missing");
+
+	string upgrade;
+	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
+	               [](char c) { return std::tolower(c); });
+	if (upgrade != "websocket")
+		throw Error("WebSocket update header mismatching: " + h->second);
+
+	h = headers.find("sec-websocket-accept");
+	if (h == headers.end())
+		throw Error("WebSocket accept header missing");
+
+	if (h->second != computeAcceptKey(mKey))
+		throw Error("WebSocket accept header is invalid");
+
+	return length;
+}
+
+string WsHandshake::generateKey() {
+	// RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key.  The value
+	// of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has
+	// been base64-encoded. [...] The nonce MUST be selected randomly for each connection.
+	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
+	random_bytes_engine generator(seed);
+	binary key(16);
+	auto k = reinterpret_cast<uint8_t *>(key.data());
+	std::generate(k, k + key.size(), [&]() { return uint8_t(generator()); });
+	return to_base64(key);
+}
+
+string WsHandshake::computeAcceptKey(const string &key) {
+	return to_base64(Sha1(string(key) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
+}
+
+size_t WsHandshake::parseHttpLines(const byte *buffer, size_t size, std::list<string> &lines) {
+	lines.clear();
+	auto begin = reinterpret_cast<const char *>(buffer);
+	auto end = begin + size;
+	auto cur = begin;
+	while (true) {
+		auto last = cur;
+		cur = std::find(cur, end, '\n');
+		if (cur == end)
+			return 0;
+		string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
+		if (line.empty())
+			break;
+		lines.emplace_back(std::move(line));
+	}
+
+	return cur - begin;
+}
+
+std::multimap<string, string> WsHandshake::parseHttpHeaders(const std::list<string> &lines) {
+	std::multimap<string, string> headers;
+	for (const auto &line : lines) {
+		if (size_t pos = line.find_first_of(':'); pos != string::npos) {
+			string key = line.substr(0, pos);
+			string value = line.substr(line.find_first_not_of(' ', pos + 1));
+			std::transform(key.begin(), key.end(), key.begin(),
+			               [](char c) { return std::tolower(c); });
+			headers.emplace(std::move(key), std::move(value));
+		} else {
+			headers.emplace(line, "");
+		}
+	}
+
+	return headers;
+}
+
+WsHandshake::Error::Error(const string &w) : std::runtime_error(w) {}
+
+WsHandshake::RequestError::RequestError(const string &w, int responseCode)
+    : Error(w), mResponseCode(responseCode) {}
+
+int WsHandshake::RequestError::RequestError::responseCode() const { return mResponseCode; }
+
+} // namespace rtc::impl
+
+#endif

+ 78 - 0
src/impl/wshandshake.hpp

@@ -0,0 +1,78 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_WS_HANDSHAKE_H
+#define RTC_IMPL_WS_HANDSHAKE_H
+
+#include "common.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <list>
+#include <map>
+
+namespace rtc::impl {
+
+class WsHandshake final {
+public:
+	WsHandshake();
+	WsHandshake(string host, string path = "/", std::vector<string> protocols = {});
+
+	string host() const;
+	string path() const;
+	std::vector<string> protocols() const;
+
+	string generateHttpRequest();
+	string generateHttpResponse();
+	string generateHttpError(int responseCode = 400);
+
+	class Error : public std::runtime_error {
+	public:
+		explicit Error(const string &w);
+	};
+
+	class RequestError : public Error {
+	public:
+		explicit RequestError(const string &w, int responseCode = 400);
+		int responseCode() const;
+
+	private:
+		const int mResponseCode;
+	};
+
+	size_t parseHttpRequest(const byte *buffer, size_t size);
+	size_t parseHttpResponse(const byte *buffer, size_t size);
+
+private:
+	static string generateKey();
+	static string computeAcceptKey(const string &key);
+	static size_t parseHttpLines(const byte *buffer, size_t size, std::list<string> &lines);
+	static std::multimap<string, string> parseHttpHeaders(const std::list<string> &lines);
+
+	string mHost;
+	string mPath;
+	std::vector<string> mProtocols;
+	string mKey;
+	mutable std::mutex mMutex;
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 80 - 132
src/impl/wstransport.cpp

@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2020 Paul-Louis Ageneau
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
@@ -17,19 +17,18 @@
  */
 
 #include "wstransport.hpp"
-#include "base64.hpp"
 #include "tcptransport.hpp"
 #include "tlstransport.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
 
+#include <algorithm>
 #include <chrono>
-#include <iterator>
-#include <list>
-#include <map>
+#include <iostream>
 #include <numeric>
 #include <random>
 #include <regex>
+#include <sstream>
 
 #ifdef _WIN32
 #include <winsock2.h>
@@ -47,25 +46,26 @@
 
 namespace rtc::impl {
 
-using namespace std::chrono;
 using std::to_integer;
 using std::to_string;
-
+using std::chrono::system_clock;
 using random_bytes_engine =
     std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
 
-WsTransport::WsTransport(shared_ptr<Transport> lower, Configuration config,
-                         message_callback recvCallback, state_callback stateCallback)
-    : Transport(lower, std::move(stateCallback)), mConfig(std::move(config)) {
+WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
+                         shared_ptr<WsHandshake> handshake, message_callback recvCallback,
+                         state_callback stateCallback)
+    : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
+                std::move(stateCallback)),
+      mHandshake(std::move(handshake)),
+      mIsClient(
+          std::visit(rtc::overloaded{[](shared_ptr<TcpTransport> l) { return l->isActive(); },
+                                     [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
+                     lower)) {
+
 	onRecv(recvCallback);
 
 	PLOG_DEBUG << "Initializing WebSocket transport";
-
-	if (mConfig.host.empty())
-		throw std::invalid_argument("WebSocket HTTP host cannot be empty");
-
-	if (mConfig.path.empty())
-		throw std::invalid_argument("WebSocket HTTP path cannot be empty");
 }
 
 WsTransport::~WsTransport() { stop(); }
@@ -74,7 +74,10 @@ void WsTransport::start() {
 	Transport::start();
 
 	registerIncoming();
-	sendHttpRequest();
+
+	changeState(State::Connecting);
+	if (mIsClient)
+		sendHttpRequest();
 }
 
 bool WsTransport::stop() {
@@ -91,7 +94,7 @@ bool WsTransport::send(message_ptr message) {
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 	return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
-	                  message->size(), true, true});
+	                  message->size(), true, mIsClient});
 }
 
 void WsTransport::incoming(message_ptr message) {
@@ -102,35 +105,57 @@ void WsTransport::incoming(message_ptr message) {
 	if (message) {
 		PLOG_VERBOSE << "Incoming size=" << message->size();
 
-		if (message->size() == 0) {
-			// TCP is idle, send a ping
-			PLOG_DEBUG << "WebSocket sending ping";
-			uint32_t dummy = 0;
-			sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, true});
-			return;
-		}
-
-		mBuffer.insert(mBuffer.end(), message->begin(), message->end());
-
 		try {
+			mBuffer.insert(mBuffer.end(), message->begin(), message->end());
+
 			if (state() == State::Connecting) {
-				if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
-					PLOG_INFO << "WebSocket open";
-					changeState(State::Connected);
-					mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				if (mIsClient) {
+					if (size_t len =
+					        mHandshake->parseHttpResponse(mBuffer.data(), mBuffer.size())) {
+						PLOG_INFO << "WebSocket client-side open";
+						changeState(State::Connected);
+						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+					}
+				} else {
+					if (size_t len = mHandshake->parseHttpRequest(mBuffer.data(), mBuffer.size())) {
+						PLOG_INFO << "WebSocket server-side open";
+						sendHttpResponse();
+						changeState(State::Connected);
+						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+					}
 				}
 			}
 
 			if (state() == State::Connected) {
-				Frame frame;
-				while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
-					recvFrame(frame);
-					mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				if (message->size() == 0) {
+					// TCP is idle, send a ping
+					PLOG_DEBUG << "WebSocket sending ping";
+					uint32_t dummy = 0;
+					sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
+
+				} else {
+					Frame frame;
+					while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
+						recvFrame(frame);
+						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+					}
 				}
 			}
 
 			return;
 
+		} catch (const WsHandshake::RequestError &e) {
+			PLOG_WARNING << e.what();
+			try {
+				sendHttpError(e.responseCode());
+
+			} catch (const std::exception &e) {
+				PLOG_WARNING << e.what();
+			}
+
+		} catch (const WsHandshake::Error &e) {
+			PLOG_WARNING << e.what();
+
 		} catch (const std::exception &e) {
 			PLOG_ERROR << e.what();
 		}
@@ -148,115 +173,38 @@ void WsTransport::incoming(message_ptr message) {
 
 void WsTransport::close() {
 	if (state() == State::Connected) {
-		sendFrame({CLOSE, NULL, 0, true, true});
+		sendFrame({CLOSE, NULL, 0, true, mIsClient});
 		PLOG_INFO << "WebSocket closing";
 		changeState(State::Disconnected);
 	}
 }
 
 bool WsTransport::sendHttpRequest() {
-	PLOG_DEBUG << "Sending WebSocket HTTP request for path " << mConfig.path;
-	changeState(State::Connecting);
-
-	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
-	random_bytes_engine generator(seed);
-
-	binary key(16);
-	auto k = reinterpret_cast<uint8_t *>(key.data());
-	std::generate(k, k + key.size(), [&]() { return uint8_t(generator()); });
-
-	string appendHeader = "";
-	if (mConfig.protocols.size() > 0) {
-		appendHeader +=
-		    "Sec-WebSocket-Protocol: " +
-		    std::accumulate(mConfig.protocols.begin(), mConfig.protocols.end(), string(),
-		                    [](const string &a, const string &b) -> string {
-			                    return a + (a.length() > 0 ? "," : "") + b;
-		                    }) +
-		    "\r\n";
-	}
-
-	const string request = "GET " + mConfig.path +
-	                       " HTTP/1.1\r\n"
-	                       "Host: " +
-	                       mConfig.host +
-	                       "\r\n"
-	                       "Connection: Upgrade\r\n"
-	                       "Upgrade: websocket\r\n"
-	                       "Sec-WebSocket-Version: 13\r\n"
-	                       "Sec-WebSocket-Key: " +
-	                       to_base64(key) + "\r\n" + std::move(appendHeader) + "\r\n";
+	PLOG_DEBUG << "Sending WebSocket HTTP request";
 
+	const string request = mHandshake->generateHttpRequest();
 	auto data = reinterpret_cast<const byte *>(request.data());
-	auto size = request.size();
-	return outgoing(make_message(data, data + size));
+	return outgoing(make_message(data, data + request.size()));
 }
 
-size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
-	std::list<string> lines;
-	auto begin = reinterpret_cast<const char *>(buffer);
-	auto end = begin + size;
-	auto cur = begin;
-
-	while (true) {
-		auto last = cur;
-		cur = std::find(cur, end, '\n');
-		if (cur == end)
-			return 0;
-		string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
-		if (line.empty())
-			break;
-		lines.emplace_back(std::move(line));
-	}
-	size_t length = cur - begin;
-
-	if (lines.empty())
-		throw std::runtime_error("Invalid HTTP response for WebSocket");
-
-	string status = std::move(lines.front());
-	lines.pop_front();
-
-	std::istringstream ss(status);
-	string protocol;
-	unsigned int code = 0;
-	ss >> protocol >> code;
-	PLOG_DEBUG << "WebSocket response code: " << code;
-	if (code != 101)
-		throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
-
-	std::multimap<string, string> headers;
-	for (const auto &line : lines) {
-		if (size_t pos = line.find_first_of(':'); pos != string::npos) {
-			string key = line.substr(0, pos);
-			string value = line.substr(line.find_first_not_of(' ', pos + 1));
-			std::transform(key.begin(), key.end(), key.begin(),
-			               [](char c) { return std::tolower(c); });
-			headers.emplace(std::move(key), std::move(value));
-		} else {
-			headers.emplace(line, "");
-		}
-	}
-
-	auto h = headers.find("upgrade");
-	if (h == headers.end())
-		throw std::runtime_error("WebSocket update header missing");
-
-	string upgrade;
-	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
-	               [](char c) { return std::tolower(c); });
-	if (upgrade != "websocket")
-		throw std::runtime_error("WebSocket update header mismatching: " + h->second);
+bool WsTransport::sendHttpResponse() {
+	PLOG_DEBUG << "Sending WebSocket HTTP response";
 
-	h = headers.find("sec-websocket-accept");
-	if (h == headers.end())
-		throw std::runtime_error("WebSocket accept header missing");
+	const string response = mHandshake->generateHttpResponse();
+	auto data = reinterpret_cast<const byte *>(response.data());
+	return outgoing(make_message(data, data + response.size()));
+}
 
-	// TODO: Verify Sec-WebSocket-Accept
+bool WsTransport::sendHttpError(int code) {
+	PLOG_WARNING << "Sending WebSocket HTTP error response " << code;
 
-	return length;
+	const string response = mHandshake->generateHttpError(code);
+	auto data = reinterpret_cast<const byte *>(response.data());
+	return outgoing(make_message(data, data + response.size()));
 }
 
-// http://tools.ietf.org/html/rfc6455#section-5.2  Base Framing Protocol
+// RFC6455 5.2. Base Framing Protocol
+// http://tools.ietf.org/html/rfc6455#section-5.2
 //
 //  0                   1                   2                   3
 //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
@@ -364,7 +312,7 @@ void WsTransport::recvFrame(const Frame &frame) {
 	}
 	case PING: {
 		PLOG_DEBUG << "WebSocket received ping, sending pong";
-		sendFrame({PONG, frame.payload, frame.length, true, true});
+		sendFrame({PONG, frame.payload, frame.length, true, mIsClient});
 		break;
 	}
 	case PONG: {
@@ -423,6 +371,6 @@ bool WsTransport::sendFrame(const Frame &frame) {
 	return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
 }
 
-} // namespace rtc
+} // namespace rtc::impl
 
 #endif

+ 12 - 14
src/impl/wstransport.hpp

@@ -1,5 +1,5 @@
 /**
- * Copyright (c) 2020 Paul-Louis Ageneau
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
  *
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
@@ -21,6 +21,7 @@
 
 #include "common.hpp"
 #include "transport.hpp"
+#include "wshandshake.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
 
@@ -29,26 +30,21 @@ namespace rtc::impl {
 class TcpTransport;
 class TlsTransport;
 
-class WsTransport : public Transport {
+class WsTransport final : public Transport {
 public:
-	struct Configuration {
-		string host;
-		string path = "/";
-		std::vector<string> protocols;
-	};
-
-	WsTransport(shared_ptr<Transport> lower, Configuration config,
-	            message_callback recvCallback, state_callback stateCallback);
+	WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
+	            shared_ptr<WsHandshake> handshake, message_callback recvCallback,
+	            state_callback stateCallback);
 	~WsTransport();
 
 	void start() override;
 	bool stop() override;
 	bool send(message_ptr message) override;
-
 	void incoming(message_ptr message) override;
-
 	void close();
 
+	bool isClient() const { return mIsClient; }
+
 private:
 	enum Opcode : uint8_t {
 		CONTINUATION = 0,
@@ -68,13 +64,15 @@ private:
 	};
 
 	bool sendHttpRequest();
-	size_t readHttpResponse(const byte *buffer, size_t size);
+	bool sendHttpError(int code);
+	bool sendHttpResponse();
 
 	size_t readFrame(byte *buffer, size_t size, Frame &frame);
 	void recvFrame(const Frame &frame);
 	bool sendFrame(const Frame &frame);
 
-	const Configuration mConfig;
+	const shared_ptr<WsHandshake> mHandshake;
+	const bool mIsClient;
 
 	binary mBuffer;
 	binary mPartial;

+ 17 - 13
src/websocket.cpp

@@ -21,25 +21,21 @@
 #include "websocket.hpp"
 #include "common.hpp"
 
-#include "impl/websocket.hpp"
 #include "impl/internals.hpp"
-
-#include <regex>
-
-#ifdef _WIN32
-#include <winsock2.h>
-#endif
+#include "impl/websocket.hpp"
 
 namespace rtc {
 
-using namespace std::placeholders;
-
 WebSocket::WebSocket() : WebSocket(Configuration()) {}
 
 WebSocket::WebSocket(Configuration config)
     : CheshireCat<impl::WebSocket>(std::move(config)),
       Channel(std::dynamic_pointer_cast<impl::Channel>(CheshireCat<impl::WebSocket>::impl())) {}
 
+WebSocket::WebSocket(impl_ptr<impl::WebSocket> impl)
+    : CheshireCat<impl::WebSocket>(std::move(impl)),
+      Channel(std::dynamic_pointer_cast<impl::Channel>(CheshireCat<impl::WebSocket>::impl())) {}
+
 WebSocket::~WebSocket() { impl()->remoteClose(); }
 
 WebSocket::State WebSocket::readyState() const { return impl()->state; }
@@ -52,10 +48,7 @@ size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
 void WebSocket::open(const string &url) {
 	PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
-
-	impl()->parse(url);
-	impl()->changeState(State::Connecting);
-	impl()->initTcpTransport();
+	impl()->open(url);
 }
 
 void WebSocket::close() {
@@ -78,6 +71,17 @@ bool WebSocket::send(const byte *data, size_t size) {
 	return impl()->outgoing(make_message(data, data + size));
 }
 
+optional<string> WebSocket::remoteAddress() const {
+	auto tcpTransport = impl()->getTcpTransport();
+	return tcpTransport ? make_optional(tcpTransport->remoteAddress()) : nullopt;
+}
+
+optional<string> WebSocket::path() const {
+	auto state = impl()->state.load();
+	auto handshake = impl()->getWsHandshake();
+	return state != State::Connecting && handshake ? make_optional(handshake->path()) : nullopt;
+}
+
 } // namespace rtc
 
 #endif

+ 46 - 0
src/websocketserver.cpp

@@ -0,0 +1,46 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "websocketserver.hpp"
+#include "common.hpp"
+
+#include "impl/internals.hpp"
+#include "impl/websocketserver.hpp"
+
+namespace rtc {
+
+WebSocketServer::WebSocketServer() : WebSocketServer(Configuration()) {}
+
+WebSocketServer::WebSocketServer(Configuration config)
+    : CheshireCat<impl::WebSocketServer>(std::move(config)) {}
+
+WebSocketServer::~WebSocketServer() { impl()->stop(); }
+
+void WebSocketServer::stop() { impl()->stop(); }
+
+uint16_t WebSocketServer::port() const { return impl()->tcpServer->port(); }
+
+void WebSocketServer::onClient(std::function<void(shared_ptr<WebSocket>)> callback) {
+	impl()->clientCallback = callback;
+}
+
+} // namespace rtc
+
+#endif

+ 173 - 0
test/capi_websocketserver.cpp

@@ -0,0 +1,173 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include <rtc/rtc.h>
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+
+#ifdef _WIN32
+#include <windows.h>
+static void sleep(unsigned int secs) { Sleep(secs * 1000); }
+#else
+#include <unistd.h> // for sleep
+#endif
+
+static const char *MESSAGE = "Hello, this is a C API WebSocket test!";
+
+static bool success = false;
+static bool failed = false;
+
+static void RTC_API openCallback(int ws, void *ptr) {
+	printf("WebSocket: Connection open\n");
+
+	if (rtcSendMessage(ws, MESSAGE, -1) < 0) { // negative size indicates a null-terminated string
+		fprintf(stderr, "rtcSendMessage failed\n");
+		failed = true;
+		return;
+	}
+}
+
+static void RTC_API closedCallback(int ws, void *ptr) { printf("WebSocket: Connection closed"); }
+
+static void RTC_API messageCallback(int ws, const char *message, int size, void *ptr) {
+	if (size < 0 && strcmp(message, MESSAGE) == 0) {
+		printf("WebSocket: Received expected message\n");
+		success = true;
+	} else {
+		fprintf(stderr, "Received UNEXPECTED message\n");
+		failed = true;
+	}
+}
+
+static void RTC_API serverOpenCallback(int ws, void *ptr) {
+	printf("WebSocketServer: Client connection open\n");
+
+	char path[256];
+	if (rtcGetWebSocketPath(ws, path, 256) < 0) {
+		fprintf(stderr, "rtcGetWebSocketPath failed\n");
+		failed = true;
+		return;
+	}
+
+	if (strcmp(path, "/mypath") != 0) {
+		fprintf(stderr, "Wrong WebSocket path: %s\n", path);
+		failed = true;
+	}
+}
+
+static void RTC_API serverClosedCallback(int ws, void *ptr) {
+	printf("WebSocketServer: Client connection closed\n");
+}
+
+static void RTC_API serverMessageCallback(int ws, const char *message, int size, void *ptr) {
+	if (rtcSendMessage(ws, message, size) < 0) {
+		fprintf(stderr, "rtcSendMessage failed\n");
+		failed = true;
+	}
+}
+
+static void RTC_API serverClientCallback(int wsserver, int ws, void *ptr) {
+	char address[256];
+	if (rtcGetWebSocketRemoteAddress(ws, address, 256) < 0) {
+		fprintf(stderr, "rtcGetWebSocketRemoteAddress failed\n");
+		failed = true;
+		return;
+	}
+
+	printf("WebSocketServer: Received client connection from %s", address);
+
+	rtcSetOpenCallback(ws, serverOpenCallback);
+	rtcSetClosedCallback(ws, serverClosedCallback);
+	rtcSetMessageCallback(ws, serverMessageCallback);
+}
+
+int test_capi_websocketserver_main() {
+	const char *url = "wss://localhost:48081/mypath";
+	const uint16_t port = 48081;
+	int wsserver = -1;
+	int ws = -1;
+	int attempts;
+
+	rtcInitLogger(RTC_LOG_DEBUG, nullptr);
+
+	rtcWsServerConfiguration serverConfig;
+	memset(&serverConfig, 0, sizeof(serverConfig));
+	serverConfig.port = port;
+	serverConfig.enableTls = true;
+	// serverConfig.certificatePemFile = ...
+	// serverConfig.keyPemFile = ...
+
+	wsserver = rtcCreateWebSocketServer(&serverConfig, serverClientCallback);
+	if (wsserver < 0)
+		goto error;
+
+	if (rtcGetWebSocketServerPort(wsserver) != int(port)) {
+		fprintf(stderr, "rtcGetWebSocketServerPort failed\n");
+		goto error;
+	}
+
+	rtcWsConfiguration config;
+	memset(&config, 0, sizeof(config));
+	config.disableTlsVerification = true;
+
+	ws = rtcCreateWebSocketEx(url, &config);
+	if (ws < 0)
+		goto error;
+
+	rtcSetOpenCallback(ws, openCallback);
+	rtcSetClosedCallback(ws, closedCallback);
+	rtcSetMessageCallback(ws, messageCallback);
+
+	attempts = 10;
+	while (!success && !failed && attempts--)
+		sleep(1);
+
+	if (failed)
+		goto error;
+
+	rtcDeleteWebSocket(ws);
+	sleep(1);
+
+	rtcDeleteWebSocketServer(wsserver);
+	sleep(1);
+
+	printf("Success\n");
+	return 0;
+
+error:
+	if (ws >= 0)
+		rtcDeleteWebSocket(ws);
+
+	if (wsserver >= 0)
+		rtcDeleteWebSocketServer(wsserver);
+
+	return -1;
+}
+
+#include <stdexcept>
+
+void test_capi_websocketserver() {
+	if (test_capi_websocketserver_main())
+		throw std::runtime_error("WebSocketServer test failed");
+}
+
+#endif

+ 22 - 3
test/main.cpp

@@ -29,6 +29,8 @@ void test_track();
 void test_capi_connectivity();
 void test_capi_track();
 void test_websocket();
+void test_websocketserver();
+void test_capi_websocketserver();
 size_t benchmark(chrono::milliseconds duration);
 
 void test_benchmark() {
@@ -101,8 +103,25 @@ int main(int argc, char **argv) {
 		return -1;
 	}
 */
+	try {
+		cout << endl << "*** Running WebSocketServer test..." << endl;
+		test_websocketserver();
+		cout << "*** Finished WebSocketServer test" << endl;
+	} catch (const exception &e) {
+		cerr << "WebSocketServer test failed: " << e.what() << endl;
+		return -1;
+	}
+	try {
+		cout << endl << "*** Running WebSocketServer C API test..." << endl;
+		test_capi_websocketserver();
+		cout << "*** Finished WebSocketServer C API test" << endl;
+	} catch (const exception &e) {
+		cerr << "WebSocketServer C API test failed: " << e.what() << endl;
+		return -1;
+	}
 #endif
-	this_thread::sleep_for(1s);
+/*
+    this_thread::sleep_for(1s);
 	try {
 		cout << endl << "*** Running WebRTC benchmark..." << endl;
 		test_benchmark();
@@ -112,7 +131,7 @@ int main(int argc, char **argv) {
 		std::this_thread::sleep_for(2s);
 		return -1;
 	}
-
-	std::this_thread::sleep_for(2s);
+*/
+	std::this_thread::sleep_for(1s);
 	return 0;
 }

+ 114 - 0
test/websocketserver.cpp

@@ -0,0 +1,114 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "rtc/rtc.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <atomic>
+#include <chrono>
+#include <iostream>
+#include <memory>
+#include <thread>
+
+using namespace rtc;
+using namespace std;
+
+template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
+
+void test_websocketserver() {
+	InitLogger(LogLevel::Debug);
+
+	const string myMessage = "Hello world from client";
+
+	WebSocketServer::Configuration serverConfig;
+	serverConfig.port = 48080;
+	serverConfig.enableTls = true;
+	// serverConfig.certificatePemFile = ...
+	// serverConfig.keyPemFile = ...
+	WebSocketServer server(std::move(serverConfig));
+
+	shared_ptr<WebSocket> client;
+	server.onClient([&client](shared_ptr<WebSocket> incoming) {
+		cout << "WebSocketServer: Client connection received" << endl;
+		client = incoming;
+
+		if(auto addr = client->remoteAddress())
+			cout << "WebSocketServer: Client remote address is " << *addr << endl;
+
+		client->onOpen([wclient = make_weak_ptr(client)]() {
+			cout << "WebSocketServer: Client connection open" << endl;
+			if(auto client = wclient.lock())
+				if(auto path = client->path())
+					cout << "WebSocketServer: Requested path is " << *path << endl;
+		});
+
+		client->onClosed([]() {
+			cout << "WebSocketServer: Client connection closed" << endl;
+		});
+
+		client->onMessage([wclient = make_weak_ptr(client)](variant<binary, string> message) {
+			if(auto client = wclient.lock())
+				client->send(std::move(message));
+		});
+	});
+
+	WebSocket::Configuration config;
+	config.disableTlsVerification = true;
+	WebSocket ws(std::move(config));
+
+	ws.onOpen([&ws, &myMessage]() {
+		cout << "WebSocket: Open" << endl;
+		ws.send(myMessage);
+	});
+
+	ws.onClosed([]() { cout << "WebSocket: Closed" << endl; });
+
+	std::atomic<bool> received = false;
+	ws.onMessage([&received, &myMessage](variant<binary, string> message) {
+		if (holds_alternative<string>(message)) {
+			string str = std::move(get<string>(message));
+			if ((received = (str == myMessage)))
+				cout << "WebSocket: Received expected message" << endl;
+			else
+				cout << "WebSocket: Received UNEXPECTED message" << endl;
+		}
+	});
+
+	ws.open("wss://localhost:48080/");
+
+	int attempts = 10;
+	while ((!ws.isOpen() || !received) && attempts--)
+		this_thread::sleep_for(1s);
+
+	if (!ws.isOpen())
+		throw runtime_error("WebSocket is not open");
+
+	if (!received)
+		throw runtime_error("Expected message not received");
+
+	ws.close();
+	this_thread::sleep_for(1s);
+
+	server.stop();
+	this_thread::sleep_for(1s);
+
+	cout << "Success" << endl;
+}
+
+#endif