Browse Source

Implemented WebSocket server

Paul-Louis Ageneau 4 years ago
parent
commit
f1bfb2758c

+ 9 - 0
CMakeLists.txt

@@ -49,6 +49,7 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpreceivingsession.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpreceivingsession.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/track.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/track.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/websocket.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/websocketserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizationconfig.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizationconfig.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpsrreporter.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtcpsrreporter.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizer.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtppacketizer.cpp
@@ -116,11 +117,15 @@ set(LIBDATACHANNEL_IMPL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/selectinterrupter.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocketserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wshandshake.cpp
 )
 )
 
 
 set(LIBDATACHANNEL_IMPL_HEADERS
 set(LIBDATACHANNEL_IMPL_HEADERS
@@ -142,11 +147,15 @@ set(LIBDATACHANNEL_IMPL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/selectinterrupter.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/verifiedtlstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocket.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/websocketserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wstransport.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/wshandshake.hpp
 )
 )
 
 
 set(TESTS_SOURCES
 set(TESTS_SOURCES

+ 2 - 2
include/rtc/common.hpp

@@ -48,9 +48,9 @@
 #include <memory>
 #include <memory>
 #include <mutex>
 #include <mutex>
 #include <optional>
 #include <optional>
-#include <variant>
 #include <string>
 #include <string>
 #include <string_view>
 #include <string_view>
+#include <variant>
 #include <vector>
 #include <vector>
 
 
 namespace rtc {
 namespace rtc {
@@ -68,8 +68,8 @@ using std::weak_ptr;
 using binary = std::vector<byte>;
 using binary = std::vector<byte>;
 using binary_ptr = std::shared_ptr<binary>;
 using binary_ptr = std::shared_ptr<binary>;
 
 
-using std::size_t;
 using std::ptrdiff_t;
 using std::ptrdiff_t;
+using std::size_t;
 using std::uint16_t;
 using std::uint16_t;
 using std::uint32_t;
 using std::uint32_t;
 using std::uint64_t;
 using std::uint64_t;

+ 1 - 0
include/rtc/rtc.hpp

@@ -31,6 +31,7 @@
 
 
 // WebSocket
 // WebSocket
 #include "websocket.hpp"
 #include "websocket.hpp"
+#include "websocketserver.hpp"
 
 
 #endif // RTC_ENABLE_WEBSOCKET
 #endif // RTC_ENABLE_WEBSOCKET
 
 

+ 4 - 0
include/rtc/websocket.hpp

@@ -49,6 +49,7 @@ public:
 
 
 	WebSocket();
 	WebSocket();
 	WebSocket(Configuration config);
 	WebSocket(Configuration config);
+	WebSocket(impl_ptr<impl::WebSocket> impl);
 	~WebSocket();
 	~WebSocket();
 
 
 	State readyState() const;
 	State readyState() const;
@@ -62,6 +63,9 @@ public:
 	bool send(const message_variant data) override;
 	bool send(const message_variant data) override;
 	bool send(const byte *data, size_t size) override;
 	bool send(const byte *data, size_t size) override;
 
 
+	optional<string> remoteAddress() const;
+	optional<string> path() const;
+
 private:
 private:
 	using CheshireCat<impl::WebSocket>::impl;
 	using CheshireCat<impl::WebSocket>::impl;
 };
 };

+ 61 - 0
include/rtc/websocketserver.hpp

@@ -0,0 +1,61 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_WEBSOCKETSERVER_H
+#define RTC_WEBSOCKETSERVER_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "channel.hpp"
+#include "common.hpp"
+#include "message.hpp"
+#include "websocket.hpp"
+
+namespace rtc {
+
+namespace impl {
+
+struct WebSocketServer;
+
+}
+
+class RTC_CPP_EXPORT WebSocketServer final : private CheshireCat<impl::WebSocketServer> {
+public:
+	struct Configuration {
+		uint16_t port = 8080;
+	};
+
+	WebSocketServer();
+	WebSocketServer(Configuration config);
+	~WebSocketServer();
+
+	void stop();
+
+	uint16_t port() const;
+
+	void onClient(std::function<void(shared_ptr<WebSocket>)> callback);
+
+private:
+	using CheshireCat<impl::WebSocketServer>::impl;
+};
+
+} // namespace rtc
+
+#endif
+
+#endif // RTC_WEBSOCKET_H

+ 95 - 27
src/impl/certificate.cpp

@@ -32,16 +32,47 @@ const string COMMON_NAME = "libdatachannel";
 
 
 #if USE_GNUTLS
 #if USE_GNUTLS
 
 
-Certificate::Certificate(string crt_pem, string key_pem)
-    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials) {
+Certificate Certificate::FromString(string crt_pem, string key_pem) {
+	Certificate certificate;
 
 
 	gnutls_datum_t crt_datum = gnutls::make_datum(crt_pem.data(), crt_pem.size());
 	gnutls_datum_t crt_datum = gnutls::make_datum(crt_pem.data(), crt_pem.size());
 	gnutls_datum_t key_datum = gnutls::make_datum(key_pem.data(), key_pem.size());
 	gnutls_datum_t key_datum = gnutls::make_datum(key_pem.data(), key_pem.size());
+	gnutls::check(gnutls_certificate_set_x509_key_mem(*certificate.mCredentials, &crt_datum,
+	                                                  &key_datum, GNUTLS_X509_FMT_PEM),
+	              "Unable to import PEM certificate and key");
 
 
-	gnutls::check(gnutls_certificate_set_x509_key_mem(*mCredentials, &crt_datum, &key_datum,
-	                                                  GNUTLS_X509_FMT_PEM),
-	              "Unable to import PEM");
+	certificate.computeFingerprint();
+	return certificate;
+}
+
+Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
+                                  const string &pass) {
+	Certificate certificate;
+
+	gnutls::check(gnutls_certificate_set_x509_key_file2(*certificate.mCredentials,
+	                                                    crt_pem_file.c_str(), key_pem_file.c_str(),
+	                                                    GNUTLS_X509_FMT_PEM, pass.c_str(), 0),
+	              "Unable to import PEM certificate and key from file");
+
+	certificate.computeFingerprint();
+	return certificate;
+}
+
+Certificate::Certificate() : mCredentials(gnutls::new_credentials(), gnutls::free_credentials) {}
+
+Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
+    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
+      mFingerprint(make_fingerprint(crt)) {
+
+	gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
+	              "Unable to set certificate and key pair in credentials");
+}
+
+gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
+
+string Certificate::fingerprint() const { return mFingerprint; }
 
 
+void Certificate::computeFingerprint() {
 	auto new_crt_list = [this]() -> gnutls_x509_crt_t * {
 	auto new_crt_list = [this]() -> gnutls_x509_crt_t * {
 		gnutls_x509_crt_t *crt_list = nullptr;
 		gnutls_x509_crt_t *crt_list = nullptr;
 		unsigned int crt_list_size = 0;
 		unsigned int crt_list_size = 0;
@@ -60,18 +91,6 @@ Certificate::Certificate(string crt_pem, string key_pem)
 	mFingerprint = make_fingerprint(*crt_list);
 	mFingerprint = make_fingerprint(*crt_list);
 }
 }
 
 
-Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
-    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
-      mFingerprint(make_fingerprint(crt)) {
-
-	gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
-	              "Unable to set certificate and key pair in credentials");
-}
-
-gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
-
-string Certificate::fingerprint() const { return mFingerprint; }
-
 string make_fingerprint(gnutls_x509_crt_t crt) {
 string make_fingerprint(gnutls_x509_crt_t crt) {
 	const size_t size = 32;
 	const size_t size = 32;
 	unsigned char buffer[size];
 	unsigned char buffer[size];
@@ -145,29 +164,78 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 
 
 #else // USE_GNUTLS==0
 #else // USE_GNUTLS==0
 
 
-Certificate::Certificate(string crt_pem, string key_pem) {
+#include <cstdio>
+
+namespace {
+
+// Dummy password callback that copies the password from user data
+int dummy_pass_cb(char *buf, int size, int /*rwflag*/, void *u) {
+	const char *pass = static_cast<char *>(u);
+	return snprintf(buf, size, "%s", pass);
+}
+
+} // namespace
+
+Certificate Certificate::FromString(string crt_pem, string key_pem) {
+	Certificate certificate;
+
 	BIO *bio = BIO_new(BIO_s_mem());
 	BIO *bio = BIO_new(BIO_s_mem());
 	BIO_write(bio, crt_pem.data(), int(crt_pem.size()));
 	BIO_write(bio, crt_pem.data(), int(crt_pem.size()));
-	mX509 = shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, 0, 0), X509_free);
+	certificate.mX509 =
+	    shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, nullptr, nullptr), X509_free);
 	BIO_free(bio);
 	BIO_free(bio);
-	if (!mX509)
-		throw std::invalid_argument("Unable to import certificate PEM");
+	if (!certificate.mX509)
+		throw std::invalid_argument("Unable to import PEM certificate");
 
 
 	bio = BIO_new(BIO_s_mem());
 	bio = BIO_new(BIO_s_mem());
 	BIO_write(bio, key_pem.data(), int(key_pem.size()));
 	BIO_write(bio, key_pem.data(), int(key_pem.size()));
-	mPKey = shared_ptr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio, nullptr, 0, 0), EVP_PKEY_free);
+	certificate.mPKey = shared_ptr<EVP_PKEY>(
+	    PEM_read_bio_PrivateKey(bio, nullptr, nullptr, nullptr), EVP_PKEY_free);
 	BIO_free(bio);
 	BIO_free(bio);
-	if (!mPKey)
-		throw std::invalid_argument("Unable to import PEM key PEM");
+	if (!certificate.mPKey)
+		throw std::invalid_argument("Unable to import PEM key");
 
 
-	mFingerprint = make_fingerprint(mX509.get());
+	certificate.computeFingerprint();
+	return certificate;
+}
+
+Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
+                                  const string &pass) {
+	Certificate certificate;
+
+	FILE *file = fopen(crt_pem_file.c_str(), "r");
+	if (!file)
+		throw std::invalid_argument("Unable to open PEM certificate file");
+
+	certificate.mX509 = shared_ptr<X509>(PEM_read_X509(file, nullptr, nullptr, nullptr), X509_free);
+	fclose(file);
+	if (!certificate.mX509)
+		throw std::invalid_argument("Unable to import PEM certificate from file");
+
+	file = fopen(key_pem_file.c_str(), "r");
+	if (!file)
+		throw std::invalid_argument("Unable to open PEM key file");
+
+	certificate.mPKey = shared_ptr<EVP_PKEY>(
+	    PEM_read_PrivateKey(file, nullptr, dummy_pass_cb, const_cast<char *>(pass.c_str())),
+	    EVP_PKEY_free);
+	fclose(file);
+	if (!certificate.mPKey)
+		throw std::invalid_argument("Unable to import PEM key from file");
+
+	certificate.computeFingerprint();
+	return certificate;
 }
 }
 
 
+Certificate::Certificate() {}
+
 Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
 Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
     : mX509(std::move(x509)), mPKey(std::move(pkey)) {
     : mX509(std::move(x509)), mPKey(std::move(pkey)) {
 	mFingerprint = make_fingerprint(mX509.get());
 	mFingerprint = make_fingerprint(mX509.get());
 }
 }
 
 
+void Certificate::computeFingerprint() { mFingerprint = make_fingerprint(mX509.get()); }
+
 string Certificate::fingerprint() const { return mFingerprint; }
 string Certificate::fingerprint() const { return mFingerprint; }
 
 
 std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
 std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
@@ -212,8 +280,8 @@ certificate_ptr make_certificate_impl(CertificateType type) {
 	case CertificateType::Ecdsa: {
 	case CertificateType::Ecdsa: {
 		PLOG_VERBOSE << "Generating ECDSA P-256 key pair";
 		PLOG_VERBOSE << "Generating ECDSA P-256 key pair";
 
 
-		unique_ptr<EC_KEY, decltype(&EC_KEY_free)> ecc(EC_KEY_new_by_curve_name(NID_X9_62_prime256v1),
-		                                               EC_KEY_free);
+		unique_ptr<EC_KEY, decltype(&EC_KEY_free)> ecc(
+		    EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
 		if (!ecc)
 		if (!ecc)
 			throw std::runtime_error("Unable to allocate structure for ECDSA P-256 key pair");
 			throw std::runtime_error("Unable to allocate structure for ECDSA P-256 key pair");
 
 

+ 7 - 2
src/impl/certificate.hpp

@@ -20,8 +20,8 @@
 #define RTC_IMPL_CERTIFICATE_H
 #define RTC_IMPL_CERTIFICATE_H
 
 
 #include "common.hpp"
 #include "common.hpp"
-#include "tls.hpp"
 #include "configuration.hpp" // for CertificateType
 #include "configuration.hpp" // for CertificateType
+#include "tls.hpp"
 
 
 #include <future>
 #include <future>
 #include <tuple>
 #include <tuple>
@@ -30,7 +30,9 @@ namespace rtc::impl {
 
 
 class Certificate {
 class Certificate {
 public:
 public:
-	Certificate(string crt_pem, string key_pem);
+	static Certificate FromString(string crt_pem, string key_pem);
+	static Certificate FromFile(const string &crt_pem_file, const string &key_pem_file,
+	                            const string &pass);
 
 
 #if USE_GNUTLS
 #if USE_GNUTLS
 	Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
 	Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
@@ -43,6 +45,9 @@ public:
 	string fingerprint() const;
 	string fingerprint() const;
 
 
 private:
 private:
+	Certificate();
+	void computeFingerprint();
+
 #if USE_GNUTLS
 #if USE_GNUTLS
 	shared_ptr<gnutls_certificate_credentials_t> mCredentials;
 	shared_ptr<gnutls_certificate_credentials_t> mCredentials;
 #else
 #else

+ 7 - 1
src/impl/dtlstransport.cpp

@@ -54,6 +54,9 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
 
+	if(!mCertificate)
+		throw std::invalid_argument("DTLS certificate is null");
+
 	gnutls_certificate_credentials_t creds = mCertificate->credentials();
 	gnutls_certificate_credentials_t creds = mCertificate->credentials();
 	gnutls_certificate_set_verify_function(creds, CertificateCallback);
 	gnutls_certificate_set_verify_function(creds, CertificateCallback);
 
 
@@ -330,7 +333,7 @@ void DtlsTransport::Cleanup() {
 	// Nothing to do
 	// Nothing to do
 }
 }
 
 
-DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
+DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr certificate,
                              optional<size_t> mtu, verifier_callback verifierCallback,
                              optional<size_t> mtu, verifier_callback verifierCallback,
                              state_callback stateChangeCallback)
                              state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
     : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
@@ -338,6 +341,9 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
       mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
       mIsClient(lower->role() == Description::Role::Active), mCurrentDscp(0) {
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
 
+	if(!mCertificate)
+		throw std::invalid_argument("DTLS certificate is null");
+
 	try {
 	try {
 		mCtx = SSL_CTX_new(DTLS_method());
 		mCtx = SSL_CTX_new(DTLS_method());
 		if (!mCtx)
 		if (!mCtx)

+ 2 - 0
src/impl/dtlstransport.hpp

@@ -51,6 +51,8 @@ public:
 	virtual bool stop() override;
 	virtual bool stop() override;
 	virtual bool send(message_ptr message) override; // false if dropped
 	virtual bool send(message_ptr message) override; // false if dropped
 
 
+	bool isClient() const { return mIsClient; }
+
 protected:
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual void incoming(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;
 	virtual bool outgoing(message_ptr message) override;

+ 10 - 3
src/impl/peerconnection.cpp

@@ -186,8 +186,11 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 
 
 		PLOG_VERBOSE << "Starting DTLS transport";
 		PLOG_VERBOSE << "Starting DTLS transport";
 
 
-		auto certificate = mCertificate.get();
 		auto lower = std::atomic_load(&mIceTransport);
 		auto lower = std::atomic_load(&mIceTransport);
+		if(!lower)
+			throw std::logic_error("No underlying ICE transport for DTLS transport");
+
+		auto certificate = mCertificate.get();
 		auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1);
 		auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1);
 		auto dtlsStateChangeCallback =
 		auto dtlsStateChangeCallback =
 		    [this, weak_this = weak_from_this()](DtlsTransport::State transportState) {
 		    [this, weak_this = weak_from_this()](DtlsTransport::State transportState) {
@@ -258,15 +261,19 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 
 
 		PLOG_VERBOSE << "Starting SCTP transport";
 		PLOG_VERBOSE << "Starting SCTP transport";
 
 
+		auto lower = std::atomic_load(&mDtlsTransport);
+		if(!lower)
+			throw std::logic_error("No underlying DTLS transport for SCTP transport");
+
 		auto remote = remoteDescription();
 		auto remote = remoteDescription();
 		if (!remote || !remote->application())
 		if (!remote || !remote->application())
 			throw std::logic_error("Starting SCTP transport without application description");
 			throw std::logic_error("Starting SCTP transport without application description");
 
 
+		uint16_t sctpPort = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT);
+
 		// This is the last occasion to ensure the stream numbers are coherent with the role
 		// This is the last occasion to ensure the stream numbers are coherent with the role
 		shiftDataChannels();
 		shiftDataChannels();
 
 
-		uint16_t sctpPort = remote->application()->sctpPort().value_or(DEFAULT_SCTP_PORT);
-		auto lower = std::atomic_load(&mDtlsTransport);
 		auto transport = std::make_shared<SctpTransport>(
 		auto transport = std::make_shared<SctpTransport>(
 		    lower, config, sctpPort, weak_bind(&PeerConnection::forwardMessage, this, _1),
 		    lower, config, sctpPort, weak_bind(&PeerConnection::forwardMessage, this, _1),
 		    weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
 		    weak_bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),

+ 88 - 0
src/impl/selectinterrupter.cpp

@@ -0,0 +1,88 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "selectinterrupter.hpp"
+#include "internals.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#ifndef _WIN32
+#include <fcntl.h>
+#include <unistd.h>
+#endif
+
+namespace rtc::impl {
+
+SelectInterrupter::SelectInterrupter() {
+#ifndef _WIN32
+	int pipefd[2];
+	if (::pipe(pipefd) != 0)
+		throw std::runtime_error("Failed to create pipe");
+	::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
+	::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
+	mPipeOut = pipefd[1]; // read
+	mPipeIn = pipefd[0];  // write
+#endif
+}
+
+SelectInterrupter::~SelectInterrupter() {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock != INVALID_SOCKET)
+		::closesocket(mDummySock);
+#else
+	::close(mPipeIn);
+	::close(mPipeOut);
+#endif
+}
+
+int SelectInterrupter::prepare(fd_set &readfds) {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock == INVALID_SOCKET)
+		mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
+	FD_SET(mDummySock, &readfds);
+	return SOCKET_TO_INT(mDummySock) + 1;
+#else
+	char dummy;
+	if (::read(mPipeIn, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
+		PLOG_WARNING << "Reading from interrupter pipe failed, errno=" << errno;
+	}
+	FD_SET(mPipeIn, &readfds);
+	return mPipeIn + 1;
+#endif
+}
+
+void SelectInterrupter::interrupt() {
+	std::lock_guard lock(mMutex);
+#ifdef _WIN32
+	if (mDummySock != INVALID_SOCKET) {
+		::closesocket(mDummySock);
+		mDummySock = INVALID_SOCKET;
+	}
+#else
+	char dummy = 0;
+	if (::write(mPipeOut, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
+		PLOG_WARNING << "Writing to interrupter pipe failed, errno=" << errno;
+	}
+#endif
+}
+
+} // namespace rtc::impl
+
+#endif

+ 55 - 0
src/impl/selectinterrupter.hpp

@@ -0,0 +1,55 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_SELECT_INTERRUPTER_H
+#define RTC_IMPL_SELECT_INTERRUPTER_H
+
+#include "common.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <mutex>
+
+// Use the socket defines from libjuice
+#include "../deps/libjuice/src/socket.h"
+
+namespace rtc::impl {
+
+// Utility class to interrupt select()
+class SelectInterrupter final {
+public:
+	SelectInterrupter();
+	~SelectInterrupter();
+
+	int prepare(fd_set &readfds);
+	void interrupt();
+
+private:
+	std::mutex mMutex;
+#ifdef _WIN32
+	socket_t mDummySock = INVALID_SOCKET;
+#else // assume POSIX
+	int mPipeIn, mPipeOut;
+#endif
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 17 - 4
src/impl/sha.cpp

@@ -28,14 +28,16 @@
 
 
 namespace rtc::impl {
 namespace rtc::impl {
 
 
-binary Sha1(const binary &input) {
+namespace {
+
+binary Sha1(const byte *data, size_t size) {
 #if USE_GNUTLS
 #if USE_GNUTLS
 
 
 binary output(SHA1_DIGEST_SIZE);
 binary output(SHA1_DIGEST_SIZE);
 struct sha1_ctx ctx;
 struct sha1_ctx ctx;
 sha1_init(&ctx);
 sha1_init(&ctx);
-sha1_update(&ctx, input.size(), input.data());
-sha1_digest(&ctx, SHA1_DIGEST_SIZE, output.size());
+sha1_update(&ctx, size, reinterpret_cast<const uint8_t*>(data));
+sha1_digest(&ctx, SHA1_DIGEST_SIZE, reinterpret_cast<uint8_t*>(output.data()));
 return output;
 return output;
 
 
 #else // USE_GNUTLS==0
 #else // USE_GNUTLS==0
@@ -43,13 +45,24 @@ return output;
 binary output(SHA_DIGEST_LENGTH);
 binary output(SHA_DIGEST_LENGTH);
 SHA_CTX ctx;
 SHA_CTX ctx;
 SHA1_Init(&ctx);
 SHA1_Init(&ctx);
-SHA1_Update(&ctx, input.data(), input.size());
+SHA1_Update(&ctx, data, size);
 SHA1_Final(reinterpret_cast<unsigned char*>(output.data()), &ctx);
 SHA1_Final(reinterpret_cast<unsigned char*>(output.data()), &ctx);
 return output;
 return output;
 
 
 #endif
 #endif
 }
 }
 
 
+}
+
+binary Sha1(const binary &input) {
+	return Sha1(input.data(), input.size());
+}
+
+
+binary Sha1(const string &input) {
+	return Sha1(reinterpret_cast<const byte*>(input.data()), input.size());
+}
+
 } // namespace rtc::impl
 } // namespace rtc::impl
 
 
 #endif
 #endif

+ 1 - 0
src/impl/sha.hpp

@@ -26,6 +26,7 @@
 namespace rtc::impl {
 namespace rtc::impl {
 
 
 binary Sha1(const binary &input);
 binary Sha1(const binary &input);
+binary Sha1(const string &input);
 
 
 } // namespace rtc::impl
 } // namespace rtc::impl
 
 

+ 176 - 0
src/impl/tcpserver.cpp

@@ -0,0 +1,176 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "tcpserver.hpp"
+#include "internals.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#ifdef _WIN32
+#include <winsock2.h>
+#else
+#include <arpa/inet.h>
+#include <fcntl.h>
+#include <unistd.h>
+#endif
+
+namespace rtc::impl {
+
+TcpServer::TcpServer(uint16_t port) {
+	PLOG_DEBUG << "Initializing TCP server";
+	listen(port);
+}
+
+TcpServer::~TcpServer() { close(); }
+
+shared_ptr<TcpTransport> TcpServer::accept() {
+	while (true) {
+		std::unique_lock lock(mSockMutex);
+
+		if (mSock == INVALID_SOCKET)
+			break;
+
+		fd_set readfds;
+		FD_ZERO(&readfds);
+		FD_SET(mSock, &readfds);
+		int n = std::max(mInterrupter.prepare(readfds), SOCKET_TO_INT(mSock) + 1);
+		lock.unlock();
+		int ret = ::select(n, &readfds, NULL, NULL, NULL);
+		lock.lock();
+		if (mSock == INVALID_SOCKET)
+			break;
+
+		if (ret < 0)
+			throw std::runtime_error("Failed to wait on socket");
+
+		if (FD_ISSET(mSock, &readfds)) {
+			struct sockaddr_storage addr;
+			socklen_t addrlen = sizeof(addr);
+			socket_t incomingSock = ::accept(mSock, (struct sockaddr *)&addr, &addrlen);
+			if (incomingSock == INVALID_SOCKET)
+				break;
+
+			return std::make_shared<TcpTransport>(incomingSock, nullptr); // no state callback
+		}
+	}
+
+	return nullptr;
+}
+
+void TcpServer::close() {
+	std::unique_lock lock(mSockMutex);
+	if (mSock != INVALID_SOCKET) {
+		PLOG_DEBUG << "Closing TCP server socket";
+		::closesocket(mSock);
+		mSock = INVALID_SOCKET;
+		mInterrupter.interrupt();
+	}
+}
+
+void TcpServer::listen(uint16_t port) {
+	PLOG_DEBUG << "Listening on port " << port;
+
+	struct addrinfo hints = {};
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_STREAM;
+	hints.ai_protocol = IPPROTO_TCP;
+	hints.ai_flags = AI_ADDRCONFIG;
+
+	struct addrinfo *result = nullptr;
+	if (::getaddrinfo(nullptr, std::to_string(port).c_str(), &hints, &result))
+		throw std::runtime_error("Resolution failed for local address");
+
+	static const auto find_family = [](struct addrinfo *ai_list, int family) {
+		struct addrinfo *ai = ai_list;
+		while (ai && ai->ai_family != family)
+			ai = ai->ai_next;
+		return ai;
+	};
+
+	struct addrinfo *ai;
+	if ((ai = find_family(result, AF_INET6)) == NULL && (ai = find_family(result, AF_INET)) == NULL)
+		throw std::runtime_error("No suitable address family found");
+
+	try {
+		std::unique_lock lock(mSockMutex);
+		PLOG_VERBOSE << "Creating TCP server socket";
+
+		// Create socket
+		mSock = ::socket(ai->ai_family, SOCK_STREAM, IPPROTO_TCP);
+		if (mSock == INVALID_SOCKET)
+			throw std::runtime_error("TCP server socket creation failed");
+
+		// Listen on both IPv6 and IPv4
+		const sockopt_t disabled = 0;
+		if (ai->ai_family == AF_INET6)
+			::setsockopt(mSock, IPPROTO_IPV6, IPV6_V6ONLY, (const char *)&disabled,
+			             sizeof(disabled));
+
+		// Set non-blocking
+		const ctl_t b = 1;
+		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+			throw std::runtime_error("Failed to set socket non-blocking mode");
+
+		// Bind socket
+		if (::bind(mSock, ai->ai_addr, ai->ai_addrlen) < 0) {
+			PLOG_WARNING << "TCP server socket binding on port " << port
+			             << " failed, errno=" << sockerrno;
+			throw std::runtime_error("TCP server socket binding failed");
+		}
+
+		// Listen
+		const int backlog = 10;
+		if (::listen(mSock, backlog) < 0) {
+			PLOG_WARNING << "TCP server socket listening failed, errno=" << sockerrno;
+			throw std::runtime_error("TCP server socket listening failed");
+		}
+
+		if (port != 0) {
+			mPort = port;
+		} else {
+			struct sockaddr_storage addr;
+			socklen_t addrlen = sizeof(addr);
+			if (::getsockname(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
+				throw std::runtime_error("getsockname failed");
+
+			switch (addr.ss_family) {
+			case AF_INET:
+				mPort = ntohs(reinterpret_cast<struct sockaddr_in *>(&addr)->sin_port);
+				break;
+			case AF_INET6:
+				mPort = ntohs(reinterpret_cast<struct sockaddr_in6 *>(&addr)->sin6_port);
+				break;
+			default:
+				throw std::logic_error("Unknown address family");
+			}
+		}
+	} catch (...) {
+		freeaddrinfo(result);
+		if (mSock != INVALID_SOCKET) {
+			::closesocket(mSock);
+			mSock = INVALID_SOCKET;
+		}
+		throw;
+	}
+
+	freeaddrinfo(result);
+}
+
+} // namespace rtc::impl
+
+#endif

+ 56 - 0
src/impl/tcpserver.hpp

@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_TCP_SERVER_H
+#define RTC_IMPL_TCP_SERVER_H
+
+#include "common.hpp"
+#include "queue.hpp"
+#include "tcptransport.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+// Use the socket defines from libjuice
+#include "../deps/libjuice/src/socket.h"
+
+namespace rtc::impl {
+
+class TcpServer {
+public:
+	TcpServer(uint16_t port);
+	~TcpServer();
+
+	shared_ptr<TcpTransport> accept();
+	void close();
+
+	uint16_t port() const { return mPort; }
+
+private:
+	void listen(uint16_t port);
+
+	uint16_t mPort;
+	socket_t mSock = INVALID_SOCKET;
+	std::mutex mSockMutex;
+	SelectInterrupter mInterrupter;
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 41 - 77
src/impl/tcptransport.cpp

@@ -21,8 +21,6 @@
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
-#include <exception>
-
 #ifndef _WIN32
 #ifndef _WIN32
 #include <fcntl.h>
 #include <fcntl.h>
 #include <unistd.h>
 #include <unistd.h>
@@ -30,67 +28,38 @@
 
 
 namespace rtc::impl {
 namespace rtc::impl {
 
 
-using std::to_string;
+TcpTransport::TcpTransport(string hostname, string service, state_callback callback)
+    : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)),
+      mService(std::move(service)) {
 
 
-SelectInterrupter::SelectInterrupter() {
-#ifndef _WIN32
-	int pipefd[2];
-	if (::pipe(pipefd) != 0)
-		throw std::runtime_error("Failed to create pipe");
-	::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
-	::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
-	mPipeOut = pipefd[1]; // read
-	mPipeIn = pipefd[0];  // write
-#endif
+	PLOG_DEBUG << "Initializing TCP transport";
 }
 }
 
 
-SelectInterrupter::~SelectInterrupter() {
-	std::lock_guard lock(mMutex);
-#ifdef _WIN32
-	if (mDummySock != INVALID_SOCKET)
-		::closesocket(mDummySock);
-#else
-	::close(mPipeIn);
-	::close(mPipeOut);
-#endif
-}
+TcpTransport::TcpTransport(socket_t sock, state_callback callback)
+    : Transport(nullptr, std::move(callback)), mIsActive(false), mSock(sock) {
 
 
-int SelectInterrupter::prepare(fd_set &readfds, [[maybe_unused]] fd_set &writefds) {
-	std::lock_guard lock(mMutex);
-#ifdef _WIN32
-	if (mDummySock == INVALID_SOCKET)
-		mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
-	FD_SET(mDummySock, &readfds);
-	return SOCKET_TO_INT(mDummySock) + 1;
-#else
-	char dummy;
-	if (::read(mPipeIn, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
-		PLOG_WARNING << "Reading from interrupter pipe failed, errno=" << errno;
-	}
-	FD_SET(mPipeIn, &readfds);
-	return mPipeIn + 1;
-#endif
-}
+	PLOG_DEBUG << "Initializing TCP transport with socket";
 
 
-void SelectInterrupter::interrupt() {
-	std::lock_guard lock(mMutex);
-#ifdef _WIN32
-	if (mDummySock != INVALID_SOCKET) {
-		::closesocket(mDummySock);
-		mDummySock = INVALID_SOCKET;
-	}
-#else
-	char dummy = 0;
-	if (::write(mPipeOut, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
-		PLOG_WARNING << "Writing to interrupter pipe failed, errno=" << errno;
-	}
-#endif
-}
+	// Set non-blocking
+	const ctl_t b = 1;
+	if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+		throw std::runtime_error("Failed to set socket non-blocking mode");
 
 
-TcpTransport::TcpTransport(const string &hostname, const string &service, state_callback callback)
-    : Transport(nullptr, std::move(callback)), mHostname(hostname), mService(service) {
+	// Retrieve hostname and service
+	struct sockaddr_storage addr;
+	socklen_t addrlen = sizeof(addr);
+	if (::getpeername(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
+		throw std::runtime_error("getsockname failed");
 
 
-	PLOG_DEBUG << "Initializing TCP transport";
+	char node[MAX_NUMERICNODE_LEN];
+	char serv[MAX_NUMERICSERV_LEN];
+	if (::getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), addrlen, node,
+	                  MAX_NUMERICNODE_LEN, serv, MAX_NUMERICSERV_LEN,
+	                  NI_NUMERICHOST | NI_NUMERICSERV) != 0)
+		throw std::runtime_error("getnameinfo failed");
+
+	mHostname = node;
+	mService = serv;
 }
 }
 
 
 TcpTransport::~TcpTransport() { stop(); }
 TcpTransport::~TcpTransport() { stop(); }
@@ -139,10 +108,12 @@ bool TcpTransport::outgoing(message_ptr message) {
 		return true;
 		return true;
 
 
 	mSendQueue.push(message);
 	mSendQueue.push(message);
-	interruptSelect(); // so the thread waits for writability
+	mInterrupter.interrupt(); // so the thread waits for writability
 	return false;
 	return false;
 }
 }
 
 
+string TcpTransport::remoteAddress() const { return mHostname + ':' + mService; }
+
 void TcpTransport::connect(const string &hostname, const string &service) {
 void TcpTransport::connect(const string &hostname, const string &service) {
 	PLOG_DEBUG << "Connecting to " << hostname << ":" << service;
 	PLOG_DEBUG << "Connecting to " << hostname << ":" << service;
 
 
@@ -197,7 +168,8 @@ void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
 		if (mSock == INVALID_SOCKET)
 		if (mSock == INVALID_SOCKET)
 			throw std::runtime_error("TCP socket creation failed");
 			throw std::runtime_error("TCP socket creation failed");
 
 
-		ctl_t b = 1;
+		// Set non-blocking
+		const ctl_t b = 1;
 		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
 		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
 			throw std::runtime_error("Failed to set socket non-blocking mode");
 			throw std::runtime_error("Failed to set socket non-blocking mode");
 
 
@@ -269,7 +241,7 @@ void TcpTransport::close() {
 		mSock = INVALID_SOCKET;
 		mSock = INVALID_SOCKET;
 	}
 	}
 	changeState(State::Disconnected);
 	changeState(State::Disconnected);
-	interruptSelect();
+	mInterrupter.interrupt();
 }
 }
 
 
 bool TcpTransport::trySendQueue() {
 bool TcpTransport::trySendQueue() {
@@ -301,7 +273,7 @@ bool TcpTransport::trySendMessage(message_ptr &message) {
 				message = make_message(message->end() - size, message->end());
 				message = make_message(message->end() - size, message->end());
 				return false;
 				return false;
 			} else {
 			} else {
-				throw std::runtime_error("Connection lost, errno=" + to_string(sockerrno));
+				throw std::runtime_error("Connection lost, errno=" + std::to_string(sockerrno));
 			}
 			}
 		}
 		}
 
 
@@ -318,7 +290,8 @@ void TcpTransport::runLoop() {
 	// Connect
 	// Connect
 	try {
 	try {
 		changeState(State::Connecting);
 		changeState(State::Connecting);
-		connect(mHostname, mService);
+		if (mSock == INVALID_SOCKET)
+			connect(mHostname, mService);
 
 
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TCP connect: " << e.what();
 		PLOG_ERROR << "TCP connect: " << e.what();
@@ -337,7 +310,13 @@ void TcpTransport::runLoop() {
 				break;
 				break;
 
 
 			fd_set readfds, writefds;
 			fd_set readfds, writefds;
-			int n = prepareSelect(readfds, writefds);
+			FD_ZERO(&readfds);
+			FD_ZERO(&writefds);
+			FD_SET(mSock, &readfds);
+			if (!mSendQueue.empty())
+				FD_SET(mSock, &writefds);
+
+			int n = std::max(mInterrupter.prepare(readfds), SOCKET_TO_INT(mSock) + 1);
 
 
 			struct timeval tv;
 			struct timeval tv;
 			tv.tv_sec = 10;
 			tv.tv_sec = 10;
@@ -388,21 +367,6 @@ void TcpTransport::runLoop() {
 	recv(nullptr);
 	recv(nullptr);
 }
 }
 
 
-int TcpTransport::prepareSelect(fd_set &readfds, fd_set &writefds) {
-	FD_ZERO(&readfds);
-	FD_ZERO(&writefds);
-	FD_SET(mSock, &readfds);
-
-	if (!mSendQueue.empty())
-		FD_SET(mSock, &writefds);
-
-	int n = SOCKET_TO_INT(mSock) + 1;
-	int m = mInterrupter.prepare(readfds, writefds);
-	return std::max(n, m);
-}
-
-void TcpTransport::interruptSelect() { mInterrupter.interrupt(); }
-
 } // namespace rtc::impl
 } // namespace rtc::impl
 
 
 #endif
 #endif

+ 8 - 22
src/impl/tcptransport.hpp

@@ -22,6 +22,7 @@
 #include "common.hpp"
 #include "common.hpp"
 #include "queue.hpp"
 #include "queue.hpp"
 #include "transport.hpp"
 #include "transport.hpp"
+#include "selectinterrupter.hpp"
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
@@ -33,27 +34,10 @@
 
 
 namespace rtc::impl {
 namespace rtc::impl {
 
 
-// Utility class to interrupt select()
-class SelectInterrupter {
-public:
-	SelectInterrupter();
-	~SelectInterrupter();
-
-	int prepare(fd_set &readfds, fd_set &writefds);
-	void interrupt();
-
-private:
-	std::mutex mMutex;
-#ifdef _WIN32
-	socket_t mDummySock = INVALID_SOCKET;
-#else // assume POSIX
-	int mPipeIn, mPipeOut;
-#endif
-};
-
 class TcpTransport : public Transport {
 class TcpTransport : public Transport {
 public:
 public:
-	TcpTransport(const string &hostname, const string &service, state_callback callback);
+	TcpTransport(string hostname, string service, state_callback callback); // active
+	TcpTransport(socket_t sock, state_callback callback);                   // passive
 	~TcpTransport();
 	~TcpTransport();
 
 
 	void start() override;
 	void start() override;
@@ -63,6 +47,10 @@ public:
 	void incoming(message_ptr message) override;
 	void incoming(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 	bool outgoing(message_ptr message) override;
 
 
+	bool isActive() const { return mIsActive; }
+
+	string remoteAddress() const;
+
 private:
 private:
 	void connect(const string &hostname, const string &service);
 	void connect(const string &hostname, const string &service);
 	void connect(const sockaddr *addr, socklen_t addrlen);
 	void connect(const sockaddr *addr, socklen_t addrlen);
@@ -73,9 +61,7 @@ private:
 
 
 	void runLoop();
 	void runLoop();
 
 
-	int prepareSelect(fd_set &readfds, fd_set &writefds);
-	void interruptSelect();
-
+	const bool mIsActive;
 	string mHostname, mService;
 	string mHostname, mService;
 
 
 	socket_t mSock = INVALID_SOCKET;
 	socket_t mSock = INVALID_SOCKET;

+ 50 - 20
src/impl/tlstransport.cpp

@@ -32,6 +32,23 @@ namespace rtc::impl {
 
 
 #if USE_GNUTLS
 #if USE_GNUTLS
 
 
+namespace {
+
+gnutls_certificate_credentials_t default_certificate_credentials() {
+	static std::mutex mutex;
+	static shared_ptr<gnutls_certificate_credentials_t> creds;
+
+	std::lock_guard lock(mutex);
+	if (!creds) {
+		creds = shared_ptr<gnutls_certificate_credentials_t>(gnutls::new_credentials(),
+		                                                     gnutls::free_credentials);
+		gnutls::check(gnutls_certificate_set_x509_system_trust(*creds));
+	}
+	return *creds;
+}
+
+} // namespace
+
 void TlsTransport::Init() {
 void TlsTransport::Init() {
 	// Nothing to do
 	// Nothing to do
 }
 }
@@ -40,25 +57,28 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 	// Nothing to do
 }
 }
 
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
-    : Transport(lower, std::move(callback)), mHost(std::move(host)) {
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
+                           certificate_ptr certificate, state_callback callback)
+    : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
 
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
 
-	gnutls::check(gnutls_certificate_allocate_credentials(&mCreds));
-	gnutls::check(gnutls_init(&mSession, GNUTLS_CLIENT));
+	gnutls::check(gnutls_init(&mSession, mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER));
 
 
 	try {
 	try {
-		gnutls::check(gnutls_certificate_set_x509_system_trust(mCreds));
-		gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
-
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
 		const char *err_pos = NULL;
 		const char *err_pos = NULL;
 		gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
 		gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
 		              "Failed to set TLS priorities");
 		              "Failed to set TLS priorities");
 
 
-		PLOG_VERBOSE << "Server Name Indication: " << mHost;
-		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost.data(), mHost.size());
+		gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE,
+		                                     certificate ? certificate->credentials()
+		                                                 : default_certificate_credentials()));
+
+		if (mHost) {
+			PLOG_VERBOSE << "Server Name Indication: " << *mHost;
+			gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost->data(), mHost->size());
+		}
 
 
 		gnutls_session_set_ptr(mSession, this);
 		gnutls_session_set_ptr(mSession, this);
 		gnutls_transport_set_ptr(mSession, this);
 		gnutls_transport_set_ptr(mSession, this);
@@ -68,7 +88,6 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 
 
 	} catch (...) {
 	} catch (...) {
 		gnutls_deinit(mSession);
 		gnutls_deinit(mSession);
-		gnutls_certificate_free_credentials(mCreds);
 		throw;
 		throw;
 	}
 	}
 }
 }
@@ -77,7 +96,6 @@ TlsTransport::~TlsTransport() {
 	stop();
 	stop();
 
 
 	gnutls_deinit(mSession);
 	gnutls_deinit(mSession);
-	gnutls_certificate_free_credentials(mCreds);
 }
 }
 
 
 void TlsTransport::start() {
 void TlsTransport::start() {
@@ -253,8 +271,9 @@ void TlsTransport::Cleanup() {
 	// Nothing to do
 	// Nothing to do
 }
 }
 
 
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
-    : Transport(lower, std::move(callback)), mHost(std::move(host)) {
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
+                           certificate_ptr certificate, state_callback callback)
+    : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
 
 
 	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
 	PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
 
 
@@ -265,8 +284,14 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
 		openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
 		               "Failed to set SSL priorities");
 		               "Failed to set SSL priorities");
 
 
-		if (!SSL_CTX_set_default_verify_paths(mCtx)) {
-			PLOG_WARNING << "SSL root CA certificates unavailable";
+		if (certificate) {
+			auto [x509, pkey] = certificate->credentials();
+			SSL_CTX_use_certificate(mCtx, x509);
+			SSL_CTX_use_PrivateKey(mCtx, pkey);
+		} else {
+			if (!SSL_CTX_set_default_verify_paths(mCtx)) {
+				PLOG_WARNING << "SSL root CA certificates unavailable";
+			}
 		}
 		}
 
 
 		SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
 		SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
@@ -281,13 +306,18 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 
 
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
 
-		SSL_set_hostflags(mSsl, 0);
-		openssl::check(SSL_set1_host(mSsl, mHost.c_str()), "Failed to set SSL host");
+		if (mHost) {
+			SSL_set_hostflags(mSsl, 0);
+			openssl::check(SSL_set1_host(mSsl, mHost->c_str()), "Failed to set SSL host");
 
 
-		PLOG_VERBOSE << "Server Name Indication: " << mHost;
-		SSL_set_tlsext_host_name(mSsl, mHost.c_str());
+			PLOG_VERBOSE << "Server Name Indication: " << *mHost;
+			SSL_set_tlsext_host_name(mSsl, mHost->c_str());
+		}
 
 
-		SSL_set_connect_state(mSsl);
+		if (mIsClient)
+			SSL_set_connect_state(mSsl);
+		else
+			SSL_set_accept_state(mSsl);
 
 
 		if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
 		if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
 			throw std::runtime_error("Failed to create BIO");
 			throw std::runtime_error("Failed to create BIO");

+ 7 - 3
src/impl/tlstransport.hpp

@@ -19,6 +19,7 @@
 #ifndef RTC_IMPL_TLS_TRANSPORT_H
 #ifndef RTC_IMPL_TLS_TRANSPORT_H
 #define RTC_IMPL_TLS_TRANSPORT_H
 #define RTC_IMPL_TLS_TRANSPORT_H
 
 
+#include "certificate.hpp"
 #include "common.hpp"
 #include "common.hpp"
 #include "queue.hpp"
 #include "queue.hpp"
 #include "tls.hpp"
 #include "tls.hpp"
@@ -37,26 +38,29 @@ public:
 	static void Init();
 	static void Init();
 	static void Cleanup();
 	static void Cleanup();
 
 
-	TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback);
+	TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host, certificate_ptr certificate,
+	             state_callback callback);
 	virtual ~TlsTransport();
 	virtual ~TlsTransport();
 
 
 	void start() override;
 	void start() override;
 	bool stop() override;
 	bool stop() override;
 	bool send(message_ptr message) override;
 	bool send(message_ptr message) override;
 
 
+	bool isClient() const { return mIsClient; }
+
 protected:
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual void incoming(message_ptr message) override;
 	virtual void postHandshake();
 	virtual void postHandshake();
 	void runRecvLoop();
 	void runRecvLoop();
 
 
-	string mHost;
+	const optional<string> mHost;
+	const bool mIsClient;
 
 
 	Queue<message_ptr> mIncomingQueue;
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
 	std::thread mRecvThread;
 
 
 #if USE_GNUTLS
 #if USE_GNUTLS
 	gnutls_session_t mSession;
 	gnutls_session_t mSession;
-	gnutls_certificate_credentials_t mCreds;
 
 
 	message_ptr mIncomingMessage;
 	message_ptr mIncomingMessage;
 	size_t mIncomingMessagePosition = 0;
 	size_t mIncomingMessagePosition = 0;

+ 1 - 0
src/impl/transport.hpp

@@ -61,6 +61,7 @@ public:
 	}
 	}
 
 
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
 	void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
+	void onStateChange(state_callback callback) { mStateChangeCallback = std::move(callback); }
 	State state() const { return mState; }
 	State state() const { return mState; }
 
 
 	virtual bool send(message_ptr message) { return outgoing(message); }
 	virtual bool send(message_ptr message) { return outgoing(message); }

+ 3 - 3
src/impl/verifiedtlstransport.cpp

@@ -24,12 +24,12 @@
 namespace rtc::impl {
 namespace rtc::impl {
 
 
 VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
 VerifiedTlsTransport::VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host,
-                                           state_callback callback)
-    : TlsTransport(std::move(lower), std::move(host), std::move(callback)) {
+                                           certificate_ptr certificate, state_callback callback)
+    : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) {
 
 
 #if USE_GNUTLS
 #if USE_GNUTLS
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	PLOG_DEBUG << "Setting up TLS certificate verification";
-	gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
+	gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0);
 #else
 #else
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);

+ 1 - 1
src/impl/verifiedtlstransport.hpp

@@ -27,7 +27,7 @@ namespace rtc::impl {
 
 
 class VerifiedTlsTransport final : public TlsTransport {
 class VerifiedTlsTransport final : public TlsTransport {
 public:
 public:
-	VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback);
+	VerifiedTlsTransport(shared_ptr<TcpTransport> lower, string host, certificate_ptr certificate, state_callback callback);
 	~VerifiedTlsTransport();
 	~VerifiedTlsTransport();
 };
 };
 
 

+ 130 - 93
src/impl/websocket.cpp

@@ -19,8 +19,8 @@
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
 #include "websocket.hpp"
 #include "websocket.hpp"
-#include "internals.hpp"
 #include "common.hpp"
 #include "common.hpp"
+#include "internals.hpp"
 #include "threadpool.hpp"
 #include "threadpool.hpp"
 
 
 #include "tcptransport.hpp"
 #include "tcptransport.hpp"
@@ -38,8 +38,10 @@ namespace rtc::impl {
 
 
 using namespace std::placeholders;
 using namespace std::placeholders;
 
 
-WebSocket::WebSocket(Configuration config_)
-    : config(std::move(config_)), mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
+WebSocket::WebSocket(optional<Configuration> optConfig, certificate_ptr certificate)
+    : config(optConfig ? std::move(*optConfig) : Configuration()),
+      mCertificate(std::move(certificate)), mIsSecure(mCertificate != nullptr),
+      mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {
 	PLOG_VERBOSE << "Creating WebSocket";
 	PLOG_VERBOSE << "Creating WebSocket";
 }
 }
 
 
@@ -48,7 +50,7 @@ WebSocket::~WebSocket() {
 	remoteClose();
 	remoteClose();
 }
 }
 
 
-void WebSocket::parse(const string &url) {
+void WebSocket::open(const string &url) {
 	PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
 	PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
 
 
 	if (state != State::Closed)
 	if (state != State::Closed)
@@ -64,34 +66,41 @@ void WebSocket::parse(const string &url) {
 	if (!std::regex_match(url, m, r) || m[10].length() == 0)
 	if (!std::regex_match(url, m, r) || m[10].length() == 0)
 		throw std::invalid_argument("Invalid WebSocket URL: " + url);
 		throw std::invalid_argument("Invalid WebSocket URL: " + url);
 
 
-	mScheme = m[2];
-	if (mScheme.empty())
-		mScheme = "ws";
-	else if (mScheme != "ws" && mScheme != "wss")
-		throw std::invalid_argument("Invalid WebSocket scheme: " + mScheme);
-
-	mHostname = m[10];
-	mService = m[12];
-	if (mService.empty()) {
-		mService = mScheme == "ws" ? "80" : "443";
-		mHost = mHostname;
+	string scheme = m[2];
+	if (scheme.empty())
+		scheme = "ws";
+
+	if (scheme != "ws" && scheme != "wss")
+		throw std::invalid_argument("Invalid WebSocket scheme: " + scheme);
+
+	mIsSecure = (scheme != "ws");
+
+	string host;
+	string hostname = m[10];
+	string service = m[12];
+	if (service.empty()) {
+		service = mIsSecure ? "443" : "80";
+		host = hostname;
 	} else {
 	} else {
-		mHost = mHostname + ':' + mService;
+		host = hostname + ':' + service;
 	}
 	}
 
 
-	while (!mHostname.empty() && mHostname.front() == '[')
-		mHostname.erase(mHostname.begin());
-	while (!mHostname.empty() && mHostname.back() == ']')
-		mHostname.pop_back();
+	while (!hostname.empty() && hostname.front() == '[')
+		hostname.erase(hostname.begin());
+	while (!hostname.empty() && hostname.back() == ']')
+		hostname.pop_back();
+
+	string path = m[13];
+	if (path.empty())
+		path += '/';
 
 
-	mPath = m[13];
-	if (mPath.empty())
-		mPath += '/';
 	if (string query = m[15]; !query.empty())
 	if (string query = m[15]; !query.empty())
-		mPath += "?" + query;
+		path += "?" + query;
+
+	std::atomic_store(&mWsHandshake, std::make_shared<WsHandshake>(host, path, config.protocols));
 
 
 	changeState(State::Connecting);
 	changeState(State::Connecting);
-	initTcpTransport();
+	setTcpTransport(std::make_shared<TcpTransport>(hostname, service, nullptr));
 }
 }
 
 
 void WebSocket::close() {
 void WebSocket::close() {
@@ -165,37 +174,41 @@ void WebSocket::incoming(message_ptr message) {
 	}
 	}
 }
 }
 
 
-shared_ptr<TcpTransport> WebSocket::initTcpTransport() {
+shared_ptr<TcpTransport> WebSocket::setTcpTransport(shared_ptr<TcpTransport> transport) {
 	PLOG_VERBOSE << "Starting TCP transport";
 	PLOG_VERBOSE << "Starting TCP transport";
+
+	if (!transport)
+		throw std::logic_error("TCP transport is null");
+
 	using State = TcpTransport::State;
 	using State = TcpTransport::State;
 	try {
 	try {
-		if (auto transport = std::atomic_load(&mTcpTransport))
-			return transport;
+		if (std::atomic_load(&mTcpTransport))
+			throw std::logic_error("TCP transport is already set");
+
+		transport->onStateChange([this, weak_this = weak_from_this()](State transportState) {
+			auto shared_this = weak_this.lock();
+			if (!shared_this)
+				return;
+			switch (transportState) {
+			case State::Connected:
+				if (mIsSecure)
+					initTlsTransport();
+				else
+					initWsTransport();
+				break;
+			case State::Failed:
+				triggerError("TCP connection failed");
+				remoteClose();
+				break;
+			case State::Disconnected:
+				remoteClose();
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		});
 
 
-		auto transport = std::make_shared<TcpTransport>(
-		    mHostname, mService, [this, weak_this = weak_from_this()](State transportState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (transportState) {
-			    case State::Connected:
-				    if (mScheme == "ws")
-					    initWsTransport();
-				    else
-					    initTlsTransport();
-				    break;
-			    case State::Failed:
-				    triggerError("TCP connection failed");
-				    remoteClose();
-				    break;
-			    case State::Disconnected:
-				    remoteClose();
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    });
 		std::atomic_store(&mTcpTransport, transport);
 		std::atomic_store(&mTcpTransport, transport);
 		if (state == WebSocket::State::Closed) {
 		if (state == WebSocket::State::Closed) {
 			mTcpTransport.reset();
 			mTcpTransport.reset();
@@ -219,6 +232,9 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			return transport;
 			return transport;
 
 
 		auto lower = std::atomic_load(&mTcpTransport);
 		auto lower = std::atomic_load(&mTcpTransport);
+		if (!lower)
+			throw std::logic_error("No underlying TCP transport for TLS transport");
+
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
 		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
 			auto shared_this = weak_this.lock();
 			auto shared_this = weak_this.lock();
 			if (!shared_this)
 			if (!shared_this)
@@ -240,18 +256,22 @@ shared_ptr<TlsTransport> WebSocket::initTlsTransport() {
 			}
 			}
 		};
 		};
 
 
-		shared_ptr<TlsTransport> transport;
+		auto handshake = getWsHandshake();
+		auto host = handshake ? make_optional(handshake->host()) : nullopt;
+		bool verify = host.has_value() && !config.disableTlsVerification;
+
 #ifdef _WIN32
 #ifdef _WIN32
-		if (!config.disableTlsVerification) {
+		if (std::exchange(verify, false)) {
 			PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows";
 			PLOG_WARNING << "TLS certificate verification with root CA is not supported on Windows";
 		}
 		}
-		transport = std::make_shared<TlsTransport>(lower, mHostname, stateChangeCallback);
 #else
 #else
-		if (config.disableTlsVerification)
-			transport = std::make_shared<TlsTransport>(lower, mHostname, stateChangeCallback);
+		shared_ptr<TlsTransport> transport;
+		if (verify)
+			transport = std::make_shared<VerifiedTlsTransport>(lower, host.value(), mCertificate,
+			                                                   stateChangeCallback);
 		else
 		else
 			transport =
 			transport =
-			    std::make_shared<VerifiedTlsTransport>(lower, mHostname, stateChangeCallback);
+			    std::make_shared<TlsTransport>(lower, host, mCertificate, stateChangeCallback);
 #endif
 #endif
 
 
 		std::atomic_store(&mTlsTransport, transport);
 		std::atomic_store(&mTlsTransport, transport);
@@ -276,41 +296,53 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		if (auto transport = std::atomic_load(&mWsTransport))
 		if (auto transport = std::atomic_load(&mWsTransport))
 			return transport;
 			return transport;
 
 
-		shared_ptr<Transport> lower = std::atomic_load(&mTlsTransport);
-		if (!lower)
-			lower = std::atomic_load(&mTcpTransport);
-
-		WsTransport::Configuration wsConfig = {};
-		wsConfig.host = mHost;
-		wsConfig.path = mPath;
-		wsConfig.protocols = config.protocols;
-
-		auto transport = std::make_shared<WsTransport>(
-		    lower, wsConfig, weak_bind(&WebSocket::incoming, this, _1),
-		    [this, weak_this = weak_from_this()](State transportState) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (transportState) {
-			    case State::Connected:
-				    if (state == WebSocket::State::Connecting) {
-					    PLOG_DEBUG << "WebSocket open";
-					    changeState(WebSocket::State::Open);
-					    triggerOpen();
-				    }
-				    break;
-			    case State::Failed:
-				    triggerError("WebSocket connection failed");
-				    remoteClose();
-				    break;
-			    case State::Disconnected:
-				    remoteClose();
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    });
+		variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower;
+		if (mIsSecure) {
+			auto transport = std::atomic_load(&mTlsTransport);
+			if (!transport)
+				throw std::logic_error("No underlying TLS transport for WebSocket transport");
+
+			lower = transport;
+		} else {
+			auto transport = std::atomic_load(&mTcpTransport);
+			if (!transport)
+				throw std::logic_error("No underlying TCP transport for WebSocket transport");
+
+			lower = transport;
+		}
+
+		if(!atomic_load(&mWsHandshake))
+			atomic_store(&mWsHandshake, std::make_shared<WsHandshake>());
+
+		auto stateChangeCallback = [this, weak_this = weak_from_this()](State transportState) {
+			auto shared_this = weak_this.lock();
+			if (!shared_this)
+				return;
+			switch (transportState) {
+			case State::Connected:
+				if (state == WebSocket::State::Connecting) {
+					PLOG_DEBUG << "WebSocket open";
+					changeState(WebSocket::State::Open);
+					triggerOpen();
+				}
+				break;
+			case State::Failed:
+				triggerError("WebSocket connection failed");
+				remoteClose();
+				break;
+			case State::Disconnected:
+				remoteClose();
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		};
+
+		auto transport = std::make_shared<WsTransport>(lower, mWsHandshake,
+		                                          weak_bind(&WebSocket::incoming, this, _1),
+		                                          stateChangeCallback);
+
 		std::atomic_store(&mWsTransport, transport);
 		std::atomic_store(&mWsTransport, transport);
 		if (state == WebSocket::State::Closed) {
 		if (state == WebSocket::State::Closed) {
 			mWsTransport.reset();
 			mWsTransport.reset();
@@ -318,6 +350,7 @@ shared_ptr<WsTransport> WebSocket::initWsTransport() {
 		}
 		}
 		transport->start();
 		transport->start();
 		return transport;
 		return transport;
+
 	} catch (const std::exception &e) {
 	} catch (const std::exception &e) {
 		PLOG_ERROR << e.what();
 		PLOG_ERROR << e.what();
 		remoteClose();
 		remoteClose();
@@ -337,6 +370,10 @@ shared_ptr<WsTransport> WebSocket::getWsTransport() const {
 	return std::atomic_load(&mWsTransport);
 	return std::atomic_load(&mWsTransport);
 }
 }
 
 
+shared_ptr<WsHandshake> WebSocket::getWsHandshake() const {
+	return std::atomic_load(&mWsHandshake);
+}
+
 void WebSocket::closeTransports() {
 void WebSocket::closeTransports() {
 	PLOG_VERBOSE << "Closing transports";
 	PLOG_VERBOSE << "Closing transports";
 
 

+ 9 - 5
src/impl/websocket.hpp

@@ -41,10 +41,10 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	using State = rtc::WebSocket::State;
 	using State = rtc::WebSocket::State;
 	using Configuration = rtc::WebSocket::Configuration;
 	using Configuration = rtc::WebSocket::Configuration;
 
 
-	WebSocket(Configuration config_);
+	WebSocket(optional<Configuration> optConfig = nullopt, certificate_ptr certificate = nullptr);
 	~WebSocket();
 	~WebSocket();
 
 
-	void parse(const string &url);
+	void open(const string &url);
 	void close();
 	void close();
 	bool outgoing(message_ptr message);
 	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 	void incoming(message_ptr message);
@@ -60,26 +60,30 @@ struct WebSocket final : public Channel, public std::enable_shared_from_this<Web
 	bool changeState(State state);
 	bool changeState(State state);
 	void remoteClose();
 	void remoteClose();
 
 
-	shared_ptr<TcpTransport> initTcpTransport();
+	shared_ptr<TcpTransport> setTcpTransport(shared_ptr<TcpTransport> transport);
 	shared_ptr<TlsTransport> initTlsTransport();
 	shared_ptr<TlsTransport> initTlsTransport();
 	shared_ptr<WsTransport> initWsTransport();
 	shared_ptr<WsTransport> initWsTransport();
 	shared_ptr<TcpTransport> getTcpTransport() const;
 	shared_ptr<TcpTransport> getTcpTransport() const;
 	shared_ptr<TlsTransport> getTlsTransport() const;
 	shared_ptr<TlsTransport> getTlsTransport() const;
 	shared_ptr<WsTransport> getWsTransport() const;
 	shared_ptr<WsTransport> getWsTransport() const;
+	shared_ptr<WsHandshake> getWsHandshake() const;
 
 
 	void closeTransports();
 	void closeTransports();
 
 
 	const Configuration config;
 	const Configuration config;
+
 	std::atomic<State> state = State::Closed;
 	std::atomic<State> state = State::Closed;
 
 
 private:
 private:
 	const init_token mInitToken = Init::Token();
 	const init_token mInitToken = Init::Token();
 
 
+	const certificate_ptr mCertificate;
+	bool mIsSecure;
+
 	shared_ptr<TcpTransport> mTcpTransport;
 	shared_ptr<TcpTransport> mTcpTransport;
 	shared_ptr<TlsTransport> mTlsTransport;
 	shared_ptr<TlsTransport> mTlsTransport;
 	shared_ptr<WsTransport> mWsTransport;
 	shared_ptr<WsTransport> mWsTransport;
-
-	string mScheme, mHost, mHostname, mService, mPath;
+	shared_ptr<WsHandshake> mWsHandshake;
 
 
 	Queue<message_ptr> mRecvQueue;
 	Queue<message_ptr> mRecvQueue;
 };
 };

+ 78 - 0
src/impl/websocketserver.cpp

@@ -0,0 +1,78 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "websocketserver.hpp"
+#include "common.hpp"
+#include "internals.hpp"
+#include "threadpool.hpp"
+
+namespace rtc::impl {
+
+using namespace std::placeholders;
+
+WebSocketServer::WebSocketServer(Configuration config_)
+    : config(std::move(config_)), tcpServer(std::make_unique<TcpServer>(config.port)),
+      mStopped(false) {
+	PLOG_VERBOSE << "Creating WebSocketServer";
+	mThread = std::thread(&WebSocketServer::runLoop, this);
+}
+
+WebSocketServer::~WebSocketServer() {
+	PLOG_VERBOSE << "Destroying WebSocketServer";
+	stop();
+}
+
+void WebSocketServer::stop() {
+	if (mStopped.exchange(true))
+		return;
+
+	PLOG_DEBUG << "Stopping WebSocketServer thread";
+	tcpServer->close();
+	mThread.join();
+}
+
+void WebSocketServer::runLoop() {
+	PLOG_INFO << "Starting WebSocketServer";
+
+	try {
+		while (auto incoming = tcpServer->accept()) {
+			try {
+				if (!clientCallback)
+					continue;
+
+				auto impl = std::make_shared<WebSocket>(nullopt, mCertificate);
+				impl->changeState(WebSocket::State::Connecting);
+				impl->setTcpTransport(incoming);
+				clientCallback(std::make_shared<rtc::WebSocket>(impl));
+
+			} catch (const std::exception &e) {
+				PLOG_ERROR << "WebSocketServer: " << e.what();
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_FATAL << "WebSocketServer: " << e.what();
+	}
+
+	PLOG_INFO << "Stopped WebSocketServer";
+}
+
+} // namespace rtc::impl
+
+#endif

+ 66 - 0
src/impl/websocketserver.hpp

@@ -0,0 +1,66 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_WEBSOCKETSERVER_H
+#define RTC_IMPL_WEBSOCKETSERVER_H
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "certificate.hpp"
+#include "common.hpp"
+#include "init.hpp"
+#include "message.hpp"
+#include "tcpserver.hpp"
+#include "websocket.hpp"
+
+#include "rtc/websocket.hpp"
+#include "rtc/websocketserver.hpp"
+
+#include <atomic>
+#include <thread>
+
+namespace rtc::impl {
+
+struct WebSocketServer final : public std::enable_shared_from_this<WebSocketServer> {
+	using Configuration = rtc::WebSocketServer::Configuration;
+
+	WebSocketServer(Configuration config_);
+	~WebSocketServer();
+
+	void stop();
+
+	const Configuration config;
+	const unique_ptr<TcpServer> tcpServer;
+
+	synchronized_callback<shared_ptr<rtc::WebSocket>> clientCallback;
+
+private:
+	const init_token mInitToken = Init::Token();
+
+	void runLoop();
+
+	certificate_ptr mCertificate;
+	std::thread mThread;
+	std::atomic<bool> mStopped;
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif // RTC_IMPL_WEBSOCKET_H

+ 277 - 0
src/impl/wshandshake.cpp

@@ -0,0 +1,277 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "wshandshake.hpp"
+#include "base64.hpp"
+#include "internals.hpp"
+#include "sha.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <algorithm>
+#include <chrono>
+#include <climits>
+#include <iostream>
+#include <iterator>
+#include <random>
+#include <sstream>
+
+using std::string;
+
+namespace {
+
+std::vector<string> explode(const string &str, char delim) {
+	std::vector<std::string> result;
+	std::istringstream ss(str);
+	string token;
+	while (std::getline(ss, token, delim))
+		result.push_back(token);
+
+	return result;
+}
+
+string implode(const std::vector<string> &tokens, char delim) {
+	string sdelim(1, delim);
+	std::ostringstream ss;
+	std::copy(tokens.begin(), tokens.end(), std::ostream_iterator<string>(ss, sdelim.c_str()));
+	string result = ss.str();
+	if (result.size() > 0)
+		result.resize(result.size() - 1);
+
+	return result;
+}
+
+} // namespace
+
+namespace rtc::impl {
+
+using std::to_string;
+using std::chrono::system_clock;
+using random_bytes_engine =
+    std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
+
+WsHandshake::WsHandshake() {}
+
+WsHandshake::WsHandshake(string host, string path, std::vector<string> protocols)
+    : mHost(std::move(host)), mPath(std::move(path)), mProtocols(std::move(protocols)) {
+
+	if (mHost.empty())
+		throw std::invalid_argument("WebSocket HTTP host cannot be empty");
+
+	if (mPath.empty())
+		throw std::invalid_argument("WebSocket HTTP path cannot be empty");
+}
+
+string WsHandshake::host() const {
+	std::unique_lock lock(mMutex);
+	return mHost;
+}
+
+string WsHandshake::path() const {
+	std::unique_lock lock(mMutex);
+	return mPath;
+}
+
+std::vector<string> WsHandshake::protocols() const {
+	std::unique_lock lock(mMutex);
+	return mProtocols;
+}
+
+string WsHandshake::generateHttpRequest() {
+	std::unique_lock lock(mMutex);
+	mKey = generateKey();
+
+	string out = "GET " + mPath +
+	             " HTTP/1.1\r\n"
+	             "Host: " +
+	             mHost +
+	             "\r\n"
+	             "Connection: Upgrade\r\n"
+	             "Upgrade: websocket\r\n"
+	             "Sec-WebSocket-Version: 13\r\n"
+	             "Sec-WebSocket-Key: " +
+	             mKey + "\r\n";
+
+	if (!mProtocols.empty())
+		out += "Sec-WebSocket-Protocol: " + implode(mProtocols, ',') + "\r\n";
+
+	out += "\r\n";
+
+	return out;
+}
+
+string WsHandshake::generateHttpResponse() {
+	std::unique_lock lock(mMutex);
+	const string out = "HTTP/1.1 101 Switching Protocols\r\n"
+	                   "Connection: Upgrade\r\n"
+	                   "Upgrade: websocket\r\n"
+	                   "Sec-WebSocket-Accept: " +
+	                   computeAcceptKey(mKey) + "\r\n\r\n";
+
+	return out;
+}
+
+size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
+	std::unique_lock lock(mMutex);
+	std::list<string> lines;
+	size_t length = parseHttpLines(buffer, size, lines);
+	if (length == 0)
+		return 0;
+
+	if (lines.empty())
+		throw std::runtime_error("Invalid HTTP request for WebSocket");
+
+	std::istringstream requestLine(std::move(lines.front()));
+	lines.pop_front();
+
+	string method, path, protocol;
+	requestLine >> method >> path >> protocol;
+	PLOG_DEBUG << "WebSocket request method \"" << method << "\" for path: " << path;
+	if (method != "GET")
+		throw std::runtime_error("Unexpected request method \"" + method + "\" for WebSocket");
+
+	mPath = std::move(path);
+
+	auto headers = parseHttpHeaders(lines);
+
+	auto h = headers.find("host");
+	if (h == headers.end())
+		throw std::runtime_error("WebSocket host header missing in request");
+
+	mHost = std::move(h->second);
+
+	h = headers.find("upgrade");
+	if (h == headers.end())
+		throw std::runtime_error("WebSocket update header missing in request");
+
+	string upgrade;
+	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
+	               [](char c) { return std::tolower(c); });
+	if (upgrade != "websocket")
+		throw std::runtime_error("WebSocket update header mismatching: " + h->second);
+
+	h = headers.find("sec-websocket-key");
+	if (h == headers.end())
+		throw std::runtime_error("WebSocket key header missing in request");
+
+	mKey = std::move(h->second);
+
+	h = headers.find("sec-websocket-protocol");
+	if (h != headers.end())
+		mProtocols = explode(h->second, ',');
+
+	return length;
+}
+
+size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) {
+	std::unique_lock lock(mMutex);
+	std::list<string> lines;
+	size_t length = parseHttpLines(buffer, size, lines);
+	if (length == 0)
+		return 0;
+
+	if (lines.empty())
+		throw std::runtime_error("Invalid HTTP response for WebSocket");
+
+	std::istringstream status(std::move(lines.front()));
+	lines.pop_front();
+
+	string protocol;
+	unsigned int code = 0;
+	status >> protocol >> code;
+	PLOG_DEBUG << "WebSocket response code: " << code;
+	if (code != 101)
+		throw std::runtime_error("Unexpected response code " + to_string(code) + " for WebSocket");
+
+	auto headers = parseHttpHeaders(lines);
+
+	auto h = headers.find("upgrade");
+	if (h == headers.end())
+		throw std::runtime_error("WebSocket update header missing");
+
+	string upgrade;
+	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
+	               [](char c) { return std::tolower(c); });
+	if (upgrade != "websocket")
+		throw std::runtime_error("WebSocket update header mismatching: " + h->second);
+
+	h = headers.find("sec-websocket-accept");
+	if (h == headers.end())
+		throw std::runtime_error("WebSocket accept header missing");
+
+	if (h->second != computeAcceptKey(mKey))
+		throw std::runtime_error("WebSocket accept header is invalid");
+
+	return length;
+}
+
+string WsHandshake::generateKey() {
+	// RFC 6455: The request MUST include a header field with the name Sec-WebSocket-Key.  The value
+	// of this header field MUST be a nonce consisting of a randomly selected 16-byte value that has
+	// been base64-encoded. [...] The nonce MUST be selected randomly for each connection.
+	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
+	random_bytes_engine generator(seed);
+	binary key(16);
+	auto k = reinterpret_cast<uint8_t *>(key.data());
+	std::generate(k, k + key.size(), [&]() { return uint8_t(generator()); });
+	return to_base64(key);
+}
+
+string WsHandshake::computeAcceptKey(const string &key) {
+	return to_base64(Sha1(string(key) + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"));
+}
+
+size_t WsHandshake::parseHttpLines(const byte *buffer, size_t size, std::list<string> &lines) {
+	lines.clear();
+	auto begin = reinterpret_cast<const char *>(buffer);
+	auto end = begin + size;
+	auto cur = begin;
+	while (true) {
+		auto last = cur;
+		cur = std::find(cur, end, '\n');
+		if (cur == end)
+			return 0;
+		string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
+		if (line.empty())
+			break;
+		lines.emplace_back(std::move(line));
+	}
+
+	return cur - begin;
+}
+
+std::multimap<string, string> WsHandshake::parseHttpHeaders(const std::list<string> &lines) {
+	std::multimap<string, string> headers;
+	for (const auto &line : lines) {
+		if (size_t pos = line.find_first_of(':'); pos != string::npos) {
+			string key = line.substr(0, pos);
+			string value = line.substr(line.find_first_not_of(' ', pos + 1));
+			std::transform(key.begin(), key.end(), key.begin(),
+			               [](char c) { return std::tolower(c); });
+			headers.emplace(std::move(key), std::move(value));
+		} else {
+			headers.emplace(line, "");
+		}
+	}
+
+	return headers;
+}
+
+} // namespace rtc::impl
+
+#endif

+ 62 - 0
src/impl/wshandshake.hpp

@@ -0,0 +1,62 @@
+/**
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_IMPL_WS_HANDSHAKE_H
+#define RTC_IMPL_WS_HANDSHAKE_H
+
+#include "common.hpp"
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include <list>
+#include <map>
+
+namespace rtc::impl {
+
+class WsHandshake final {
+public:
+	WsHandshake();
+	WsHandshake(string host, string path = "/", std::vector<string> protocols = {});
+
+	string host() const;
+	string path() const;
+	std::vector<string> protocols() const;
+
+	string generateHttpRequest();
+	string generateHttpResponse();
+	size_t parseHttpRequest(const byte *buffer, size_t size);
+	size_t parseHttpResponse(const byte *buffer, size_t size);
+
+private:
+	static string generateKey();
+	static string computeAcceptKey(const string &key);
+	static size_t parseHttpLines(const byte *buffer, size_t size, std::list<string> &lines);
+	static std::multimap<string, string> parseHttpHeaders(const std::list<string> &lines);
+
+	string mHost;
+	string mPath;
+	std::vector<string> mProtocols;
+	string mKey;
+	mutable std::mutex mMutex;
+};
+
+} // namespace rtc::impl
+
+#endif
+
+#endif

+ 56 - 121
src/impl/wstransport.cpp

@@ -1,5 +1,5 @@
 /**
 /**
- * Copyright (c) 2020 Paul-Louis Ageneau
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
  *
  *
  * This library is free software; you can redistribute it and/or
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
  * modify it under the terms of the GNU Lesser General Public
@@ -17,19 +17,18 @@
  */
  */
 
 
 #include "wstransport.hpp"
 #include "wstransport.hpp"
-#include "base64.hpp"
 #include "tcptransport.hpp"
 #include "tcptransport.hpp"
 #include "tlstransport.hpp"
 #include "tlstransport.hpp"
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
+#include <algorithm>
 #include <chrono>
 #include <chrono>
-#include <iterator>
-#include <list>
-#include <map>
+#include <iostream>
 #include <numeric>
 #include <numeric>
 #include <random>
 #include <random>
 #include <regex>
 #include <regex>
+#include <sstream>
 
 
 #ifdef _WIN32
 #ifdef _WIN32
 #include <winsock2.h>
 #include <winsock2.h>
@@ -47,25 +46,26 @@
 
 
 namespace rtc::impl {
 namespace rtc::impl {
 
 
-using namespace std::chrono;
 using std::to_integer;
 using std::to_integer;
 using std::to_string;
 using std::to_string;
-
+using std::chrono::system_clock;
 using random_bytes_engine =
 using random_bytes_engine =
     std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
     std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned short>;
 
 
-WsTransport::WsTransport(shared_ptr<Transport> lower, Configuration config,
-                         message_callback recvCallback, state_callback stateCallback)
-    : Transport(lower, std::move(stateCallback)), mConfig(std::move(config)) {
+WsTransport::WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
+                         shared_ptr<WsHandshake> handshake, message_callback recvCallback,
+                         state_callback stateCallback)
+    : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
+                std::move(stateCallback)),
+      mHandshake(std::move(handshake)),
+      mIsClient(
+          std::visit(rtc::overloaded{[](shared_ptr<TcpTransport> l) { return l->isActive(); },
+                                     [](shared_ptr<TlsTransport> l) { return l->isClient(); }},
+                     lower)) {
+
 	onRecv(recvCallback);
 	onRecv(recvCallback);
 
 
 	PLOG_DEBUG << "Initializing WebSocket transport";
 	PLOG_DEBUG << "Initializing WebSocket transport";
-
-	if (mConfig.host.empty())
-		throw std::invalid_argument("WebSocket HTTP host cannot be empty");
-
-	if (mConfig.path.empty())
-		throw std::invalid_argument("WebSocket HTTP path cannot be empty");
 }
 }
 
 
 WsTransport::~WsTransport() { stop(); }
 WsTransport::~WsTransport() { stop(); }
@@ -74,7 +74,10 @@ void WsTransport::start() {
 	Transport::start();
 	Transport::start();
 
 
 	registerIncoming();
 	registerIncoming();
-	sendHttpRequest();
+
+	changeState(State::Connecting);
+	if (mIsClient)
+		sendHttpRequest();
 }
 }
 
 
 bool WsTransport::stop() {
 bool WsTransport::stop() {
@@ -91,7 +94,7 @@ bool WsTransport::send(message_ptr message) {
 
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 	PLOG_VERBOSE << "Send size=" << message->size();
 	return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
 	return sendFrame({message->type == Message::String ? TEXT_FRAME : BINARY_FRAME, message->data(),
-	                  message->size(), true, true});
+	                  message->size(), true, mIsClient});
 }
 }
 
 
 void WsTransport::incoming(message_ptr message) {
 void WsTransport::incoming(message_ptr message) {
@@ -103,10 +106,12 @@ void WsTransport::incoming(message_ptr message) {
 		PLOG_VERBOSE << "Incoming size=" << message->size();
 		PLOG_VERBOSE << "Incoming size=" << message->size();
 
 
 		if (message->size() == 0) {
 		if (message->size() == 0) {
-			// TCP is idle, send a ping
-			PLOG_DEBUG << "WebSocket sending ping";
-			uint32_t dummy = 0;
-			sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, true});
+			if (state() == State::Connected) {
+				// TCP is idle, send a ping
+				PLOG_DEBUG << "WebSocket sending ping";
+				uint32_t dummy = 0;
+				sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
+			}
 			return;
 			return;
 		}
 		}
 
 
@@ -114,10 +119,20 @@ void WsTransport::incoming(message_ptr message) {
 
 
 		try {
 		try {
 			if (state() == State::Connecting) {
 			if (state() == State::Connecting) {
-				if (size_t len = readHttpResponse(mBuffer.data(), mBuffer.size())) {
-					PLOG_INFO << "WebSocket open";
-					changeState(State::Connected);
-					mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				if (mIsClient) {
+					if (size_t len =
+					        mHandshake->parseHttpResponse(mBuffer.data(), mBuffer.size())) {
+						PLOG_INFO << "WebSocket client-side open";
+						changeState(State::Connected);
+						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+					}
+				} else {
+					if (size_t len = mHandshake->parseHttpRequest(mBuffer.data(), mBuffer.size())) {
+						PLOG_INFO << "WebSocket server-side open";
+						sendHttpResponse();
+						changeState(State::Connected);
+						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+					}
 				}
 				}
 			}
 			}
 
 
@@ -148,115 +163,35 @@ void WsTransport::incoming(message_ptr message) {
 
 
 void WsTransport::close() {
 void WsTransport::close() {
 	if (state() == State::Connected) {
 	if (state() == State::Connected) {
-		sendFrame({CLOSE, NULL, 0, true, true});
+		sendFrame({CLOSE, NULL, 0, true, mIsClient});
 		PLOG_INFO << "WebSocket closing";
 		PLOG_INFO << "WebSocket closing";
 		changeState(State::Disconnected);
 		changeState(State::Disconnected);
 	}
 	}
 }
 }
 
 
 bool WsTransport::sendHttpRequest() {
 bool WsTransport::sendHttpRequest() {
-	PLOG_DEBUG << "Sending WebSocket HTTP request for path " << mConfig.path;
-	changeState(State::Connecting);
-
-	auto seed = static_cast<unsigned int>(system_clock::now().time_since_epoch().count());
-	random_bytes_engine generator(seed);
-
-	binary key(16);
-	auto k = reinterpret_cast<uint8_t *>(key.data());
-	std::generate(k, k + key.size(), [&]() { return uint8_t(generator()); });
-
-	string appendHeader = "";
-	if (mConfig.protocols.size() > 0) {
-		appendHeader +=
-		    "Sec-WebSocket-Protocol: " +
-		    std::accumulate(mConfig.protocols.begin(), mConfig.protocols.end(), string(),
-		                    [](const string &a, const string &b) -> string {
-			                    return a + (a.length() > 0 ? "," : "") + b;
-		                    }) +
-		    "\r\n";
-	}
-
-	const string request = "GET " + mConfig.path +
-	                       " HTTP/1.1\r\n"
-	                       "Host: " +
-	                       mConfig.host +
-	                       "\r\n"
-	                       "Connection: Upgrade\r\n"
-	                       "Upgrade: websocket\r\n"
-	                       "Sec-WebSocket-Version: 13\r\n"
-	                       "Sec-WebSocket-Key: " +
-	                       to_base64(key) + "\r\n" + std::move(appendHeader) + "\r\n";
+	PLOG_DEBUG << "Sending WebSocket HTTP request";
 
 
+	const string request = mHandshake->generateHttpRequest();
 	auto data = reinterpret_cast<const byte *>(request.data());
 	auto data = reinterpret_cast<const byte *>(request.data());
 	auto size = request.size();
 	auto size = request.size();
 	return outgoing(make_message(data, data + size));
 	return outgoing(make_message(data, data + size));
 }
 }
 
 
-size_t WsTransport::readHttpResponse(const byte *buffer, size_t size) {
-	std::list<string> lines;
-	auto begin = reinterpret_cast<const char *>(buffer);
-	auto end = begin + size;
-	auto cur = begin;
-
-	while (true) {
-		auto last = cur;
-		cur = std::find(cur, end, '\n');
-		if (cur == end)
-			return 0;
-		string line(last, cur != begin && *std::prev(cur) == '\r' ? std::prev(cur++) : cur++);
-		if (line.empty())
-			break;
-		lines.emplace_back(std::move(line));
-	}
-	size_t length = cur - begin;
-
-	if (lines.empty())
-		throw std::runtime_error("Invalid HTTP response for WebSocket");
-
-	string status = std::move(lines.front());
-	lines.pop_front();
-
-	std::istringstream ss(status);
-	string protocol;
-	unsigned int code = 0;
-	ss >> protocol >> code;
-	PLOG_DEBUG << "WebSocket response code: " << code;
-	if (code != 101)
-		throw std::runtime_error("Unexpected response code for WebSocket: " + to_string(code));
-
-	std::multimap<string, string> headers;
-	for (const auto &line : lines) {
-		if (size_t pos = line.find_first_of(':'); pos != string::npos) {
-			string key = line.substr(0, pos);
-			string value = line.substr(line.find_first_not_of(' ', pos + 1));
-			std::transform(key.begin(), key.end(), key.begin(),
-			               [](char c) { return std::tolower(c); });
-			headers.emplace(std::move(key), std::move(value));
-		} else {
-			headers.emplace(line, "");
-		}
-	}
-
-	auto h = headers.find("upgrade");
-	if (h == headers.end())
-		throw std::runtime_error("WebSocket update header missing");
-
-	string upgrade;
-	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
-	               [](char c) { return std::tolower(c); });
-	if (upgrade != "websocket")
-		throw std::runtime_error("WebSocket update header mismatching: " + h->second);
-
-	h = headers.find("sec-websocket-accept");
-	if (h == headers.end())
-		throw std::runtime_error("WebSocket accept header missing");
+bool WsTransport::sendHttpResponse() {
+	PLOG_DEBUG << "Sending WebSocket HTTP response";
 
 
-	// TODO: Verify Sec-WebSocket-Accept
+	const string response = mHandshake->generateHttpResponse();
+	auto data = reinterpret_cast<const byte *>(response.data());
+	auto size = response.size();
+	bool ret = outgoing(make_message(data, data + size));
 
 
-	return length;
+	changeState(State::Connected);
+	return ret;
 }
 }
 
 
-// http://tools.ietf.org/html/rfc6455#section-5.2  Base Framing Protocol
+// RFC6455 5.2. Base Framing Protocol
+// http://tools.ietf.org/html/rfc6455#section-5.2
 //
 //
 //  0                   1                   2                   3
 //  0                   1                   2                   3
 //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
 //  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
@@ -364,7 +299,7 @@ void WsTransport::recvFrame(const Frame &frame) {
 	}
 	}
 	case PING: {
 	case PING: {
 		PLOG_DEBUG << "WebSocket received ping, sending pong";
 		PLOG_DEBUG << "WebSocket received ping, sending pong";
-		sendFrame({PONG, frame.payload, frame.length, true, true});
+		sendFrame({PONG, frame.payload, frame.length, true, mIsClient});
 		break;
 		break;
 	}
 	}
 	case PONG: {
 	case PONG: {
@@ -423,6 +358,6 @@ bool WsTransport::sendFrame(const Frame &frame) {
 	return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
 	return outgoing(make_message(frame.payload, frame.payload + frame.length)); // payload
 }
 }
 
 
-} // namespace rtc
+} // namespace rtc::impl
 
 
 #endif
 #endif

+ 11 - 14
src/impl/wstransport.hpp

@@ -1,5 +1,5 @@
 /**
 /**
- * Copyright (c) 2020 Paul-Louis Ageneau
+ * Copyright (c) 2020-2021 Paul-Louis Ageneau
  *
  *
  * This library is free software; you can redistribute it and/or
  * This library is free software; you can redistribute it and/or
  * modify it under the terms of the GNU Lesser General Public
  * modify it under the terms of the GNU Lesser General Public
@@ -21,6 +21,7 @@
 
 
 #include "common.hpp"
 #include "common.hpp"
 #include "transport.hpp"
 #include "transport.hpp"
+#include "wshandshake.hpp"
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
@@ -29,26 +30,21 @@ namespace rtc::impl {
 class TcpTransport;
 class TcpTransport;
 class TlsTransport;
 class TlsTransport;
 
 
-class WsTransport : public Transport {
+class WsTransport final : public Transport {
 public:
 public:
-	struct Configuration {
-		string host;
-		string path = "/";
-		std::vector<string> protocols;
-	};
-
-	WsTransport(shared_ptr<Transport> lower, Configuration config,
-	            message_callback recvCallback, state_callback stateCallback);
+	WsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<TlsTransport>> lower,
+	            shared_ptr<WsHandshake> handshake, message_callback recvCallback,
+	            state_callback stateCallback);
 	~WsTransport();
 	~WsTransport();
 
 
 	void start() override;
 	void start() override;
 	bool stop() override;
 	bool stop() override;
 	bool send(message_ptr message) override;
 	bool send(message_ptr message) override;
-
 	void incoming(message_ptr message) override;
 	void incoming(message_ptr message) override;
-
 	void close();
 	void close();
 
 
+	bool isClient() const { return mIsClient; }
+
 private:
 private:
 	enum Opcode : uint8_t {
 	enum Opcode : uint8_t {
 		CONTINUATION = 0,
 		CONTINUATION = 0,
@@ -68,13 +64,14 @@ private:
 	};
 	};
 
 
 	bool sendHttpRequest();
 	bool sendHttpRequest();
-	size_t readHttpResponse(const byte *buffer, size_t size);
+	bool sendHttpResponse();
 
 
 	size_t readFrame(byte *buffer, size_t size, Frame &frame);
 	size_t readFrame(byte *buffer, size_t size, Frame &frame);
 	void recvFrame(const Frame &frame);
 	void recvFrame(const Frame &frame);
 	bool sendFrame(const Frame &frame);
 	bool sendFrame(const Frame &frame);
 
 
-	const Configuration mConfig;
+	const shared_ptr<WsHandshake> mHandshake;
+	const bool mIsClient;
 
 
 	binary mBuffer;
 	binary mBuffer;
 	binary mPartial;
 	binary mPartial;

+ 17 - 13
src/websocket.cpp

@@ -21,25 +21,21 @@
 #include "websocket.hpp"
 #include "websocket.hpp"
 #include "common.hpp"
 #include "common.hpp"
 
 
-#include "impl/websocket.hpp"
 #include "impl/internals.hpp"
 #include "impl/internals.hpp"
-
-#include <regex>
-
-#ifdef _WIN32
-#include <winsock2.h>
-#endif
+#include "impl/websocket.hpp"
 
 
 namespace rtc {
 namespace rtc {
 
 
-using namespace std::placeholders;
-
 WebSocket::WebSocket() : WebSocket(Configuration()) {}
 WebSocket::WebSocket() : WebSocket(Configuration()) {}
 
 
 WebSocket::WebSocket(Configuration config)
 WebSocket::WebSocket(Configuration config)
     : CheshireCat<impl::WebSocket>(std::move(config)),
     : CheshireCat<impl::WebSocket>(std::move(config)),
       Channel(std::dynamic_pointer_cast<impl::Channel>(CheshireCat<impl::WebSocket>::impl())) {}
       Channel(std::dynamic_pointer_cast<impl::Channel>(CheshireCat<impl::WebSocket>::impl())) {}
 
 
+WebSocket::WebSocket(impl_ptr<impl::WebSocket> impl)
+    : CheshireCat<impl::WebSocket>(std::move(impl)),
+      Channel(std::dynamic_pointer_cast<impl::Channel>(CheshireCat<impl::WebSocket>::impl())) {}
+
 WebSocket::~WebSocket() { impl()->remoteClose(); }
 WebSocket::~WebSocket() { impl()->remoteClose(); }
 
 
 WebSocket::State WebSocket::readyState() const { return impl()->state; }
 WebSocket::State WebSocket::readyState() const { return impl()->state; }
@@ -52,10 +48,7 @@ size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
 
 void WebSocket::open(const string &url) {
 void WebSocket::open(const string &url) {
 	PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
 	PLOG_VERBOSE << "Opening WebSocket to URL: " << url;
-
-	impl()->parse(url);
-	impl()->changeState(State::Connecting);
-	impl()->initTcpTransport();
+	impl()->open(url);
 }
 }
 
 
 void WebSocket::close() {
 void WebSocket::close() {
@@ -78,6 +71,17 @@ bool WebSocket::send(const byte *data, size_t size) {
 	return impl()->outgoing(make_message(data, data + size));
 	return impl()->outgoing(make_message(data, data + size));
 }
 }
 
 
+optional<string> WebSocket::remoteAddress() const {
+	auto tcpTransport = impl()->getTcpTransport();
+	return tcpTransport ? make_optional(tcpTransport->remoteAddress()) : nullopt;
+}
+
+optional<string> WebSocket::path() const {
+	auto state = impl()->state.load();
+	auto handshake = impl()->getWsHandshake();
+	return state != State::Connecting && handshake ? make_optional(handshake->path()) : nullopt;
+}
+
 } // namespace rtc
 } // namespace rtc
 
 
 #endif
 #endif

+ 46 - 0
src/websocketserver.cpp

@@ -0,0 +1,46 @@
+/**
+ * Copyright (c) 2021 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#if RTC_ENABLE_WEBSOCKET
+
+#include "websocketserver.hpp"
+#include "common.hpp"
+
+#include "impl/internals.hpp"
+#include "impl/websocketserver.hpp"
+
+namespace rtc {
+
+WebSocketServer::WebSocketServer() : WebSocketServer(Configuration()) {}
+
+WebSocketServer::WebSocketServer(Configuration config)
+    : CheshireCat<impl::WebSocketServer>(std::move(config)) {}
+
+WebSocketServer::~WebSocketServer() { impl()->stop(); }
+
+void WebSocketServer::stop() { impl()->stop(); }
+
+uint16_t WebSocketServer::port() const { return impl()->tcpServer->port(); }
+
+void WebSocketServer::onClient(std::function<void(shared_ptr<WebSocket>)> callback) {
+	impl()->clientCallback = callback;
+}
+
+} // namespace rtc
+
+#endif