tlstransport.cpp 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This Source Code Form is subject to the terms of the Mozilla Public
  5. * License, v. 2.0. If a copy of the MPL was not distributed with this
  6. * file, You can obtain one at https://mozilla.org/MPL/2.0/.
  7. */
  8. #include "tlstransport.hpp"
  9. #include "httpproxytransport.hpp"
  10. #include "tcptransport.hpp"
  11. #include "threadpool.hpp"
  12. #if RTC_ENABLE_WEBSOCKET
  13. #include <algorithm>
  14. #include <chrono>
  15. #include <cstring>
  16. #include <exception>
  17. using namespace std::chrono;
  18. namespace rtc::impl {
  19. void TlsTransport::enqueueRecv() {
  20. if (mPendingRecvCount > 0)
  21. return;
  22. if (auto shared_this = weak_from_this().lock()) {
  23. ++mPendingRecvCount;
  24. ThreadPool::Instance().enqueue(&TlsTransport::doRecv, std::move(shared_this));
  25. }
  26. }
  27. #if USE_GNUTLS
  28. namespace {
  29. gnutls_certificate_credentials_t default_certificate_credentials() {
  30. static std::mutex mutex;
  31. static shared_ptr<gnutls_certificate_credentials_t> creds;
  32. std::lock_guard lock(mutex);
  33. if (!creds) {
  34. creds = shared_ptr<gnutls_certificate_credentials_t>(gnutls::new_credentials(),
  35. gnutls::free_credentials);
  36. gnutls::check(gnutls_certificate_set_x509_system_trust(*creds));
  37. }
  38. return *creds;
  39. }
  40. } // namespace
  41. void TlsTransport::Init() {
  42. // Nothing to do
  43. }
  44. void TlsTransport::Cleanup() {
  45. // Nothing to do
  46. }
  47. TlsTransport::TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower,
  48. optional<string> host, certificate_ptr certificate,
  49. state_callback callback)
  50. : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
  51. std::move(callback)),
  52. mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)),
  53. mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
  54. PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
  55. unsigned int flags = GNUTLS_NONBLOCK | (mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER);
  56. gnutls::check(gnutls_init(&mSession, flags));
  57. try {
  58. const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
  59. const char *err_pos = NULL;
  60. gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
  61. "Failed to set TLS priorities");
  62. gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE,
  63. certificate ? certificate->credentials()
  64. : default_certificate_credentials()));
  65. if (mIsClient && mHost) {
  66. PLOG_VERBOSE << "Server Name Indication: " << *mHost;
  67. gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost->data(), mHost->size());
  68. }
  69. gnutls_session_set_ptr(mSession, this);
  70. gnutls_transport_set_ptr(mSession, this);
  71. gnutls_transport_set_push_function(mSession, WriteCallback);
  72. gnutls_transport_set_pull_function(mSession, ReadCallback);
  73. gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
  74. } catch (...) {
  75. gnutls_deinit(mSession);
  76. throw;
  77. }
  78. }
  79. TlsTransport::~TlsTransport() {
  80. stop();
  81. gnutls_deinit(mSession);
  82. }
  83. void TlsTransport::start() {
  84. PLOG_DEBUG << "Starting TLS transport";
  85. registerIncoming();
  86. changeState(State::Connecting);
  87. enqueueRecv(); // to initiate the handshake
  88. }
  89. void TlsTransport::stop() {
  90. PLOG_DEBUG << "Stopping TLS transport";
  91. unregisterIncoming();
  92. mIncomingQueue.stop();
  93. enqueueRecv();
  94. }
  95. bool TlsTransport::send(message_ptr message) {
  96. if (state() != State::Connected)
  97. throw std::runtime_error("TLS is not open");
  98. if (!message || message->size() == 0)
  99. return outgoing(message); // pass through
  100. PLOG_VERBOSE << "Send size=" << message->size();
  101. ssize_t ret;
  102. do {
  103. ret = gnutls_record_send(mSession, message->data(), message->size());
  104. } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
  105. if (!gnutls::check(ret))
  106. throw std::runtime_error("TLS send failed");
  107. return mOutgoingResult;
  108. }
  109. void TlsTransport::incoming(message_ptr message) {
  110. if (!message) {
  111. mIncomingQueue.stop();
  112. enqueueRecv();
  113. return;
  114. }
  115. PLOG_VERBOSE << "Incoming size=" << message->size();
  116. mIncomingQueue.push(message);
  117. enqueueRecv();
  118. }
  119. bool TlsTransport::outgoing(message_ptr message) {
  120. bool result = Transport::outgoing(std::move(message));
  121. mOutgoingResult = result;
  122. return result;
  123. }
  124. void TlsTransport::postHandshake() {
  125. // Dummy
  126. }
  127. void TlsTransport::doRecv() {
  128. std::lock_guard lock(mRecvMutex);
  129. --mPendingRecvCount;
  130. const size_t bufferSize = 4096;
  131. char buffer[bufferSize];
  132. try {
  133. // Handle handshake if connecting
  134. if (state() == State::Connecting) {
  135. int ret;
  136. do {
  137. ret = gnutls_handshake(mSession);
  138. if (ret == GNUTLS_E_AGAIN)
  139. return;
  140. } while (!gnutls::check(ret, "Handshake failed")); // Re-call on non-fatal error
  141. PLOG_INFO << "TLS handshake finished";
  142. changeState(State::Connected);
  143. postHandshake();
  144. }
  145. if (state() == State::Connected) {
  146. while (true) {
  147. ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize);
  148. if (ret == GNUTLS_E_AGAIN)
  149. return;
  150. // Consider premature termination as remote closing
  151. if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
  152. PLOG_DEBUG << "TLS connection terminated";
  153. break;
  154. }
  155. if (gnutls::check(ret)) {
  156. if (ret == 0) {
  157. // Closed
  158. PLOG_DEBUG << "TLS connection cleanly closed";
  159. break;
  160. }
  161. auto *b = reinterpret_cast<byte *>(buffer);
  162. recv(make_message(b, b + ret));
  163. }
  164. }
  165. }
  166. } catch (const std::exception &e) {
  167. PLOG_ERROR << "TLS recv: " << e.what();
  168. }
  169. gnutls_bye(mSession, GNUTLS_SHUT_WR);
  170. if (state() == State::Connected) {
  171. PLOG_INFO << "TLS closed";
  172. changeState(State::Disconnected);
  173. recv(nullptr);
  174. } else {
  175. PLOG_ERROR << "TLS handshake failed";
  176. changeState(State::Failed);
  177. }
  178. }
  179. ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
  180. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  181. try {
  182. if (len > 0) {
  183. auto b = reinterpret_cast<const byte *>(data);
  184. t->outgoing(make_message(b, b + len));
  185. }
  186. gnutls_transport_set_errno(t->mSession, 0);
  187. return ssize_t(len);
  188. } catch (const std::exception &e) {
  189. PLOG_WARNING << e.what();
  190. gnutls_transport_set_errno(t->mSession, ECONNRESET);
  191. return -1;
  192. }
  193. }
  194. ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
  195. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  196. try {
  197. message_ptr &message = t->mIncomingMessage;
  198. size_t &position = t->mIncomingMessagePosition;
  199. if (message && position >= message->size())
  200. message.reset();
  201. if (!message) {
  202. position = 0;
  203. while (auto next = t->mIncomingQueue.pop()) {
  204. message = *next;
  205. if (message->size() > 0)
  206. break;
  207. else
  208. t->recv(message); // Pass zero-sized messages through
  209. }
  210. }
  211. if (message) {
  212. size_t available = message->size() - position;
  213. ssize_t len = std::min(maxlen, available);
  214. std::memcpy(data, message->data() + position, len);
  215. position += len;
  216. gnutls_transport_set_errno(t->mSession, 0);
  217. return len;
  218. } else if (t->mIncomingQueue.running()) {
  219. gnutls_transport_set_errno(t->mSession, EAGAIN);
  220. return -1;
  221. } else {
  222. // Closed
  223. gnutls_transport_set_errno(t->mSession, 0);
  224. return 0;
  225. }
  226. } catch (const std::exception &e) {
  227. PLOG_WARNING << e.what();
  228. gnutls_transport_set_errno(t->mSession, ECONNRESET);
  229. return -1;
  230. }
  231. }
  232. int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int /* ms */) {
  233. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  234. try {
  235. message_ptr &message = t->mIncomingMessage;
  236. size_t &position = t->mIncomingMessagePosition;
  237. if (message && position < message->size())
  238. return 1;
  239. return !t->mIncomingQueue.empty() ? 1 : 0;
  240. } catch (const std::exception &e) {
  241. PLOG_WARNING << e.what();
  242. return 1;
  243. }
  244. }
  245. #elif USE_MBEDTLS
  246. void TlsTransport::Init() {
  247. // Nothing to do
  248. }
  249. void TlsTransport::Cleanup() {
  250. // Nothing to do
  251. }
  252. TlsTransport::TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower,
  253. optional<string> host, certificate_ptr certificate,
  254. state_callback callback)
  255. : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
  256. std::move(callback)),
  257. mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)),
  258. mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
  259. PLOG_DEBUG << "Initializing TLS transport (MbedTLS)";
  260. mbedtls_entropy_init(&mEntropy);
  261. mbedtls_ctr_drbg_init(&mDrbg);
  262. mbedtls_ssl_init(&mSsl);
  263. mbedtls_ssl_config_init(&mConf);
  264. mbedtls_ctr_drbg_set_prediction_resistance(&mDrbg, MBEDTLS_CTR_DRBG_PR_ON);
  265. try {
  266. mbedtls::check(mbedtls_ctr_drbg_seed(&mDrbg, mbedtls_entropy_func, &mEntropy, NULL, 0));
  267. mbedtls::check(mbedtls_ssl_config_defaults(
  268. &mConf, mIsClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
  269. MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT));
  270. mbedtls_ssl_conf_authmode(&mConf, MBEDTLS_SSL_VERIFY_OPTIONAL);
  271. mbedtls_ssl_conf_rng(&mConf, mbedtls_ctr_drbg_random, &mDrbg);
  272. if (certificate) {
  273. auto [crt, pk] = certificate->credentials();
  274. mbedtls::check(mbedtls_ssl_conf_own_cert(&mConf, crt.get(), pk.get()));
  275. }
  276. mbedtls::check(mbedtls_ssl_setup(&mSsl, &mConf));
  277. mbedtls_ssl_set_bio(&mSsl, static_cast<void *>(this), WriteCallback, ReadCallback, NULL);
  278. } catch (...) {
  279. mbedtls_entropy_free(&mEntropy);
  280. mbedtls_ctr_drbg_free(&mDrbg);
  281. mbedtls_ssl_free(&mSsl);
  282. mbedtls_ssl_config_free(&mConf);
  283. throw;
  284. }
  285. }
  286. TlsTransport::~TlsTransport() {}
  287. void TlsTransport::start() {
  288. PLOG_DEBUG << "Starting TLS transport";
  289. registerIncoming();
  290. changeState(State::Connecting);
  291. enqueueRecv(); // to initiate the handshake
  292. }
  293. void TlsTransport::stop() {
  294. PLOG_DEBUG << "Stopping TLS transport";
  295. unregisterIncoming();
  296. mIncomingQueue.stop();
  297. enqueueRecv();
  298. }
  299. bool TlsTransport::send(message_ptr message) {
  300. if (state() != State::Connected)
  301. throw std::runtime_error("TLS is not open");
  302. if (!message || message->size() == 0)
  303. return outgoing(message); // pass through
  304. PLOG_VERBOSE << "Send size=" << message->size();
  305. int ret;
  306. do {
  307. std::lock_guard lock(mSslMutex);
  308. ret = mbedtls_ssl_write(&mSsl, reinterpret_cast<const unsigned char *>(message->data()),
  309. int(message->size()));
  310. } while (ret == MBEDTLS_ERR_SSL_WANT_WRITE);
  311. mbedtls::check(ret);
  312. return mOutgoingResult;
  313. }
  314. void TlsTransport::incoming(message_ptr message) {
  315. if (!message) {
  316. mIncomingQueue.stop();
  317. enqueueRecv();
  318. return;
  319. }
  320. PLOG_VERBOSE << "Incoming size=" << message->size();
  321. mIncomingQueue.push(message);
  322. enqueueRecv();
  323. }
  324. bool TlsTransport::outgoing(message_ptr message) {
  325. bool result = Transport::outgoing(std::move(message));
  326. mOutgoingResult = result;
  327. return result;
  328. }
  329. void TlsTransport::postHandshake() {
  330. // Dummy
  331. }
  332. void TlsTransport::doRecv() {
  333. std::lock_guard lock(mRecvMutex);
  334. --mPendingRecvCount;
  335. if (state() != State::Connecting && state() != State::Connected)
  336. return;
  337. try {
  338. const size_t bufferSize = 4096;
  339. char buffer[bufferSize];
  340. // Handle handshake if connecting
  341. if (state() == State::Connecting) {
  342. while (true) {
  343. int ret;
  344. {
  345. std::lock_guard lock(mSslMutex);
  346. ret = mbedtls_ssl_handshake(&mSsl);
  347. }
  348. if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
  349. return;
  350. }
  351. if (mbedtls::check(ret, "Handshake failed")) {
  352. PLOG_INFO << "TLS handshake finished";
  353. changeState(State::Connected);
  354. postHandshake();
  355. break;
  356. }
  357. }
  358. }
  359. if (state() == State::Connected) {
  360. while (true) {
  361. int ret;
  362. {
  363. std::lock_guard lock(mSslMutex);
  364. ret = mbedtls_ssl_read(&mSsl, reinterpret_cast<unsigned char *>(buffer),
  365. bufferSize);
  366. }
  367. if (ret == MBEDTLS_ERR_SSL_WANT_READ) {
  368. return;
  369. }
  370. if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
  371. PLOG_DEBUG << "TLS connection cleanly closed";
  372. break;
  373. }
  374. if (mbedtls::check(ret)) {
  375. if (ret == 0) {
  376. PLOG_DEBUG << "TLS connection terminated";
  377. break;
  378. }
  379. auto *b = reinterpret_cast<byte *>(buffer);
  380. recv(make_message(b, b + ret));
  381. }
  382. }
  383. }
  384. } catch (const std::exception &e) {
  385. PLOG_ERROR << "TLS recv: " << e.what();
  386. }
  387. if (state() == State::Connected) {
  388. PLOG_INFO << "TLS closed";
  389. changeState(State::Disconnected);
  390. recv(nullptr);
  391. } else {
  392. PLOG_ERROR << "TLS handshake failed";
  393. changeState(State::Failed);
  394. }
  395. }
  396. int TlsTransport::WriteCallback(void *ctx, const unsigned char *buf, size_t len) {
  397. auto *t = static_cast<TlsTransport *>(ctx);
  398. auto *b = reinterpret_cast<const byte *>(buf);
  399. t->outgoing(make_message(b, b + len));
  400. return int(len);
  401. }
  402. int TlsTransport::ReadCallback(void *ctx, unsigned char *buf, size_t len) {
  403. TlsTransport *t = static_cast<TlsTransport *>(ctx);
  404. try {
  405. message_ptr &message = t->mIncomingMessage;
  406. size_t &position = t->mIncomingMessagePosition;
  407. if (message && position >= message->size())
  408. message.reset();
  409. if (!message) {
  410. position = 0;
  411. while (auto next = t->mIncomingQueue.pop()) {
  412. message = *next;
  413. if (message->size() > 0)
  414. break;
  415. else
  416. t->recv(message); // Pass zero-sized messages through
  417. }
  418. }
  419. if (message) {
  420. size_t available = message->size() - position;
  421. size_t writeLen = std::min(len, available);
  422. std::memcpy(buf, message->data() + position, writeLen);
  423. position += writeLen;
  424. return int(writeLen);
  425. } else if (t->mIncomingQueue.running()) {
  426. return MBEDTLS_ERR_SSL_WANT_READ;
  427. } else {
  428. return MBEDTLS_ERR_SSL_CONN_EOF;
  429. }
  430. } catch (const std::exception &e) {
  431. PLOG_WARNING << e.what();
  432. return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
  433. }
  434. }
  435. #else
  436. int TlsTransport::TransportExIndex = -1;
  437. void TlsTransport::Init() {
  438. openssl::init();
  439. if (TransportExIndex < 0) {
  440. TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
  441. }
  442. }
  443. void TlsTransport::Cleanup() {
  444. // Nothing to do
  445. }
  446. TlsTransport::TlsTransport(variant<shared_ptr<TcpTransport>, shared_ptr<HttpProxyTransport>> lower,
  447. optional<string> host, certificate_ptr certificate,
  448. state_callback callback)
  449. : Transport(std::visit([](auto l) { return std::static_pointer_cast<Transport>(l); }, lower),
  450. std::move(callback)),
  451. mHost(std::move(host)), mIsClient(std::visit([](auto l) { return l->isActive(); }, lower)),
  452. mIncomingQueue(RECV_QUEUE_LIMIT, message_size_func) {
  453. PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
  454. try {
  455. if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
  456. throw std::runtime_error("Failed to create SSL context");
  457. openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
  458. "Failed to set SSL priorities");
  459. #if OPENSSL_VERSION_NUMBER >= 0x30000000
  460. openssl::check(SSL_CTX_set1_groups_list(mCtx, "P-256"), "Failed to set SSL groups");
  461. #else
  462. auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
  463. EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
  464. SSL_CTX_set_tmp_ecdh(mCtx, ecdh.get());
  465. SSL_CTX_set_options(mCtx, SSL_OP_SINGLE_ECDH_USE);
  466. #endif
  467. if (certificate) {
  468. auto [x509, pkey] = certificate->credentials();
  469. SSL_CTX_use_certificate(mCtx, x509);
  470. SSL_CTX_use_PrivateKey(mCtx, pkey);
  471. } else {
  472. if (!SSL_CTX_set_default_verify_paths(mCtx)) {
  473. PLOG_WARNING << "SSL root CA certificates unavailable";
  474. }
  475. }
  476. SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
  477. SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
  478. SSL_CTX_set_read_ahead(mCtx, 1);
  479. SSL_CTX_set_quiet_shutdown(mCtx, 1);
  480. SSL_CTX_set_info_callback(mCtx, InfoCallback);
  481. SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL);
  482. if (!(mSsl = SSL_new(mCtx)))
  483. throw std::runtime_error("Failed to create SSL instance");
  484. SSL_set_ex_data(mSsl, TransportExIndex, this);
  485. if (mIsClient && mHost) {
  486. SSL_set_hostflags(mSsl, 0);
  487. openssl::check(SSL_set1_host(mSsl, mHost->c_str()), "Failed to set SSL host");
  488. PLOG_VERBOSE << "Server Name Indication: " << *mHost;
  489. SSL_set_tlsext_host_name(mSsl, mHost->c_str());
  490. }
  491. if (mIsClient)
  492. SSL_set_connect_state(mSsl);
  493. else
  494. SSL_set_accept_state(mSsl);
  495. if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
  496. throw std::runtime_error("Failed to create BIO");
  497. BIO_set_mem_eof_return(mInBio, BIO_EOF);
  498. BIO_set_mem_eof_return(mOutBio, BIO_EOF);
  499. SSL_set_bio(mSsl, mInBio, mOutBio);
  500. } catch (...) {
  501. if (mSsl)
  502. SSL_free(mSsl);
  503. if (mCtx)
  504. SSL_CTX_free(mCtx);
  505. throw;
  506. }
  507. }
  508. TlsTransport::~TlsTransport() {
  509. stop();
  510. SSL_free(mSsl);
  511. SSL_CTX_free(mCtx);
  512. }
  513. void TlsTransport::start() {
  514. PLOG_DEBUG << "Starting TLS transport";
  515. registerIncoming();
  516. changeState(State::Connecting);
  517. // Initiate the handshake
  518. std::lock_guard lock(mSslMutex);
  519. int ret = SSL_do_handshake(mSsl);
  520. openssl::check(mSsl, ret, "Handshake initiation failed");
  521. flushOutput();
  522. }
  523. void TlsTransport::stop() {
  524. PLOG_DEBUG << "Stopping TLS transport";
  525. unregisterIncoming();
  526. mIncomingQueue.stop();
  527. enqueueRecv();
  528. }
  529. bool TlsTransport::send(message_ptr message) {
  530. if (state() != State::Connected)
  531. throw std::runtime_error("TLS is not open");
  532. if (!message || message->size() == 0)
  533. return outgoing(message); // pass through
  534. PLOG_VERBOSE << "Send size=" << message->size();
  535. std::lock_guard lock(mSslMutex);
  536. int ret = SSL_write(mSsl, message->data(), int(message->size()));
  537. if (!openssl::check(mSsl, ret))
  538. throw std::runtime_error("TLS send failed");
  539. return flushOutput();
  540. }
  541. void TlsTransport::incoming(message_ptr message) {
  542. if (!message) {
  543. mIncomingQueue.stop();
  544. enqueueRecv();
  545. return;
  546. }
  547. PLOG_VERBOSE << "Incoming size=" << message->size();
  548. mIncomingQueue.push(message);
  549. enqueueRecv();
  550. }
  551. bool TlsTransport::outgoing(message_ptr message) { return Transport::outgoing(std::move(message)); }
  552. void TlsTransport::postHandshake() {
  553. // Dummy
  554. }
  555. void TlsTransport::doRecv() {
  556. std::lock_guard lock(mRecvMutex);
  557. --mPendingRecvCount;
  558. if (state() != State::Connecting && state() != State::Connected)
  559. return;
  560. try {
  561. const size_t bufferSize = 4096;
  562. byte buffer[bufferSize];
  563. // Process incoming messages
  564. while (mIncomingQueue.running()) {
  565. auto next = mIncomingQueue.pop();
  566. if (!next)
  567. return;
  568. message_ptr message = std::move(*next);
  569. if (message->size() > 0)
  570. BIO_write(mInBio, message->data(), int(message->size())); // Input
  571. else
  572. recv(message); // Pass zero-sized messages through
  573. if (state() == State::Connecting) {
  574. // Continue the handshake
  575. bool finished;
  576. {
  577. std::lock_guard lock(mSslMutex);
  578. int ret = SSL_do_handshake(mSsl);
  579. if (!openssl::check(mSsl, ret, "Handshake failed"))
  580. break;
  581. flushOutput();
  582. finished = (SSL_is_init_finished(mSsl) != 0);
  583. }
  584. if (finished) {
  585. PLOG_INFO << "TLS handshake finished";
  586. changeState(State::Connected);
  587. postHandshake();
  588. }
  589. }
  590. if (state() == State::Connected) {
  591. int ret;
  592. while (true) {
  593. {
  594. std::lock_guard lock(mSslMutex);
  595. ret = SSL_read(mSsl, buffer, bufferSize);
  596. }
  597. if (ret > 0)
  598. recv(make_message(buffer, buffer + ret));
  599. else
  600. break;
  601. }
  602. {
  603. std::lock_guard lock(mSslMutex);
  604. if (!openssl::check(mSsl, ret))
  605. break;
  606. flushOutput(); // SSL_read() can also cause write operations
  607. }
  608. }
  609. }
  610. } catch (const std::exception &e) {
  611. PLOG_ERROR << "TLS recv: " << e.what();
  612. }
  613. if (state() == State::Connected) {
  614. PLOG_INFO << "TLS closed";
  615. changeState(State::Disconnected);
  616. recv(nullptr);
  617. } else {
  618. PLOG_ERROR << "TLS handshake failed";
  619. changeState(State::Failed);
  620. }
  621. {
  622. std::lock_guard lock(mSslMutex);
  623. SSL_shutdown(mSsl);
  624. }
  625. }
  626. bool TlsTransport::flushOutput() {
  627. // Requires mSslMutex to be locked
  628. bool result = true;
  629. const size_t bufferSize = 4096;
  630. byte buffer[bufferSize];
  631. int len;
  632. while ((len = BIO_read(mOutBio, buffer, bufferSize)) > 0)
  633. result = outgoing(make_message(buffer, buffer + len));
  634. return result;
  635. }
  636. void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
  637. TlsTransport *t =
  638. static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
  639. if (where & SSL_CB_ALERT) {
  640. if (ret != 256) { // Close Notify
  641. PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
  642. }
  643. t->mIncomingQueue.stop(); // Close the connection
  644. }
  645. }
  646. #endif
  647. } // namespace rtc::impl
  648. #endif