Browse Source

Implement Mbed TLS Backend

Co-authored-by: tytan652 <[email protected]>
Co-authored-by: Paul-Louis Ageneau <[email protected]>
Sean DuBois 2 years ago
parent
commit
e6fbddeb9c

+ 37 - 0
.github/workflows/build-mbedtls.yml

@@ -0,0 +1,37 @@
+name: Build with Mbed TLS
+on:
+  push:
+    branches:
+    - master
+  pull_request:
+jobs:
+  build-linux:
+    runs-on: ubuntu-latest
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Homebrew
+      uses: Homebrew/actions/setup-homebrew@master
+    - name: Install Mbed TLS
+      run: brew update && brew install mbedtls
+    - name: submodules
+      run: git submodule update --init --recursive --depth 1
+    - name: cmake
+      run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1  -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls)
+    - name: make
+      run: (cd build; make -j2)
+    - name: test
+      run: ./build/tests
+  build-macos:
+    runs-on: macos-latest
+    steps:
+    - uses: actions/checkout@v2
+    - name: Install Mbed TLS
+      run: brew update && brew install mbedtls
+    - name: submodules
+      run: git submodule update --init --recursive --depth 1
+    - name: cmake
+      run: cmake -B build -DUSE_MBEDTLS=1 -DWARNINGS_AS_ERRORS=1 -DENABLE_LOCAL_ADDRESS_TRANSLATION=1  -DCMAKE_PREFIX_PATH=$(brew --prefix mbedtls)
+    - name: make
+      run: (cd build; make -j2)
+    - name: test
+      run: ./build/tests

+ 19 - 1
CMakeLists.txt

@@ -5,6 +5,7 @@ project(libdatachannel
 set(PROJECT_DESCRIPTION "C/C++ WebRTC network library featuring Data Channels, Media Transport, and WebSockets")
 set(PROJECT_DESCRIPTION "C/C++ WebRTC network library featuring Data Channels, Media Transport, and WebSockets")
 
 
 # Options
 # Options
+option(USE_MBEDTLS "Use Mbed TLS instead of OpenSSL" OFF)
 option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
 option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
 option(USE_NICE "Use libnice instead of libjuice" OFF)
 option(USE_NICE "Use libnice instead of libjuice" OFF)
 option(PREFER_SYSTEM_LIB "Prefer system libraries over deps folder" OFF)
 option(PREFER_SYSTEM_LIB "Prefer system libraries over deps folder" OFF)
@@ -21,12 +22,22 @@ option(WARNINGS_AS_ERRORS "Treat warnings as errors" OFF)
 option(CAPI_STDCALL "Set calling convention of C API callbacks stdcall" OFF)
 option(CAPI_STDCALL "Set calling convention of C API callbacks stdcall" OFF)
 option(SCTP_DEBUG "Enable SCTP debugging output to verbose log" OFF)
 option(SCTP_DEBUG "Enable SCTP debugging output to verbose log" OFF)
 
 
+if (USE_MBEDTLS AND USE_GNUTLS)
+	message(FATAL_ERROR "Both USE_MBEDTLS and USE_GNUTLS can not be enabled at the same time")
+endif()
+
+
 if(USE_GNUTLS)
 if(USE_GNUTLS)
 	option(USE_NETTLE "Use Nettle in libjuice" ON)
 	option(USE_NETTLE "Use Nettle in libjuice" ON)
 else()
 else()
 	option(USE_NETTLE "Use Nettle in libjuice" OFF)
 	option(USE_NETTLE "Use Nettle in libjuice" OFF)
+
 	if(NOT USE_SYSTEM_SRTP)
 	if(NOT USE_SYSTEM_SRTP)
-		option(ENABLE_OPENSSL "Enable OpenSSL crypto engine for SRTP" ON)
+		if (USE_MBEDTLS)
+			option(ENABLE_MBEDTLS "Enable Mbed TLS crypto engine for SRTP" ON)
+		else()
+			option(ENABLE_OPENSSL "Enable OpenSSL crypto engine for SRTP" ON)
+		endif()
 	endif()
 	endif()
 endif()
 endif()
 
 
@@ -337,6 +348,13 @@ if (USE_GNUTLS)
 		target_link_libraries(datachannel PRIVATE Nettle::Nettle)
 		target_link_libraries(datachannel PRIVATE Nettle::Nettle)
 		target_link_libraries(datachannel-static PRIVATE Nettle::Nettle)
 		target_link_libraries(datachannel-static PRIVATE Nettle::Nettle)
 	endif()
 	endif()
+elseif(USE_MBEDTLS)
+	find_package(MbedTLS 3 REQUIRED)
+
+	target_compile_definitions(datachannel PRIVATE USE_MBEDTLS)
+	target_compile_definitions(datachannel-static PRIVATE USE_MBEDTLS)
+	target_link_libraries(datachannel PRIVATE MbedTLS::MbedTLS)
+	target_link_libraries(datachannel-static PRIVATE MbedTLS::MbedTLS)
 else()
 else()
 	if(APPLE)
 	if(APPLE)
 		# This is a bug in CMake that causes it to prefer the system version over
 		# This is a bug in CMake that causes it to prefer the system version over

+ 214 - 0
cmake/Modules/FindMbedTLS.cmake

@@ -0,0 +1,214 @@
+#[=======================================================================[.rst
+FindMbedTLS
+-----------
+
+FindModule for MbedTLS and associated libraries
+
+Components
+^^^^^^^^^^
+
+This module contains provides several components:
+
+``MbedCrypto``
+``MbedTLS``
+``MbedX509``
+
+Import targets exist for each component.
+
+Imported Targets
+^^^^^^^^^^^^^^^^
+
+This module defines the :prop_tgt:`IMPORTED` targets:
+
+``MbedTLS::MbedCrypto``
+  Crypto component
+
+``MbedTLS::MbedTLS``
+  TLS component
+
+``MbedTLS::MbedX509``
+  X509 component
+
+Result Variables
+^^^^^^^^^^^^^^^^
+
+This module sets the following variables:
+
+``MbedTLS_FOUND``
+  True, if all required components and the core library were found.
+``MbedTLS_VERSION``
+  Detected version of found MbedTLS libraries.
+
+``MbedTLS_<COMPONENT>_VERSION``
+  Detected version of found MbedTLS component library.
+
+Cache variables
+^^^^^^^^^^^^^^^
+
+The following cache variables may also be set:
+
+``MbedTLS_<COMPONENT>_LIBRARY``
+  Path to the library component of MbedTLS.
+``MbedTLS_<COMPONENT>_INCLUDE_DIR``
+  Directory containing ``<COMPONENT>.h``.
+
+Distributed under the MIT License, see accompanying LICENSE file or
+https://github.com/PatTheMav/cmake-finders/blob/master/LICENSE for details.
+(c) 2023 Patrick Heyer
+
+#]=======================================================================]
+
+# cmake-format: off
+# cmake-lint: disable=C0103
+# cmake-lint: disable=C0301
+# cmake-lint: disable=C0307
+# cmake-format: on
+
+include(FindPackageHandleStandardArgs)
+
+find_package(PkgConfig QUIET)
+if(PKG_CONFIG_FOUND)
+  pkg_check_modules(PC_MbedTLS QUIET mbedtls mbedcrypto mbedx509)
+endif()
+
+# MbedTLS_set_soname: Set SONAME on imported library targets
+macro(MbedTLS_set_soname component)
+  if(CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin")
+    execute_process(
+      COMMAND sh -c "otool -D '${Mbed${component}_LIBRARY}' | grep -v '${Mbed${component}_LIBRARY}'"
+      OUTPUT_VARIABLE _output
+      RESULT_VARIABLE _result)
+
+    if(_result EQUAL 0 AND _output MATCHES "^@rpath/")
+      set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_SONAME "${_output}")
+    endif()
+  elseif(CMAKE_HOST_SYSTEM_NAME MATCHES "Linux|FreeBSD")
+    execute_process(
+      COMMAND sh -c "objdump -p '${Mbed${component}_LIBRARY}' | grep SONAME"
+      OUTPUT_VARIABLE _output
+      RESULT_VARIABLE _result)
+
+    if(_result EQUAL 0)
+      string(REGEX REPLACE "[ \t]+SONAME[ \t]+([^ \t]+)" "\\1" _soname "${_output}")
+      set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_SONAME "${_soname}")
+      unset(_soname)
+    endif()
+  endif()
+  unset(_output)
+  unset(_result)
+endmacro()
+
+find_path(
+  MbedTLS_INCLUDE_DIR
+  NAMES mbedtls/ssl.h
+  HINTS "${PC_MbedTLS_INCLUDE_DIRS}"
+  PATHS /usr/include /usr/local/include
+  DOC "MbedTLS include directory")
+
+if(PC_MbedTLS_VERSION VERSION_GREATER 0)
+  set(MbedTLS_VERSION ${PC_MbedTLS_VERSION})
+elseif(EXISTS "${MbedTLS_INCLUDE_DIR}/mbedtls/build_info.h")
+  file(STRINGS "${MbedTLS_INCLUDE_DIR}/mbedtls/build_info.h" _VERSION_STRING
+       REGEX "#define[ \t]+MBEDTLS_VERSION_STRING[ \t]+.+")
+  string(REGEX REPLACE ".*#define[ \t]+MBEDTLS_VERSION_STRING[ \t]+\"(.+)\".*" "\\1" MbedTLS_VERSION
+                       "${_VERSION_STRING}")
+else()
+  if(NOT MbedTLS_FIND_QUIETLY)
+    message(AUTHOR_WARNING "Failed to find MbedTLS version.")
+  endif()
+  set(MbedTLS_VERSION 0.0.0)
+endif()
+
+find_library(
+  MbedTLS_LIBRARY
+  NAMES libmbedtls mbedtls
+  HINTS "${PC_MbedTLS_LIBRARY_DIRS}"
+  PATHS /usr/lib /usr/local/lib
+  DOC "MbedTLS location")
+
+find_library(
+  MbedCrypto_LIBRARY
+  NAMES libmbedcrypto mbedcrypto
+  HINTS "${PC_MbedTLS_LIBRARY_DIRS}"
+  PATHS /usr/lib /usr/local/lib
+  DOC "MbedCrypto location")
+
+find_library(
+  MbedX509_LIBRARY
+  NAMES libmbedx509 mbedx509
+  HINTS "${PC_MbedTLS_LIBRARY_DIRS}"
+  PATHS /usr/lib /usr/local/lib
+  DOC "MbedX509 location")
+
+if(MbedTLS_LIBRARY
+   AND NOT MbedCrypto_LIBRARY
+   AND NOT MbedX509_LIBRARY)
+  set(CMAKE_REQUIRED_LIBRARIES "${MbedTLS_LIBRARY}")
+  set(CMAKE_REQUIRED_INCLUDES "${MbedTLS_INCLUDE_DIR}")
+
+  check_symbol_exists(mbedtls_x509_crt_init "mbedtls/x590_crt.h" MbedTLS_INCLUDES_X509)
+  check_symbol_exists(mbedtls_sha256_init "mbedtls/sha256.h" MbedTLS_INCLUDES_CRYPTO)
+  unset(CMAKE_REQUIRED_LIBRARIES)
+  unset(CMAKE_REQUIRED_INCLUDES)
+endif()
+
+if(CMAKE_HOST_SYSTEM_NAME MATCHES "Darwin|Windows")
+  set(MbedTLS_ERROR_REASON "Ensure that an MbedTLS distribution is provided as part of CMAKE_PREFIX_PATH.")
+elseif(CMAKE_HOST_SYSTEM_NAME MATCHES "Linux|FreeBSD")
+  set(MbedTLS_ERROR_REASON "Ensure that MbedTLS is installed on the system.")
+endif()
+
+if(MbedTLS_INCLUDES_X509 AND MbedTLS_INCLUDES_CRYPTO)
+  find_package_handle_standard_args(
+    MbedTLS
+    REQUIRED_VARS MbedTLS_LIBRARY MbedTLS_INCLUDE_DIR
+    VERSION_VAR MbedTLS_VERSION REASON_FAILURE_MESSAGE "${MbedTLS_ERROR_REASON}")
+  mark_as_advanced(MbedTLS_LIBRARY MbedTLS_INCLUDE_DIR)
+  list(APPEND _COMPONENTS TLS)
+else()
+  find_package_handle_standard_args(
+    MbedTLS
+    REQUIRED_VARS MbedTLS_LIBRARY MbedCrypto_LIBRARY MbedX509_LIBRARY MbedTLS_INCLUDE_DIR
+    VERSION_VAR MbedTLS_VERSION REASON_FAILURE_MESSAGE "${MbedTLS_ERROR_REASON}")
+  mark_as_advanced(MbedTLS_LIBRARY MbedCrypto_LIBRARY MbedX509_LIBRARY MbedTLS_INCLUDE_DIR)
+  list(APPEND _COMPONENTS TLS Crypto X509)
+endif()
+unset(MbedTLS_ERROR_REASON)
+
+if(MbedTLS_FOUND)
+  foreach(component IN LISTS _COMPONENTS)
+    if(NOT TARGET MbedTLS::Mbed${component})
+      if(IS_ABSOLUTE "${Mbed${component}_LIBRARY}")
+        add_library(MbedTLS::Mbed${component} UNKNOWN IMPORTED)
+        set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_LOCATION "${Mbed${component}_LIBRARY}")
+      else()
+        add_library(MbedTLS::Mbed${component} INTERFACE IMPORTED)
+        set_property(TARGET MbedTLS::Mbed${component} PROPERTY IMPORTED_LIBNAME "${Mbed${component}_LIBRARY}")
+      endif()
+
+      mbedtls_set_soname(${component})
+      set_target_properties(
+        MbedTLS::MbedTLS
+        PROPERTIES INTERFACE_COMPILE_OPTIONS "${PC_MbedTLS_CFLAGS_OTHER}"
+                   INTERFACE_INCLUDE_DIRECTORIES "${MbedTLS_INCLUDE_DIR}"
+                   VERSION ${MbedTLS_VERSION})
+    endif()
+  endforeach()
+
+  if(MbedTLS_INCLUDES_X509 AND MbedTLS_INCLUDES_CRYPTO)
+    set(MbedTLS_LIBRARIES ${MbedTLS_LIBRARY})
+    set(MBEDTLS_INCLUDE_DIRS ${MbedTLS_INCLUDE_DIR})
+  else()
+    set(MbedTLS_LIBRARIES ${MbedTLS_LIBRARY} ${MbedCrypto_LIBRARY} ${MbedX509_LIBRARY})
+    set_property(TARGET MbedTLS::MbedTLS PROPERTY INTERFACE_LINK_LIBRARIES MbedTLS::MbedCrypto MbedTLS::MbedX509)
+    set(MBEDTLS_INCLUDE_DIRS ${MbedTLS_INCLUDE_DIR})
+  endif()
+endif()
+
+include(FeatureSummary)
+set_package_properties(
+  MbedTLS PROPERTIES
+  URL "https://www.trustedfirmware.org/projects/mbed-tls"
+  DESCRIPTION
+    "A C library implementing cryptographic primitives, X.509 certificate manipulation, and the SSL/TLS and DTLS protocols."
+)

+ 169 - 6
src/impl/certificate.cpp

@@ -111,8 +111,6 @@ Certificate::Certificate(shared_ptr<gnutls_certificate_credentials_t> creds)
 
 
 gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
 gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
 
 
-string Certificate::fingerprint() const { return mFingerprint; }
-
 string make_fingerprint(gnutls_certificate_credentials_t credentials) {
 string make_fingerprint(gnutls_certificate_credentials_t credentials) {
 	auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * {
 	auto new_crt_list = [credentials]() -> gnutls_x509_crt_t * {
 		gnutls_x509_crt_t *crt_list = nullptr;
 		gnutls_x509_crt_t *crt_list = nullptr;
@@ -149,7 +147,172 @@ string make_fingerprint(gnutls_x509_crt_t crt) {
 	return oss.str();
 	return oss.str();
 }
 }
 
 
-#else // USE_GNUTLS==0
+#elif USE_MBEDTLS
+string make_fingerprint(shared_ptr<mbedtls_x509_crt> crt) {
+	const int size = 32;
+	uint8_t buffer[size];
+	std::stringstream fingerprint;
+
+	mbedtls::check(
+	    mbedtls_sha256(crt->raw.p, crt->raw.len, reinterpret_cast<unsigned char *>(buffer), 0),
+	    "Failed to generate certificate fingerprint");
+
+	for (auto i = 0; i < size; i++) {
+		fingerprint << std::setfill('0') << std::setw(2) << std::hex << static_cast<int>(buffer[i]);
+		if (i != (size - 1)) {
+			fingerprint << ":";
+		}
+	}
+
+	return fingerprint.str();
+}
+
+Certificate::Certificate(shared_ptr<mbedtls_x509_crt> crt, shared_ptr<mbedtls_pk_context> pk)
+    : mCrt(crt), mPk(pk), mFingerprint(make_fingerprint(crt)) {}
+
+Certificate Certificate::FromString(string crt_pem, string key_pem) {
+	PLOG_DEBUG << "Importing certificate from PEM string (MbedTLS)";
+
+	auto crt = mbedtls::new_x509_crt();
+	auto pk = mbedtls::new_pk_context();
+
+	mbedtls::check(mbedtls_x509_crt_parse(crt.get(),
+	                                      reinterpret_cast<const unsigned char *>(crt_pem.c_str()),
+	                                      crt_pem.length()),
+	               "Failed to parse certificate");
+	mbedtls::check(mbedtls_pk_parse_key(pk.get(),
+	                                    reinterpret_cast<const unsigned char *>(key_pem.c_str()),
+	                                    key_pem.size(), NULL, 0, NULL, 0),
+	               "Failed to parse key");
+
+	return Certificate(std::move(crt), std::move(pk));
+}
+
+Certificate Certificate::FromFile(const string &crt_pem_file, const string &key_pem_file,
+                                  const string &pass) {
+	PLOG_DEBUG << "Importing certificate from PEM file (MbedTLS): " << crt_pem_file;
+
+	auto crt = mbedtls::new_x509_crt();
+	auto pk = mbedtls::new_pk_context();
+
+	mbedtls::check(mbedtls_x509_crt_parse_file(crt.get(), crt_pem_file.c_str()),
+	               "Failed to parse certificate");
+	mbedtls::check(mbedtls_pk_parse_keyfile(pk.get(), key_pem_file.c_str(), pass.c_str(), 0, NULL),
+	               "Failed to parse key");
+
+	return Certificate(std::move(crt), std::move(pk));
+}
+
+Certificate Certificate::Generate(CertificateType type, const string &commonName) {
+	PLOG_DEBUG << "Generating certificate (MbedTLS)";
+
+	mbedtls_entropy_context entropy;
+	mbedtls_ctr_drbg_context drbg;
+	mbedtls_x509write_cert wcrt;
+	mbedtls_mpi serial;
+	auto crt = mbedtls::new_x509_crt();
+	auto pk = mbedtls::new_pk_context();
+
+	mbedtls_entropy_init(&entropy);
+	mbedtls_ctr_drbg_init(&drbg);
+	mbedtls_ctr_drbg_set_prediction_resistance(&drbg, MBEDTLS_CTR_DRBG_PR_ON);
+	mbedtls_x509write_crt_init(&wcrt);
+	mbedtls_mpi_init(&serial);
+
+	try {
+		mbedtls::check(mbedtls_ctr_drbg_seed(
+		    &drbg, mbedtls_entropy_func, &entropy,
+		    reinterpret_cast<const unsigned char *>(commonName.data()), commonName.size()));
+
+		switch (type) {
+		// RFC 8827 WebRTC Security Architecture 6.5. Communications Security
+		// All implementations MUST support DTLS 1.2 with the
+		// TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 cipher suite and the P-256 curve
+		// See https://www.rfc-editor.org/rfc/rfc8827.html#section-6.5
+		case CertificateType::Default:
+		case CertificateType::Ecdsa: {
+			mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_ECKEY)));
+			mbedtls::check(mbedtls_ecp_gen_key(MBEDTLS_ECP_DP_SECP256R1, mbedtls_pk_ec(*pk.get()),
+			                                   mbedtls_ctr_drbg_random, &drbg),
+			               "Unable to generate ECDSA P-256 key pair");
+			break;
+		}
+		case CertificateType::Rsa: {
+			const unsigned int nbits = 2048;
+			const int exponent = 65537;
+
+			mbedtls::check(mbedtls_pk_setup(pk.get(), mbedtls_pk_info_from_type(MBEDTLS_PK_RSA)));
+			mbedtls::check(mbedtls_rsa_gen_key(mbedtls_pk_rsa(*pk.get()), mbedtls_ctr_drbg_random,
+			                                   &drbg, nbits, exponent),
+			               "Unable to generate RSA key pair");
+			break;
+		}
+		default:
+			throw std::invalid_argument("Unknown certificate type");
+		}
+
+		auto now = std::chrono::system_clock::now();
+		string notBefore = mbedtls::format_time(now - std::chrono::hours(1));
+		string notAfter = mbedtls::format_time(now + std::chrono::hours(24 * 365));
+
+		const size_t serialBufferSize = 16;
+		unsigned char serialBuffer[serialBufferSize];
+		mbedtls::check(mbedtls_ctr_drbg_random(&drbg, serialBuffer, serialBufferSize),
+		               "Failed to generate certificate");
+		mbedtls::check(mbedtls_mpi_read_binary(&serial, serialBuffer, serialBufferSize),
+		               "Failed to generate certificate");
+
+		std::string name = std::string("O=" + commonName + ",CN=" + commonName);
+		mbedtls::check(mbedtls_x509write_crt_set_serial(&wcrt, &serial),
+		               "Failed to generate certificate");
+		mbedtls::check(mbedtls_x509write_crt_set_subject_name(&wcrt, name.c_str()),
+		               "Failed to generate certificate");
+		mbedtls::check(mbedtls_x509write_crt_set_issuer_name(&wcrt, name.c_str()),
+		               "Failed to generate certificate");
+		mbedtls::check(
+		    mbedtls_x509write_crt_set_validity(&wcrt, notBefore.c_str(), notAfter.c_str()),
+		    "Failed to generate certificate");
+
+		mbedtls_x509write_crt_set_version(&wcrt, MBEDTLS_X509_CRT_VERSION_3);
+		mbedtls_x509write_crt_set_subject_key(&wcrt, pk.get());
+		mbedtls_x509write_crt_set_issuer_key(&wcrt, pk.get());
+		mbedtls_x509write_crt_set_md_alg(&wcrt, MBEDTLS_MD_SHA256);
+
+		const size_t certificateBufferSize = 4096;
+		unsigned char certificateBuffer[certificateBufferSize];
+		std::memset(certificateBuffer, 0, certificateBufferSize);
+
+		auto certificateLen = mbedtls_x509write_crt_der(
+		    &wcrt, certificateBuffer, certificateBufferSize, mbedtls_ctr_drbg_random, &drbg);
+		if (certificateLen <= 0) {
+			throw std::runtime_error("Certificate generation failed");
+		}
+
+		mbedtls::check(mbedtls_x509_crt_parse_der(
+		                   crt.get(), (certificateBuffer + certificateBufferSize - certificateLen),
+		                   certificateLen),
+		               "Failed to generate certificate");
+	} catch (...) {
+		mbedtls_entropy_free(&entropy);
+		mbedtls_ctr_drbg_free(&drbg);
+		mbedtls_x509write_crt_free(&wcrt);
+		mbedtls_mpi_free(&serial);
+		throw;
+	}
+
+	mbedtls_entropy_free(&entropy);
+	mbedtls_ctr_drbg_free(&drbg);
+	mbedtls_x509write_crt_free(&wcrt);
+	mbedtls_mpi_free(&serial);
+	return Certificate(std::move(crt), std::move(pk));
+}
+
+std::tuple<shared_ptr<mbedtls_x509_crt>, shared_ptr<mbedtls_pk_context>>
+Certificate::credentials() const {
+	return {mCrt, mPk};
+}
+
+#else // OPENSSL
 
 
 namespace {
 namespace {
 
 
@@ -291,8 +454,6 @@ Certificate Certificate::Generate(CertificateType type, const string &commonName
 Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
 Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey)
     : mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {}
     : mX509(std::move(x509)), mPKey(std::move(pkey)), mFingerprint(make_fingerprint(mX509.get())) {}
 
 
-string Certificate::fingerprint() const { return mFingerprint; }
-
 std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
 std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const {
 	return {mX509.get(), mPKey.get()};
 	return {mX509.get(), mPKey.get()};
 }
 }
@@ -316,7 +477,7 @@ string make_fingerprint(X509 *x509) {
 
 
 #endif
 #endif
 
 
-// Common for GnuTLS and OpenSSL
+// Common for GnuTLS, Mbed TLS, and OpenSSL
 
 
 future_certificate_ptr make_certificate(CertificateType type) {
 future_certificate_ptr make_certificate(CertificateType type) {
 	return ThreadPool::Instance().enqueue([type, token = Init::Instance().token()]() {
 	return ThreadPool::Instance().enqueue([type, token = Init::Instance().token()]() {
@@ -324,4 +485,6 @@ future_certificate_ptr make_certificate(CertificateType type) {
 	});
 	});
 }
 }
 
 
+string Certificate::fingerprint() const { return mFingerprint; }
+
 } // namespace rtc::impl
 } // namespace rtc::impl

+ 9 - 1
src/impl/certificate.hpp

@@ -29,7 +29,10 @@ public:
 #if USE_GNUTLS
 #if USE_GNUTLS
 	Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
 	Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
 	gnutls_certificate_credentials_t credentials() const;
 	gnutls_certificate_credentials_t credentials() const;
-#else
+#elif USE_MBEDTLS
+	Certificate(shared_ptr<mbedtls_x509_crt> crt, shared_ptr<mbedtls_pk_context> pk);
+	std::tuple<shared_ptr<mbedtls_x509_crt>, shared_ptr<mbedtls_pk_context>> credentials() const;
+#else // OPENSSL
 	Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey);
 	Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey);
 	std::tuple<X509 *, EVP_PKEY *> credentials() const;
 	std::tuple<X509 *, EVP_PKEY *> credentials() const;
 #endif
 #endif
@@ -42,6 +45,9 @@ private:
 #if USE_GNUTLS
 #if USE_GNUTLS
 	Certificate(shared_ptr<gnutls_certificate_credentials_t> creds);
 	Certificate(shared_ptr<gnutls_certificate_credentials_t> creds);
 	const shared_ptr<gnutls_certificate_credentials_t> mCredentials;
 	const shared_ptr<gnutls_certificate_credentials_t> mCredentials;
+#elif USE_MBEDTLS
+	const shared_ptr<mbedtls_x509_crt> mCrt;
+	const shared_ptr<mbedtls_pk_context> mPk;
 #else
 #else
 	const shared_ptr<X509> mX509;
 	const shared_ptr<X509> mX509;
 	const shared_ptr<EVP_PKEY> mPKey;
 	const shared_ptr<EVP_PKEY> mPKey;
@@ -53,6 +59,8 @@ private:
 #if USE_GNUTLS
 #if USE_GNUTLS
 string make_fingerprint(gnutls_certificate_credentials_t credentials);
 string make_fingerprint(gnutls_certificate_credentials_t credentials);
 string make_fingerprint(gnutls_x509_crt_t crt);
 string make_fingerprint(gnutls_x509_crt_t crt);
+#elif USE_MBEDTLS
+string make_fingerprint(shared_ptr<mbedtls_x509_crt> crt);
 #else
 #else
 string make_fingerprint(X509 *x509);
 string make_fingerprint(X509 *x509);
 #endif
 #endif

+ 36 - 2
src/impl/dtlssrtptransport.cpp

@@ -213,7 +213,7 @@ bool DtlsSrtpTransport::demuxMessage(message_ptr message) {
 	} else {
 	} else {
 		COUNTER_UNKNOWN_PACKET_TYPE++;
 		COUNTER_UNKNOWN_PACKET_TYPE++;
 		PLOG_DEBUG << "Unknown packet type, value=" << unsigned(value1)
 		PLOG_DEBUG << "Unknown packet type, value=" << unsigned(value1)
-		             << ", size=" << message->size();
+		           << ", size=" << message->size();
 		return true;
 		return true;
 	}
 	}
 }
 }
@@ -263,12 +263,46 @@ void DtlsSrtpTransport::postHandshake() {
 
 
 	serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
 	serverKey = reinterpret_cast<const unsigned char *>(serverKeyDatum.data);
 	serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
 	serverSalt = reinterpret_cast<const unsigned char *>(serverSaltDatum.data);
+#elif USE_MBEDTLS
+	PLOG_INFO << "Deriving SRTP keying material (Mbed TLS)";
+	unsigned int keySize = SRTP_AES_128_KEY_LEN;
+	unsigned int saltSize = SRTP_SALT_LEN;
+	auto srtpProfile = srtp_profile_aes128_cm_sha1_80;
+	auto keySizeWithSalt = SRTP_AES_ICM_128_KEY_LEN_WSALT;
+	mbedtls_dtls_srtp_info srtpInfo;
+
+	mbedtls_ssl_get_dtls_srtp_negotiation_result(&mSsl, &srtpInfo);
+	if (srtpInfo.private_chosen_dtls_srtp_profile != MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80) {
+		throw std::runtime_error("Failed to get SRTP profile");
+	}
+
+	const size_t materialLen = keySizeWithSalt * 2;
+	std::vector<unsigned char> material(materialLen);
+	// The extractor provides the client write master key, the server write master key, the client
+	// write master salt and the server write master salt in that order.
+	const string label = "EXTRACTOR-dtls_srtp";
+
+	if (mTlsProfile == MBEDTLS_SSL_TLS_PRF_NONE) {
+		throw std::logic_error("Failed to get SRTP profile");
+	}
+
+	if (mbedtls_ssl_tls_prf(mTlsProfile, reinterpret_cast<const unsigned char *>(mMasterSecret), 32,
+	                        label.c_str(), reinterpret_cast<const unsigned char *>(mRandBytes), 32,
+	                        material.data(), materialLen) != 0) {
+		throw std::runtime_error("Failed to derive SRTP keys");
+	}
+
+	// Order is client key, server key, client salt, and server salt
+	clientKey = material.data();
+	serverKey = clientKey + keySize;
+	clientSalt = serverKey + keySize;
+	serverSalt = clientSalt + saltSize;
 #else
 #else
 	PLOG_INFO << "Deriving SRTP keying material (OpenSSL)";
 	PLOG_INFO << "Deriving SRTP keying material (OpenSSL)";
 	auto profile = SSL_get_selected_srtp_profile(mSsl);
 	auto profile = SSL_get_selected_srtp_profile(mSsl);
 	if (!profile)
 	if (!profile)
 		throw std::runtime_error("Failed to get SRTP profile: " +
 		throw std::runtime_error("Failed to get SRTP profile: " +
-					openssl::error_string(ERR_get_error()));
+		                         openssl::error_string(ERR_get_error()));
 	PLOG_DEBUG << "srtp profile used is: " << profile->name;
 	PLOG_DEBUG << "srtp profile used is: " << profile->name;
 	auto [keySize, saltSize, srtpProfile] = getEncryptionParams(profile->name);
 	auto [keySize, saltSize, srtpProfile] = getEncryptionParams(profile->name);
 	auto keySizeWithSalt = keySize + saltSize;
 	auto keySizeWithSalt = keySize + saltSize;

+ 298 - 3
src/impl/dtlstransport.cpp

@@ -360,7 +360,302 @@ int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* m
 	}
 	}
 }
 }
 
 
-#else // USE_GNUTLS==0
+#elif USE_MBEDTLS
+
+mbedtls_ssl_srtp_profile srtpSupportedProtectionProfiles[] = {
+    MBEDTLS_TLS_SRTP_AES128_CM_HMAC_SHA1_80,
+    MBEDTLS_TLS_SRTP_UNSET,
+};
+
+DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr certificate,
+                             optional<size_t> mtu, verifier_callback verifierCallback,
+                             state_callback stateChangeCallback)
+    : Transport(lower, std::move(stateChangeCallback)), mMtu(mtu), mCertificate(certificate),
+      mVerifierCallback(std::move(verifierCallback)),
+      mIsClient(lower->role() == Description::Role::Active) {
+
+	PLOG_DEBUG << "Initializing DTLS transport (MbedTLS)";
+
+	if (!mCertificate)
+		throw std::invalid_argument("DTLS certificate is null");
+
+	mbedtls_entropy_init(&mEntropy);
+	mbedtls_ctr_drbg_init(&mDrbg);
+	mbedtls_ssl_init(&mSsl);
+	mbedtls_ssl_config_init(&mConf);
+	mbedtls_ctr_drbg_set_prediction_resistance(&mDrbg, MBEDTLS_CTR_DRBG_PR_ON);
+
+	try {
+		mbedtls::check(mbedtls_ctr_drbg_seed(&mDrbg, mbedtls_entropy_func, &mEntropy, NULL, 0),
+		               "Failed creating Mbed TLS Context");
+
+		mbedtls::check(mbedtls_ssl_config_defaults(
+		                   &mConf, mIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
+		                   MBEDTLS_SSL_TRANSPORT_DATAGRAM, MBEDTLS_SSL_PRESET_DEFAULT),
+		               "Failed creating Mbed TLS Context");
+
+		mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+		mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg);
+
+		auto [crt, pk] = mCertificate->credentials();
+		mbedtls::check(mbedtls_ssl_conf_own_cert(&mConf, crt.get(), pk.get()),
+		               "Failed creating Mbed TLS Context");
+
+		mbedtls_ssl_conf_dtls_cookies(&mConf, NULL, NULL, NULL);
+		mbedtls_ssl_conf_dtls_srtp_protection_profiles(&mConf, srtpSupportedProtectionProfiles);
+
+		mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf), "Failed creating Mbed TLS Context");
+
+		size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
+		mbedtls_ssl_set_mtu(&mSsl, static_cast<unsigned int>(mtu));
+		PLOG_VERBOSE << "DTLS MTU set to " << mtu;
+
+		mbedtls_ssl_set_export_keys_cb(&mSsl, DtlsTransport::ExportKeysCallback, this);
+		mbedtls_ssl_set_bio(&mSsl, this, WriteCallback, ReadCallback, NULL);
+		mbedtls_ssl_set_timer_cb(&mSsl, this, SetTimerCallback, GetTimerCallback);
+	} catch (...) {
+		mbedtls_entropy_free(&mEntropy);
+		mbedtls_ctr_drbg_free(&mDrbg);
+		mbedtls_ssl_free(&mSsl);
+		mbedtls_ssl_config_free(&mConf);
+		throw;
+	}
+
+	// Set recommended medium-priority DSCP value for handshake
+	// See https://www.rfc-editor.org/rfc/rfc8837.html#section-5
+	mCurrentDscp = 10; // AF11: Assured Forwarding class 1, low drop probability
+}
+
+DtlsTransport::~DtlsTransport() {
+	stop();
+
+	PLOG_DEBUG << "Destroying DTLS transport";
+	mbedtls_entropy_free(&mEntropy);
+	mbedtls_ctr_drbg_free(&mDrbg);
+	mbedtls_ssl_free(&mSsl);
+	mbedtls_ssl_config_free(&mConf);
+}
+
+void DtlsTransport::Init() {
+	// Nothing to do
+}
+
+void DtlsTransport::Cleanup() {
+	// Nothing to do
+}
+
+void DtlsTransport::start() {
+	PLOG_DEBUG << "Starting DTLS transport";
+	registerIncoming();
+	changeState(State::Connecting);
+
+	enqueueRecv(); // to initiate the handshake
+}
+
+void DtlsTransport::stop() {
+	PLOG_DEBUG << "Stopping DTLS transport";
+	unregisterIncoming();
+	mIncomingQueue.stop();
+	enqueueRecv();
+}
+
+bool DtlsTransport::send(message_ptr message) {
+	if (!message || state() != State::Connected)
+		return false;
+
+	PLOG_VERBOSE << "Send size=" << message->size();
+
+	int ret;
+	do {
+		std::lock_guard lock(mMutex);
+		mCurrentDscp = message->dscp;
+
+		if (message->size() > size_t(mbedtls_ssl_get_max_out_record_payload(&mSsl)))
+			return false;
+
+		ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
+		                        message->size());
+	} while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
+	mbedtls::check(ret);
+
+	return mOutgoingResult;
+}
+
+void DtlsTransport::incoming(message_ptr message) {
+	if (!message) {
+		mIncomingQueue.stop();
+		return;
+	}
+
+	PLOG_VERBOSE << "Incoming size=" << message->size();
+	mIncomingQueue.push(message);
+	enqueueRecv();
+}
+
+bool DtlsTransport::outgoing(message_ptr message) {
+	message->dscp = mCurrentDscp;
+
+	bool result = Transport::outgoing(std::move(message));
+	mOutgoingResult = result;
+	return result;
+}
+
+bool DtlsTransport::demuxMessage(message_ptr) {
+	// Dummy
+	return false;
+}
+
+void DtlsTransport::postHandshake() {
+	// Dummy
+}
+
+void DtlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
+
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
+
+	try {
+		const size_t bufferSize = 4096;
+		char buffer[bufferSize];
+
+		// Handle handshake if connecting
+		if (state() == State::Connecting) {
+			while (true) {
+				auto ret = mbedtls_ssl_handshake(&mSsl);
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+				ThreadPool::Instance().schedule(mTimerSetAt + milliseconds(mFinMs), [weak_this = weak_from_this()]() {
+					if (auto locked = weak_this.lock())
+						locked->doRecv();
+					});
+				return;
+				} else if ( ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+					continue;
+				}
+
+				mbedtls::check(ret);
+				PLOG_INFO << "DTLS handshake finished";
+				changeState(State::Connected);
+				postHandshake();
+				break;
+			}
+		}
+
+		if (state() == State::Connected) {
+			while (true) {
+				mMutex.lock();
+				auto ret =
+				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
+				mMutex.unlock();
+
+				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
+					// Closed
+					PLOG_DEBUG << "DTLS connection cleanly closed";
+					break;
+				}
+
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+				    ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS ||
+				    ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+					return;
+				}
+				mbedtls::check(ret);
+
+				auto *b = reinterpret_cast<byte *>(buffer);
+				recv(make_message(b, b + ret));
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "DTLS recv: " << e.what();
+	}
+
+	PLOG_INFO << "DTLS closed";
+	changeState(State::Disconnected);
+	recv(nullptr);
+}
+
+void DtlsTransport::ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type /*type*/,
+                                       const unsigned char *secret, size_t secret_len,
+                                       const unsigned char client_random[32],
+                                       const unsigned char server_random[32],
+                                       mbedtls_tls_prf_types tls_prf_type) {
+	auto dtlsTransport = static_cast<DtlsTransport *>(ctx);
+	std::memcpy(dtlsTransport->mMasterSecret, secret, secret_len);
+	std::memcpy(dtlsTransport->mRandBytes, client_random, 32);
+	std::memcpy(dtlsTransport->mRandBytes + 32, server_random, 32);
+	dtlsTransport->mTlsProfile = tls_prf_type;
+}
+
+int DtlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) {
+	auto *t = static_cast<DtlsTransport *>(ctx);
+	try {
+		if (len > 0) {
+			auto b = reinterpret_cast<const byte *>(buf);
+			t->outgoing(make_message(b, b + len));
+		}
+		return int(len);
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+	}
+}
+
+int DtlsTransport::ReadCallback(void *ctx, unsigned char *buf, size_t len) {
+	auto *t = static_cast<DtlsTransport *>(ctx);
+	try {
+		while (t->mIncomingQueue.running()) {
+			auto next = t->mIncomingQueue.pop();
+			if (!next) {
+				return MBEDTLS_ERR_SSL_WANT_READ;
+			}
+
+			message_ptr message = std::move(*next);
+			if (t->demuxMessage(message))
+				continue;
+
+			auto bufMin = std::min(len, size_t(message->size()));
+			std::memcpy(buf, message->data(), bufMin);
+			return int(len);
+		}
+
+		// Closed
+		return 0;
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+		;
+	}
+}
+
+void DtlsTransport::SetTimerCallback(void *ctx, uint32_t int_ms, uint32_t fin_ms) {
+	auto dtlsTransport = static_cast<DtlsTransport *>(ctx);
+	dtlsTransport->mIntMs = int_ms;
+	dtlsTransport->mFinMs = fin_ms;
+
+	if (fin_ms != 0) {
+		dtlsTransport->mTimerSetAt = std::chrono::steady_clock::now();
+	}
+}
+
+int DtlsTransport::GetTimerCallback(void *ctx) {
+	auto dtlsTransport = static_cast<DtlsTransport *>(ctx);
+	auto now = std::chrono::steady_clock::now();
+
+	if (dtlsTransport->mFinMs == 0) {
+		return -1;
+	} else if (now >= dtlsTransport->mTimerSetAt + milliseconds(dtlsTransport->mFinMs)) {
+		return 2;
+	} else if (now >= dtlsTransport->mTimerSetAt + milliseconds(dtlsTransport->mIntMs)) {
+		return 1;
+	} else {
+		return 0;
+	}
+}
+
+#else // OPENSSL
 
 
 BIO_METHOD *DtlsTransport::BioMethods = NULL;
 BIO_METHOD *DtlsTransport::BioMethods = NULL;
 int DtlsTransport::TransportExIndex = -1;
 int DtlsTransport::TransportExIndex = -1;
@@ -415,8 +710,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 
 
 		SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION);
 		SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION);
 		SSL_CTX_set_read_ahead(mCtx, 1);
 		SSL_CTX_set_read_ahead(mCtx, 1);
-		//sent the dtls close_notify alert
-		//SSL_CTX_set_quiet_shutdown(mCtx, 1);
+		// sent the dtls close_notify alert
+		// SSL_CTX_set_quiet_shutdown(mCtx, 1);
 		SSL_CTX_set_info_callback(mCtx, InfoCallback);
 		SSL_CTX_set_info_callback(mCtx, InfoCallback);
 
 
 		SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
 		SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,

+ 27 - 1
src/impl/dtlstransport.hpp

@@ -70,7 +70,33 @@ protected:
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
-#else
+
+#elif USE_MBEDTLS
+	std::mutex mMutex;
+
+	mbedtls_entropy_context mEntropy;
+	mbedtls_ctr_drbg_context mDrbg;
+	mbedtls_ssl_config mConf;
+	mbedtls_ssl_context mSsl;
+
+	uint32_t mFinMs = 0, mIntMs = 0;
+	std::chrono::time_point<std::chrono::steady_clock> mTimerSetAt;
+
+	char mMasterSecret[48];
+	char mRandBytes[64];
+	mbedtls_tls_prf_types mTlsProfile = MBEDTLS_SSL_TLS_PRF_NONE;
+
+	static int WriteCallback(void *ctx, const unsigned char *buf, size_t len);
+	static int ReadCallback(void *ctx, unsigned char *buf, size_t len);
+	static void ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type type,
+	                               const unsigned char *secret, size_t secret_len,
+	                               const unsigned char client_random[32],
+	                               const unsigned char server_random[32],
+	                               mbedtls_tls_prf_types tls_prf_type);
+	static void SetTimerCallback(void *ctx, uint32_t int_ms, uint32_t fin_ms);
+	static int GetTimerCallback(void *ctx);
+
+#else // OPENSSL
 	SSL_CTX *mCtx = NULL;
 	SSL_CTX *mCtx = NULL;
 	SSL *mSsl = NULL;
 	SSL *mSsl = NULL;
 	BIO *mInBio, *mOutBio;
 	BIO *mInBio, *mOutBio;

+ 2 - 0
src/impl/init.cpp

@@ -128,6 +128,8 @@ void Init::doInit() {
 
 
 #if USE_GNUTLS
 #if USE_GNUTLS
 	// Nothing to do
 	// Nothing to do
+#elif USE_MBEDTLS
+	// Nothing to do
 #else
 #else
 	openssl::init();
 	openssl::init();
 #endif
 #endif

+ 13 - 2
src/impl/sha.cpp

@@ -14,7 +14,11 @@
 
 
 #include <nettle/sha1.h>
 #include <nettle/sha1.h>
 
 
-#else // USE_GNUTLS==0
+#elif USE_MBEDTLS
+
+#include <mbedtls/sha1.h>
+
+#else
 
 
 #ifndef OPENSSL_API_COMPAT
 #ifndef OPENSSL_API_COMPAT
 #define OPENSSL_API_COMPAT 0x10100000L
 #define OPENSSL_API_COMPAT 0x10100000L
@@ -38,7 +42,14 @@ binary Sha1(const byte *data, size_t size) {
 	sha1_digest(&ctx, SHA1_DIGEST_SIZE, reinterpret_cast<uint8_t *>(output.data()));
 	sha1_digest(&ctx, SHA1_DIGEST_SIZE, reinterpret_cast<uint8_t *>(output.data()));
 	return output;
 	return output;
 
 
-#else // USE_GNUTLS==0
+#elif USE_MBEDTLS
+
+	binary output(20);
+	mbedtls_sha1(reinterpret_cast<const unsigned char *>(data), size,
+	             reinterpret_cast<unsigned char *>(output.data()));
+	return output;
+
+#else
 
 
 	binary output(SHA_DIGEST_LENGTH);
 	binary output(SHA_DIGEST_LENGTH);
 	SHA_CTX ctx;
 	SHA_CTX ctx;

+ 75 - 1
src/impl/tls.cpp

@@ -70,7 +70,81 @@ gnutls_datum_t make_datum(char *data, size_t size) {
 
 
 } // namespace rtc::gnutls
 } // namespace rtc::gnutls
 
 
-#else // USE_GNUTLS==0
+#elif USE_MBEDTLS
+
+#include <time.h>
+
+namespace {
+
+// Safe gmtime
+int my_gmtime(const time_t *t, struct tm *buf) {
+#ifdef _WIN32
+	return ::gmtime_s(buf, t) == 0 ? 0 : -1;
+#else // POSIX
+	return ::gmtime_r(t, buf) != NULL ? 0 : -1;
+#endif
+}
+
+// Format time_t as UTC
+size_t my_strftme(char *buf, size_t size, const char *format, const time_t *t) {
+	struct tm g;
+	if (my_gmtime(t, &g) != 0)
+		return 0;
+
+	return ::strftime(buf, size, format, &g);
+}
+
+} // namespace
+
+namespace rtc::mbedtls {
+
+void check(int ret, const string &message) {
+	if (ret < 0) {
+		const size_t bufferSize = 1024;
+		char buffer[bufferSize];
+		mbedtls_strerror(ret, reinterpret_cast<char *>(buffer), bufferSize);
+		PLOG_ERROR << message << ": " << buffer;
+		throw std::runtime_error(message + ": " + std::string(buffer));
+	}
+}
+
+string format_time(const std::chrono::system_clock::time_point &tp) {
+	time_t t = std::chrono::system_clock::to_time_t(tp);
+	const size_t bufferSize = 256;
+	char buffer[bufferSize];
+	if (my_strftme(buffer, bufferSize, "%Y%m%d%H%M%S", &t) == 0)
+		throw std::runtime_error("Time conversion failed");
+
+	return string(buffer);
+};
+
+std::shared_ptr<mbedtls_pk_context> new_pk_context() {
+	return std::shared_ptr<mbedtls_pk_context>{[]() {
+		                                           auto p = new mbedtls_pk_context;
+		                                           mbedtls_pk_init(p);
+		                                           return p;
+	                                           }(),
+	                                           [](mbedtls_pk_context *p) {
+		                                           mbedtls_pk_free(p);
+		                                           delete p;
+	                                           }};
+}
+
+std::shared_ptr<mbedtls_x509_crt> new_x509_crt() {
+	return std::shared_ptr<mbedtls_x509_crt>{[]() {
+		                                         auto p = new mbedtls_x509_crt;
+		                                         mbedtls_x509_crt_init(p);
+		                                         return p;
+	                                         }(),
+	                                         [](mbedtls_x509_crt *crt) {
+		                                         mbedtls_x509_crt_free(crt);
+		                                         delete crt;
+	                                         }};
+}
+
+} // namespace rtc::mbedtls
+
+#else // OPENSSL
 
 
 namespace rtc::openssl {
 namespace rtc::openssl {
 
 

+ 26 - 1
src/impl/tls.hpp

@@ -11,6 +11,8 @@
 
 
 #include "common.hpp"
 #include "common.hpp"
 
 
+#include <chrono>
+
 #if USE_GNUTLS
 #if USE_GNUTLS
 
 
 #include <gnutls/gnutls.h>
 #include <gnutls/gnutls.h>
@@ -36,7 +38,30 @@ gnutls_datum_t make_datum(char *data, size_t size);
 
 
 } // namespace rtc::gnutls
 } // namespace rtc::gnutls
 
 
-#else // USE_GNUTLS==0
+#elif USE_MBEDTLS
+
+#include "mbedtls/ctr_drbg.h"
+#include "mbedtls/ecdsa.h"
+#include "mbedtls/entropy.h"
+#include "mbedtls/error.h"
+#include "mbedtls/pk.h"
+#include "mbedtls/rsa.h"
+#include "mbedtls/sha256.h"
+#include "mbedtls/ssl.h"
+#include "mbedtls/x509_crt.h"
+
+namespace rtc::mbedtls {
+
+void check(int ret, const string &message = "MbedTLS error");
+
+string format_time(const std::chrono::system_clock::time_point &tp);
+
+std::shared_ptr<mbedtls_pk_context> new_pk_context();
+std::shared_ptr<mbedtls_x509_crt> new_x509_crt();
+
+} // namespace rtc::mbedtls
+
+#else // OPENSSL
 
 
 #ifdef _WIN32
 #ifdef _WIN32
 // Include winsock2.h header first since OpenSSL may include winsock.h
 // Include winsock2.h header first since OpenSSL may include winsock.h

+ 212 - 1
src/impl/tlstransport.cpp

@@ -296,7 +296,218 @@ int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms
 	}
 	}
 }
 }
 
 
-#else // USE_GNUTLS==0
+#elif USE_MBEDTLS
+
+void TlsTransport::Init() {
+	// Nothing to do
+}
+
+void TlsTransport::Cleanup() {
+	// Nothing to do
+}
+
+TlsTransport::TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower,
+                           optional<string> host, certificate_ptr certificate,
+                           state_callback callback)
+    : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
+                std::move(callback)),
+      mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)),
+      mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
+
+	PLOG_DEBUG << "Initializing TLS transport (MbedTLS)";
+
+	mbedtls_entropy_init(&mEntropy);
+	mbedtls_ctr_drbg_init(&mDrbg);
+	mbedtls_ssl_init(&mSsl);
+	mbedtls_ssl_config_init(&mConf);
+	mbedtls_ctr_drbg_set_prediction_resistance(&mDrbg, MBEDTLS_CTR_DRBG_PR_ON);
+
+	try {
+		mbedtls::check(mbedtls_ctr_drbg_seed(&mDrbg, mbedtls_entropy_func, &mEntropy, NULL, 0));
+
+		mbedtls::check(mbedtls_ssl_config_defaults(
+		    &mConf, mIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
+		    MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT));
+
+		mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+		mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg);
+
+		if (certificate) {
+			auto [crt, pk] = certificate->credentials();
+			mbedtls::check(mbedtls_ssl_conf_own_cert(&mConf, crt.get(), pk.get()));
+		}
+
+		mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf));
+		mbedtls_ssl_set_bio(&mSsl, static_cast<void *>(this), WriteCallback, ReadCallback, NULL);
+	} catch (...) {
+		mbedtls_entropy_free(&mEntropy);
+		mbedtls_ctr_drbg_free(&mDrbg);
+		mbedtls_ssl_free(&mSsl);
+		mbedtls_ssl_config_free(&mConf);
+		throw;
+	}
+}
+
+TlsTransport::~TlsTransport() {}
+
+void TlsTransport::start() {
+	PLOG_DEBUG << "Starting TLS transport";
+	registerIncoming();
+	changeState(State::Connecting);
+	enqueueRecv(); // to initiate the handshake
+}
+
+void TlsTransport::stop() {
+	PLOG_DEBUG << "Stopping TLS transport";
+	unregisterIncoming();
+	mIncomingQueue.stop();
+	enqueueRecv();
+}
+
+bool TlsTransport::send(message_ptr message) {
+	if (state() != State::Connected)
+		throw std::runtime_error("TLS is not open");
+
+	if (!message || message->size() == 0)
+		return outgoing(message); // pass through
+
+	PLOG_VERBOSE << "Send size=" << message->size();
+
+	mbedtls::check(mbedtls_ssl_write(
+	    &mSsl, reinterpret_cast<const unsigned char *>(message->data()), int(message->size())));
+
+	return mOutgoingResult;
+}
+
+void TlsTransport::incoming(message_ptr message) {
+	if (!message) {
+		mIncomingQueue.stop();
+		enqueueRecv();
+		return;
+	}
+
+	PLOG_VERBOSE << "Incoming size=" << message->size();
+	mIncomingQueue.push(message);
+	enqueueRecv();
+}
+
+bool TlsTransport::outgoing(message_ptr message) {
+	bool result = Transport::outgoing(std::move(message));
+	mOutgoingResult = result;
+	return result;
+}
+
+void TlsTransport::postHandshake() {
+	// Dummy
+}
+
+void TlsTransport::doRecv() {
+	std::lock_guard lock(mRecvMutex);
+	--mPendingRecvCount;
+
+	if (state() != State::Connecting && state() != State::Connected)
+		return;
+
+	try {
+		const size_t bufferSize = 4096;
+		char buffer[bufferSize];
+
+		// Handle handshake if connecting
+		if (state() == State::Connecting) {
+			while (true) {
+				auto ret = mbedtls_ssl_handshake(&mSsl);
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+					return;
+				} else if ( ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+					continue;
+				}
+
+				mbedtls::check(ret);
+				PLOG_INFO << "TLS handshake finished";
+				changeState(State::Connected);
+				postHandshake();
+				break;
+			}
+		}
+
+		if (state() == State::Connected) {
+			while (true) {
+				auto ret =
+				    mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer), bufferSize);
+
+				if (ret == 0 || ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
+					// Closed
+					PLOG_DEBUG << "TLS connection cleanly closed";
+					break;
+				}
+
+				if (ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
+					return;
+				} else if ( ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS || ret == MBEDTLS_ERR_SSL_CRYPTO_IN_PROGRESS) {
+					continue;
+				}
+				mbedtls::check(ret);
+
+				auto *b = reinterpret_cast<byte *>(buffer);
+				recv(make_message(b, b + ret));
+			}
+		}
+	} catch (const std::exception &e) {
+		PLOG_ERROR << "TLS recv: " << e.what();
+	}
+
+	PLOG_INFO << "TLS closed";
+	changeState(State::Disconnected);
+	recv(nullptr);
+}
+
+int TlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) {
+	auto *t = static_cast<TlsTransport *>(ctx);
+	auto *b = reinterpret_cast<const byte *>(buf);
+	t->outgoing(make_message(b, b + len));
+
+	return int(len);
+}
+
+int TlsTransport::ReadCallback(void *ctx, unsigned char *buf, size_t len) {
+	TlsTransport *t = static_cast<TlsTransport *>(ctx);
+	try {
+		message_ptr &message = t->mIncomingMessage;
+		size_t &position = t->mIncomingMessagePosition;
+
+		if (message && position >= message->size())
+			message.reset();
+
+		if (!message) {
+			position = 0;
+			while (auto next = t->mIncomingQueue.pop()) {
+				message = *next;
+				if (message->size() > 0)
+					break;
+				else
+					t->recv(message); // Pass zero-sized messages through
+			}
+		}
+
+		if (message) {
+			size_t available = message->size() - position;
+			size_t writeLen = std::min(len, available);
+			std::memcpy(buf, message->data() + position, writeLen);
+			position += writeLen;
+			return int(writeLen);
+		} else if (t->mIncomingQueue.running()) {
+			return MBEDTLS_ERR_SSL_WANT_READ;
+		} else {
+			return MBEDTLS_ERR_SSL_CONN_EOF;
+		}
+
+	} catch (const std::exception &e) {
+		PLOG_WARNING << e.what();
+		return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+	}
+}
+
+#else
 
 
 int TlsTransport::TransportExIndex = -1;
 int TlsTransport::TransportExIndex = -1;
 
 

+ 15 - 0
src/impl/tlstransport.hpp

@@ -65,6 +65,21 @@ protected:
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
 	static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
+#elif USE_MBEDTLS
+	std::mutex mSendMutex;
+	std::atomic<bool> mOutgoingResult = true;
+
+	mbedtls_entropy_context mEntropy;
+	mbedtls_ctr_drbg_context mDrbg;
+	mbedtls_ssl_config mConf;
+	mbedtls_ssl_context mSsl;
+
+	message_ptr mIncomingMessage;
+	size_t mIncomingMessagePosition = 0;
+
+	static int WriteCallback(void *ctx, const unsigned char *buf, size_t len);
+	static int ReadCallback(void *ctx, unsigned char *buf, size_t len);
+
 #else
 #else
 	SSL_CTX *mCtx;
 	SSL_CTX *mCtx;
 	SSL *mSsl;
 	SSL *mSsl;

+ 4 - 2
src/impl/verifiedtlstransport.cpp

@@ -18,11 +18,13 @@ VerifiedTlsTransport::VerifiedTlsTransport(
     certificate_ptr certificate, state_callback callback)
     certificate_ptr certificate, state_callback callback)
     : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) {
     : TlsTransport(std::move(lower), std::move(host), std::move(certificate), std::move(callback)) {
 
 
-#if USE_GNUTLS
 	PLOG_DEBUG << "Setting up TLS certificate verification";
 	PLOG_DEBUG << "Setting up TLS certificate verification";
+
+#if USE_GNUTLS
 	gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0);
 	gnutls_session_set_verify_cert(mSession, mHost->c_str(), 0);
+#elif USE_MBEDTLS
+	mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_REQUIRED);
 #else
 #else
-	PLOG_DEBUG << "Setting up TLS certificate verification";
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);
 	SSL_set_verify(mSsl, SSL_VERIFY_PEER, NULL);
 	SSL_set_verify_depth(mSsl, 4);
 	SSL_set_verify_depth(mSsl, 4);
 #endif
 #endif