Browse Source

Merge branch 'api_updates_fix_openssl' of https://github.com/paullouisageneau/libdatachannel into paullouisageneau-api_updates_fix_openssl

Staz M 4 years ago
parent
commit
e966710988
3 changed files with 12 additions and 10 deletions
  1. 5 5
      src/dtlssrtptransport.cpp
  2. 6 4
      src/dtlssrtptransport.hpp
  3. 1 1
      src/dtlstransport.cpp

+ 5 - 5
src/dtlssrtptransport.cpp

@@ -90,9 +90,6 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 	if (size < 8)
 		throw std::runtime_error("RTP/RTCP packet too short");
 
-//    return outgoing(message);
-//    return DtlsTransport::send(message);
-
 	// srtp_protect() and srtp_protect_rtcp() assume that they can write SRTP_MAX_TRAILER_LEN (for
 	// the authentication tag) into the location in memory immediately following the RTP packet.
 	message->resize(size + SRTP_MAX_TRAILER_LEN);
@@ -145,7 +142,6 @@ bool DtlsSrtpTransport::sendMedia(message_ptr message) {
 
 	message->resize(size);
 	return outgoing(message);
-//	return DtlsTransport::send(message);
 }
 
 void DtlsSrtpTransport::incoming(message_ptr message) {
@@ -247,6 +243,8 @@ void DtlsSrtpTransport::postHandshake() {
 	if (mInitDone)
 		return;
 
+	static_assert(SRTP_AES_ICM_128_KEY_LEN_WSALT == SRTP_AES_128_KEY_LEN + SRTP_SALT_LEN);
+
 	const size_t materialLen = SRTP_AES_ICM_128_KEY_LEN_WSALT * 2;
 	unsigned char material[materialLen];
 	const unsigned char *clientKey, *clientSalt, *serverKey, *serverSalt;
@@ -319,6 +317,9 @@ void DtlsSrtpTransport::postHandshake() {
 }
 
 void DtlsSrtpTransport::addSSRC(uint32_t ssrc) {
+	if (!mInitDone)
+		throw std::logic_error("Attempted to add SSRC before SRTP keying material is derived");
+
     srtp_policy_t inbound = {};
     srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtp);
     srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&inbound.rtcp);
@@ -331,7 +332,6 @@ void DtlsSrtpTransport::addSSRC(uint32_t ssrc) {
         throw std::runtime_error("SRTP add inbound stream failed, status=" +
                                  to_string(static_cast<int>(err)));
 
-
     srtp_policy_t outbound = {};
     srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtp);
     srtp_crypto_policy_set_aes_cm_128_hmac_sha1_80(&outbound.rtcp);

+ 6 - 4
src/dtlssrtptransport.hpp

@@ -26,6 +26,8 @@
 
 #include <srtp2/srtp.h>
 
+#include <atomic>
+
 namespace rtc {
 
 class DtlsSrtpTransport final : public DtlsTransport {
@@ -39,7 +41,7 @@ public:
 	~DtlsSrtpTransport();
 
 	bool sendMedia(message_ptr message);
-    void addSSRC(uint32_t ssrc);
+	void addSSRC(uint32_t ssrc);
 
 private:
 	void incoming(message_ptr message) override;
@@ -48,10 +50,10 @@ private:
 	message_callback mSrtpRecvCallback;
 
 	srtp_t mSrtpIn, mSrtpOut;
-	bool mInitDone = false;
 
-    unsigned char mClientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
-    unsigned char mServerSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
+	std::atomic<bool> mInitDone = false;
+	unsigned char mClientSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
+	unsigned char mServerSessionKey[SRTP_AES_ICM_128_KEY_LEN_WSALT];
 };
 
 } // namespace rtc

+ 1 - 1
src/dtlstransport.cpp

@@ -453,8 +453,8 @@ void DtlsTransport::runRecvLoop() {
 						SSL_set_mtu(mSsl, maxMtu + 1);
 
 						PLOG_INFO << "DTLS handshake finished";
-						changeState(State::Connected);
 						postHandshake();
+						changeState(State::Connected);
 					}
 				} else {
 					ret = SSL_read(mSsl, buffer, bufferSize);