|
@@ -16,6 +16,8 @@
|
|
|
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
|
|
*/
|
|
|
|
|
|
+#if ENABLE_WEBSOCKET
|
|
|
+
|
|
|
#include "tlstransport.hpp"
|
|
|
#include "tcptransport.hpp"
|
|
|
|
|
@@ -51,8 +53,15 @@ static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
|
|
|
|
|
|
namespace rtc {
|
|
|
|
|
|
-TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
|
|
|
- : Transport(lower), mHost(host) {
|
|
|
+void TlsTransport::Init() {
|
|
|
+ // Nothing to do
|
|
|
+}
|
|
|
+
|
|
|
+void TlsTransport::Cleanup() {
|
|
|
+ // Nothing to do
|
|
|
+}
|
|
|
+
|
|
|
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
|
|
|
|
|
|
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
|
|
|
|
|
@@ -62,7 +71,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
|
|
|
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
|
|
|
const char *err_pos = NULL;
|
|
|
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
|
|
|
- "Unable to set TLS priorities");
|
|
|
+ "Failed to set TLS priorities");
|
|
|
|
|
|
gnutls_session_set_ptr(mSession, this);
|
|
|
gnutls_transport_set_ptr(mSession, this);
|
|
@@ -72,7 +81,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
|
|
|
|
|
|
gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
|
|
|
|
|
|
- mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
|
|
|
+ mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
|
|
|
|
|
|
} catch (...) {
|
|
|
|
|
@@ -81,7 +90,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, const string &host)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-TlsTransport::~DtlsTransport() {
|
|
|
+TlsTransport::~TlsTransport() {
|
|
|
stop();
|
|
|
gnutls_deinit(mSession);
|
|
|
}
|
|
@@ -96,7 +105,7 @@ bool DtlsTransport::stop() {
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
-bool DtlsTransport::send(message_ptr message) {
|
|
|
+bool TlsTransport::send(message_ptr message) {
|
|
|
if (!message)
|
|
|
return false;
|
|
|
|
|
@@ -108,7 +117,7 @@ bool DtlsTransport::send(message_ptr message) {
|
|
|
return check_gnutls(ret);
|
|
|
}
|
|
|
|
|
|
-void DtlsTransport::incoming(message_ptr message) {
|
|
|
+void TlsTransport::incoming(message_ptr message) {
|
|
|
if (message)
|
|
|
mIncomingQueue.push(message);
|
|
|
else
|
|
@@ -128,7 +137,6 @@ void TlsTransport::runRecvLoop() {
|
|
|
|
|
|
} catch (const std::exception &e) {
|
|
|
PLOG_ERROR << "TLS handshake: " << e.what();
|
|
|
- changeState(State::Failed);
|
|
|
return;
|
|
|
}
|
|
|
|
|
@@ -169,7 +177,7 @@ void TlsTransport::runRecvLoop() {
|
|
|
}
|
|
|
|
|
|
ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
|
|
|
- DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
|
|
|
+ TlsTransport *t = static_cast<TlsTransport *>(ptr);
|
|
|
if (len > 0) {
|
|
|
auto b = reinterpret_cast<const byte *>(data);
|
|
|
t->outgoing(make_message(b, b + len));
|
|
@@ -179,7 +187,7 @@ ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data
|
|
|
}
|
|
|
|
|
|
ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
|
|
|
- TlsTransport *t = static_cast<DtlsTransport *>(ptr);
|
|
|
+ TlsTransport *t = static_cast<TlsTransport *>(ptr);
|
|
|
if (auto next = t->mIncomingQueue.pop()) {
|
|
|
auto message = *next;
|
|
|
ssize_t len = std::min(maxlen, message->size());
|
|
@@ -193,7 +201,7 @@ ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_
|
|
|
}
|
|
|
|
|
|
int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
|
|
|
- TlsTransport *t = static_cast<DtlsTransport *>(ptr);
|
|
|
+ TlsTransport *t = static_cast<TlsTransport *>(ptr);
|
|
|
if (ms != GNUTLS_INDEFINITE_TIMEOUT)
|
|
|
t->mIncomingQueue.wait(milliseconds(ms));
|
|
|
else
|
|
@@ -204,6 +212,202 @@ int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
|
|
|
} // namespace rtc
|
|
|
|
|
|
#else // USE_GNUTLS==0
|
|
|
-// TODO
|
|
|
+
|
|
|
+#include <openssl/bio.h>
|
|
|
+#include <openssl/ec.h>
|
|
|
+#include <openssl/err.h>
|
|
|
+#include <openssl/ssl.h>
|
|
|
+
|
|
|
+namespace {
|
|
|
+
|
|
|
+const int BIO_EOF = -1;
|
|
|
+
|
|
|
+string openssl_error_string(unsigned long err) {
|
|
|
+ const size_t bufferSize = 256;
|
|
|
+ char buffer[bufferSize];
|
|
|
+ ERR_error_string_n(err, buffer, bufferSize);
|
|
|
+ return string(buffer);
|
|
|
+}
|
|
|
+
|
|
|
+bool check_openssl(int success, const string &message = "OpenSSL error") {
|
|
|
+ if (success)
|
|
|
+ return true;
|
|
|
+
|
|
|
+ string str = openssl_error_string(ERR_get_error());
|
|
|
+ PLOG_ERROR << message << ": " << str;
|
|
|
+ throw std::runtime_error(message + ": " + str);
|
|
|
+}
|
|
|
+
|
|
|
+bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
|
|
|
+ if (ret == BIO_EOF)
|
|
|
+ return true;
|
|
|
+
|
|
|
+ unsigned long err = SSL_get_error(ssl, ret);
|
|
|
+ if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (err == SSL_ERROR_ZERO_RETURN) {
|
|
|
+ PLOG_DEBUG << "TLS connection cleanly closed";
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ string str = openssl_error_string(err);
|
|
|
+ PLOG_ERROR << str;
|
|
|
+ throw std::runtime_error(message + ": " + str);
|
|
|
+}
|
|
|
+
|
|
|
+} // namespace
|
|
|
+
|
|
|
+namespace rtc {
|
|
|
+
|
|
|
+int TlsTransport::TransportExIndex = -1;
|
|
|
+
|
|
|
+void TlsTransport::Init() {
|
|
|
+ if (TransportExIndex < 0) {
|
|
|
+ TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void TlsTransport::Cleanup() {
|
|
|
+ // Nothing to do
|
|
|
+}
|
|
|
+
|
|
|
+TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host) : Transport(lower) {
|
|
|
+
|
|
|
+ PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
|
|
|
+ GlobalInit();
|
|
|
+
|
|
|
+ if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
|
|
|
+ throw std::runtime_error("Failed to create SSL context");
|
|
|
+
|
|
|
+ check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
|
|
|
+ "Failed to set SSL priorities");
|
|
|
+
|
|
|
+ SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
|
|
|
+ SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
|
|
|
+ SSL_CTX_set_read_ahead(mCtx, 1);
|
|
|
+ SSL_CTX_set_quiet_shutdown(mCtx, 1);
|
|
|
+ SSL_CTX_set_info_callback(mCtx, InfoCallback);
|
|
|
+
|
|
|
+ SSL_CTX_set_default_verify_paths(mCtx);
|
|
|
+ SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER, NULL);
|
|
|
+ SSL_CTX_set_verify_depth(mCtx, 4);
|
|
|
+
|
|
|
+ if (!(mSsl = SSL_new(mCtx)))
|
|
|
+ throw std::runtime_error("Failed to create SSL instance");
|
|
|
+
|
|
|
+ SSL_set_ex_data(mSsl, TransportExIndex, this);
|
|
|
+ SSL_set_tlsext_host_name(mSsl, host.c_str());
|
|
|
+
|
|
|
+ SSL_set_connect_state(mSsl);
|
|
|
+
|
|
|
+ if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(Bio_s_mem())))
|
|
|
+ throw std::runtime_error("Failed to create BIO");
|
|
|
+
|
|
|
+ BIO_set_mem_eof_return(mInBio, BIO_EOF);
|
|
|
+ BIO_set_mem_eof_return(mOutBio, BIO_EOF);
|
|
|
+ SSL_set_bio(mSsl, mInBio, mOutBio);
|
|
|
+
|
|
|
+ auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
|
|
|
+ EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
|
|
|
+ SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
|
|
|
+ SSL_set_tmp_ecdh(mSsl, ecdh.get());
|
|
|
+
|
|
|
+ mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
|
|
|
+}
|
|
|
+
|
|
|
+TlsTransport::~TlsTransport() {
|
|
|
+ stop();
|
|
|
+
|
|
|
+ SSL_free(mSsl);
|
|
|
+ SSL_CTX_free(mCtx);
|
|
|
+}
|
|
|
+
|
|
|
+bool TlsTransport::stop() {
|
|
|
+ if (!Transport::stop())
|
|
|
+ return false;
|
|
|
+
|
|
|
+ PLOG_DEBUG << "Stopping TLS recv thread";
|
|
|
+ mIncomingQueue.stop();
|
|
|
+ mRecvThread.join();
|
|
|
+ SSL_shutdown(mSsl);
|
|
|
+ return true;
|
|
|
+}
|
|
|
+
|
|
|
+bool TlsTransport::send(message_ptr message) {
|
|
|
+ if (!message)
|
|
|
+ return false;
|
|
|
+
|
|
|
+ int ret = SSL_write(mSsl, message->data(), message->size());
|
|
|
+ if(!check_openssl_ret(mSsl, ret)
|
|
|
+ return false;
|
|
|
+
|
|
|
+ while (int len = BIO_read(mOutBio, buffer, bufferSize); len > 0)
|
|
|
+ outgoing(make_message(buffer, buffer + len));
|
|
|
+
|
|
|
+ return true;
|
|
|
+}
|
|
|
+
|
|
|
+void TlsTransport::incoming(message_ptr message) {
|
|
|
+ if (message)
|
|
|
+ mIncomingQueue.push(message);
|
|
|
+ else
|
|
|
+ mIncomingQueue.stop();
|
|
|
+}
|
|
|
+
|
|
|
+void TlsTransport::runRecvLoop() {
|
|
|
+ const size_t bufferSize = 4096;
|
|
|
+
|
|
|
+ byte buffer[bufferSize];
|
|
|
+ bool initFinished = false;
|
|
|
+ try {
|
|
|
+ SSL_do_handshake(mSsl);
|
|
|
+ while (int len = BIO_read(mOutBio, buffer, bufferSize); len > 0)
|
|
|
+ outgoing(make_message(buffer, buffer + len));
|
|
|
+
|
|
|
+ while (auto next = mIncomingQueue.pop()) {
|
|
|
+ auto message = *next;
|
|
|
+ BIO_write(mInBio, message->data(), message->size());
|
|
|
+ int ret = SSL_read(mSsl, buffer, bufferSize);
|
|
|
+ if (!check_openssl_ret(mSsl, ret))
|
|
|
+ break;
|
|
|
+
|
|
|
+ auto received = ret > 0 ? make_message(buffer, buffer + ret) : nullptr;
|
|
|
+
|
|
|
+ while (int len = BIO_read(mOutBio, buffer, bufferSize); len > 0)
|
|
|
+ outgoing(make_message(buffer, buffer + len));
|
|
|
+
|
|
|
+ if (!initFinished && SSL_is_init_finished(mSsl))
|
|
|
+ initFinished = true;
|
|
|
+
|
|
|
+ if (received)
|
|
|
+ recv(received);
|
|
|
+ }
|
|
|
+ } catch (const std::exception &e) {
|
|
|
+ PLOG_ERROR << "TLS recv: " << e.what();
|
|
|
+ }
|
|
|
+
|
|
|
+ if (initFinished) {
|
|
|
+ PLOG_INFO << "TLS disconnected";
|
|
|
+ recv(nullptr);
|
|
|
+ } else {
|
|
|
+ PLOG_ERROR << "TLS handshake failed";
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
|
|
|
+ TlsTransport *t =
|
|
|
+ static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
|
|
|
+
|
|
|
+ if (where & SSL_CB_ALERT) {
|
|
|
+ if (ret != 256) // Close Notify
|
|
|
+ PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
|
|
|
+
|
|
|
+ t->mIncomingQueue.stop(); // Close the connection
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+} // namespace rtc
|
|
|
+
|
|
|
#endif
|
|
|
|
|
|
+#endif
|