Prechádzať zdrojové kódy

Check connection fingerprint with mbedtls

Gasper Lah 2 rokov pred
rodič
commit
073ff5183a

+ 2 - 2
src/impl/certificate.cpp

@@ -148,7 +148,7 @@ string make_fingerprint(gnutls_x509_crt_t crt) {
 }
 
 #elif USE_MBEDTLS
-string make_fingerprint(shared_ptr<mbedtls_x509_crt> crt) {
+string make_fingerprint(mbedtls_x509_crt* crt) {
 	const int size = 32;
 	uint8_t buffer[size];
 	std::stringstream fingerprint;
@@ -168,7 +168,7 @@ string make_fingerprint(shared_ptr<mbedtls_x509_crt> crt) {
 }
 
 Certificate::Certificate(shared_ptr<mbedtls_x509_crt> crt, shared_ptr<mbedtls_pk_context> pk)
-    : mCrt(crt), mPk(pk), mFingerprint(make_fingerprint(crt)) {}
+    : mCrt(crt), mPk(pk), mFingerprint(make_fingerprint(crt.get())) {}
 
 Certificate Certificate::FromString(string crt_pem, string key_pem) {
 	PLOG_DEBUG << "Importing certificate from PEM string (MbedTLS)";

+ 1 - 1
src/impl/certificate.hpp

@@ -60,7 +60,7 @@ private:
 string make_fingerprint(gnutls_certificate_credentials_t credentials);
 string make_fingerprint(gnutls_x509_crt_t crt);
 #elif USE_MBEDTLS
-string make_fingerprint(shared_ptr<mbedtls_x509_crt> crt);
+string make_fingerprint(mbedtls_x509_crt* crt);
 #else
 string make_fingerprint(X509 *x509);
 #endif

+ 9 - 0
src/impl/dtlstransport.cpp

@@ -400,6 +400,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		               "Failed creating Mbed TLS Context");
 
 		mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL);
+		mbedtls_ssl_conf_verify(&mConf, DtlsTransport::CertificateCallback, this);
+
 		mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg);
 
 		auto [crt, pk] = mCertificate->credentials();
@@ -603,6 +605,13 @@ void DtlsTransport::doRecv() {
 	}
 }
 
+int DtlsTransport::CertificateCallback(void *ctx, mbedtls_x509_crt *crt, int /*depth*/, uint32_t */*flags*/) {
+	auto this_ = static_cast<DtlsTransport *>(ctx);
+	string fingerprint = make_fingerprint(crt);
+	std::transform(fingerprint.begin(), fingerprint.end(), fingerprint.begin(), [](char c) { return char(std::toupper(c)); });
+	return this_->mVerifierCallback(fingerprint) ? 0 : 1;
+}
+
 void DtlsTransport::ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type /*type*/,
                                        const unsigned char *secret, size_t secret_len,
                                        const unsigned char client_random[32],

+ 1 - 0
src/impl/dtlstransport.hpp

@@ -85,6 +85,7 @@ protected:
 	char mRandBytes[64];
 	mbedtls_tls_prf_types mTlsProfile = MBEDTLS_SSL_TLS_PRF_NONE;
 
+	static int CertificateCallback(void *ctx, mbedtls_x509_crt *crt, int depth, uint32_t *flags);
 	static int WriteCallback(void *ctx, const unsigned char *buf, size_t len);
 	static int ReadCallback(void *ctx, unsigned char *buf, size_t len);
 	static void ExportKeysCallback(void *ctx, mbedtls_ssl_key_export_type type,

+ 11 - 2
test/connectivity.cpp

@@ -21,7 +21,7 @@ using namespace std;
 
 template <class T> weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
 
-void test_connectivity() {
+void test_connectivity(bool signal_wrong_fingerprint) {
 	InitLogger(LogLevel::Debug);
 
 	Configuration config1;
@@ -47,8 +47,17 @@ void test_connectivity() {
 
 	PeerConnection pc2(config2);
 
-	pc1.onLocalDescription([&pc2](Description sdp) {
+	pc1.onLocalDescription([&pc2, signal_wrong_fingerprint](Description sdp) {
 		cout << "Description 1: " << sdp << endl;
+		if (signal_wrong_fingerprint) {
+			auto f = sdp.fingerprint();
+			if (f.has_value()) {
+				auto s = f.value();
+				auto& c = s[0];
+				if (c == 'F' || c == 'f') c = '0'; else c++;
+				sdp.setFingerprint(s);
+			}
+		}
 		pc2.setRemoteDescription(string(sdp));
 	});
 

+ 9 - 2
test/main.cpp

@@ -16,7 +16,7 @@ using namespace std;
 using namespace chrono_literals;
 
 void test_negotiated();
-void test_connectivity();
+void test_connectivity(bool signal_wrong_fingerprint);
 void test_turn_connectivity();
 void test_track();
 void test_capi_connectivity();
@@ -41,12 +41,19 @@ int main(int argc, char **argv) {
 	// C++ API tests
 	try {
 		cout << endl << "*** Running WebRTC connectivity test..." << endl;
-		test_connectivity();
+		test_connectivity(false);
 		cout << "*** Finished WebRTC connectivity test" << endl;
 	} catch (const exception &e) {
 		cerr << "WebRTC connectivity test failed: " << e.what() << endl;
 		return -1;
 	}
+	try {
+		cout << endl << "*** Running WebRTC broken fingerprint test..." << endl;
+		test_connectivity(true);
+		cerr << "WebRTC connectivity test failed to detect broken fingerprint" << endl;
+		return -1;
+	} catch (const exception &) {
+	}
 
 // TODO: Temporarily disabled as the Open Relay TURN server is unreliable
 /*