Browse Source

Implemented DTLS-SRTP transport for GnuTLS (draft)

Paul-Louis Ageneau 5 years ago
parent
commit
a7e23dd210
4 changed files with 110 additions and 27 deletions
  1. 87 16
      src/dtlssrtptransport.cpp
  2. 4 2
      src/dtlssrtptransport.hpp
  3. 16 6
      src/dtlstransport.cpp
  4. 3 3
      src/dtlstransport.hpp

+ 87 - 16
src/dtlssrtptransport.cpp

@@ -19,8 +19,10 @@
 #include "dtlssrtptransport.hpp"
 #include "dtlssrtptransport.hpp"
 
 
 #include <exception>
 #include <exception>
+#include <srtp2/srtp.h>
 
 
 using std::shared_ptr;
 using std::shared_ptr;
+using std::to_string;
 
 
 namespace rtc {
 namespace rtc {
 
 
@@ -30,48 +32,117 @@ DtlsSrtpTransport::DtlsSrtpTransport(std::shared_ptr<IceTransport> lower,
                                      message_callback recvCallback,
                                      message_callback recvCallback,
                                      state_callback stateChangeCallback)
                                      state_callback stateChangeCallback)
     : DtlsTransport(lower, certificate, std::move(verifierCallback),
     : DtlsTransport(lower, certificate, std::move(verifierCallback),
-                    std::move(stateChangeCallback)) {
-	onRecv(recvCallback);
+                    std::move(stateChangeCallback)),
+      mRecvCallback(std::move(recvCallback)) {
 
 
 	// TODO: global init
 	// TODO: global init
 	srtp_init();
 	srtp_init();
 
 
 	PLOG_DEBUG << "Initializing SRTP transport";
 	PLOG_DEBUG << "Initializing SRTP transport";
 
 
-	mPolicy = {};
-	srtp_crypto_policy_set_rtp_default(&mPolicy.rtp);
-	srtp_crypto_policy_set_rtcp_default(&mPolicy.rtcp);
+#if USE_GNUTLS
+	// TODO: check_gnutls
+	gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80);
+#else
+	// TODO
+#endif
 }
 }
 
 
 DtlsSrtpTransport::~DtlsSrtpTransport() { stop(); }
 DtlsSrtpTransport::~DtlsSrtpTransport() { stop(); }
 
 
-void DtlsSrtpTransport::stop() {
-	Transport::stop();
-	onRecv(nullptr);
+bool DtlsSrtpTransport::stop() {
+	if (!Transport::stop())
+		return false;
 
 
 	// TODO: global cleanup
 	// TODO: global cleanup
 	srtp_shutdown();
 	srtp_shutdown();
+	return true;
 }
 }
 
 
 bool DtlsSrtpTransport::send(message_ptr message) {
 bool DtlsSrtpTransport::send(message_ptr message) {
 	if (!message)
 	if (!message)
 		return false;
 		return false;
 
 
-	PLOG_VERBOSE << "Send size=" << message->size();
-
-	// TODO
-	return false;
+	int size = message->size();
+	PLOG_VERBOSE << "Send size=" << size;
+
+	// srtp_protect() assumes that it can write SRTP_MAX_TRAILER_LEN (for the authentication tag)
+	// into the location in memory immediately following the RTP packet.
+	message->resize(size + SRTP_MAX_TRAILER_LEN);
+	if (srtp_err_status_t err = srtp_protect(mSrtp, message->data(), &size)) {
+		if (err == srtp_err_status_replay_fail)
+			throw std::runtime_error("SRTP packet is a replay");
+		else
+			throw std::runtime_error("SRTP protect error");
+	}
+	PLOG_VERBOSE << "Protected SRTP packet, size=" << size;
+	message->resize(size);
+	outgoing(message);
+	return true;
 }
 }
 
 
 void DtlsSrtpTransport::incoming(message_ptr message) {
 void DtlsSrtpTransport::incoming(message_ptr message) {
-	//
+	// TODO: demultiplexing
+	// detect dtls and pass to DtlsTransport::incoming
+
+	int size = message->size();
+	PLOG_VERBOSE << "Incoming SRTP packet, size=" << size;
+
+	if (srtp_err_status_t err = srtp_unprotect(mSrtp, message->data(), &size)) {
+		if (err == srtp_err_status_replay_fail)
+			PLOG_WARNING << "Incoming SRTP packet is a replay";
+		else
+			PLOG_WARNING << "SRTP unprotect error, status=" << err;
+		return;
+	}
+	PLOG_VERBOSE << "Unprotected SRTP packet, size=" << size;
+	message->resize(size);
+	mRecvCallback(message);
 }
 }
 
 
 void DtlsSrtpTransport::postHandshake() {
 void DtlsSrtpTransport::postHandshake() {
-	// TODO: derive keys
+	srtp_policy_t inbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
+	inbound.ssrc.type = ssrc_any_inbound;
+
+	srtp_policy_t outbound = {};
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
+	srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);
+	outbound.ssrc.type = ssrc_any_outbound;
+
+#if USE_GNUTLS
+	unsigned char material[SRTP_MAX_KEY_LEN * 2];
+	gnutls_datum_t clientKey, clientSalt, serverKey, serverSalt;
+	// TODO: check_gnutls
+	gnutls_srtp_get_keys(mSession, material, STRP_MAX_KEY_LEN * 2, &clientKey, &clientSalt,
+	                     &serverKey, &serverSalt);
+
+	unsigned char clientSessionKey[SRTP_MAX_KEY_LEN];
+	std::memcpy(clientSessionKey, clientKey.data, clientKey.size);
+	std::memcpy(clientSessionKey + clientKey.size, clientSalt.data, clientSalt.size);
+
+	unsigned char serverSessionKey[SRTP_MAX_KEY_LEN];
+	std::memcpy(serverSessionKey, serverKey.data, serverKey.size);
+	std::memcpy(serverSessionKey + serverKey.size, serverSalt.data, serverSalt.size);
+
+	if (mIsClient) {
+		inbound.key = serverSessionKey;
+		outbound.key = clientSessionKey;
+	} else {
+		inbound.key = clientSessionKey;
+		outbound.key = serverSessionKey;
+	}
+#else
+	// TODO
+#endif
+
+	srtp_policy_t *policies = &inbound;
+	inbound.next = &outbound;
+	outbound.next = nullptr;
 
 
-	mPolicy.ssrc = mSsrc;
-	mPolicy.key = key;
+	if (srtp_err_status_t err = srtp_create(&mSrtp, policies))
+		throw std::runtime_error("SRTP create failed, status=" + to_string(static_cast<int>(err)));
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 4 - 2
src/dtlssrtptransport.hpp

@@ -33,14 +33,16 @@ public:
 	                  state_callback stateChangeCallback);
 	                  state_callback stateChangeCallback);
 	~DtlsSrtpTransport();
 	~DtlsSrtpTransport();
 
 
-	void stop() override;
+	bool stop() override;
 	bool send(message_ptr message) override;
 	bool send(message_ptr message) override;
 
 
 private:
 private:
 	void incoming(message_ptr message) override;
 	void incoming(message_ptr message) override;
+	void postHandshake() override;
+
+	message_callback mRecvCallback;
 
 
 	srtp_t mSrtp;
 	srtp_t mSrtp;
-	srtp_policy_t mPolicy;
 };
 };
 
 
 } // namespace rtc
 } // namespace rtc

+ 16 - 6
src/dtlstransport.cpp

@@ -64,12 +64,13 @@ void DtlsTransport::Cleanup() {
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
     : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
-      mVerifierCallback(std::move(verifierCallback)) {
+      mVerifierCallback(std::move(verifierCallback)),
+      mIsClient(lower->role() == Description::Role::Active) {
 
 
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 	PLOG_DEBUG << "Initializing DTLS transport (GnuTLS)";
 
 
-	bool active = lower->role() == Description::Role::Active;
-	unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER);
+	unsigned int flags = GNUTLS_DATAGRAM | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER);
+
 	check_gnutls(gnutls_init(&mSession, flags));
 	check_gnutls(gnutls_init(&mSession, flags));
 
 
 	try {
 	try {
@@ -148,6 +149,10 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 	mIncomingQueue.push(message);
 }
 }
 
 
+void DtlsTransport::postHandshake() {
+	// Dummy
+}
+
 void DtlsTransport::runRecvLoop() {
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 	const size_t maxMtu = 4096;
 
 
@@ -180,6 +185,7 @@ void DtlsTransport::runRecvLoop() {
 	try {
 	try {
 		PLOG_INFO << "DTLS handshake finished";
 		PLOG_INFO << "DTLS handshake finished";
 		changeState(State::Connected);
 		changeState(State::Connected);
+		postHandshake();
 
 
 		const size_t bufferSize = maxMtu;
 		const size_t bufferSize = maxMtu;
 		char buffer[bufferSize];
 		char buffer[bufferSize];
@@ -354,8 +360,7 @@ void DtlsTransport::Cleanup() {
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
 DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
                              verifier_callback verifierCallback, state_callback stateChangeCallback)
     : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
     : Transport(lower, std::move(stateChangeCallback)), mCertificate(certificate),
-      mVerifierCallback(std::move(verifierCallback)) {
-
+      mVerifierCallback(std::move(verifierCallback), mIsClient(lower->role() == Description::Role::Active) {
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 	PLOG_DEBUG << "Initializing DTLS transport (OpenSSL)";
 
 
 	try {
 	try {
@@ -388,7 +393,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
 
 
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 		SSL_set_ex_data(mSsl, TransportExIndex, this);
 
 
-		if (lower->role() == Description::Role::Active)
+		if (mIsClient)
 			SSL_set_connect_state(mSsl);
 			SSL_set_connect_state(mSsl);
 		else
 		else
 			SSL_set_accept_state(mSsl);
 			SSL_set_accept_state(mSsl);
@@ -455,6 +460,10 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 	mIncomingQueue.push(message);
 }
 }
 
 
+void DtlsTransport::postHandshake() {
+	// Dummy
+}
+
 void DtlsTransport::runRecvLoop() {
 void DtlsTransport::runRecvLoop() {
 	const size_t maxMtu = 4096;
 	const size_t maxMtu = 4096;
 	try {
 	try {
@@ -486,6 +495,7 @@ void DtlsTransport::runRecvLoop() {
 
 
 						PLOG_INFO << "DTLS handshake finished";
 						PLOG_INFO << "DTLS handshake finished";
 						changeState(State::Connected);
 						changeState(State::Connected);
+						postHandshake();
 					}
 					}
 				} else {
 				} else {
 					int ret = SSL_read(mSsl, buffer, bufferSize);
 					int ret = SSL_read(mSsl, buffer, bufferSize);

+ 3 - 3
src/dtlstransport.hpp

@@ -57,16 +57,16 @@ public:
 
 
 protected:
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual void incoming(message_ptr message) override;
-
+	virtual void postHandshake();
 	void runRecvLoop();
 	void runRecvLoop();
 
 
 	const std::shared_ptr<Certificate> mCertificate;
 	const std::shared_ptr<Certificate> mCertificate;
+	const verifier_callback mVerifierCallback;
+	const bool mIsClient;
 
 
 	Queue<message_ptr> mIncomingQueue;
 	Queue<message_ptr> mIncomingQueue;
 	std::thread mRecvThread;
 	std::thread mRecvThread;
 
 
-	verifier_callback mVerifierCallback;
-
 #if USE_GNUTLS
 #if USE_GNUTLS
 	gnutls_session_t mSession;
 	gnutls_session_t mSession;