tlstransport.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. /**
  2. * Copyright (c) 2020 Paul-Louis Ageneau
  3. *
  4. * This library is free software; you can redistribute it and/or
  5. * modify it under the terms of the GNU Lesser General Public
  6. * License as published by the Free Software Foundation; either
  7. * version 2.1 of the License, or (at your option) any later version.
  8. *
  9. * This library is distributed in the hope that it will be useful,
  10. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  12. * Lesser General Public License for more details.
  13. *
  14. * You should have received a copy of the GNU Lesser General Public
  15. * License along with this library; if not, write to the Free Software
  16. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  17. */
  18. #include "tlstransport.hpp"
  19. #include "tcptransport.hpp"
  20. #if RTC_ENABLE_WEBSOCKET
  21. #include <algorithm>
  22. #include <chrono>
  23. #include <cstring>
  24. #include <exception>
  25. using namespace std::chrono;
  26. namespace rtc::impl {
  27. #if USE_GNUTLS
  28. namespace {
  29. gnutls_certificate_credentials_t default_certificate_credentials() {
  30. static std::mutex mutex;
  31. static shared_ptr<gnutls_certificate_credentials_t> creds;
  32. std::lock_guard lock(mutex);
  33. if (!creds) {
  34. creds = shared_ptr<gnutls_certificate_credentials_t>(gnutls::new_credentials(),
  35. gnutls::free_credentials);
  36. gnutls::check(gnutls_certificate_set_x509_system_trust(*creds));
  37. }
  38. return *creds;
  39. }
  40. } // namespace
  41. void TlsTransport::Init() {
  42. // Nothing to do
  43. }
  44. void TlsTransport::Cleanup() {
  45. // Nothing to do
  46. }
  47. TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
  48. certificate_ptr certificate, state_callback callback)
  49. : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
  50. PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
  51. gnutls::check(gnutls_init(&mSession, mIsClient ? GNUTLS_CLIENT : GNUTLS_SERVER));
  52. try {
  53. const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
  54. const char *err_pos = NULL;
  55. gnutls::check(gnutls_priority_set_direct(mSession, priorities, &err_pos),
  56. "Failed to set TLS priorities");
  57. gnutls::check(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE,
  58. certificate ? certificate->credentials()
  59. : default_certificate_credentials()));
  60. if (mIsClient && mHost) {
  61. PLOG_VERBOSE << "Server Name Indication: " << *mHost;
  62. gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost->data(), mHost->size());
  63. }
  64. gnutls_session_set_ptr(mSession, this);
  65. gnutls_transport_set_ptr(mSession, this);
  66. gnutls_transport_set_push_function(mSession, WriteCallback);
  67. gnutls_transport_set_pull_function(mSession, ReadCallback);
  68. gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
  69. } catch (...) {
  70. gnutls_deinit(mSession);
  71. throw;
  72. }
  73. }
  74. TlsTransport::~TlsTransport() {
  75. stop();
  76. gnutls_deinit(mSession);
  77. }
  78. void TlsTransport::start() {
  79. Transport::start();
  80. registerIncoming();
  81. PLOG_DEBUG << "Starting TLS recv thread";
  82. mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
  83. }
  84. bool TlsTransport::stop() {
  85. if (!Transport::stop())
  86. return false;
  87. PLOG_DEBUG << "Stopping TLS recv thread";
  88. mIncomingQueue.stop();
  89. mRecvThread.join();
  90. return true;
  91. }
  92. bool TlsTransport::send(message_ptr message) {
  93. if (!message || state() != State::Connected)
  94. return false;
  95. PLOG_VERBOSE << "Send size=" << message->size();
  96. if (message->size() == 0)
  97. return true;
  98. ssize_t ret;
  99. do {
  100. ret = gnutls_record_send(mSession, message->data(), message->size());
  101. } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
  102. return gnutls::check(ret);
  103. }
  104. void TlsTransport::incoming(message_ptr message) {
  105. if (!message) {
  106. mIncomingQueue.stop();
  107. return;
  108. }
  109. PLOG_VERBOSE << "Incoming size=" << message->size();
  110. mIncomingQueue.push(message);
  111. }
  112. void TlsTransport::postHandshake() {
  113. // Dummy
  114. }
  115. void TlsTransport::runRecvLoop() {
  116. const size_t bufferSize = 4096;
  117. char buffer[bufferSize];
  118. // Handshake loop
  119. try {
  120. changeState(State::Connecting);
  121. int ret;
  122. do {
  123. ret = gnutls_handshake(mSession);
  124. } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
  125. !gnutls::check(ret, "TLS handshake failed"));
  126. } catch (const std::exception &e) {
  127. PLOG_ERROR << "TLS handshake: " << e.what();
  128. changeState(State::Failed);
  129. return;
  130. }
  131. // Receive loop
  132. try {
  133. PLOG_INFO << "TLS handshake finished";
  134. changeState(State::Connected);
  135. postHandshake();
  136. while (true) {
  137. ssize_t ret;
  138. do {
  139. ret = gnutls_record_recv(mSession, buffer, bufferSize);
  140. } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
  141. // Consider premature termination as remote closing
  142. if (ret == GNUTLS_E_PREMATURE_TERMINATION) {
  143. PLOG_DEBUG << "TLS connection terminated";
  144. break;
  145. }
  146. if (gnutls::check(ret)) {
  147. if (ret == 0) {
  148. // Closed
  149. PLOG_DEBUG << "TLS connection cleanly closed";
  150. break;
  151. }
  152. auto *b = reinterpret_cast<byte *>(buffer);
  153. recv(make_message(b, b + ret));
  154. }
  155. }
  156. } catch (const std::exception &e) {
  157. PLOG_ERROR << "TLS recv: " << e.what();
  158. }
  159. gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
  160. PLOG_INFO << "TLS closed";
  161. changeState(State::Disconnected);
  162. recv(nullptr);
  163. }
  164. ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len) {
  165. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  166. try {
  167. if (len > 0) {
  168. auto b = reinterpret_cast<const byte *>(data);
  169. t->outgoing(make_message(b, b + len));
  170. }
  171. gnutls_transport_set_errno(t->mSession, 0);
  172. return ssize_t(len);
  173. } catch (const std::exception &e) {
  174. PLOG_WARNING << e.what();
  175. gnutls_transport_set_errno(t->mSession, ECONNRESET);
  176. return -1;
  177. }
  178. }
  179. ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
  180. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  181. try {
  182. message_ptr &message = t->mIncomingMessage;
  183. size_t &position = t->mIncomingMessagePosition;
  184. if (message && position >= message->size())
  185. message.reset();
  186. if (!message) {
  187. position = 0;
  188. while (auto next = t->mIncomingQueue.pop()) {
  189. message = *next;
  190. if (message->size() > 0)
  191. break;
  192. else
  193. t->recv(message); // Pass zero-sized messages through
  194. }
  195. }
  196. if (message) {
  197. size_t available = message->size() - position;
  198. ssize_t len = std::min(maxlen, available);
  199. std::memcpy(data, message->data() + position, len);
  200. position += len;
  201. gnutls_transport_set_errno(t->mSession, 0);
  202. return len;
  203. } else {
  204. // Closed
  205. gnutls_transport_set_errno(t->mSession, 0);
  206. return 0;
  207. }
  208. } catch (const std::exception &e) {
  209. PLOG_WARNING << e.what();
  210. gnutls_transport_set_errno(t->mSession, ECONNRESET);
  211. return -1;
  212. }
  213. }
  214. int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
  215. TlsTransport *t = static_cast<TlsTransport *>(ptr);
  216. try {
  217. bool isReadable = t->mIncomingQueue.wait(
  218. ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
  219. return isReadable ? 1 : 0;
  220. } catch (const std::exception &e) {
  221. PLOG_WARNING << e.what();
  222. return 1;
  223. }
  224. }
  225. #else // USE_GNUTLS==0
  226. int TlsTransport::TransportExIndex = -1;
  227. void TlsTransport::Init() {
  228. openssl::init();
  229. if (TransportExIndex < 0) {
  230. TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
  231. }
  232. }
  233. void TlsTransport::Cleanup() {
  234. // Nothing to do
  235. }
  236. TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, optional<string> host,
  237. certificate_ptr certificate, state_callback callback)
  238. : Transport(lower, std::move(callback)), mHost(std::move(host)), mIsClient(lower->isActive()) {
  239. PLOG_DEBUG << "Initializing TLS transport (OpenSSL)";
  240. try {
  241. if (!(mCtx = SSL_CTX_new(SSLv23_method()))) // version-flexible
  242. throw std::runtime_error("Failed to create SSL context");
  243. openssl::check(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
  244. "Failed to set SSL priorities");
  245. if (certificate) {
  246. auto [x509, pkey] = certificate->credentials();
  247. SSL_CTX_use_certificate(mCtx, x509);
  248. SSL_CTX_use_PrivateKey(mCtx, pkey);
  249. } else {
  250. if (!SSL_CTX_set_default_verify_paths(mCtx)) {
  251. PLOG_WARNING << "SSL root CA certificates unavailable";
  252. }
  253. }
  254. SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3);
  255. SSL_CTX_set_min_proto_version(mCtx, TLS1_VERSION);
  256. SSL_CTX_set_read_ahead(mCtx, 1);
  257. SSL_CTX_set_quiet_shutdown(mCtx, 1);
  258. SSL_CTX_set_info_callback(mCtx, InfoCallback);
  259. SSL_CTX_set_verify(mCtx, SSL_VERIFY_NONE, NULL);
  260. if (!(mSsl = SSL_new(mCtx)))
  261. throw std::runtime_error("Failed to create SSL instance");
  262. SSL_set_ex_data(mSsl, TransportExIndex, this);
  263. if (mIsClient && mHost) {
  264. SSL_set_hostflags(mSsl, 0);
  265. openssl::check(SSL_set1_host(mSsl, mHost->c_str()), "Failed to set SSL host");
  266. PLOG_VERBOSE << "Server Name Indication: " << *mHost;
  267. SSL_set_tlsext_host_name(mSsl, mHost->c_str());
  268. }
  269. if (mIsClient)
  270. SSL_set_connect_state(mSsl);
  271. else
  272. SSL_set_accept_state(mSsl);
  273. if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
  274. throw std::runtime_error("Failed to create BIO");
  275. BIO_set_mem_eof_return(mInBio, BIO_EOF);
  276. BIO_set_mem_eof_return(mOutBio, BIO_EOF);
  277. SSL_set_bio(mSsl, mInBio, mOutBio);
  278. auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
  279. EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
  280. SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
  281. SSL_set_tmp_ecdh(mSsl, ecdh.get());
  282. } catch (...) {
  283. if (mSsl)
  284. SSL_free(mSsl);
  285. if (mCtx)
  286. SSL_CTX_free(mCtx);
  287. throw;
  288. }
  289. }
  290. TlsTransport::~TlsTransport() {
  291. stop();
  292. SSL_free(mSsl);
  293. SSL_CTX_free(mCtx);
  294. }
  295. void TlsTransport::start() {
  296. Transport::start();
  297. registerIncoming();
  298. PLOG_DEBUG << "Starting TLS recv thread";
  299. mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
  300. }
  301. bool TlsTransport::stop() {
  302. if (!Transport::stop())
  303. return false;
  304. PLOG_DEBUG << "Stopping TLS recv thread";
  305. mIncomingQueue.stop();
  306. mRecvThread.join();
  307. SSL_shutdown(mSsl);
  308. return true;
  309. }
  310. bool TlsTransport::send(message_ptr message) {
  311. if (!message || state() != State::Connected)
  312. return false;
  313. PLOG_VERBOSE << "Send size=" << message->size();
  314. if (message->size() == 0)
  315. return true;
  316. int ret = SSL_write(mSsl, message->data(), int(message->size()));
  317. if (!openssl::check(mSsl, ret))
  318. return false;
  319. const size_t bufferSize = 4096;
  320. byte buffer[bufferSize];
  321. while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
  322. outgoing(make_message(buffer, buffer + ret));
  323. return true;
  324. }
  325. void TlsTransport::incoming(message_ptr message) {
  326. if (!message) {
  327. mIncomingQueue.stop();
  328. return;
  329. }
  330. PLOG_VERBOSE << "Incoming size=" << message->size();
  331. mIncomingQueue.push(message);
  332. }
  333. void TlsTransport::postHandshake() {
  334. // Dummy
  335. }
  336. void TlsTransport::runRecvLoop() {
  337. const size_t bufferSize = 4096;
  338. byte buffer[bufferSize];
  339. try {
  340. changeState(State::Connecting);
  341. int ret;
  342. while (true) {
  343. if (state() == State::Connecting) {
  344. // Initiate or continue the handshake
  345. ret = SSL_do_handshake(mSsl);
  346. if (!openssl::check(mSsl, ret, "Handshake failed"))
  347. break;
  348. // Output
  349. while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
  350. outgoing(make_message(buffer, buffer + ret));
  351. if (SSL_is_init_finished(mSsl)) {
  352. PLOG_INFO << "TLS handshake finished";
  353. changeState(State::Connected);
  354. postHandshake();
  355. }
  356. }
  357. if (state() == State::Connected) {
  358. // Input
  359. while ((ret = SSL_read(mSsl, buffer, bufferSize)) > 0)
  360. recv(make_message(buffer, buffer + ret));
  361. if (!openssl::check(mSsl, ret))
  362. break;
  363. }
  364. auto next = mIncomingQueue.pop();
  365. if (!next)
  366. break;
  367. message_ptr message = std::move(*next);
  368. if (message->size() > 0)
  369. BIO_write(mInBio, message->data(), int(message->size())); // Input
  370. else
  371. recv(message); // Pass zero-sized messages through
  372. }
  373. } catch (const std::exception &e) {
  374. PLOG_ERROR << "TLS recv: " << e.what();
  375. }
  376. if (state() == State::Connected) {
  377. PLOG_INFO << "TLS closed";
  378. recv(nullptr);
  379. } else {
  380. PLOG_ERROR << "TLS handshake failed";
  381. }
  382. }
  383. void TlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
  384. TlsTransport *t =
  385. static_cast<TlsTransport *>(SSL_get_ex_data(ssl, TlsTransport::TransportExIndex));
  386. if (where & SSL_CB_ALERT) {
  387. if (ret != 256) { // Close Notify
  388. PLOG_ERROR << "TLS alert: " << SSL_alert_desc_string_long(ret);
  389. }
  390. t->mIncomingQueue.stop(); // Close the connection
  391. }
  392. }
  393. #endif
  394. } // namespace rtc::impl
  395. #endif