Browse Source

Implemented HTTP error response on WebSocket handshake failure

Paul-Louis Ageneau 4 years ago
parent
commit
54e52d25bf
4 changed files with 110 additions and 35 deletions
  1. 58 13
      src/impl/wshandshake.cpp
  2. 16 0
      src/impl/wshandshake.hpp
  3. 35 22
      src/impl/wstransport.cpp
  4. 1 0
      src/impl/wstransport.hpp

+ 58 - 13
src/impl/wshandshake.cpp

@@ -101,7 +101,7 @@ string WsHandshake::generateHttpRequest() {
 	             "Host: " +
 	             mHost +
 	             "\r\n"
-	             "Connection: Upgrade\r\n"
+	             "Connection: upgrade\r\n"
 	             "Upgrade: websocket\r\n"
 	             "Sec-WebSocket-Version: 13\r\n"
 	             "Sec-WebSocket-Key: " +
@@ -118,7 +118,8 @@ string WsHandshake::generateHttpRequest() {
 string WsHandshake::generateHttpResponse() {
 	std::unique_lock lock(mMutex);
 	const string out = "HTTP/1.1 101 Switching Protocols\r\n"
-	                   "Connection: Upgrade\r\n"
+	                   "Server: libdatachannel\r\n"
+	                   "Connection: upgrade\r\n"
 	                   "Upgrade: websocket\r\n"
 	                   "Sec-WebSocket-Accept: " +
 	                   computeAcceptKey(mKey) + "\r\n\r\n";
@@ -126,6 +127,43 @@ string WsHandshake::generateHttpResponse() {
 	return out;
 }
 
+namespace {
+
+string GetHttpErrorName(int responseCode) {
+	switch(responseCode) {
+	case 400:
+		return "Bad Request";
+	case 404:
+		return "Not Found";
+	case 405:
+		return "Method Not Allowed";
+	case 426:
+		return "Upgrade Required";
+	case 500:
+		return "Internal Server Error";
+	default:
+		return "Error";
+	}
+}
+
+}
+
+string WsHandshake::generateHttpError(int responseCode) {
+	std::unique_lock lock(mMutex);
+
+	const string error = to_string(responseCode) + " " + GetHttpErrorName(responseCode);
+
+	const string out = "HTTP/1.1 " + error + "\r\n"
+	                   "Server: libdatachannel\r\n"
+	                   "Connection: upgrade\r\n"
+	                   "Upgrade: websocket\r\n"
+	                   "Content-Type: text/plain\r\n"
+	                   "Content-Length: " + to_string(error.size()) + "\r\n"
+	                   "Access-Control-Allow-Origin: *\r\n\r\n" + error;
+
+	return out;
+}
+
 size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
 	std::unique_lock lock(mMutex);
 	std::list<string> lines;
@@ -134,7 +172,7 @@ size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
 		return 0;
 
 	if (lines.empty())
-		throw std::runtime_error("Invalid HTTP request for WebSocket");
+		throw RequestError("Invalid HTTP request for WebSocket", 400);
 
 	std::istringstream requestLine(std::move(lines.front()));
 	lines.pop_front();
@@ -143,7 +181,7 @@ size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
 	requestLine >> method >> path >> protocol;
 	PLOG_DEBUG << "WebSocket request method \"" << method << "\" for path: " << path;
 	if (method != "GET")
-		throw std::runtime_error("Unexpected request method \"" + method + "\" for WebSocket");
+		throw RequestError("Invalid request method \"" + method + "\" for WebSocket", 405);
 
 	mPath = std::move(path);
 
@@ -151,23 +189,23 @@ size_t WsHandshake::parseHttpRequest(const byte *buffer, size_t size) {
 
 	auto h = headers.find("host");
 	if (h == headers.end())
-		throw std::runtime_error("WebSocket host header missing in request");
+		throw RequestError("WebSocket host header missing in request", 400);
 
 	mHost = std::move(h->second);
 
 	h = headers.find("upgrade");
 	if (h == headers.end())
-		throw std::runtime_error("WebSocket update header missing in request");
+		throw RequestError("WebSocket upgrade header missing in request", 426);
 
 	string upgrade;
 	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
 	               [](char c) { return std::tolower(c); });
 	if (upgrade != "websocket")
-		throw std::runtime_error("WebSocket update header mismatching: " + h->second);
+		throw RequestError("WebSocket upgrade header mismatching: " + h->second, 426);
 
 	h = headers.find("sec-websocket-key");
 	if (h == headers.end())
-		throw std::runtime_error("WebSocket key header missing in request");
+		throw RequestError("WebSocket key header missing in request", 400);
 
 	mKey = std::move(h->second);
 
@@ -186,7 +224,7 @@ size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) {
 		return 0;
 
 	if (lines.empty())
-		throw std::runtime_error("Invalid HTTP response for WebSocket");
+		throw Error("Invalid HTTP response for WebSocket");
 
 	std::istringstream status(std::move(lines.front()));
 	lines.pop_front();
@@ -202,20 +240,20 @@ size_t WsHandshake::parseHttpResponse(const byte *buffer, size_t size) {
 
 	auto h = headers.find("upgrade");
 	if (h == headers.end())
-		throw std::runtime_error("WebSocket update header missing");
+		throw Error("WebSocket update header missing");
 
 	string upgrade;
 	std::transform(h->second.begin(), h->second.end(), std::back_inserter(upgrade),
 	               [](char c) { return std::tolower(c); });
 	if (upgrade != "websocket")
-		throw std::runtime_error("WebSocket update header mismatching: " + h->second);
+		throw Error("WebSocket update header mismatching: " + h->second);
 
 	h = headers.find("sec-websocket-accept");
 	if (h == headers.end())
-		throw std::runtime_error("WebSocket accept header missing");
+		throw Error("WebSocket accept header missing");
 
 	if (h->second != computeAcceptKey(mKey))
-		throw std::runtime_error("WebSocket accept header is invalid");
+		throw Error("WebSocket accept header is invalid");
 
 	return length;
 }
@@ -272,6 +310,13 @@ std::multimap<string, string> WsHandshake::parseHttpHeaders(const std::list<stri
 	return headers;
 }
 
+WsHandshake::Error::Error(const string &w) : std::runtime_error(w) {}
+
+WsHandshake::RequestError::RequestError(const string &w, int responseCode)
+    : Error(w), mResponseCode(responseCode) {}
+
+int WsHandshake::RequestError::RequestError::responseCode() const { return mResponseCode; }
+
 } // namespace rtc::impl
 
 #endif

+ 16 - 0
src/impl/wshandshake.hpp

@@ -39,6 +39,22 @@ public:
 
 	string generateHttpRequest();
 	string generateHttpResponse();
+	string generateHttpError(int responseCode = 400);
+
+	class Error : public std::runtime_error {
+	public:
+		explicit Error(const string &w);
+	};
+
+	class RequestError : public Error {
+	public:
+		explicit RequestError(const string &w, int responseCode = 400);
+		int responseCode() const;
+
+	private:
+		const int mResponseCode;
+	};
+
 	size_t parseHttpRequest(const byte *buffer, size_t size);
 	size_t parseHttpResponse(const byte *buffer, size_t size);
 

+ 35 - 22
src/impl/wstransport.cpp

@@ -105,19 +105,9 @@ void WsTransport::incoming(message_ptr message) {
 	if (message) {
 		PLOG_VERBOSE << "Incoming size=" << message->size();
 
-		if (message->size() == 0) {
-			if (state() == State::Connected) {
-				// TCP is idle, send a ping
-				PLOG_DEBUG << "WebSocket sending ping";
-				uint32_t dummy = 0;
-				sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
-			}
-			return;
-		}
-
-		mBuffer.insert(mBuffer.end(), message->begin(), message->end());
-
 		try {
+			mBuffer.insert(mBuffer.end(), message->begin(), message->end());
+
 			if (state() == State::Connecting) {
 				if (mIsClient) {
 					if (size_t len =
@@ -137,15 +127,35 @@ void WsTransport::incoming(message_ptr message) {
 			}
 
 			if (state() == State::Connected) {
-				Frame frame;
-				while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
-					recvFrame(frame);
-					mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+				if (message->size() == 0) {
+					// TCP is idle, send a ping
+					PLOG_DEBUG << "WebSocket sending ping";
+					uint32_t dummy = 0;
+					sendFrame({PING, reinterpret_cast<byte *>(&dummy), 4, true, mIsClient});
+
+				} else {
+					Frame frame;
+					while (size_t len = readFrame(mBuffer.data(), mBuffer.size(), frame)) {
+						recvFrame(frame);
+						mBuffer.erase(mBuffer.begin(), mBuffer.begin() + len);
+					}
 				}
 			}
 
 			return;
 
+		} catch (const WsHandshake::RequestError &e) {
+			PLOG_WARNING << e.what();
+			try {
+				sendHttpError(e.responseCode());
+
+			} catch (const std::exception &e) {
+				PLOG_WARNING << e.what();
+			}
+
+		} catch (const WsHandshake::Error &e) {
+			PLOG_WARNING << e.what();
+
 		} catch (const std::exception &e) {
 			PLOG_ERROR << e.what();
 		}
@@ -174,8 +184,7 @@ bool WsTransport::sendHttpRequest() {
 
 	const string request = mHandshake->generateHttpRequest();
 	auto data = reinterpret_cast<const byte *>(request.data());
-	auto size = request.size();
-	return outgoing(make_message(data, data + size));
+	return outgoing(make_message(data, data + request.size()));
 }
 
 bool WsTransport::sendHttpResponse() {
@@ -183,11 +192,15 @@ bool WsTransport::sendHttpResponse() {
 
 	const string response = mHandshake->generateHttpResponse();
 	auto data = reinterpret_cast<const byte *>(response.data());
-	auto size = response.size();
-	bool ret = outgoing(make_message(data, data + size));
+	return outgoing(make_message(data, data + response.size()));
+}
 
-	changeState(State::Connected);
-	return ret;
+bool WsTransport::sendHttpError(int code) {
+	PLOG_WARNING << "Sending WebSocket HTTP error response " << code;
+
+	const string response = mHandshake->generateHttpError(code);
+	auto data = reinterpret_cast<const byte *>(response.data());
+	return outgoing(make_message(data, data + response.size()));
 }
 
 // RFC6455 5.2. Base Framing Protocol

+ 1 - 0
src/impl/wstransport.hpp

@@ -64,6 +64,7 @@ private:
 	};
 
 	bool sendHttpRequest();
+	bool sendHttpError(int code);
 	bool sendHttpResponse();
 
 	size_t readFrame(byte *buffer, size_t size, Frame &frame);