|
@@ -92,10 +92,6 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
|
|
|
gnutls_transport_set_pull_function(mSession, ReadCallback);
|
|
|
gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
|
|
|
|
|
|
- size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
|
|
|
- gnutls_dtls_set_mtu(mSession, static_cast<unsigned int>(mtu));
|
|
|
- PLOG_VERBOSE << "DTLS MTU set to " << mtu;
|
|
|
-
|
|
|
} catch (...) {
|
|
|
gnutls_deinit(mSession);
|
|
|
throw;
|
|
@@ -117,6 +113,11 @@ void DtlsTransport::start() {
|
|
|
PLOG_DEBUG << "Starting DTLS transport";
|
|
|
registerIncoming();
|
|
|
changeState(State::Connecting);
|
|
|
+
|
|
|
+ size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
|
|
|
+ gnutls_dtls_set_mtu(mSession, static_cast<unsigned int>(mtu));
|
|
|
+ PLOG_VERBOSE << "DTLS MTU set to " << mtu;
|
|
|
+
|
|
|
enqueueRecv(); // to initiate the handshake
|
|
|
}
|
|
|
|
|
@@ -181,10 +182,13 @@ void DtlsTransport::doRecv() {
|
|
|
std::lock_guard lock(mRecvMutex);
|
|
|
--mPendingRecvCount;
|
|
|
|
|
|
- const size_t bufferSize = 4096;
|
|
|
- char buffer[bufferSize];
|
|
|
+ if (state() != State::Connecting && state() != State::Connected)
|
|
|
+ return;
|
|
|
|
|
|
try {
|
|
|
+ const size_t bufferSize = 4096;
|
|
|
+ char buffer[bufferSize];
|
|
|
+
|
|
|
// Handle handshake if connecting
|
|
|
if (state() == State::Connecting) {
|
|
|
int ret;
|
|
@@ -193,7 +197,7 @@ void DtlsTransport::doRecv() {
|
|
|
|
|
|
if (ret == GNUTLS_E_AGAIN) {
|
|
|
// Schedule next call on timeout and return
|
|
|
- duration timeout = milliseconds(gnutls_dtls_get_timeout(mSession));
|
|
|
+ auto timeout = milliseconds(gnutls_dtls_get_timeout(mSession));
|
|
|
ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
|
|
|
if (auto locked = weak_this.lock())
|
|
|
locked->doRecv();
|
|
@@ -317,7 +321,13 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
|
|
|
ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
|
|
|
DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
|
|
|
try {
|
|
|
- while (auto next = t->mIncomingQueue.pop()) {
|
|
|
+ while (t->mIncomingQueue.running()) {
|
|
|
+ auto next = t->mIncomingQueue.pop();
|
|
|
+ if (!next) {
|
|
|
+ gnutls_transport_set_errno(t->mSession, EAGAIN);
|
|
|
+ return -1;
|
|
|
+ }
|
|
|
+
|
|
|
message_ptr message = std::move(*next);
|
|
|
if (t->demuxMessage(message))
|
|
|
continue;
|
|
@@ -328,14 +338,9 @@ ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size
|
|
|
return len;
|
|
|
}
|
|
|
|
|
|
- if (t->mIncomingQueue.running()) {
|
|
|
- gnutls_transport_set_errno(t->mSession, EAGAIN);
|
|
|
- return -1;
|
|
|
- } else {
|
|
|
- // Closed
|
|
|
- gnutls_transport_set_errno(t->mSession, 0);
|
|
|
- return 0;
|
|
|
- }
|
|
|
+ // Closed
|
|
|
+ gnutls_transport_set_errno(t->mSession, 0);
|
|
|
+ return 0;
|
|
|
|
|
|
} catch (const std::exception &e) {
|
|
|
PLOG_WARNING << e.what();
|
|
@@ -461,6 +466,7 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, certificate_ptr cer
|
|
|
throw std::runtime_error("Failed to set SRTP profile: " +
|
|
|
openssl::error_string(ERR_get_error()));
|
|
|
}
|
|
|
+
|
|
|
} catch (...) {
|
|
|
if (mSsl)
|
|
|
SSL_free(mSsl);
|
|
@@ -483,23 +489,26 @@ DtlsTransport::~DtlsTransport() {
|
|
|
}
|
|
|
|
|
|
void DtlsTransport::start() {
|
|
|
- if (mStarted.exchange(true))
|
|
|
- return;
|
|
|
-
|
|
|
- PLOG_DEBUG << "Starting DTLS recv thread";
|
|
|
+ PLOG_DEBUG << "Starting DTLS transport";
|
|
|
registerIncoming();
|
|
|
- mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
|
|
|
+ changeState(State::Connecting);
|
|
|
+
|
|
|
+ size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
|
|
|
+ SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
|
|
|
+ PLOG_VERBOSE << "DTLS MTU set to " << mtu;
|
|
|
+
|
|
|
+ // Initiate the handshake
|
|
|
+ int ret = SSL_do_handshake(mSsl);
|
|
|
+ openssl::check(mSsl, ret, "Handshake initiation failed");
|
|
|
+
|
|
|
+ handleTimeout();
|
|
|
}
|
|
|
|
|
|
void DtlsTransport::stop() {
|
|
|
- if (!mStarted.exchange(false))
|
|
|
- return;
|
|
|
-
|
|
|
- PLOG_DEBUG << "Stopping DTLS recv thread";
|
|
|
+ PLOG_DEBUG << "Stopping DTLS transport";
|
|
|
unregisterIncoming();
|
|
|
mIncomingQueue.stop();
|
|
|
- mRecvThread.join();
|
|
|
- SSL_shutdown(mSsl);
|
|
|
+ enqueueRecv();
|
|
|
}
|
|
|
|
|
|
bool DtlsTransport::send(message_ptr message) {
|
|
@@ -519,11 +528,13 @@ bool DtlsTransport::send(message_ptr message) {
|
|
|
void DtlsTransport::incoming(message_ptr message) {
|
|
|
if (!message) {
|
|
|
mIncomingQueue.stop();
|
|
|
+ enqueueRecv();
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
PLOG_VERBOSE << "Incoming size=" << message->size();
|
|
|
mIncomingQueue.push(message);
|
|
|
+ enqueueRecv();
|
|
|
}
|
|
|
|
|
|
bool DtlsTransport::outgoing(message_ptr message) {
|
|
@@ -543,86 +554,65 @@ void DtlsTransport::postHandshake() {
|
|
|
// Dummy
|
|
|
}
|
|
|
|
|
|
-void DtlsTransport::runRecvLoop() {
|
|
|
- const size_t bufferSize = 4096;
|
|
|
- try {
|
|
|
- changeState(State::Connecting);
|
|
|
-
|
|
|
- size_t mtu = mMtu.value_or(DEFAULT_MTU) - 8 - 40; // UDP/IPv6
|
|
|
- SSL_set_mtu(mSsl, static_cast<unsigned int>(mtu));
|
|
|
- PLOG_VERBOSE << "SSL MTU set to " << mtu;
|
|
|
+void DtlsTransport::doRecv() {
|
|
|
+ std::lock_guard lock(mRecvMutex);
|
|
|
+ --mPendingRecvCount;
|
|
|
|
|
|
- // Initiate the handshake
|
|
|
- int ret = SSL_do_handshake(mSsl);
|
|
|
- openssl::check(mSsl, ret, "Handshake failed");
|
|
|
+ if (state() != State::Connecting && state() != State::Connected)
|
|
|
+ return;
|
|
|
|
|
|
+ try {
|
|
|
+ const size_t bufferSize = 4096;
|
|
|
byte buffer[bufferSize];
|
|
|
+
|
|
|
+ // Process pending messages
|
|
|
while (mIncomingQueue.running()) {
|
|
|
- // Process pending messages
|
|
|
- while (auto next = mIncomingQueue.pop()) {
|
|
|
- message_ptr message = std::move(*next);
|
|
|
- if (demuxMessage(message))
|
|
|
- continue;
|
|
|
+ auto next = mIncomingQueue.pop();
|
|
|
+ if (!next) {
|
|
|
+ // No more messages pending, handle timeout if connecting
|
|
|
+ if (state() == State::Connecting)
|
|
|
+ handleTimeout();
|
|
|
|
|
|
- BIO_write(mInBio, message->data(), int(message->size()));
|
|
|
+ return;
|
|
|
+ }
|
|
|
|
|
|
- if (state() == State::Connecting) {
|
|
|
- // Continue the handshake
|
|
|
- ret = SSL_do_handshake(mSsl);
|
|
|
- if (!openssl::check(mSsl, ret, "Handshake failed"))
|
|
|
- break;
|
|
|
+ message_ptr message = std::move(*next);
|
|
|
+ if (demuxMessage(message))
|
|
|
+ continue;
|
|
|
|
|
|
- if (SSL_is_init_finished(mSsl)) {
|
|
|
- // RFC 8261: DTLS MUST support sending messages larger than the current path
|
|
|
- // MTU See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
|
|
|
- SSL_set_mtu(mSsl, bufferSize + 1);
|
|
|
+ BIO_write(mInBio, message->data(), int(message->size()));
|
|
|
|
|
|
- PLOG_INFO << "DTLS handshake finished";
|
|
|
- postHandshake();
|
|
|
- changeState(State::Connected);
|
|
|
- }
|
|
|
- } else {
|
|
|
- ret = SSL_read(mSsl, buffer, bufferSize);
|
|
|
- if (!openssl::check(mSsl, ret))
|
|
|
- break;
|
|
|
+ if (state() == State::Connecting) {
|
|
|
+ // Continue the handshake
|
|
|
+ int ret = SSL_do_handshake(mSsl);
|
|
|
+ if (!openssl::check(mSsl, ret, "Handshake failed"))
|
|
|
+ break;
|
|
|
|
|
|
- if (ret > 0)
|
|
|
- recv(make_message(buffer, buffer + ret));
|
|
|
- }
|
|
|
- }
|
|
|
+ if (SSL_is_init_finished(mSsl)) {
|
|
|
+ // RFC 8261: DTLS MUST support sending messages larger than the current path
|
|
|
+ // MTU See https://www.rfc-editor.org/rfc/rfc8261.html#section-5
|
|
|
+ SSL_set_mtu(mSsl, bufferSize + 1);
|
|
|
|
|
|
- // No more messages pending, retransmit and rearm timeout if connecting
|
|
|
- optional<milliseconds> duration;
|
|
|
- if (state() == State::Connecting) {
|
|
|
- // Warning: This function breaks the usual return value convention
|
|
|
- ret = DTLSv1_handle_timeout(mSsl);
|
|
|
- if (ret < 0) {
|
|
|
- throw std::runtime_error("Handshake timeout"); // write BIO can't fail
|
|
|
- } else if (ret > 0) {
|
|
|
- LOG_VERBOSE << "OpenSSL did DTLS retransmit";
|
|
|
+ PLOG_INFO << "DTLS handshake finished";
|
|
|
+ postHandshake();
|
|
|
+ changeState(State::Connected);
|
|
|
}
|
|
|
+ } else {
|
|
|
+ int ret = SSL_read(mSsl, buffer, bufferSize);
|
|
|
+ if (!openssl::check(mSsl, ret))
|
|
|
+ break;
|
|
|
|
|
|
- struct timeval timeout = {};
|
|
|
- if (state() == State::Connecting && DTLSv1_get_timeout(mSsl, &timeout)) {
|
|
|
- duration = milliseconds(timeout.tv_sec * 1000 + timeout.tv_usec / 1000);
|
|
|
- // Also handle handshake timeout manually because OpenSSL actually doesn't...
|
|
|
- // OpenSSL backs off exponentially in base 2 starting from the recommended 1s
|
|
|
- // so this allows for 5 retransmissions and fails after roughly 30s.
|
|
|
- if (duration > 30s) {
|
|
|
- throw std::runtime_error("Handshake timeout");
|
|
|
- } else {
|
|
|
- LOG_VERBOSE << "OpenSSL DTLS retransmit timeout is " << duration->count()
|
|
|
- << "ms";
|
|
|
- }
|
|
|
- }
|
|
|
+ if (ret > 0)
|
|
|
+ recv(make_message(buffer, buffer + ret));
|
|
|
}
|
|
|
-
|
|
|
- mIncomingQueue.wait(duration); // TODO
|
|
|
}
|
|
|
+
|
|
|
} catch (const std::exception &e) {
|
|
|
PLOG_ERROR << "DTLS recv: " << e.what();
|
|
|
}
|
|
|
|
|
|
+ SSL_shutdown(mSsl);
|
|
|
+
|
|
|
if (state() == State::Connected) {
|
|
|
PLOG_INFO << "DTLS closed";
|
|
|
changeState(State::Disconnected);
|
|
@@ -633,6 +623,33 @@ void DtlsTransport::runRecvLoop() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+void DtlsTransport::handleTimeout() {
|
|
|
+ // Warning: This function breaks the usual return value convention
|
|
|
+ int ret = DTLSv1_handle_timeout(mSsl);
|
|
|
+ if (ret < 0) {
|
|
|
+ throw std::runtime_error("Handshake timeout"); // write BIO can't fail
|
|
|
+ } else if (ret > 0) {
|
|
|
+ LOG_VERBOSE << "DTLS retransmit done";
|
|
|
+ }
|
|
|
+
|
|
|
+ struct timeval tv = {};
|
|
|
+ if (DTLSv1_get_timeout(mSsl, &tv)) {
|
|
|
+ auto timeout = milliseconds(tv.tv_sec * 1000 + tv.tv_usec / 1000);
|
|
|
+ // Also handle handshake timeout manually because OpenSSL actually
|
|
|
+ // doesn't... OpenSSL backs off exponentially in base 2 starting from the
|
|
|
+ // recommended 1s so this allows for 5 retransmissions and fails after
|
|
|
+ // roughly 30s.
|
|
|
+ if (timeout > 30s)
|
|
|
+ throw std::runtime_error("Handshake timeout");
|
|
|
+
|
|
|
+ LOG_VERBOSE << "DTLS retransmit timeout is " << timeout.count() << "ms";
|
|
|
+ ThreadPool::Instance().schedule(timeout, [weak_this = weak_from_this()]() {
|
|
|
+ if (auto locked = weak_this.lock())
|
|
|
+ locked->doRecv();
|
|
|
+ });
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
int DtlsTransport::CertificateCallback(int /*preverify_ok*/, X509_STORE_CTX *ctx) {
|
|
|
SSL *ssl =
|
|
|
static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
|