tlstransport.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This library is free software; you can redistribute it and/or
  5. * modify it under the terms of the GNU Lesser General Public
  6. * License as published by the Free Software Foundation; either
  7. * version 2.1 of the License, or (at your option) any later version.
  8. *
  9. * This library is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. * Lesser General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU Lesser General Public
  15. * License along with this library; if not, write to the Free Software
  16. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. */
  18. #include "tlstransport.hpp"
  19. #include "tcptransport.hpp"
  20. #if RTC_ENABLE_WEBSOCKET
  21. #include <chrono>
  22. #include <cstring>
  23. #include <exception>
  24. #include <iostream>
  25. using namespace std::chrono;
  26. namespace rtc::impl {
  27. #if USE_GNUTLS
  28. namespace {
  29. gnutls_certificate_credentials_t default_certificate_credentials() {
  30. static std::mutex mutex;
  31. static shared_ptr<gnutls_certificate_credentials_t> creds;
  32. std::lock_guard lock(mutex);
  33. if (!creds) {
  34. creds = shared_ptr<gnutls_certificate_credentials_t>(gnutls::new_credentials(),
  35. gnutls::free_credentials);
  36. gnutls::check(gnutls_certificate_set_x509_system_trust(*creds));
  37. }
  38. return *creds;
  39. }
  40. } // namespace
  41. void TlsTransport::Init() {
  42. // Nothing to do
  43. }
  44. void TlsTransport::Cleanup() {
  45. // Nothing to do
  46. }
  47. TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
  48. certificate_ptr certificate, state_callback callback)
  49. : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
  50. PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
  51. gnutls::check(gnutls_init(&mSession, mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER));
  52. try {
  53. const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
  54. const char *err_pos = NULL;
  55. gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
  56. "Failed to set TLS priorities");
  57. gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE,
  58. certificate ? certificate->credentials()
  59. : default_certificate_credentials()));
  60. if (mIsClient && mHost) {
  61. PLOG_VERBOSE << "Server Name Indication: " << *mHost;
  62. gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost->data(), mHost->size());
  63. }
  64. gnutls_session_set_ptr(mSession, this);
  65. gnutls_transport_set_ptr(mSession, this);
  66. gnutls_transport_set_push_function(mSession, WriteCallback);
  67. gnutls_transport_set_pull_function(mSession, ReadCallback);
  68. gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
  69. } catch (...) {
  70. gnutls_deinit(mSession);
  71. throw;
  72. }
  73. }
  74. TlsTransport::~TlsTransport() {
  75. stop();
  76. gnutls_deinit(mSession);
  77. }
  78. void TlsTransport::start() {
  79. Transport::start();
  80. registerIncoming();
  81. PLOG_DEBUG << "Starting TLS recv thread";
  82. mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
  83. }
  84. bool TlsTransport::stop() {
  85. if (!Transport::stop())
  86. return false;
  87. PLOG_DEBUG << "Stopping TLS recv thread";
  88. mIncomingQueue.stop();
  89. mRecvThread.join();
  90. return true;
  91. }
  92. bool TlsTransport::send(message_ptr message) {
  93. if (!message || state() != State::Connected)
  94. return false;
  95. PLOG_VERBOSE << "Send size=" << message->size();
  96. if (message->size() == 0)
  97. return true;
  98. ssize_t ret;
  99. do {
  100. ret = gnutls_record_send(mSession, message->data(), message->size());
  101. } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
  102. return gnutls::check(ret);
  103. }
  104. void TlsTransport::incoming(message_ptr message) {
  105. if (!message) {
  106. mIncomingQueue.stop();
  107. return;
  108. }
  109. PLOG_VERBOSE << "Incoming size=" << message->size();
  110. mIncomingQueue.push(message);
  111. }
  112. void TlsTransport::postHandshake() {
  113. // Dummy
  114. }
  115. void TlsTransport::runRecvLoop() {
  116. const size_t bufferSize = 4096;
  117. char buffer[bufferSize];
  118. // Handshake loop
  119. try {
  120. changeState(State::Connecting);
  121. int ret;
  122. do {
  123. ret = gnutls_handshake(mSession);
  124. } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
  125. !gnutls::check(ret, "TLS handshake failed"));
  126. } catch (const std::exception &e) {
  127. PLOG_ERROR << "TLS handshake: " << e.what();
  128. changeState(State::Failed);
  129. return;
  130. }
  131. // Receive loop
  132. try {
  133. PLOG_INFO << "TLS handshake finished";
  134. changeState(State::Connected);
  135. postHandshake();
  136. while (true) {
  137. ssize_t ret;
  138. do {
  139. ret = gnutls_record_recv(mSession, buffer, bufferSize);
  140. } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
  141. // Consider premature termination as remote closing
  142. if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
  143. PLOG_DEBUG << "TLS connection terminated";
  144. break;
  145. }
  146. if (gnutls::check(ret)) {
  147. if (ret == 0) {
  148. // Closed
  149. PLOG_DEBUG << "TLS connection cleanly closed";
  150. break;
  151. }
  152. auto *b = reinterpret_cast<byte *>(buffer);
  153. recv(make_message(b, b + ret));
  154. }
  155. }
  156. } catch (const std::exception &e) {
  157. PLOG_ERROR << "TLS recv: " << e.what();
  158. }
  159. gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
  160. PLOG_INFO << "TLS closed";
  161. changeState(State::Disconnected);
  162. recv(nullptr);
  163. }
  164. ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
  165. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  166. if (len > 0) {
  167. auto b = reinterpret_cast<const byte *>(data);
  168. t->outgoing(make_message(b, b + len));
  169. }
  170. gnutls_transport_set_errno(t->mSession, 0);
  171. return ssize_t(len);
  172. }
  173. ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
  174. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  175. message_ptr &message = t->mIncomingMessage;
  176. size_t &position = t->mIncomingMessagePosition;
  177. if (message && position >= message->size())
  178. message.reset();
  179. if (!message) {
  180. position = 0;
  181. while (auto next = t->mIncomingQueue.pop()) {
  182. message = *next;
  183. if (message->size() > 0)
  184. break;
  185. else
  186. t->recv(message); // Pass zero-sized messages through
  187. }
  188. }
  189. if (message) {
  190. size_t available = message->size() - position;
  191. ssize_t len = std::min(maxlen, available);
  192. std::memcpy(data, message->data() + position, len);
  193. position += len;
  194. gnutls_transport_set_errno(t->mSession, 0);
  195. return len;
  196. } else {
  197. // Closed
  198. gnutls_transport_set_errno(t->mSession, 0);
  199. return 0;
  200. }
  201. }
  202. int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
  203. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  204. bool notEmpty = t->mIncomingQueue.wait(
  205. ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
  206. return notEmpty ? 1 : 0;
  207. }
  208. #else // USE_GNUTLS==0
  209. int TlsTransport::TransportExIndex = -1;
  210. void TlsTransport::Init() {
  211. openssl::init();
  212. if (TransportExIndex < 0) {
  213. TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
  214. }
  215. }
  216. void TlsTransport::Cleanup() {
  217. // Nothing to do
  218. }
  219. TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
  220. certificate_ptr certificate, state_callback callback)
  221. : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
  222. PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
  223. try {
  224. if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
  225. throw std::runtime_error("Failed to create SSL context");
  226. openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
  227. "Failed to set SSL priorities");
  228. if (certificate) {
  229. auto [x509, pkey] = certificate->credentials();
  230. SSL_CTX_use_certificate(mCtx, x509);
  231. SSL_CTX_use_PrivateKey(mCtx, pkey);
  232. } else {
  233. if (!SSL_CTX_set_default_verify_paths(mCtx)) {
  234. PLOG_WARNING << "SSL root CA certificates unavailable";
  235. }
  236. }
  237. SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
  238. SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
  239. SSL_CTX_set_read_ahead(mCtx, 1);
  240. SSL_CTX_set_quiet_shutdown(mCtx, 1);
  241. SSL_CTX_set_info_callback(mCtx, InfoCallback);
  242. SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL);
  243. if (!(mSsl = SSL_new(mCtx)))
  244. throw std::runtime_error("Failed to create SSL instance");
  245. SSL_set_ex_data(mSsl, TransportExIndex, this);
  246. if (mIsClient && mHost) {
  247. SSL_set_hostflags(mSsl, 0);
  248. openssl::check(SSL_set1_host(mSsl, mHost->c_str()), "Failed to set SSL host");
  249. PLOG_VERBOSE << "Server Name Indication: " << *mHost;
  250. SSL_set_tlsext_host_name(mSsl, mHost->c_str());
  251. }
  252. if (mIsClient)
  253. SSL_set_connect_state(mSsl);
  254. else
  255. SSL_set_accept_state(mSsl);
  256. if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
  257. throw std::runtime_error("Failed to create BIO");
  258. BIO_set_mem_eof_return(mInBio, BIO_EOF);
  259. BIO_set_mem_eof_return(mOutBio, BIO_EOF);
  260. SSL_set_bio(mSsl, mInBio, mOutBio);
  261. auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
  262. EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
  263. SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
  264. SSL_set_tmp_ecdh(mSsl, ecdh.get());
  265. } catch (...) {
  266. if (mSsl)
  267. SSL_free(mSsl);
  268. if (mCtx)
  269. SSL_CTX_free(mCtx);
  270. throw;
  271. }
  272. }
  273. TlsTransport::~TlsTransport() {
  274. stop();
  275. SSL_free(mSsl);
  276. SSL_CTX_free(mCtx);
  277. }
  278. void TlsTransport::start() {
  279. Transport::start();
  280. registerIncoming();
  281. PLOG_DEBUG << "Starting TLS recv thread";
  282. mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
  283. }
  284. bool TlsTransport::stop() {
  285. if (!Transport::stop())
  286. return false;
  287. PLOG_DEBUG << "Stopping TLS recv thread";
  288. mIncomingQueue.stop();
  289. mRecvThread.join();
  290. SSL_shutdown(mSsl);
  291. return true;
  292. }
  293. bool TlsTransport::send(message_ptr message) {
  294. if (!message || state() != State::Connected)
  295. return false;
  296. PLOG_VERBOSE << "Send size=" << message->size();
  297. if (message->size() == 0)
  298. return true;
  299. int ret = SSL_write(mSsl, message->data(), int(message->size()));
  300. if (!openssl::check(mSsl, ret))
  301. return false;
  302. const size_t bufferSize = 4096;
  303. byte buffer[bufferSize];
  304. while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
  305. outgoing(make_message(buffer, buffer + ret));
  306. return true;
  307. }
  308. void TlsTransport::incoming(message_ptr message) {
  309. if (!message) {
  310. mIncomingQueue.stop();
  311. return;
  312. }
  313. PLOG_VERBOSE << "Incoming size=" << message->size();
  314. mIncomingQueue.push(message);
  315. }
  316. void TlsTransport::postHandshake() {
  317. // Dummy
  318. }
  319. void TlsTransport::runRecvLoop() {
  320. const size_t bufferSize = 4096;
  321. byte buffer[bufferSize];
  322. try {
  323. changeState(State::Connecting);
  324. while (true) {
  325. if (state() == State::Connecting) {
  326. // Initiate or continue the handshake
  327. int ret = SSL_do_handshake(mSsl);
  328. if (!openssl::check(mSsl, ret, "Handshake failed"))
  329. break;
  330. // Output
  331. while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
  332. outgoing(make_message(buffer, buffer + ret));
  333. if (SSL_is_init_finished(mSsl)) {
  334. PLOG_INFO << "TLS handshake finished";
  335. changeState(State::Connected);
  336. postHandshake();
  337. }
  338. } else {
  339. int ret = SSL_read(mSsl, buffer, bufferSize);
  340. if (!openssl::check(mSsl, ret))
  341. break;
  342. if (ret > 0)
  343. recv(make_message(buffer, buffer + ret));
  344. }
  345. auto next = mIncomingQueue.pop();
  346. if (!next)
  347. break;
  348. message_ptr message = std::move(*next);
  349. if (message->size() > 0)
  350. BIO_write(mInBio, message->data(), int(message->size())); // Input
  351. else
  352. recv(message); // Pass zero-sized messages through
  353. }
  354. } catch (const std::exception &e) {
  355. PLOG_ERROR << "TLS recv: " << e.what();
  356. }
  357. if (state() == State::Connected) {
  358. PLOG_INFO << "TLS closed";
  359. recv(nullptr);
  360. } else {
  361. PLOG_ERROR << "TLS handshake failed";
  362. }
  363. }
  364. void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
  365. TlsTransport *t =
  366. static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
  367. if (where & SSL_CB_ALERT) {
  368. if (ret != 256) { // Close Notify
  369. PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
  370. }
  371. t->mIncomingQueue.stop(); // Close the connection
  372. }
  373. }
  374. #endif
  375. } // namespace rtc::impl
  376. #endif