|
@@ -62,31 +62,38 @@ void TlsTransport::Cleanup() {
|
|
|
}
|
|
|
|
|
|
TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_callback callback)
|
|
|
- : Transport(lower, std::move(callback)) {
|
|
|
+ : Transport(lower, std::move(callback)), mHost(std::move(host)) {
|
|
|
|
|
|
PLOG_DEBUG << "Initializing TLS transport (GnuTLS)";
|
|
|
|
|
|
+ check_gnutls(gnutls_certificate_allocate_credentials(&mCreds));
|
|
|
check_gnutls(gnutls_init(&mSession, GNUTLS_CLIENT));
|
|
|
|
|
|
try {
|
|
|
+ check_gnutls(gnutls_certificate_set_x509_system_trust(mCreds));
|
|
|
+ check_gnutls(gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCreds));
|
|
|
+ gnutls_session_set_verify_cert(mSession, mHost.c_str(), 0);
|
|
|
+
|
|
|
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128";
|
|
|
const char *err_pos = NULL;
|
|
|
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
|
|
|
"Failed to set TLS priorities");
|
|
|
|
|
|
- gnutls_session_set_ptr(mSession, this);
|
|
|
+ PLOG_VERBOSE << "Server Name Indication: " << mHost;
|
|
|
+ gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, mHost.data(), mHost.size());
|
|
|
+
|
|
|
+ gnutls_session_set_ptr(mSession, this);
|
|
|
gnutls_transport_set_ptr(mSession, this);
|
|
|
gnutls_transport_set_push_function(mSession, WriteCallback);
|
|
|
gnutls_transport_set_pull_function(mSession, ReadCallback);
|
|
|
gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
|
|
|
|
|
|
- gnutls_server_name_set(mSession, GNUTLS_NAME_DNS, host.data(), host.size());
|
|
|
-
|
|
|
- mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
|
|
|
+ mRecvThread = std::thread(&TlsTransport::runRecvLoop, this);
|
|
|
registerIncoming();
|
|
|
|
|
|
} catch (...) {
|
|
|
gnutls_deinit(mSession);
|
|
|
+ gnutls_certificate_free_credentials(mCreds);
|
|
|
throw;
|
|
|
}
|
|
|
}
|
|
@@ -94,6 +101,7 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
|
|
|
TlsTransport::~TlsTransport() {
|
|
|
stop();
|
|
|
gnutls_deinit(mSession);
|
|
|
+ gnutls_certificate_free_credentials(mCreds);
|
|
|
}
|
|
|
|
|
|
bool TlsTransport::stop() {
|
|
@@ -111,6 +119,9 @@ bool TlsTransport::send(message_ptr message) {
|
|
|
return false;
|
|
|
|
|
|
PLOG_VERBOSE << "Send size=" << message->size();
|
|
|
+ if(message->size() == 0)
|
|
|
+ return true;
|
|
|
+
|
|
|
ssize_t ret;
|
|
|
do {
|
|
|
ret = gnutls_record_send(mSession, message->data(), message->size());
|
|
@@ -196,20 +207,37 @@ ssize_t TlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *data
|
|
|
|
|
|
ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
|
|
|
TlsTransport *t = static_cast<TlsTransport *>(ptr);
|
|
|
- while (auto next = t->mIncomingQueue.pop()) {
|
|
|
- auto message = *next;
|
|
|
- if (message->size() > 0) {
|
|
|
- ssize_t len = std::min(maxlen, message->size());
|
|
|
- std::memcpy(data, message->data(), len);
|
|
|
- gnutls_transport_set_errno(t->mSession, 0);
|
|
|
- return len;
|
|
|
+
|
|
|
+ message_ptr &message = t->mIncomingMessage;
|
|
|
+ size_t &position = t->mIncomingMessagePosition;
|
|
|
+
|
|
|
+ if(message && position >= message->size())
|
|
|
+ message.reset();
|
|
|
+
|
|
|
+ if(!message) {
|
|
|
+ position = 0;
|
|
|
+ while (auto next = t->mIncomingQueue.pop()) {
|
|
|
+ message = *next;
|
|
|
+ if (message->size() > 0)
|
|
|
+ break;
|
|
|
+
|
|
|
+ t->recv(message); // Pass zero-sized messages through
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- t->recv(message); // Pass zero-sized messages through
|
|
|
+ if(message) {
|
|
|
+ size_t available = message->size() - position;
|
|
|
+ ssize_t len = std::min(maxlen, available);
|
|
|
+ std::memcpy(data, message->data() + position, len);
|
|
|
+ position+= len;
|
|
|
+ gnutls_transport_set_errno(t->mSession, 0);
|
|
|
+ return len;
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ // Closed
|
|
|
+ gnutls_transport_set_errno(t->mSession, 0);
|
|
|
+ return 0;
|
|
|
}
|
|
|
- // Closed
|
|
|
- gnutls_transport_set_errno(t->mSession, 0);
|
|
|
- return 0;
|
|
|
}
|
|
|
|
|
|
int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
|
|
@@ -308,6 +336,8 @@ TlsTransport::TlsTransport(shared_ptr<TcpTransport> lower, string host, state_ca
|
|
|
throw std::runtime_error("Failed to create SSL instance");
|
|
|
|
|
|
SSL_set_ex_data(mSsl, TransportExIndex, this);
|
|
|
+
|
|
|
+ PLOG_VERBOSE << "Server Name Indication: " << host;
|
|
|
SSL_set_tlsext_host_name(mSsl, host.c_str());
|
|
|
|
|
|
SSL_set_connect_state(mSsl);
|