Browse Source

Introduced postCreation method to DTLS-SRTP

Paul-Louis Ageneau 5 years ago
parent
commit
103935bdd5
4 changed files with 23 additions and 13 deletions
  1. 15 13
      src/dtlssrtptransport.cpp
  2. 1 0
      src/dtlssrtptransport.hpp
  3. 6 0
      src/dtlstransport.cpp
  4. 1 0
      src/dtlstransport.hpp

+ 15 - 13
src/dtlssrtptransport.cpp

@@ -44,16 +44,6 @@ DtlsSrtpTransport::DtlsSrtpTransport(std::shared_ptr<IceTransport> lower,
       mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback
       mSrtpRecvCallback(std::move(srtpRecvCallback)) { // distinct from Transport recv callback
 
 
 	PLOG_DEBUG << "Initializing SRTP transport";
 	PLOG_DEBUG << "Initializing SRTP transport";
-
-#if USE_GNUTLS
-	PLOG_DEBUG << "Initializing DTLS-SRTP transport (GnuTLS)";
-	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
-	              "Failed to set SRTP profile");
-#else
-	PLOG_DEBUG << "Initializing DTLS-SRTP transport (OpenSSL)";
-	openssl::check(SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"),
-	               "Failed to set SRTP profile");
-#endif
 }
 }
 
 
 DtlsSrtpTransport::~DtlsSrtpTransport() {
 DtlsSrtpTransport::~DtlsSrtpTransport() {
@@ -67,14 +57,14 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 	if (!message)
 	if (!message)
 		return false;
 		return false;
 
 
-	int size = message->size();
-	PLOG_VERBOSE << "Send size=" << size;
-
 	if (!mCreated) {
 	if (!mCreated) {
 		PLOG_WARNING << "SRTP media sent before keys are derived";
 		PLOG_WARNING << "SRTP media sent before keys are derived";
 		return false;
 		return false;
 	}
 	}
 
 
+	int size = message->size();
+	PLOG_VERBOSE << "Send size=" << size;
+
 	// srtp_protect() assumes that it can write SRTP_MAX_TRAILER_LEN (for the authentication tag)
 	// srtp_protect() assumes that it can write SRTP_MAX_TRAILER_LEN (for the authentication tag)
 	// into the location in memory immediately following the RTP packet.
 	// into the location in memory immediately following the RTP packet.
 	message->resize(size + SRTP_MAX_TRAILER_LEN);
 	message->resize(size + SRTP_MAX_TRAILER_LEN);
@@ -130,6 +120,18 @@ void DtlsSrtpTransport::incoming(message_ptr message) {
 	}
 	}
 }
 }
 
 
+void DtlsSrtpTransport::postCreation() {
+#if USE_GNUTLS
+	PLOG_DEBUG << "Initializing DTLS-SRTP transport (GnuTLS)";
+	gnutls::check(gnutls_srtp_set_profile(mSession, GNUTLS_SRTP_AES128_CM_HMAC_SHA1_80),
+	              "Failed to set SRTP profile");
+#else
+	PLOG_DEBUG << "Initializing DTLS-SRTP transport (OpenSSL)";
+	openssl::check(SSL_set_tlsext_use_srtp(mSsl, "SRTP_AES128_CM_SHA1_80"),
+	               "Failed to set SRTP profile");
+#endif
+}
+
 void DtlsSrtpTransport::postHandshake() {
 void DtlsSrtpTransport::postHandshake() {
 	if (mCreated)
 	if (mCreated)
 		return;
 		return;

+ 1 - 0
src/dtlssrtptransport.hpp

@@ -42,6 +42,7 @@ public:
 
 
 private:
 private:
 	void incoming(message_ptr message) override;
 	void incoming(message_ptr message) override;
+	void postCreation() override;
 	void postHandshake() override;
 	void postHandshake() override;
 
 
 	message_callback mSrtpRecvCallback;
 	message_callback mSrtpRecvCallback;

+ 6 - 0
src/dtlstransport.cpp

@@ -85,6 +85,8 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_function(mSession, ReadCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 		gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
 
 
+		postCreation();
+
 		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
 		mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
 		registerIncoming();
 		registerIncoming();
 
 
@@ -137,6 +139,10 @@ void DtlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 	mIncomingQueue.push(message);
 }
 }
 
 
+void DtlsTransport::postCreation() {
+	// Dummy
+}
+
 void DtlsTransport::postHandshake() {
 void DtlsTransport::postHandshake() {
 	// Dummy
 	// Dummy
 }
 }

+ 1 - 0
src/dtlstransport.hpp

@@ -52,6 +52,7 @@ public:
 
 
 protected:
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual void incoming(message_ptr message) override;
+	virtual void postCreation();
 	virtual void postHandshake();
 	virtual void postHandshake();
 	void runRecvLoop();
 	void runRecvLoop();