Browse Source

Merge pull request #84 from paullouisageneau/srtp

SRTP transport support
Paul-Louis Ageneau 5 years ago
parent
commit
516a529952

+ 28 - 8
CMakeLists.txt

@@ -31,6 +31,7 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/configuration.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/datachannel.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/description.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/dtlssrtptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/dtlstransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/icetransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/init.cpp
@@ -38,6 +39,7 @@ set(LIBDATACHANNEL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/peerconnection.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/tls.cpp
 )
 
 set(LIBDATACHANNEL_WEBSOCKET_SOURCES
@@ -77,6 +79,7 @@ set(TESTS_SOURCES
 set(CMAKE_THREAD_PREFER_PTHREAD TRUE)
 set(THREADS_PREFER_PTHREAD_FLAG TRUE)
 find_package(Threads REQUIRED)
+find_package(SRTP)
 
 set(CMAKE_POLICY_DEFAULT_CMP0048 NEW)
 add_subdirectory(deps/plog)
@@ -135,39 +138,56 @@ if(WIN32)
 	target_link_libraries(datachannel-static PRIVATE wsock32 ws2_32) # winsock2
 endif()
 
+if(SRTP_FOUND)
+	if(NOT TARGET SRTP::SRTP)
+		add_library(SRTP::SRTP UNKNOWN IMPORTED)
+		set_target_properties(SRTP::SRTP PROPERTIES
+			INTERFACE_INCLUDE_DIRECTORIES ${SRTP_INCLUDE_DIRS}
+			IMPORTED_LINK_INTERFACE_LANGUAGES C
+			IMPORTED_LOCATION ${SRTP_LIBRARIES})
+	endif()
+	target_compile_definitions(datachannel PUBLIC RTC_ENABLE_MEDIA=1)
+	target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_MEDIA=1)
+	target_link_libraries(datachannel PRIVATE SRTP::SRTP)
+	target_link_libraries(datachannel-static PRIVATE SRTP::SRTP)
+else()
+	target_compile_definitions(datachannel PUBLIC RTC_ENABLE_MEDIA=0)
+	target_compile_definitions(datachannel-static PUBLIC RTC_ENABLE_MEDIA=0)
+endif()
+
 if (USE_GNUTLS)
 	find_package(GnuTLS REQUIRED)
 	if(NOT TARGET GnuTLS::GnuTLS)
 		add_library(GnuTLS::GnuTLS UNKNOWN IMPORTED)
 		set_target_properties(GnuTLS::GnuTLS PROPERTIES
-			INTERFACE_INCLUDE_DIRECTORIES "${GNUTLS_INCLUDE_DIRS}"
-			INTERFACE_COMPILE_DEFINITIONS "${GNUTLS_DEFINITIONS}"
-			IMPORTED_LINK_INTERFACE_LANGUAGES "C"
-			IMPORTED_LOCATION "${GNUTLS_LIBRARIES}")
+			INTERFACE_INCLUDE_DIRECTORIES ${GNUTLS_INCLUDE_DIRS}
+			INTERFACE_COMPILE_DEFINITIONS ${GNUTLS_DEFINITIONS}
+			IMPORTED_LINK_INTERFACE_LANGUAGES C
+			IMPORTED_LOCATION ${GNUTLS_LIBRARIES})
 	endif()
 	target_compile_definitions(datachannel PRIVATE USE_GNUTLS=1)
-	target_link_libraries(datachannel PRIVATE GnuTLS::GnuTLS)
 	target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=1)
+	target_link_libraries(datachannel PRIVATE GnuTLS::GnuTLS)
 	target_link_libraries(datachannel-static PRIVATE GnuTLS::GnuTLS)
 else()
 	find_package(OpenSSL REQUIRED)
 	target_compile_definitions(datachannel PRIVATE USE_GNUTLS=0)
-	target_link_libraries(datachannel PRIVATE OpenSSL::SSL)
 	target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=0)
+	target_link_libraries(datachannel PRIVATE OpenSSL::SSL)
 	target_link_libraries(datachannel-static PRIVATE OpenSSL::SSL)
 endif()
 
 if (USE_JUICE)
 	add_subdirectory(deps/libjuice EXCLUDE_FROM_ALL)
 	target_compile_definitions(datachannel PRIVATE USE_JUICE=1)
-	target_link_libraries(datachannel PRIVATE LibJuice::LibJuiceStatic)
 	target_compile_definitions(datachannel-static PRIVATE USE_JUICE=1)
+	target_link_libraries(datachannel PRIVATE LibJuice::LibJuiceStatic)
 	target_link_libraries(datachannel-static PRIVATE LibJuice::LibJuiceStatic)
 else()
 	find_package(LibNice REQUIRED)
 	target_compile_definitions(datachannel PRIVATE USE_JUICE=0)
-	target_link_libraries(datachannel PRIVATE LibNice::LibNice)
 	target_compile_definitions(datachannel-static PRIVATE USE_JUICE=0)
+	target_link_libraries(datachannel PRIVATE LibNice::LibNice)
 	target_link_libraries(datachannel-static PRIVATE LibNice::LibNice)
 endif()
 

+ 1 - 0
Jamfile

@@ -10,6 +10,7 @@ lib libdatachannel
 	<cxxstd>17
 	<include>./include/rtc
 	<define>USE_JUICE=1
+	<define>RTC_ENABLE_MEDIA=0
 	<define>RTC_ENABLE_WEBSOCKET=0
 	<library>/libdatachannel//usrsctp
 	<library>/libdatachannel//juice

+ 8 - 0
Makefile

@@ -38,6 +38,14 @@ else
         LIBS+=glib-2.0 gobject-2.0 nice
 endif
 
+RTC_ENABLE_MEDIA ?= 0
+ifneq ($(RTC_ENABLE_MEDIA), 0)
+        CPPFLAGS+=-DRTC_ENABLE_MEDIA=1
+        LIBS+=srtp
+else
+        CPPFLAGS+=-DRTC_ENABLE_MEDIA=0
+endif
+
 RTC_ENABLE_WEBSOCKET ?= 1
 ifneq ($(RTC_ENABLE_WEBSOCKET), 0)
         CPPFLAGS+=-DRTC_ENABLE_WEBSOCKET=1

+ 2 - 0
README.md

@@ -28,6 +28,7 @@ Features:
 - Trickle ICE ([draft-ietf-ice-trickle-21](https://tools.ietf.org/html/draft-ietf-ice-trickle-21))
 - Multicast DNS candidates ([draft-ietf-rtcweb-mdns-ice-candidates-04](https://tools.ietf.org/html/draft-ietf-rtcweb-mdns-ice-candidates-04))
 - TURN relaying ([RFC5766](https://tools.ietf.org/html/rfc5766)) with [libnice](https://github.com/libnice/libnice) as ICE backend
+- SRTP media transport ([RFC3711](https://tools.ietf.org/html/rfc3711)) with [libSRTP](https://github.com/cisco/libsrtp)
 
 ### WebSocket
 
@@ -47,6 +48,7 @@ Features:
 
 Optional:
 - libnice: https://nice.freedesktop.org/ (substituable with libjuice)
+- libSRTP: https://github.com/cisco/libsrtp
 
 Submodules:
 - libjuice: https://github.com/paullouisageneau/libjuice

+ 73 - 0
cmake/Modules/FindSRTP.cmake

@@ -0,0 +1,73 @@
+############################################################################
+# FindSRTP.txt
+# Copyright (C) 2014  Belledonne Communications, Grenoble France
+#
+############################################################################
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; either version 2
+# of the License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+#
+############################################################################
+#
+# - Find the SRTP include file and library
+#
+#  SRTP_FOUND - system has SRTP
+#  SRTP_INCLUDE_DIRS - the SRTP include directory
+#  SRTP_LIBRARIES - The libraries needed to use SRTP
+
+set(_SRTP_ROOT_PATHS
+	${CMAKE_INSTALL_PREFIX}
+)
+
+find_path(SRTP2_INCLUDE_DIRS
+	NAMES srtp2/srtp.h
+	HINTS _SRTP_ROOT_PATHS
+	PATH_SUFFIXES include
+)
+
+if(SRTP2_INCLUDE_DIRS)
+	set(HAVE_SRTP_SRTP_H 1)
+	set(SRTP_INCLUDE_DIRS ${SRTP2_INCLUDE_DIRS})
+	set(SRTP_VERSION 2)
+	find_library(SRTP_LIBRARIES
+		NAMES srtp2
+		HINTS ${_SRTP_ROOT_PATHS}
+		PATH_SUFFIXES bin lib
+	)
+else()
+	find_path(SRTP_INCLUDE_DIRS
+		NAMES srtp/srtp.h
+		HINTS _SRTP_ROOT_PATHS
+		PATH_SUFFIXES include
+	)
+	if(SRTP_INCLUDE_DIRS)
+		set(HAVE_SRTP_SRTP_H 1)
+		set(SRTP_VERSION 1)
+	endif()
+	find_library(SRTP_LIBRARIES
+	NAMES srtp
+	HINTS ${_SRTP_ROOT_PATHS}
+	PATH_SUFFIXES bin lib
+)
+endif()
+
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(SRTP
+	DEFAULT_MSG
+	SRTP_INCLUDE_DIRS SRTP_LIBRARIES HAVE_SRTP_SRTP_H SRTP_VERSION
+)
+
+mark_as_advanced(SRTP_INCLUDE_DIRS SRTP_LIBRARIES HAVE_SRTP_SRTP_H SRTP_VERSION)
+

+ 1 - 1
include/rtc/datachannel.hpp

@@ -66,7 +66,7 @@ public:
 private:
 	void remoteClose();
 	void open(std::shared_ptr<SctpTransport> transport);
-	bool outgoing(mutable_message_ptr message);
+	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 	void processOpenMessage(message_ptr message);
 

+ 26 - 7
include/rtc/description.hpp

@@ -42,11 +42,11 @@ public:
 	string typeString() const;
 	Role role() const;
 	string roleString() const;
-	string mid() const;
+	string dataMid() const;
 	std::optional<string> fingerprint() const;
 	std::optional<uint16_t> sctpPort() const;
 	std::optional<size_t> maxMessageSize() const;
-	bool trickleEnabled() const;
+	bool ended() const;
 
 	void hintType(Type type);
 	void setFingerprint(string fingerprint);
@@ -57,21 +57,40 @@ public:
 	void endCandidates();
 	std::vector<Candidate> extractCandidates();
 
-	operator string() const;
+	bool hasMedia() const;
+	void addMedia(const Description &source);
 
+	operator string() const;
 	string generateSdp(const string &eol) const;
 
 private:
 	Type mType;
 	Role mRole;
 	string mSessionId;
-	string mMid;
 	string mIceUfrag, mIcePwd;
 	std::optional<string> mFingerprint;
-	std::optional<uint16_t> mSctpPort;
-	std::optional<size_t> mMaxMessageSize;
+
+	// Data
+	struct Data {
+		string mid;
+		std::optional<uint16_t> sctpPort;
+		std::optional<size_t> maxMessageSize;
+	};
+	Data mData;
+
+	// Media (non-data)
+	struct Media {
+		Media(const string &mline);
+		string type;
+		string description;
+		string mid;
+		std::vector<string> attributes;
+	};
+	std::map<string, Media> mMedia; // by mid
+
+	// Candidates
 	std::vector<Candidate> mCandidates;
-	bool mTrickle;
+	bool mEnded = false;
 
 	static Type stringToType(const string &typeString);
 	static string typeToString(Type type);

+ 4 - 0
include/rtc/include.hpp

@@ -19,6 +19,10 @@
 #ifndef RTC_INCLUDE_H
 #define RTC_INCLUDE_H
 
+#ifndef RTC_ENABLE_MEDIA
+#define RTC_ENABLE_MEDIA 1
+#endif
+
 #ifndef RTC_ENABLE_WEBSOCKET
 #define RTC_ENABLE_WEBSOCKET 1
 #endif

+ 1 - 2
include/rtc/message.hpp

@@ -42,8 +42,7 @@ struct Message : binary {
 	std::shared_ptr<Reliability> reliability;
 };
 
-using message_ptr = std::shared_ptr<const Message>;
-using mutable_message_ptr = std::shared_ptr<Message>;
+using message_ptr = std::shared_ptr<Message>;
 using message_callback = std::function<void(message_ptr message)>;
 
 constexpr auto message_size_func = [](const message_ptr &m) -> size_t {

+ 15 - 2
include/rtc/peerconnection.hpp

@@ -79,6 +79,7 @@ public:
 	std::optional<string> localAddress() const;
 	std::optional<string> remoteAddress() const;
 
+	void setLocalDescription(std::optional<Description> description = nullopt);
 	void setRemoteDescription(Description description);
 	void addRemoteCandidate(Candidate candidate);
 
@@ -91,14 +92,22 @@ public:
 	void onStateChange(std::function<void(State state)> callback);
 	void onGatheringStateChange(std::function<void(GatheringState state)> callback);
 
-	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
-
 	// Stats
 	void clearStats();
 	size_t bytesSent();
 	size_t bytesReceived();
 	std::optional<std::chrono::milliseconds> rtt();
 
+	// Media
+	bool hasMedia() const;
+	void sendMedia(const binary &packet);
+	void send(const byte *packet, size_t size);
+
+	void onMedia(std::function<void(const binary &packet)> callback);
+
+	// libnice only
+	bool getSelectedCandidatePair(CandidateInfo *local, CandidateInfo *remote);
+
 private:
 	std::shared_ptr<IceTransport> initIceTransport(Description::Role role);
 	std::shared_ptr<DtlsTransport> initDtlsTransport();
@@ -108,6 +117,7 @@ private:
 	void endLocalCandidates();
 	bool checkFingerprint(const std::string &fingerprint) const;
 	void forwardMessage(message_ptr message);
+	void forwardMedia(message_ptr message);
 	void forwardBufferedAmount(uint16_t stream, size_t amount);
 
 	std::shared_ptr<DataChannel> emplaceDataChannel(Description::Role role, const string &label,
@@ -127,6 +137,8 @@ private:
 
 	void resetCallbacks();
 
+	void outgoingMedia(message_ptr message);
+
 	const Configuration mConfig;
 	const future_certificate_ptr mCertificate;
 
@@ -150,6 +162,7 @@ private:
 	synchronized_callback<const Candidate &> mLocalCandidateCallback;
 	synchronized_callback<State> mStateChangeCallback;
 	synchronized_callback<GatheringState> mGatheringStateChangeCallback;
+	synchronized_callback<const binary &> mMediaCallback;
 };
 
 } // namespace rtc

+ 4 - 0
include/rtc/rtc.h

@@ -27,6 +27,10 @@ extern "C" {
 
 // libdatachannel C API
 
+#ifndef RTC_ENABLE_MEDIA
+#define RTC_ENABLE_MEDIA 1
+#endif
+
 #ifndef RTC_ENABLE_WEBSOCKET
 #define RTC_ENABLE_WEBSOCKET 1
 #endif

+ 1 - 1
include/rtc/websocket.hpp

@@ -67,7 +67,7 @@ public:
 private:
 	bool changeState(State state);
 	void remoteClose();
-	bool outgoing(mutable_message_ptr message);
+	bool outgoing(message_ptr message);
 	void incoming(message_ptr message);
 
 	std::shared_ptr<TcpTransport> initTcpTransport();

+ 37 - 93
src/certificate.cpp

@@ -31,93 +31,43 @@ using std::unique_ptr;
 
 #if USE_GNUTLS
 
-#include <gnutls/crypto.h>
-
-namespace {
-
-void check_gnutls(int ret, const string &message = "GnuTLS error") {
-	if (ret != GNUTLS_E_SUCCESS)
-		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
-}
-
-gnutls_certificate_credentials_t *create_credentials() {
-	auto creds = new gnutls_certificate_credentials_t;
-	check_gnutls(gnutls_certificate_allocate_credentials(creds));
-	return creds;
-}
-
-void delete_credentials(gnutls_certificate_credentials_t *creds) {
-	gnutls_certificate_free_credentials(*creds);
-	delete creds;
-}
-
-gnutls_x509_crt_t *create_crt() {
-	auto crt = new gnutls_x509_crt_t;
-	check_gnutls(gnutls_x509_crt_init(crt));
-	return crt;
-}
-
-void delete_crt(gnutls_x509_crt_t *crt) {
-	gnutls_x509_crt_deinit(*crt);
-	delete crt;
-}
-
-gnutls_x509_privkey_t *create_privkey() {
-	auto privkey = new gnutls_x509_privkey_t;
-	check_gnutls(gnutls_x509_privkey_init(privkey));
-	return privkey;
-}
-
-void delete_privkey(gnutls_x509_privkey_t *privkey) {
-	gnutls_x509_privkey_deinit(*privkey);
-	delete privkey;
-}
-
-gnutls_datum_t make_datum(char *data, size_t size) {
-	gnutls_datum_t datum;
-	datum.data = reinterpret_cast<unsigned char *>(data);
-	datum.size = size;
-	return datum;
-}
-
-} // namespace
-
 namespace rtc {
 
 Certificate::Certificate(string crt_pem, string key_pem)
-    : mCredentials(create_credentials(), delete_credentials) {
+    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials) {
 
-	gnutls_datum_t crt_datum = make_datum(crt_pem.data(), crt_pem.size());
-	gnutls_datum_t key_datum = make_datum(key_pem.data(), key_pem.size());
+	gnutls_datum_t crt_datum = gnutls::make_datum(crt_pem.data(), crt_pem.size());
+	gnutls_datum_t key_datum = gnutls::make_datum(key_pem.data(), key_pem.size());
 
-	check_gnutls(gnutls_certificate_set_x509_key_mem(*mCredentials, &crt_datum, &key_datum,
-	                                                 GNUTLS_X509_FMT_PEM),
-	             "Unable to import PEM");
+	gnutls::check(gnutls_certificate_set_x509_key_mem(*mCredentials, &crt_datum, &key_datum,
+	                                                  GNUTLS_X509_FMT_PEM),
+	              "Unable to import PEM");
 
-	auto create_crt_list = [this]() -> gnutls_x509_crt_t * {
+	auto new_crt_list = [this]() -> gnutls_x509_crt_t * {
 		gnutls_x509_crt_t *crt_list = nullptr;
 		unsigned int crt_list_size = 0;
-		check_gnutls(gnutls_certificate_get_x509_crt(*mCredentials, 0, &crt_list, &crt_list_size));
+		gnutls::check(gnutls_certificate_get_x509_crt(*mCredentials, 0, &crt_list, &crt_list_size));
 		assert(crt_list_size == 1);
 		return crt_list;
 	};
 
-	auto delete_crt_list = [](gnutls_x509_crt_t *crt_list) {
+	auto free_crt_list = [](gnutls_x509_crt_t *crt_list) {
 		gnutls_x509_crt_deinit(crt_list[0]);
 		gnutls_free(crt_list);
 	};
 
-	std::unique_ptr<gnutls_x509_crt_t, decltype(delete_crt_list)> crt_list(create_crt_list(),
-	                                                                       delete_crt_list);
+	std::unique_ptr<gnutls_x509_crt_t, decltype(free_crt_list)> crt_list(new_crt_list(),
+	                                                                     free_crt_list);
 
 	mFingerprint = make_fingerprint(*crt_list);
 }
 
 Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
-    : mCredentials(create_credentials(), delete_credentials), mFingerprint(make_fingerprint(crt)) {
+    : mCredentials(gnutls::new_credentials(), gnutls::free_credentials),
+      mFingerprint(make_fingerprint(crt)) {
 
-	check_gnutls(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
-	             "Unable to set certificate and key pair in credentials");
+	gnutls::check(gnutls_certificate_set_x509_key(*mCredentials, &crt, 1, privkey),
+	              "Unable to set certificate and key pair in credentials");
 }
 
 gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
@@ -128,8 +78,8 @@ string make_fingerprint(gnutls_x509_crt_t crt) {
 	const size_t size = 32;
 	unsigned char buffer[size];
 	size_t len = size;
-	check_gnutls(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len),
-	             "X509 fingerprint error");
+	gnutls::check(gnutls_x509_crt_get_fingerprint(crt, GNUTLS_DIG_SHA256, buffer, &len),
+	              "X509 fingerprint error");
 
 	std::ostringstream oss;
 	oss << std::hex << std::uppercase << std::setfill('0');
@@ -144,13 +94,13 @@ string make_fingerprint(gnutls_x509_crt_t crt) {
 namespace {
 
 certificate_ptr make_certificate_impl(string commonName) {
-	std::unique_ptr<gnutls_x509_crt_t, decltype(&delete_crt)> crt(create_crt(), delete_crt);
-	std::unique_ptr<gnutls_x509_privkey_t, decltype(&delete_privkey)> privkey(create_privkey(),
-	                                                                          delete_privkey);
+	using namespace gnutls;
+	unique_ptr<gnutls_x509_crt_t, decltype(&free_crt)> crt(new_crt(), free_crt);
+	unique_ptr<gnutls_x509_privkey_t, decltype(&free_privkey)> privkey(new_privkey(), free_privkey);
 
 	const unsigned int bits = gnutls_sec_param_to_pk_bits(GNUTLS_PK_RSA, GNUTLS_SEC_PARAM_HIGH);
-	check_gnutls(gnutls_x509_privkey_generate(*privkey, GNUTLS_PK_RSA, bits, 0),
-	             "Unable to generate key pair");
+	gnutls::check(gnutls_x509_privkey_generate(*privkey, GNUTLS_PK_RSA, bits, 0),
+	              "Unable to generate key pair");
 
 	using namespace std::chrono;
 	auto now = time_point_cast<seconds>(system_clock::now());
@@ -166,8 +116,8 @@ certificate_ptr make_certificate_impl(string commonName) {
 	gnutls_rnd(GNUTLS_RND_NONCE, serial, serialSize);
 	gnutls_x509_crt_set_serial(*crt, serial, serialSize);
 
-	check_gnutls(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0),
-	             "Unable to auto-sign certificate");
+	gnutls::check(gnutls_x509_crt_sign2(*crt, *crt, *privkey, GNUTLS_DIG_SHA256, 0),
+	              "Unable to auto-sign certificate");
 
 	return std::make_shared<Certificate>(*crt, *privkey);
 }
@@ -176,30 +126,24 @@ certificate_ptr make_certificate_impl(string commonName) {
 
 } // namespace rtc
 
-#else
-
-#include <openssl/err.h>
-#include <openssl/pem.h>
-#include <openssl/ssl.h>
+#else // USE_GNUTLS==0
 
 namespace rtc {
 
 Certificate::Certificate(string crt_pem, string key_pem) {
-    BIO *bio;
-
-    bio = BIO_new(BIO_s_mem());
-    BIO_write(bio, crt_pem.data(), crt_pem.size());
-    mX509 = shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, 0, 0), X509_free);
-    BIO_free(bio);
-    if (!mX509)
-      throw std::invalid_argument("Unable to import certificate PEM");
-
-    bio = BIO_new(BIO_s_mem());
-    BIO_write(bio, key_pem.data(), key_pem.size());
+	BIO *bio = BIO_new(BIO_s_mem());
+	BIO_write(bio, crt_pem.data(), crt_pem.size());
+	mX509 = shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, 0, 0), X509_free);
+	BIO_free(bio);
+	if (!mX509)
+		throw std::invalid_argument("Unable to import certificate PEM");
+
+	bio = BIO_new(BIO_s_mem());
+	BIO_write(bio, key_pem.data(), key_pem.size());
 	mPKey = shared_ptr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio, nullptr, 0, 0), EVP_PKEY_free);
-    BIO_free(bio);
-    if (!mPKey)
-      throw std::invalid_argument("Unable to import PEM key PEM");
+	BIO_free(bio);
+	if (!mPKey)
+		throw std::invalid_argument("Unable to import PEM key PEM");
 
 	mFingerprint = make_fingerprint(mX509.get());
 }

+ 1 - 6
src/certificate.hpp

@@ -20,16 +20,11 @@
 #define RTC_CERTIFICATE_H
 
 #include "include.hpp"
+#include "tls.hpp"
 
 #include <future>
 #include <tuple>
 
-#if USE_GNUTLS
-#include <gnutls/x509.h>
-#else
-#include <openssl/x509.h>
-#endif
-
 namespace rtc {
 
 class Certificate {

+ 4 - 4
src/datachannel.cpp

@@ -154,13 +154,13 @@ bool DataChannel::isOpen(void) const { return mIsOpen; }
 bool DataChannel::isClosed(void) const { return mIsClosed; }
 
 size_t DataChannel::maxMessageSize() const {
-	size_t max = DEFAULT_MAX_MESSAGE_SIZE;
+	size_t remoteMax = DEFAULT_MAX_MESSAGE_SIZE;
 	if (auto pc = mPeerConnection.lock())
 		if (auto description = pc->remoteDescription())
 			if (auto maxMessageSize = description->maxMessageSize())
-				return *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
+				remoteMax = *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
 
-	return std::min(max, LOCAL_MAX_MESSAGE_SIZE);
+	return std::min(remoteMax, LOCAL_MAX_MESSAGE_SIZE);
 }
 
 size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
@@ -196,7 +196,7 @@ void DataChannel::open(shared_ptr<SctpTransport> transport) {
 	transport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
 }
 
-bool DataChannel::outgoing(mutable_message_ptr message) {
+bool DataChannel::outgoing(message_ptr message) {
 	if (mIsClosed)
 		throw std::runtime_error("DataChannel is closed");
 

+ 147 - 52
src/description.cpp

@@ -29,7 +29,7 @@ using std::string;
 
 namespace {
 
-inline bool hasprefix(const string &str, const string &prefix) {
+inline bool match_prefix(const string &str, const string &prefix) {
 	return str.size() >= prefix.size() &&
 	       std::mismatch(prefix.begin(), prefix.end(), str.begin()).first == prefix.end();
 }
@@ -50,7 +50,8 @@ Description::Description(const string &sdp, const string &typeString)
 Description::Description(const string &sdp, Type type) : Description(sdp, type, Role::ActPass) {}
 
 Description::Description(const string &sdp, Type type, Role role)
-    : mType(Type::Unspec), mRole(role), mMid("0"), mIceUfrag(""), mIcePwd(""), mTrickle(true) {
+    : mType(Type::Unspec), mRole(role) {
+	mData.mid = "data";
 	hintType(type);
 
 	auto seed = std::chrono::system_clock::now().time_since_epoch().count();
@@ -59,37 +60,79 @@ Description::Description(const string &sdp, Type type, Role role)
 	mSessionId = std::to_string(uniform(generator));
 
 	std::istringstream ss(sdp);
-	string line;
-	while (std::getline(ss, line)) {
+	std::optional<Media> currentMedia;
+
+	bool finished;
+	do {
+		string line;
+		finished = !std::getline(ss, line) && line.empty();
 		trim_end(line);
-		if (hasprefix(line, "a=setup:")) {
-			const string setup = line.substr(line.find(':') + 1);
-			if (setup == "active")
-				mRole = Role::Active;
-			else if (setup == "passive")
-				mRole = Role::Passive;
-			else
-				mRole = Role::ActPass;
-		} else if (hasprefix(line, "a=mid:")) {
-			mMid = line.substr(line.find(':') + 1);
-		} else if (hasprefix(line, "a=fingerprint:sha-256")) {
-			mFingerprint = line.substr(line.find(' ') + 1);
-			std::transform(mFingerprint->begin(), mFingerprint->end(), mFingerprint->begin(),
-						   [](char c) { return std::toupper(c); });
-		} else if (hasprefix(line, "a=ice-ufrag")) {
-			mIceUfrag = line.substr(line.find(':') + 1);
-		} else if (hasprefix(line, "a=ice-pwd")) {
-			mIcePwd = line.substr(line.find(':') + 1);
-		} else if (hasprefix(line, "a=sctp-port")) {
-			mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1)));
-		} else if (hasprefix(line, "a=max-message-size")) {
-			mMaxMessageSize = size_t(std::stoul(line.substr(line.find(':') + 1)));
-		} else if (hasprefix(line, "a=candidate")) {
-			addCandidate(Candidate(line.substr(2), mMid));
-		} else if (hasprefix(line, "a=end-of-candidates")) {
-			mTrickle = false;
+
+		// Media description line (aka m-line)
+		if (finished || match_prefix(line, "m=")) {
+			if (currentMedia) {
+				if (!currentMedia->mid.empty()) {
+					if (currentMedia->type == "application")
+						mData.mid = currentMedia->mid;
+					else
+						mMedia.emplace(currentMedia->mid, std::move(*currentMedia));
+
+				} else if (line.find(" ICE/SDP") != string::npos) {
+					PLOG_WARNING << "SDP \"m=\" line has no corresponding mid, ignoring";
+				}
+			}
+			if (!finished)
+				currentMedia.emplace(Media(line.substr(2)));
+
+			// Attribute line
+		} else if (match_prefix(line, "a=")) {
+			string attr = line.substr(2);
+
+			string key, value;
+			if (size_t separator = attr.find(':'); separator != string::npos) {
+				key = attr.substr(0, separator);
+				value = attr.substr(separator + 1);
+			} else {
+				key = attr;
+			}
+
+			if (key == "mid") {
+				if (currentMedia)
+					currentMedia->mid = value;
+
+			} else if (key == "setup") {
+				if (value == "active")
+					mRole = Role::Active;
+				else if (value == "passive")
+					mRole = Role::Passive;
+				else
+					mRole = Role::ActPass;
+
+			} else if (key == "fingerprint") {
+				if (match_prefix(value, "sha-256 ")) {
+					mFingerprint = value.substr(8);
+					std::transform(mFingerprint->begin(), mFingerprint->end(),
+					               mFingerprint->begin(), [](char c) { return std::toupper(c); });
+				} else {
+					PLOG_WARNING << "Unknown SDP fingerprint type: " << value;
+				}
+			} else if (key == "ice-ufrag") {
+				mIceUfrag = value;
+			} else if (key == "ice-pwd") {
+				mIcePwd = value;
+			} else if (key == "sctp-port") {
+				mData.sctpPort = uint16_t(std::stoul(value));
+			} else if (key == "max-message-size") {
+				mData.maxMessageSize = size_t(std::stoul(value));
+			} else if (key == "candidate") {
+				addCandidate(Candidate(attr, currentMedia ? currentMedia->mid : mData.mid));
+			} else if (key == "end-of-candidates") {
+				mEnded = true;
+			} else if (currentMedia) {
+				currentMedia->attributes.emplace_back(line.substr(2));
+			}
 		}
-	}
+	} while (!finished);
 }
 
 Description::Type Description::type() const { return mType; }
@@ -100,15 +143,15 @@ Description::Role Description::role() const { return mRole; }
 
 string Description::roleString() const { return roleToString(mRole); }
 
-string Description::mid() const { return mMid; }
+string Description::dataMid() const { return mData.mid; }
 
 std::optional<string> Description::fingerprint() const { return mFingerprint; }
 
-std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
+std::optional<uint16_t> Description::sctpPort() const { return mData.sctpPort; }
 
-std::optional<size_t> Description::maxMessageSize() const { return mMaxMessageSize; }
+std::optional<size_t> Description::maxMessageSize() const { return mData.maxMessageSize; }
 
-bool Description::trickleEnabled() const { return mTrickle; }
+bool Description::ended() const { return mEnded; }
 
 void Description::hintType(Type type) {
 	if (mType == Type::Unspec) {
@@ -122,23 +165,33 @@ void Description::setFingerprint(string fingerprint) {
 	mFingerprint.emplace(std::move(fingerprint));
 }
 
-void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
+void Description::setSctpPort(uint16_t port) { mData.sctpPort.emplace(port); }
 
-void Description::setMaxMessageSize(size_t size) { mMaxMessageSize.emplace(size); }
+void Description::setMaxMessageSize(size_t size) { mData.maxMessageSize.emplace(size); }
 
 void Description::addCandidate(Candidate candidate) {
 	mCandidates.emplace_back(std::move(candidate));
 }
 
-void Description::endCandidates() { mTrickle = false; }
+void Description::endCandidates() { mEnded = true; }
 
 std::vector<Candidate> Description::extractCandidates() {
 	std::vector<Candidate> result;
 	std::swap(mCandidates, result);
-	mTrickle = true;
+	mEnded = false;
 	return result;
 }
 
+bool Description::hasMedia() const { return !mMedia.empty(); }
+
+void Description::addMedia(const Description &source) {
+	for (auto [mid, media] : source.mMedia)
+		if (mid != mData.mid)
+			mMedia.emplace(mid, media);
+		else
+			PLOG_WARNING << "Media mid \"" << mid << "\" is the same as data mid, ignoring";
+}
+
 Description::operator string() const { return generateSdp("\r\n"); }
 
 string Description::generateSdp(const string &eol) const {
@@ -146,36 +199,78 @@ string Description::generateSdp(const string &eol) const {
 		throw std::logic_error("Fingerprint must be set to generate an SDP string");
 
 	std::ostringstream sdp;
+
+	// Header
 	sdp << "v=0" << eol;
 	sdp << "o=- " << mSessionId << " 0 IN IP4 127.0.0.1" << eol;
 	sdp << "s=-" << eol;
 	sdp << "t=0 0" << eol;
-	sdp << "a=group:BUNDLE 0" << eol;
-	sdp << "m=application 9 UDP/DTLS/SCTP webrtc-datachannel" << eol;
+
+	// Bundle
+	// see Negotiating Media Multiplexing Using the Session Description Protocol
+	// https://tools.ietf.org/html/draft-ietf-mmusic-sdp-bundle-negotiation-54
+	sdp << "a=group:BUNDLE";
+	for (const auto &[mid, _] : mMedia)
+		sdp << " " << mid;
+	sdp << " " << mData.mid << eol;
+
+	// Data
+	const string dataDescription = "UDP/DTLS/SCTP webrtc-datachannel";
+	sdp << "m=application" << ' ' << (!mMedia.empty() ? 0 : 9) << ' ' << dataDescription << eol;
 	sdp << "c=IN IP4 0.0.0.0" << eol;
+	if (!mMedia.empty())
+		sdp << "a=bundle-only" << eol;
+	sdp << "a=mid:" << mData.mid << eol;
+	if (mData.sctpPort)
+		sdp << "a=sctp-port:" << *mData.sctpPort << eol;
+	if (mData.maxMessageSize)
+		sdp << "a=max-message-size:" << *mData.maxMessageSize << eol;
+
+	// Non-data media
+	if (!mMedia.empty()) {
+		// Lip-sync
+		sdp << "a=group:LS";
+		for (const auto &[mid, _] : mMedia)
+			sdp << " " << mid;
+		sdp << eol;
+
+		// Descriptions and attributes
+		for (const auto &[_, media] : mMedia) {
+			sdp << "m=" << media.type << ' ' << 0 << ' ' << media.description << eol;
+			sdp << "c=IN IP4 0.0.0.0" << eol;
+			sdp << "a=bundle-only" << eol;
+			sdp << "a=mid:" << media.mid << eol;
+			for (const auto &attr : media.attributes)
+				sdp << "a=" << attr << eol;
+		}
+	}
+
+	// Common
+	sdp << "a=ice-options:trickle" << eol;
 	sdp << "a=ice-ufrag:" << mIceUfrag << eol;
 	sdp << "a=ice-pwd:" << mIcePwd << eol;
-	if (mTrickle)
-		sdp << "a=ice-options:trickle" << eol;
-	sdp << "a=mid:" << mMid << eol;
 	sdp << "a=setup:" << roleToString(mRole) << eol;
 	sdp << "a=dtls-id:1" << eol;
 	if (mFingerprint)
 		sdp << "a=fingerprint:sha-256 " << *mFingerprint << eol;
-	if (mSctpPort)
-		sdp << "a=sctp-port:" << *mSctpPort << eol;
-	if (mMaxMessageSize)
-		sdp << "a=max-message-size:" << *mMaxMessageSize << eol;
-	for (const auto &candidate : mCandidates) {
-		sdp << string(candidate) << eol;
-	}
 
-	if (!mTrickle)
+	// Candidates
+	for (const auto &candidate : mCandidates)
+		sdp << string(candidate) << eol;
+	if (mEnded)
 		sdp << "a=end-of-candidates" << eol;
 
 	return sdp.str();
 }
 
+Description::Media::Media(const string &mline) {
+	size_t p = mline.find(' ');
+	this->type = mline.substr(0, p);
+	if (p != string::npos)
+		if (size_t q = mline.find(' ', p + 1); q != string::npos)
+			this->description = mline.substr(q + 1);
+}
+
 Description::Type Description::stringToType(const string &typeString) {
 	if (typeString == "offer")
 		return Type::Offer;

+ 205 - 0
src/dtlssrtptransport.cpp

@@ -0,0 +1,205 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "dtlssrtptransport.hpp"
+#include "tls.hpp"
+
+#if RTC_ENABLE_MEDIA
+
+#include <cstring>
+#include <exception>
+
+using std::shared_ptr;
+using std::to_integer;
+using std::to_string;
+
+namespace rtc {
+
+void DtlsSrtpTransport::Init() { srtp_init(); }
+
+void DtlsSrtpTransport::Cleanup() { srtp_shutdown(); }
+
+DtlsSrtpTransport::DtlsSrtpTransport(std::shared_ptr<IceTransport> lower,
+                                     shared_ptr<Certificate> certificate,
+                                     verifier_callback verifierCallback,
+                                     message_callback srtpRecvCallback,
+                                     state_callback stateChangeCallback)
+    : DtlsTransport(lower, certificate, std::move(verifierCallback),
+                    std::move(stateChangeCallback)),
+      mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback
+
+	PLOG_DEBUG << "Initializing SRTP transport";
+
+#if USE_GNUTLS
+	PLOG_DEBUG << "Initializing DTLS-SRTP transport (GnuTLS)";
+	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
+	              "Failed to set SRTP profile");
+#else
+	PLOG_DEBUG << "Initializing DTLS-SRTP transport (OpenSSL)";
+	openssl::check(SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"),
+	               "Failed to set SRTP profile");
+#endif
+}
+
+DtlsSrtpTransport::~DtlsSrtpTransport() {
+	stop();
+
+	if (mCreated)
+		srtp_dealloc(mSrtp);
+}
+
+bool DtlsSrtpTransport::send(message_ptr message) {
+	if (!message)
+		return false;
+
+	int size = message->size();
+	PLOG_VERBOSE << "Send size=" << size;
+
+	// srtp_protect() assumes that it can write SRTP_MAX_TRAILER_LEN (for the authentication tag)
+	// into the location in memory immediately following the RTP packet.
+	message->resize(size + SRTP_MAX_TRAILER_LEN);
+	if (srtp_err_status_t err = srtp_protect(mSrtp, message->data(), &size)) {
+		if (err == srtp_err_status_replay_fail)
+			throw std::runtime_error("SRTP packet is a replay");
+		else
+			throw std::runtime_error("SRTP protect error, status=" +
+			                         to_string(static_cast<int>(err)));
+	}
+	PLOG_VERBOSE << "Protected SRTP packet, size=" << size;
+	message->resize(size);
+	outgoing(message);
+	return true;
+}
+
+void DtlsSrtpTransport::incoming(message_ptr message) {
+	int size = message->size();
+	if (size == 0)
+		return;
+
+	// RFC 5764 5.1.2. Reception
+	// The process for demultiplexing a packet is as follows. The receiver looks at the first byte
+	// of the packet. [...] If the value is in between 128 and 191 (inclusive), then the packet is
+	// RTP (or RTCP [...]). If the value is between 20 and 63 (inclusive), the packet is DTLS.
+	uint8_t value = to_integer<uint8_t>(*message->begin());
+
+	if (value >= 128 && value <= 192) {
+		PLOG_VERBOSE << "Incoming DTLS packet, size=" << size;
+		DtlsTransport::incoming(message);
+	} else if (value >= 20 && value <= 64) {
+		PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
+
+		if (srtp_err_status_t err = srtp_unprotect(mSrtp, message->data(), &size)) {
+			if (err == srtp_err_status_replay_fail)
+				PLOG_WARNING << "Incoming SRTP packet is a replay";
+			else
+				PLOG_WARNING << "SRTP unprotect error, status=" << err;
+			return;
+		}
+		PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
+		message->resize(size);
+		mSrtpRecvCallback(message);
+
+	} else {
+		PLOG_WARNING << "Unknown packet type, value=" << value << ", size=" << size;
+	}
+}
+
+void DtlsSrtpTransport::postHandshake() {
+	if (mCreated)
+		return;
+
+	srtp_policy_t inbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+	inbound.ssrc.type = ssrc_any_inbound;
+
+	srtp_policy_t outbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+	outbound.ssrc.type = ssrc_any_outbound;
+
+	const size_t materialLen = SRTP_AES_ICM_128_KEY_LEN_WSALT * 2;
+	unsigned char material[materialLen];
+	const unsigned char *clientKey, *clientSalt, *serverKey, *serverSalt;
+
+#if USE_GNUTLS
+	gnutls_datum_t clientKeyDatum, clientSaltDatum, serverKeyDatum, serverSaltDatum;
+	gnutls::check(gnutls_srtp_get_keys(mSession, material, materialLen, &clientKeyDatum,
+	                                   &clientSaltDatum, &serverKeyDatum, &serverSaltDatum),
+	              "Failed to derive SRTP keys");
+
+	if (clientKeyDatum.size != SRTP_AES_128_KEY_LEN)
+		throw std::logic_error("Unexpected SRTP master key length: " +
+		                       to_string(clientKeyDatum.size));
+	if (clientSaltDatum.size != SRTP_SALT_LEN)
+		throw std::logic_error("Unexpected SRTP salt length: " + to_string(clientSaltDatum.size));
+	if (serverKeyDatum.size != SRTP_AES_128_KEY_LEN)
+		throw std::logic_error("Unexpected SRTP master key length: " +
+		                       to_string(serverKeyDatum.size));
+	if (serverSaltDatum.size != SRTP_SALT_LEN)
+		throw std::logic_error("Unexpected SRTP salt size: " + to_string(serverSaltDatum.size));
+
+	clientKey = reinterpret_cast<const unsigned char *>(clientKeyDatum.data);
+	clientSalt = reinterpret_cast<const unsigned char *>(clientSaltDatum.data);
+
+	serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
+	serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
+#else
+	// This provides the client write master key, the server write master key, the client write
+	// master salt and the server write master salt in that order.
+	const string label = "EXTRACTOR-dtls_srtp";
+	openssl::check(SSL_export_keying_material(mSsl, material, materialLen, label.c_str(),
+	                                          label.size(), nullptr, 0, 0),
+	               "Failed to derive SRTP keys");
+
+	clientKey = material;
+	clientSalt = clientKey + SRTP_AES_128_KEY_LEN;
+
+	serverKey = material + SRTP_AES_ICM_128_KEY_LEN_WSALT;
+	serverSalt = serverSalt + SRTP_AES_128_KEY_LEN;
+#endif
+
+	unsigned char clientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
+	std::memcpy(clientSessionKey, clientKey, SRTP_AES_128_KEY_LEN);
+	std::memcpy(clientSessionKey + SRTP_AES_128_KEY_LEN, clientSalt, SRTP_SALT_LEN);
+
+	unsigned char serverSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
+	std::memcpy(serverSessionKey, serverKey, SRTP_AES_128_KEY_LEN);
+	std::memcpy(serverSessionKey + SRTP_AES_128_KEY_LEN, serverSalt, SRTP_SALT_LEN);
+
+	if (mIsClient) {
+		inbound.key = serverSessionKey;
+		outbound.key = clientSessionKey;
+	} else {
+		inbound.key = clientSessionKey;
+		outbound.key = serverSessionKey;
+	}
+
+	srtp_policy_t *policies = &inbound;
+	inbound.next = &outbound;
+	outbound.next = nullptr;
+
+	if (srtp_err_status_t err = srtp_create(&mSrtp, policies))
+		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
+
+	mCreated = true;
+}
+
+} // namespace rtc
+
+#endif

+ 57 - 0
src/dtlssrtptransport.hpp

@@ -0,0 +1,57 @@
+/**
+ * Copyright (c) 2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_DTLS_SRTP_TRANSPORT_H
+#define RTC_DTLS_SRTP_TRANSPORT_H
+
+#include "dtlstransport.hpp"
+#include "include.hpp"
+
+#if RTC_ENABLE_MEDIA
+
+#include <srtp2/srtp.h>
+
+namespace rtc {
+
+class DtlsSrtpTransport final : public DtlsTransport {
+public:
+	static void Init();
+	static void Cleanup();
+
+	DtlsSrtpTransport(std::shared_ptr<IceTransport> lower, std::shared_ptr<Certificate> certificate,
+	                  verifier_callback verifierCallback, message_callback srtpRecvCallback,
+	                  state_callback stateChangeCallback);
+	~DtlsSrtpTransport();
+
+	bool send(message_ptr message) override;
+
+private:
+	void incoming(message_ptr message) override;
+	void postHandshake() override;
+
+	message_callback mSrtpRecvCallback;
+
+	srtp_t mSrtp;
+	bool mCreated = false;
+};
+
+} // namespace rtc
+
+#endif
+
+#endif

+ 42 - 94
src/dtlstransport.cpp

@@ -31,28 +31,10 @@ using std::string;
 using std::unique_ptr;
 using std::weak_ptr;
 
-#if USE_GNUTLS
-
-#include <gnutls/dtls.h>
-
-namespace {
-
-static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
-	if (ret < 0) {
-		if (!gnutls_error_is_fatal(ret)) {
-			PLOG_INFO << gnutls_strerror(ret);
-			return false;
-		}
-		PLOG_ERROR << message << ": " << gnutls_strerror(ret);
-		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
-	}
-	return true;
-}
-
-} // namespace
-
 namespace rtc {
 
+#if USE_GNUTLS
+
 void DtlsTransport::Init() {
 	// Nothing to do
 }
@@ -64,13 +46,16 @@ void DtlsTransport::Cleanup() {
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr certificate,
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
-      mVerifierCallback(std::move(verifierCallback)) {
+      mVerifierCallback(std::move(verifierCallback)),
+      mIsClient(lower->role() == Description::Role::Active) {
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
-	bool active = lower->role() == Description::Role::Active;
-	unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER);
-	check_gnutls(gnutls_init(&mSession, flags));
+	gnutls_certificate_credentials_t creds = mCertificate->credentials();
+	gnutls_certificate_set_verify_function(creds, CertificateCallback);
+
+	unsigned int flags = GNUTLS_DATAGRAM | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER);
+	gnutls::check(gnutls_init(&mSession, flags));
 
 	try {
 		// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
@@ -78,12 +63,10 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		// See https://tools.ietf.org/html/rfc8261#section-5
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL:+COMP-NULL";
 		const char *err_pos = NULL;
-		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
-		             "Failed to set TLS priorities");
+		gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
+		              "Failed to set TLS priorities");
 
-		gnutls_certificate_set_verify_function(mCertificate->credentials(), CertificateCallback);
-		check_gnutls(
-		    gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCertificate->credentials()));
+		gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, creds));
 
 		gnutls_dtls_set_timeouts(mSession,
 		                         1000,   // 1s retransmission timeout recommended by RFC 6347
@@ -135,7 +118,7 @@ bool DtlsTransport::send(message_ptr message) {
 	if (ret == GNUTLS_E_LARGE_PACKET)
 		return false;
 
-	return check_gnutls(ret);
+	return gnutls::check(ret);
 }
 
 void DtlsTransport::incoming(message_ptr message) {
@@ -148,6 +131,10 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
+void DtlsTransport::postHandshake() {
+	// Dummy
+}
+
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 
@@ -164,7 +151,7 @@ void DtlsTransport::runRecvLoop() {
 				throw std::runtime_error("MTU is too low");
 
 		} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
-		         !check_gnutls(ret, "DTLS handshake failed"));
+		         !gnutls::check(ret, "DTLS handshake failed"));
 
 		// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
 		// See https://tools.ietf.org/html/rfc8261#section-5
@@ -180,6 +167,7 @@ void DtlsTransport::runRecvLoop() {
 	try {
 		PLOG_INFO << "DTLS handshake finished";
 		changeState(State::Connected);
+		postHandshake();
 
 		const size_t bufferSize = maxMtu;
 		char buffer[bufferSize];
@@ -196,7 +184,7 @@ void DtlsTransport::runRecvLoop() {
 				break;
 			}
 
-			if (check_gnutls(ret)) {
+			if (gnutls::check(ret)) {
 				if (ret == 0) {
 					// Closed
 					PLOG_DEBUG << "DTLS connection cleanly closed";
@@ -232,7 +220,7 @@ int DtlsTransport::CertificateCallback(gnutls_session_t session) {
 	}
 
 	gnutls_x509_crt_t crt;
-	check_gnutls(gnutls_x509_crt_init(&crt));
+	gnutls::check(gnutls_x509_crt_init(&crt));
 	int ret = gnutls_x509_crt_import(crt, &array[0], GNUTLS_X509_FMT_DER);
 	if (ret != GNUTLS_E_SUCCESS) {
 		gnutls_x509_crt_deinit(crt);
@@ -277,62 +265,17 @@ int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms)
 	return !t->mIncomingQueue.empty() ? 1 : 0;
 }
 
-} // namespace rtc
-
 #else // USE_GNUTLS==0
 
-#include <openssl/bio.h>
-#include <openssl/ec.h>
-#include <openssl/err.h>
-#include <openssl/ssl.h>
-
-namespace {
-
-const int BIO_EOF = -1;
-
-string openssl_error_string(unsigned long err) {
-	const size_t bufferSize = 256;
-	char buffer[bufferSize];
-	ERR_error_string_n(err, buffer, bufferSize);
-	return string(buffer);
-}
-
-bool check_openssl(int success, const string &message = "OpenSSL error") {
-	if (success)
-		return true;
-
-	string str = openssl_error_string(ERR_get_error());
-	PLOG_ERROR << message << ": " << str;
-	throw std::runtime_error(message + ": " + str);
-}
-
-bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
-	if (ret == BIO_EOF)
-		return true;
-
-	unsigned long err = SSL_get_error(ssl, ret);
-	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
-		return true;
-	}
-	if (err == SSL_ERROR_ZERO_RETURN) {
-		PLOG_DEBUG << "DTLS connection cleanly closed";
-		return false;
-	}
-	string str = openssl_error_string(err);
-	PLOG_ERROR << str;
-	throw std::runtime_error(message + ": " + str);
-}
-
-} // namespace
-
-namespace rtc {
-
 BIO_METHOD *DtlsTransport::BioMethods = NULL;
 int DtlsTransport::TransportExIndex = -1;
 std::mutex DtlsTransport::GlobalMutex;
 
 void DtlsTransport::Init() {
 	std::lock_guard lock(GlobalMutex);
+
+	openssl::init();
+
 	if (!BioMethods) {
 		BioMethods = BIO_meth_new(BIO_TYPE_BIO, "DTLS writer");
 		if (!BioMethods)
@@ -354,16 +297,16 @@ void DtlsTransport::Cleanup() {
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
-      mVerifierCallback(std::move(verifierCallback)) {
-
+      mVerifierCallback(std::move(verifierCallback)),
+      mIsClient(lower->role() == Description::Role::Active) {
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
 	try {
 		if (!(mCtx = SSL_CTX_new(DTLS_method())))
 			throw std::runtime_error("Failed to create SSL context");
 
-		check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
-		              "Failed to set SSL priorities");
+		openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
+		               "Failed to set SSL priorities");
 
 		// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
 		// Therefore, the DTLS layer MUST NOT use any compression algorithm.
@@ -381,14 +324,14 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 		SSL_CTX_use_certificate(mCtx, x509);
 		SSL_CTX_use_PrivateKey(mCtx, pkey);
 
-		check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
+		openssl::check(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
 
 		if (!(mSsl = SSL_new(mCtx)))
 			throw std::runtime_error("Failed to create SSL instance");
 
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
-		if (lower->role() == Description::Role::Active)
+		if (mIsClient)
 			SSL_set_connect_state(mSsl);
 		else
 			SSL_set_accept_state(mSsl);
@@ -442,7 +385,7 @@ bool DtlsTransport::send(message_ptr message) {
 	PLOG_VERBOSE << "Send size=" << message->size();
 
 	int ret = SSL_write(mSsl, message->data(), message->size());
-	return check_openssl_ret(mSsl, ret);
+	return openssl::check(mSsl, ret);
 }
 
 void DtlsTransport::incoming(message_ptr message) {
@@ -455,6 +398,10 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 }
 
+void DtlsTransport::postHandshake() {
+	// Dummy
+}
+
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 	try {
@@ -463,7 +410,7 @@ void DtlsTransport::runRecvLoop() {
 
 		// Initiate the handshake
 		int ret = SSL_do_handshake(mSsl);
-		check_openssl_ret(mSsl, ret, "Handshake failed");
+		openssl::check(mSsl, ret, "Handshake failed");
 
 		const size_t bufferSize = maxMtu;
 		byte buffer[bufferSize];
@@ -476,7 +423,7 @@ void DtlsTransport::runRecvLoop() {
 				if (state() == State::Connecting) {
 					// Continue the handshake
 					int ret = SSL_do_handshake(mSsl);
-					if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
+					if (!openssl::check(mSsl, ret, "Handshake failed"))
 						break;
 
 					if (SSL_is_init_finished(mSsl)) {
@@ -486,10 +433,11 @@ void DtlsTransport::runRecvLoop() {
 
 						PLOG_INFO << "DTLS handshake finished";
 						changeState(State::Connected);
+						postHandshake();
 					}
 				} else {
 					int ret = SSL_read(mSsl, buffer, bufferSize);
-					if (!check_openssl_ret(mSsl, ret))
+					if (!openssl::check(mSsl, ret))
 						break;
 
 					if (ret > 0)
@@ -547,7 +495,7 @@ int DtlsTransport::CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx) {
 	    static_cast<DtlsTransport *>(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex));
 
 	X509 *crt = X509_STORE_CTX_get_current_cert(ctx);
-	std::string fingerprint = make_fingerprint(crt);
+	string fingerprint = make_fingerprint(crt);
 
 	return t->mVerifierCallback(fingerprint) ? 1 : 0;
 }
@@ -604,7 +552,7 @@ long DtlsTransport::BioMethodCtrl(BIO *bio, int cmd, long num, void *ptr) {
 	return 0;
 }
 
-} // namespace rtc
-
 #endif
 
+} // namespace rtc
+

+ 8 - 12
src/dtlstransport.hpp

@@ -23,6 +23,7 @@
 #include "include.hpp"
 #include "peerconnection.hpp"
 #include "queue.hpp"
+#include "tls.hpp"
 #include "transport.hpp"
 
 #include <atomic>
@@ -31,12 +32,6 @@
 #include <mutex>
 #include <thread>
 
-#if USE_GNUTLS
-#include <gnutls/gnutls.h>
-#else
-#include <openssl/ssl.h>
-#endif
-
 namespace rtc {
 
 class IceTransport;
@@ -52,20 +47,21 @@ public:
 	              verifier_callback verifierCallback, state_callback stateChangeCallback);
 	~DtlsTransport();
 
-	bool stop() override;
-	bool send(message_ptr message) override; // false if dropped
+	virtual bool stop() override;
+	virtual bool send(message_ptr message) override; // false if dropped
 
-private:
-	void incoming(message_ptr message) override;
+protected:
+	virtual void incoming(message_ptr message) override;
+	virtual void postHandshake();
 	void runRecvLoop();
 
 	const certificate_ptr mCertificate;
+	const verifier_callback mVerifierCallback;
+	const bool mIsClient;
 
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
 
-	verifier_callback mVerifierCallback;
-
 #if USE_GNUTLS
 	gnutls_session_t mSession;
 

+ 3 - 3
src/icetransport.cpp

@@ -122,7 +122,7 @@ Description IceTransport::getLocalDescription(Description::Type type) const {
 void IceTransport::setRemoteDescription(const Description &description) {
 	mRole = description.role() == Description::Role::Active ? Description::Role::Passive
 	                                                        : Description::Role::Active;
-	mMid = description.mid();
+	mMid = description.dataMid();
 	if (juice_set_remote_description(mAgent.get(), string(description).c_str()) < 0)
 		throw std::runtime_error("Failed to parse remote SDP");
 }
@@ -483,8 +483,8 @@ Description IceTransport::getLocalDescription(Description::Type type) const {
 void IceTransport::setRemoteDescription(const Description &description) {
 	mRole = description.role() == Description::Role::Active ? Description::Role::Passive
 	                                                        : Description::Role::Active;
-	mMid = description.mid();
-	mTrickleTimeout = description.trickleEnabled() ? 30s : 0s;
+	mMid = description.dataMid();
+	mTrickleTimeout = !description.ended() ? 30s : 0s;
 
 	// Warning: libnice expects "\n" as end of line
 	if (nice_agent_parse_remote_sdp(mNiceAgent.get(), description.generateSdp("\n").c_str()) < 0)

+ 12 - 10
src/init.cpp

@@ -21,20 +21,18 @@
 #include "certificate.hpp"
 #include "dtlstransport.hpp"
 #include "sctptransport.hpp"
+#include "tls.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
 #include "tlstransport.hpp"
 #endif
 
-#ifdef _WIN32
-#include <winsock2.h>
+#if RTC_ENABLE_MEDIA
+#include "dtlssrtptransport.hpp"
 #endif
 
-#if USE_GNUTLS
-// Nothing to do
-#else
-#include <openssl/err.h>
-#include <openssl/ssl.h>
+#ifdef _WIN32
+#include <winsock2.h>
 #endif
 
 using std::shared_ptr;
@@ -69,9 +67,7 @@ Init::Init() {
 #if USE_GNUTLS
 		// Nothing to do
 #else
-	OPENSSL_init_ssl(0, NULL);
-	SSL_load_error_strings();
-	ERR_load_crypto_strings();
+	openssl::init();
 #endif
 
 	SctpTransport::Init();
@@ -79,6 +75,9 @@ Init::Init() {
 #if RTC_ENABLE_WEBSOCKET
 	TlsTransport::Init();
 #endif
+#if RTC_ENABLE_MEDIA
+	DtlsSrtpTransport::Init();
+#endif
 }
 
 Init::~Init() {
@@ -88,6 +87,9 @@ Init::~Init() {
 #if RTC_ENABLE_WEBSOCKET
 	TlsTransport::Cleanup();
 #endif
+#if RTC_ENABLE_MEDIA
+	DtlsSrtpTransport::Cleanup();
+#endif
 
 #ifdef _WIN32
 	WSACleanup();

+ 111 - 23
src/peerconnection.cpp

@@ -23,6 +23,10 @@
 #include "include.hpp"
 #include "sctptransport.hpp"
 
+#if RTC_ENABLE_MEDIA
+#include "dtlssrtptransport.hpp"
+#endif
+
 #include <thread>
 
 namespace rtc {
@@ -67,6 +71,22 @@ std::optional<Description> PeerConnection::remoteDescription() const {
 	return mRemoteDescription;
 }
 
+void PeerConnection::setLocalDescription(std::optional<Description> description) {
+	if (auto iceTransport = std::atomic_load(&mIceTransport)) {
+		throw std::logic_error("Local description is already set");
+	} else {
+		// RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
+		// setup:actpass.
+		// See https://tools.ietf.org/html/rfc5763#section-5
+		iceTransport = initIceTransport(Description::Role::ActPass);
+		Description localDescription = iceTransport->getLocalDescription(Description::Type::Offer);
+		if (description)
+			localDescription.addMedia(*description);
+		processLocalDescription(localDescription);
+		iceTransport->gatherLocalCandidates();
+	}
+}
+
 void PeerConnection::setRemoteDescription(Description description) {
 	description.hintType(localDescription() ? Description::Type::Answer : Description::Type::Offer);
 	auto remoteCandidates = description.extractCandidates();
@@ -82,7 +102,9 @@ void PeerConnection::setRemoteDescription(Description description) {
 
 	if (mRemoteDescription->type() == Description::Type::Offer) {
 		// This is an offer and we are the answerer.
-		processLocalDescription(iceTransport->getLocalDescription(Description::Type::Answer));
+		Description localDescription = iceTransport->getLocalDescription(Description::Type::Answer);
+		localDescription.addMedia(description); // blindly accept media
+		processLocalDescription(localDescription);
 		iceTransport->gatherLocalCandidates();
 	} else {
 		// This is an answer and we are the offerer.
@@ -190,6 +212,39 @@ void PeerConnection::onGatheringStateChange(std::function<void(GatheringState st
 	mGatheringStateChangeCallback = callback;
 }
 
+bool PeerConnection::hasMedia() const {
+	auto local = localDescription();
+	auto remote = remoteDescription();
+	return (local && local->hasMedia()) || (remote && remote->hasMedia());
+}
+
+void PeerConnection::sendMedia(const binary &packet) {
+	outgoingMedia(make_message(packet.begin(), packet.end(), Message::Binary));
+}
+
+void PeerConnection::send(const byte *packet, size_t size) {
+	outgoingMedia(make_message(packet, packet + size, Message::Binary));
+}
+
+void PeerConnection::onMedia(std::function<void(const binary &packet)> callback) {
+	mMediaCallback = callback;
+}
+
+void PeerConnection::outgoingMedia(message_ptr message) {
+	if (!hasMedia())
+		throw std::runtime_error("PeerConnection has no media support");
+
+#if RTC_ENABLE_MEDIA
+	auto transport = std::atomic_load(&mDtlsTransport);
+	if (!transport)
+		throw std::runtime_error("PeerConnection is not open");
+
+	std::dynamic_pointer_cast<DtlsSrtpTransport>(transport)->send(message);
+#else
+	PLOG_WARNING << "Ignoring sent media (not compiled with SRTP support)";
+#endif
+}
+
 shared_ptr<IceTransport> PeerConnection::initIceTransport(Description::Role role) {
 	try {
 		if (auto transport = std::atomic_load(&mIceTransport))
@@ -259,27 +314,48 @@ shared_ptr<DtlsTransport> PeerConnection::initDtlsTransport() {
 
 		auto certificate = mCertificate.get();
 		auto lower = std::atomic_load(&mIceTransport);
-		auto transport = std::make_shared<DtlsTransport>(
-		    lower, certificate, weak_bind(&PeerConnection::checkFingerprint, this, _1),
-		    [this, weak_this = weak_from_this()](DtlsTransport::State state) {
-			    auto shared_this = weak_this.lock();
-			    if (!shared_this)
-				    return;
-			    switch (state) {
-			    case DtlsTransport::State::Connected:
-				    initSctpTransport();
-				    break;
-			    case DtlsTransport::State::Failed:
-				    changeState(State::Failed);
-				    break;
-			    case DtlsTransport::State::Disconnected:
-				    changeState(State::Disconnected);
-				    break;
-			    default:
-				    // Ignore
-				    break;
-			    }
-		    });
+		auto verifierCallback = weak_bind(&PeerConnection::checkFingerprint, this, _1);
+		auto stateChangeCallback = [this,
+		                            weak_this = weak_from_this()](DtlsTransport::State state) {
+			auto shared_this = weak_this.lock();
+			if (!shared_this)
+				return;
+
+			switch (state) {
+			case DtlsTransport::State::Connected:
+				initSctpTransport();
+				break;
+			case DtlsTransport::State::Failed:
+				changeState(State::Failed);
+				break;
+			case DtlsTransport::State::Disconnected:
+				changeState(State::Disconnected);
+				break;
+			default:
+				// Ignore
+				break;
+			}
+		};
+
+		shared_ptr<DtlsTransport> transport;
+		if (hasMedia()) {
+#if RTC_ENABLE_MEDIA
+			PLOG_INFO << "This connection requires media support";
+
+			// DTLS-SRTP
+			transport = std::make_shared<DtlsSrtpTransport>(
+			    lower, certificate, verifierCallback,
+			    std::bind(&PeerConnection::forwardMedia, this, _1), stateChangeCallback);
+#else
+			PLOG_WARNING << "Ignoring media support (not compiled with SRTP support)";
+#endif
+		}
+
+		if (!transport) {
+			// DTLS only
+			transport = std::make_shared<DtlsTransport>(lower, certificate, verifierCallback,
+			                                            stateChangeCallback);
+		}
 
 		std::atomic_store(&mDtlsTransport, transport);
 		if (mState == State::Closed) {
@@ -316,8 +392,15 @@ shared_ptr<SctpTransport> PeerConnection::initSctpTransport() {
 				    openDataChannels();
 				    break;
 			    case SctpTransport::State::Failed:
+				    LOG_WARNING << "SCTP transport failed";
 				    remoteCloseDataChannels();
+#if RTC_ENABLE_MEDIA
+				    // Ignore SCTP failure if media is present
+				    if (!hasMedia())
+					    changeState(State::Failed);
+#else
 				    changeState(State::Failed);
+#endif
 				    break;
 			    case SctpTransport::State::Disconnected:
 				    remoteCloseDataChannels();
@@ -358,7 +441,7 @@ void PeerConnection::closeTransports() {
 	auto dtls = std::atomic_exchange(&mDtlsTransport, decltype(mDtlsTransport)(nullptr));
 	auto ice = std::atomic_exchange(&mIceTransport, decltype(mIceTransport)(nullptr));
 	if (sctp || dtls || ice) {
-		std::thread t([sctp, dtls, ice]() mutable {
+		std::thread t([sctp, dtls, ice, token = mInitToken]() mutable {
 			if (sctp)
 				sctp->stop();
 			if (dtls)
@@ -422,6 +505,11 @@ void PeerConnection::forwardMessage(message_ptr message) {
 	channel->incoming(message);
 }
 
+void PeerConnection::forwardMedia(message_ptr message) {
+	if (message)
+		mMediaCallback(*message);
+}
+
 void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
 	if (auto channel = findDataChannel(stream))
 		channel->triggerBufferedAmount(amount);

+ 2 - 2
src/tcptransport.cpp

@@ -16,10 +16,10 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-#if RTC_ENABLE_WEBSOCKET
-
 #include "tcptransport.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+
 #include <exception>
 #ifndef _WIN32
 #include <fcntl.h>

+ 2 - 2
src/tcptransport.hpp

@@ -19,12 +19,12 @@
 #ifndef RTC_TCP_TRANSPORT_H
 #define RTC_TCP_TRANSPORT_H
 
-#if RTC_ENABLE_WEBSOCKET
-
 #include "include.hpp"
 #include "queue.hpp"
 #include "transport.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+
 #include <mutex>
 #include <thread>
 

+ 132 - 0
src/tls.cpp

@@ -0,0 +1,132 @@
+/**
+ * Copyright (c) 2019-2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "tls.hpp"
+
+#if USE_GNUTLS
+
+namespace rtc::gnutls {
+
+bool check(int ret, const string &message) {
+	if (ret < 0) {
+		if (!gnutls_error_is_fatal(ret)) {
+			PLOG_INFO << gnutls_strerror(ret);
+			return false;
+		}
+		PLOG_ERROR << message << ": " << gnutls_strerror(ret);
+		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
+	}
+	return true;
+}
+
+gnutls_certificate_credentials_t *new_credentials() {
+	auto creds = new gnutls_certificate_credentials_t;
+	gnutls::check(gnutls_certificate_allocate_credentials(creds));
+	return creds;
+}
+
+void free_credentials(gnutls_certificate_credentials_t *creds) {
+	gnutls_certificate_free_credentials(*creds);
+	delete creds;
+}
+
+gnutls_x509_crt_t *new_crt() {
+	auto crt = new gnutls_x509_crt_t;
+	gnutls::check(gnutls_x509_crt_init(crt));
+	return crt;
+}
+
+void free_crt(gnutls_x509_crt_t *crt) {
+	gnutls_x509_crt_deinit(*crt);
+	delete crt;
+}
+
+gnutls_x509_privkey_t *new_privkey() {
+	auto privkey = new gnutls_x509_privkey_t;
+	gnutls::check(gnutls_x509_privkey_init(privkey));
+	return privkey;
+}
+
+void free_privkey(gnutls_x509_privkey_t *privkey) {
+	gnutls_x509_privkey_deinit(*privkey);
+	delete privkey;
+}
+
+gnutls_datum_t make_datum(char *data, size_t size) {
+	gnutls_datum_t datum;
+	datum.data = reinterpret_cast<unsigned char *>(data);
+	datum.size = size;
+	return datum;
+}
+
+} // namespace rtc::gnutls
+
+#else // USE_GNUTLS==0
+
+namespace rtc::openssl {
+
+void init() {
+	static std::mutex mutex;
+	static bool done = false;
+
+	std::lock_guard lock(mutex);
+	if (!done) {
+		OPENSSL_init_ssl(0, NULL);
+		SSL_load_error_strings();
+		ERR_load_crypto_strings();
+		done = true;
+	}
+}
+
+string error_string(unsigned long err) {
+	const size_t bufferSize = 256;
+	char buffer[bufferSize];
+	ERR_error_string_n(err, buffer, bufferSize);
+	return string(buffer);
+}
+
+bool check(int success, const string &message) {
+	if (success)
+		return true;
+
+	string str = error_string(ERR_get_error());
+	PLOG_ERROR << message << ": " << str;
+	throw std::runtime_error(message + ": " + str);
+}
+
+bool check(SSL *ssl, int ret, const string &message) {
+	if (ret == BIO_EOF)
+		return true;
+
+	unsigned long err = SSL_get_error(ssl, ret);
+	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
+		return true;
+	}
+	if (err == SSL_ERROR_ZERO_RETURN) {
+		PLOG_DEBUG << "DTLS connection cleanly closed";
+		return false;
+	}
+	string str = error_string(err);
+	PLOG_ERROR << str;
+	throw std::runtime_error(message + ": " + str);
+}
+
+} // namespace rtc::openssl
+
+#endif
+

+ 75 - 0
src/tls.hpp

@@ -0,0 +1,75 @@
+/**
+ * Copyright (c) 2019-2020 Paul-Louis Ageneau
+ *
+ * This library is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with this library; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#ifndef RTC_TLS_H
+#define RTC_TLS_H
+
+#include "include.hpp"
+
+#if USE_GNUTLS
+
+#include <gnutls/gnutls.h>
+
+#include <gnutls/crypto.h>
+#include <gnutls/dtls.h>
+#include <gnutls/x509.h>
+
+namespace rtc::gnutls {
+
+bool check(int ret, const string &message = "GnuTLS error");
+
+gnutls_certificate_credentials_t *new_credentials();
+void free_credentials(gnutls_certificate_credentials_t *creds);
+
+gnutls_x509_crt_t *new_crt();
+void free_crt(gnutls_x509_crt_t *crt);
+
+gnutls_x509_privkey_t *new_privkey();
+void free_privkey(gnutls_x509_privkey_t *privkey);
+
+gnutls_datum_t make_datum(char *data, size_t size);
+
+} // namespace rtc::gnutls
+
+#else // USE_GNUTLS==0
+
+#include <openssl/ssl.h>
+
+#include <openssl/bio.h>
+#include <openssl/ec.h>
+#include <openssl/err.h>
+#include <openssl/pem.h>
+#include <openssl/x509.h>
+
+#ifndef BIO_EOF
+#define BIO_EOF -1
+#endif
+
+namespace rtc::openssl {
+
+void init();
+string error_string(unsigned long err);
+
+bool check(int success, const string &message = "OpenSSL error");
+bool check(SSL *ssl, int ret, const string &message = "OpenSSL error");
+
+} // namespace rtc::openssl
+
+#endif
+
+#endif

+ 25 - 86
src/tlstransport.cpp

@@ -16,11 +16,11 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-#if RTC_ENABLE_WEBSOCKET
-
 #include "tlstransport.hpp"
 #include "tcptransport.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+
 #include <chrono>
 #include <cstring>
 #include <exception>
@@ -33,26 +33,10 @@ using std::string;
 using std::unique_ptr;
 using std::weak_ptr;
 
-#if USE_GNUTLS
-
-namespace {
-
-static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
-	if (ret < 0) {
-		if (!gnutls_error_is_fatal(ret)) {
-			PLOG_INFO << gnutls_strerror(ret);
-			return false;
-		}
-		PLOG_ERROR << message << ": " << gnutls_strerror(ret);
-		throw std::runtime_error(message + ": " + gnutls_strerror(ret));
-	}
-	return true;
-}
-
-} // namespace
-
 namespace rtc {
 
+#if USE_GNUTLS
+
 void TlsTransport::Init() {
 	// Nothing to do
 }
@@ -66,18 +50,18 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 
 	PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
 
-	check_gnutls(gnutls_certificate_allocate_credentials(&mCreds));
-	check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
+	gnutls::check(gnutls_certificate_allocate_credentials(&mCreds));
+	gnutls::check(gnutls_init(&mSession, GNUTLS_CLIENT));
 
 	try {
-        check_gnutls(gnutls_certificate_set_x509_system_trust(mCreds));
-        check_gnutls(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
-        gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
+		gnutls::check(gnutls_certificate_set_x509_system_trust(mCreds));
+		gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
+		gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
 
 		const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
 		const char *err_pos = NULL;
-		check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
-		             "Failed to set TLS priorities");
+		gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
+		              "Failed to set TLS priorities");
 
        	PLOG_VERBOSE << "Server Name Indication: " << mHost;
 		gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost.data(), mHost.size());
@@ -100,6 +84,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 
 TlsTransport::~TlsTransport() {
 	stop();
+
 	gnutls_deinit(mSession);
 	gnutls_certificate_free_credentials(mCreds);
 }
@@ -128,7 +113,7 @@ bool TlsTransport::send(message_ptr message) {
 		ret = gnutls_record_send(mSession, message->data(), message->size());
 	} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
 
-	return check_gnutls(ret);
+	return gnutls::check(ret);
 }
 
 void TlsTransport::incoming(message_ptr message) {
@@ -150,7 +135,7 @@ void TlsTransport::runRecvLoop() {
 		do {
 			ret = gnutls_handshake(mSession);
 		} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
-		         !check_gnutls(ret, "TLS handshake failed"));
+		         !gnutls::check(ret, "TLS handshake failed"));
 
 	} catch (const std::exception &e) {
 		PLOG_ERROR << "TLS handshake: " << e.what();
@@ -175,7 +160,7 @@ void TlsTransport::runRecvLoop() {
 				break;
 			}
 
-			if (check_gnutls(ret)) {
+			if (gnutls::check(ret)) {
 				if (ret == 0) {
 					// Closed
 					PLOG_DEBUG << "TLS connection cleanly closed";
@@ -250,59 +235,13 @@ int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 	return !t->mIncomingQueue.empty() ? 1 : 0;
 }
 
-} // namespace rtc
-
 #else // USE_GNUTLS==0
 
-#include <openssl/bio.h>
-#include <openssl/ec.h>
-#include <openssl/err.h>
-#include <openssl/ssl.h>
-
-namespace {
-
-const int BIO_EOF = -1;
-
-string openssl_error_string(unsigned long err) {
-	const size_t bufferSize = 256;
-	char buffer[bufferSize];
-	ERR_error_string_n(err, buffer, bufferSize);
-	return string(buffer);
-}
-
-bool check_openssl(int success, const string &message = "OpenSSL error") {
-	if (success)
-		return true;
-
-	string str = openssl_error_string(ERR_get_error());
-	PLOG_ERROR << message << ": " << str;
-	throw std::runtime_error(message + ": " + str);
-}
-
-bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
-	if (ret == BIO_EOF)
-		return true;
-
-	unsigned long err = SSL_get_error(ssl, ret);
-	if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
-		return true;
-	}
-	if (err == SSL_ERROR_ZERO_RETURN) {
-		PLOG_DEBUG << "TLS connection cleanly closed";
-		return false;
-	}
-	string str = openssl_error_string(err);
-	PLOG_ERROR << str;
-	throw std::runtime_error(message + ": " + str);
-}
-
-} // namespace
-
-namespace rtc {
-
 int TlsTransport::TransportExIndex = -1;
 
 void TlsTransport::Init() {
+	openssl::init();
+
 	if (TransportExIndex < 0) {
 		TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
 	}
@@ -321,8 +260,8 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
 			throw std::runtime_error("Failed to create SSL context");
 
-		check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
-		              "Failed to set SSL priorities");
+		openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
+		               "Failed to set SSL priorities");
 
 		SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
 		SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
@@ -340,7 +279,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
 		SSL_set_hostflags(mSsl, 0);
-		check_openssl(SSL_set1_host(mSsl, mHost.c_str()), "Failed to set SSL host");
+		openssl::check(SSL_set1_host(mSsl, mHost.c_str()), "Failed to set SSL host");
 
 		PLOG_VERBOSE << "Server Name Indication: " << mHost;
 		SSL_set_tlsext_host_name(mSsl, mHost.c_str());
@@ -399,7 +338,7 @@ bool TlsTransport::send(message_ptr message) {
 		return true;
 
 	int ret = SSL_write(mSsl, message->data(), message->size());
-	if (!check_openssl_ret(mSsl, ret))
+	if (!openssl::check(mSsl, ret))
 		return false;
 
 	const size_t bufferSize = 4096;
@@ -428,7 +367,7 @@ void TlsTransport::runRecvLoop() {
 			if (state() == State::Connecting) {
 				// Initiate or continue the handshake
 				int ret = SSL_do_handshake(mSsl);
-				if (!check_openssl_ret(mSsl, ret, "Handshake failed"))
+				if (!openssl::check(mSsl, ret, "Handshake failed"))
 					break;
 
 				// Output
@@ -441,7 +380,7 @@ void TlsTransport::runRecvLoop() {
 				}
 			} else {
 				int ret = SSL_read(mSsl, buffer, bufferSize);
-				if (!check_openssl_ret(mSsl, ret))
+				if (!openssl::check(mSsl, ret))
 					break;
 
 				if (ret > 0)
@@ -483,8 +422,8 @@ void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
 	}
 }
 
-} // namespace rtc
-
 #endif
 
+} // namespace rtc
+
 #endif

+ 3 - 9
src/tlstransport.hpp

@@ -19,22 +19,16 @@
 #ifndef RTC_TLS_TRANSPORT_H
 #define RTC_TLS_TRANSPORT_H
 
-#if RTC_ENABLE_WEBSOCKET
-
 #include "include.hpp"
 #include "queue.hpp"
+#include "tls.hpp"
 #include "transport.hpp"
 
-#include <memory>
+#if RTC_ENABLE_WEBSOCKET
+
 #include <mutex>
 #include <thread>
 
-#if USE_GNUTLS
-#include <gnutls/gnutls.h>
-#else
-#include <openssl/ssl.h>
-#endif
-
 namespace rtc {
 
 class TcpTransport;

+ 2 - 2
src/websocket.cpp

@@ -131,7 +131,7 @@ size_t WebSocket::availableAmount() const { return mRecvQueue.amount(); }
 
 bool WebSocket::changeState(State state) { return mState.exchange(state) != state; }
 
-bool WebSocket::outgoing(mutable_message_ptr message) {
+bool WebSocket::outgoing(message_ptr message) {
 	if (mState != State::Open || !mWsTransport)
 		throw std::runtime_error("WebSocket is not open");
 
@@ -302,7 +302,7 @@ void WebSocket::closeTransports() {
 	auto tls = std::atomic_exchange(&mTlsTransport, decltype(mTlsTransport)(nullptr));
 	auto tcp = std::atomic_exchange(&mTcpTransport, decltype(mTcpTransport)(nullptr));
 	if (ws || tls || tcp) {
-		std::thread t([ws, tls, tcp]() mutable {
+		std::thread t([ws, tls, tcp, token = mInitToken]() mutable {
 			if (ws)
 				ws->stop();
 			if (tls)

+ 2 - 11
src/wstransport.cpp

@@ -16,14 +16,13 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-#if RTC_ENABLE_WEBSOCKET
-
 #include "wstransport.hpp"
 #include "tcptransport.hpp"
 #include "tlstransport.hpp"
-
 #include "base64.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+
 #include <chrono>
 #include <list>
 #include <map>
@@ -75,14 +74,6 @@ bool WsTransport::stop() {
 }
 
 bool WsTransport::send(message_ptr message) {
-	if (!message)
-		return false;
-
-	// Call the mutable message overload with a copy
-	return send(std::make_shared<Message>(*message));
-}
-
-bool WsTransport::send(mutable_message_ptr message) {
 	if (!message || state() != State::Connected)
 		return false;
 

+ 2 - 3
src/wstransport.hpp

@@ -19,11 +19,11 @@
 #ifndef RTC_WS_TRANSPORT_H
 #define RTC_WS_TRANSPORT_H
 
-#if RTC_ENABLE_WEBSOCKET
-
 #include "include.hpp"
 #include "transport.hpp"
 
+#if RTC_ENABLE_WEBSOCKET
+
 namespace rtc {
 
 class TcpTransport;
@@ -37,7 +37,6 @@ public:
 
 	bool stop() override;
 	bool send(message_ptr message) override;
-	bool send(mutable_message_ptr message);
 
 	void incoming(message_ptr message) override;