Browse Source

Fixed TLS send return value to reflect TCP buffering

Paul-Louis Ageneau 3 years ago
parent
commit
d03f617c9e
3 changed files with 32 additions and 15 deletions
  1. 1 1
      src/impl/tcptransport.cpp
  2. 28 14
      src/impl/tlstransport.cpp
  3. 3 0
      src/impl/tlstransport.hpp

+ 1 - 1
src/impl/tcptransport.cpp

@@ -95,7 +95,7 @@ bool TcpTransport::send(message_ptr message) {
 	if (state() != State::Connected)
 	if (state() != State::Connected)
 		throw std::runtime_error("Connection is not open");
 		throw std::runtime_error("Connection is not open");
 
 
-	if (!message)
+	if (!message || message->size() == 0)
 		return trySendQueue();
 		return trySendQueue();
 
 
 	PLOG_VERBOSE << "Send size=" << message->size();
 	PLOG_VERBOSE << "Send size=" << message->size();

+ 28 - 14
src/impl/tlstransport.cpp

@@ -118,20 +118,23 @@ bool TlsTransport::stop() {
 }
 }
 
 
 bool TlsTransport::send(message_ptr message) {
 bool TlsTransport::send(message_ptr message) {
-	if (!message || state() != State::Connected)
-		return false;
+	if (state() != State::Connected)
+		throw std::runtime_error("TLS is not open");
 
 
-	PLOG_VERBOSE << "Send size=" << message->size();
+	if (!message || message->size() == 0)
+		return outgoing(message); // pass through
 
 
-	if (message->size() == 0)
-		return true;
+	PLOG_VERBOSE << "Send size=" << message->size();
 
 
 	ssize_t ret;
 	ssize_t ret;
 	do {
 	do {
 		ret = gnutls_record_send(mSession, message->data(), message->size());
 		ret = gnutls_record_send(mSession, message->data(), message->size());
 	} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
 	} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
 
 
-	return gnutls::check(ret);
+	if (!gnutls::check(ret))
+		throw std::runtime_error("TLS send failed");
+
+	return mOutgoingResult;
 }
 }
 
 
 void TlsTransport::incoming(message_ptr message) {
 void TlsTransport::incoming(message_ptr message) {
@@ -144,6 +147,12 @@ void TlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 	mIncomingQueue.push(message);
 }
 }
 
 
+bool TlsTransport::outgoing(message_ptr message) {
+	bool result = Transport::outgoing(std::move(message));
+	mOutgoingResult = result;
+	return result;
+}
+
 void TlsTransport::postHandshake() {
 void TlsTransport::postHandshake() {
 	// Dummy
 	// Dummy
 }
 }
@@ -390,24 +399,25 @@ bool TlsTransport::stop() {
 }
 }
 
 
 bool TlsTransport::send(message_ptr message) {
 bool TlsTransport::send(message_ptr message) {
-	if (!message || state() != State::Connected)
-		return false;
+	if (state() != State::Connected)
+		throw std::runtime_error("TLS is not open");
 
 
-	PLOG_VERBOSE << "Send size=" << message->size();
+	if (!message || message->size() == 0)
+		return outgoing(message); // pass through
 
 
-	if (message->size() == 0)
-		return true;
+	PLOG_VERBOSE << "Send size=" << message->size();
 
 
 	int ret = SSL_write(mSsl, message->data(), int(message->size()));
 	int ret = SSL_write(mSsl, message->data(), int(message->size()));
 	if (!openssl::check(mSsl, ret))
 	if (!openssl::check(mSsl, ret))
-		return false;
+		throw std::runtime_error("TLS send failed");
 
 
 	const size_t bufferSize = 4096;
 	const size_t bufferSize = 4096;
 	byte buffer[bufferSize];
 	byte buffer[bufferSize];
+	bool result = true;
 	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
 	while ((ret = BIO_read(mOutBio, buffer, bufferSize)) > 0)
-		outgoing(make_message(buffer, buffer + ret));
+		result = outgoing(make_message(buffer, buffer + ret));
 
 
-	return true;
+	return result;
 }
 }
 
 
 void TlsTransport::incoming(message_ptr message) {
 void TlsTransport::incoming(message_ptr message) {
@@ -420,6 +430,10 @@ void TlsTransport::incoming(message_ptr message) {
 	mIncomingQueue.push(message);
 	mIncomingQueue.push(message);
 }
 }
 
 
+bool TlsTransport::outgoing(message_ptr message) {
+	return Transport::outgoing(std::move(message));
+}
+
 void TlsTransport::postHandshake() {
 void TlsTransport::postHandshake() {
 	// Dummy
 	// Dummy
 }
 }

+ 3 - 0
src/impl/tlstransport.hpp

@@ -27,6 +27,7 @@
 
 
 #if RTC_ENABLE_WEBSOCKET
 #if RTC_ENABLE_WEBSOCKET
 
 
+#include <atomic>
 #include <thread>
 #include <thread>
 
 
 namespace rtc::impl {
 namespace rtc::impl {
@@ -50,6 +51,7 @@ public:
 
 
 protected:
 protected:
 	virtual void incoming(message_ptr message) override;
 	virtual void incoming(message_ptr message) override;
+	virtual bool outgoing(message_ptr message) override;
 	virtual void postHandshake();
 	virtual void postHandshake();
 	void runRecvLoop();
 	void runRecvLoop();
 
 
@@ -64,6 +66,7 @@ protected:
 
 
 	message_ptr mIncomingMessage;
 	message_ptr mIncomingMessage;
 	size_t mIncomingMessagePosition = 0;
 	size_t mIncomingMessagePosition = 0;
+	std::atomic<bool> mOutgoingResult = true;
 
 
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
 	static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);