Browse Source

Merge pull request #482 from paullouisageneau/use-poll

 Replace select() by poll() for WebSockets
Paul-Louis Ageneau 3 years ago
parent
commit
739bdf48b5

+ 2 - 2
CMakeLists.txt

@@ -119,7 +119,7 @@ set(LIBDATACHANNEL_IMPL_SOURCES
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.cpp
-	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/selectinterrupter.cpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollinterrupter.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.cpp
@@ -149,7 +149,7 @@ set(LIBDATACHANNEL_IMPL_HEADERS
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/processor.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/base64.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/sha.hpp
-	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/selectinterrupter.hpp
+	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/pollinterrupter.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcpserver.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tcptransport.hpp
 	${CMAKE_CURRENT_SOURCE_DIR}/src/impl/tlstransport.hpp

+ 9 - 9
src/impl/selectinterrupter.cpp → src/impl/pollinterrupter.cpp

@@ -16,7 +16,7 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-#include "selectinterrupter.hpp"
+#include "pollinterrupter.hpp"
 #include "internals.hpp"
 
 #if RTC_ENABLE_WEBSOCKET
@@ -28,7 +28,7 @@
 
 namespace rtc::impl {
 
-SelectInterrupter::SelectInterrupter() {
+PollInterrupter::PollInterrupter() {
 #ifndef _WIN32
 	int pipefd[2];
 	if (::pipe(pipefd) != 0)
@@ -40,7 +40,7 @@ SelectInterrupter::SelectInterrupter() {
 #endif
 }
 
-SelectInterrupter::~SelectInterrupter() {
+PollInterrupter::~PollInterrupter() {
 	std::lock_guard lock(mMutex);
 #ifdef _WIN32
 	if (mDummySock != INVALID_SOCKET)
@@ -51,24 +51,24 @@ SelectInterrupter::~SelectInterrupter() {
 #endif
 }
 
-int SelectInterrupter::prepare(fd_set &readfds) {
+void PollInterrupter::prepare(struct pollfd &pfd) {
 	std::lock_guard lock(mMutex);
 #ifdef _WIN32
 	if (mDummySock == INVALID_SOCKET)
 		mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
-	FD_SET(mDummySock, &readfds);
-	return SOCKET_TO_INT(mDummySock) + 1;
+	pfd.fd = mDummySock;
+	pfd.events = POLLIN;
 #else
 	char dummy;
 	if (::read(mPipeIn, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
 		PLOG_WARNING << "Reading from interrupter pipe failed, errno=" << errno;
 	}
-	FD_SET(mPipeIn, &readfds);
-	return mPipeIn + 1;
+	pfd.fd = mPipeIn;
+	pfd.events = POLLIN;
 #endif
 }
 
-void SelectInterrupter::interrupt() {
+void PollInterrupter::interrupt() {
 	std::lock_guard lock(mMutex);
 #ifdef _WIN32
 	if (mDummySock != INVALID_SOCKET) {

+ 7 - 7
src/impl/selectinterrupter.hpp → src/impl/pollinterrupter.hpp

@@ -16,8 +16,8 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-#ifndef RTC_IMPL_SELECT_INTERRUPTER_H
-#define RTC_IMPL_SELECT_INTERRUPTER_H
+#ifndef RTC_IMPL_POLL_INTERRUPTER_H
+#define RTC_IMPL_POLL_INTERRUPTER_H
 
 #include "common.hpp"
 #include "socket.hpp"
@@ -28,13 +28,13 @@
 
 namespace rtc::impl {
 
-// Utility class to interrupt select()
-class SelectInterrupter final {
+// Utility class to interrupt poll()
+class PollInterrupter final {
 public:
-	SelectInterrupter();
-	~SelectInterrupter();
+	PollInterrupter();
+	~PollInterrupter();
 
-	int prepare(fd_set &readfds);
+	void prepare(struct pollfd &pfd);
 	void interrupt();
 
 private:

+ 5 - 3
src/impl/socket.hpp

@@ -49,13 +49,15 @@
 
 typedef SOCKET socket_t;
 typedef SOCKADDR sockaddr;
-typedef u_long ctl_t;
+typedef ULONG ctl_t;
 typedef DWORD sockopt_t;
 #define sockerrno ((int)WSAGetLastError())
 #define IP_DONTFRAG IP_DONTFRAGMENT
-#define SOCKET_TO_INT(x) 0
 #define HOST_NAME_MAX 256
 
+#define poll WSAPoll
+typedef ULONG nfds_t;
+
 #define SEADDRINUSE WSAEADDRINUSE
 #define SEINTR WSAEINTR
 #define SEAGAIN WSAEWOULDBLOCK
@@ -76,6 +78,7 @@ typedef DWORD sockopt_t;
 #include <netdb.h>
 #include <netinet/in.h>
 #include <netinet/tcp.h>
+#include <poll.h>
 #include <sys/ioctl.h>
 #include <sys/select.h>
 #include <sys/socket.h>
@@ -99,7 +102,6 @@ typedef int ctl_t;
 typedef int sockopt_t;
 #define sockerrno errno
 #define INVALID_SOCKET -1
-#define SOCKET_TO_INT(x) (x)
 #define ioctlsocket ioctl
 #define closesocket close
 

+ 11 - 6
src/impl/tcpserver.cpp

@@ -45,13 +45,14 @@ shared_ptr<TcpTransport> TcpServer::accept() {
 		if (mSock == INVALID_SOCKET)
 			break;
 
-		fd_set readfds;
-		FD_ZERO(&readfds);
-		FD_SET(mSock, &readfds);
-		int n = std::max(mInterrupter.prepare(readfds), SOCKET_TO_INT(mSock) + 1);
+		struct pollfd pfd[2];
+		pfd[0].fd = mSock;
+		pfd[0].events = POLLIN;
+		mInterrupter.prepare(pfd[1]);
 		lock.unlock();
-		int ret = ::select(n, &readfds, NULL, NULL, NULL);
+		int ret = ::poll(pfd, 2, -1);
 		lock.lock();
+
 		if (mSock == INVALID_SOCKET)
 			break;
 
@@ -62,7 +63,11 @@ shared_ptr<TcpTransport> TcpServer::accept() {
 				throw std::runtime_error("Failed to wait for socket connection");
 		}
 
-		if (FD_ISSET(mSock, &readfds)) {
+		if (pfd[0].revents & POLLNVAL || pfd[0].revents & POLLERR) {
+			throw std::runtime_error("Error while waiting for socket connection");
+		}
+
+		if (pfd[0].revents & POLLIN) {
 			struct sockaddr_storage addr;
 			socklen_t addrlen = sizeof(addr);
 			socket_t incomingSock = ::accept(mSock, (struct sockaddr *)&addr, &addrlen);

+ 2 - 1
src/impl/tcpserver.hpp

@@ -20,6 +20,7 @@
 #define RTC_IMPL_TCP_SERVER_H
 
 #include "common.hpp"
+#include "pollinterrupter.hpp"
 #include "queue.hpp"
 #include "socket.hpp"
 #include "tcptransport.hpp"
@@ -44,7 +45,7 @@ private:
 	uint16_t mPort;
 	socket_t mSock = INVALID_SOCKET;
 	std::mutex mSockMutex;
-	SelectInterrupter mInterrupter;
+	PollInterrupter mInterrupter;
 };
 
 } // namespace rtc::impl

+ 35 - 24
src/impl/tcptransport.cpp

@@ -26,8 +26,13 @@
 #include <unistd.h>
 #endif
 
+#include <chrono>
+
 namespace rtc::impl {
 
+using namespace std::chrono_literals;
+using std::chrono::milliseconds;
+
 TcpTransport::TcpTransport(string hostname, string service, state_callback callback)
     : Transport(nullptr, std::move(callback)), mIsActive(true), mHostname(std::move(hostname)),
       mService(std::move(service)) {
@@ -192,13 +197,11 @@ void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
 		}
 
 		while (true) {
-			fd_set writefds;
-			FD_ZERO(&writefds);
-			FD_SET(mSock, &writefds);
-			struct timeval tv;
-			tv.tv_sec = 10; // TODO: Make the timeout configurable
-			tv.tv_usec = 0;
-			ret = ::select(SOCKET_TO_INT(mSock) + 1, NULL, &writefds, NULL, &tv);
+			struct pollfd pfd[1];
+			pfd[0].fd = mSock;
+			pfd[0].events = POLLOUT;
+			milliseconds timeout = 10s; // TODO: Make the timeout configurable
+			int ret = ::poll(pfd, 1, int(timeout.count()));
 
 			if (ret < 0) {
 				if (sockerrno == SEINTR || sockerrno == SEAGAIN) // interrupted
@@ -207,7 +210,11 @@ void TcpTransport::connect(const sockaddr *addr, socklen_t addrlen) {
 					throw std::runtime_error("Failed to wait for socket connection");
 			}
 
-			if (ret == 0) {
+			if (pfd[0].revents & POLLNVAL || pfd[0].revents & POLLERR) {
+				throw std::runtime_error("Error while waiting for socket connection");
+			}
+
+			if (!(pfd[0].revents & POLLOUT)) {
 				std::ostringstream msg;
 				msg << "TCP connection to " << node << ":" << serv << " timed out";
 				throw std::runtime_error(msg.str());
@@ -310,29 +317,28 @@ void TcpTransport::runLoop() {
 
 		while (true) {
 			std::unique_lock lock(mSockMutex);
+
 			if (mSock == INVALID_SOCKET)
 				break;
 
-			fd_set readfds, writefds;
-			FD_ZERO(&readfds);
-			FD_ZERO(&writefds);
-			FD_SET(mSock, &readfds);
-			if (!mSendQueue.empty())
-				FD_SET(mSock, &writefds);
-
-			int n = std::max(mInterrupter.prepare(readfds), SOCKET_TO_INT(mSock) + 1);
-
-			struct timeval tv;
-			tv.tv_sec = 10;
-			tv.tv_usec = 0;
+			struct pollfd pfd[2];
+			pfd[0].fd = mSock;
+			pfd[0].events = !mSendQueue.empty() ? (POLLIN | POLLOUT) : POLLIN;
+			mInterrupter.prepare(pfd[1]);
+			milliseconds timeout = 10s;
 			lock.unlock();
-			int ret = ::select(n, &readfds, &writefds, NULL, &tv);
+			int ret = ::poll(pfd, 2, int(timeout.count()));
 			lock.lock();
+
 			if (mSock == INVALID_SOCKET)
 				break;
 
 			if (ret < 0) {
-				throw std::runtime_error("Failed to wait on socket");
+				if (sockerrno == SEINTR || sockerrno == SEAGAIN) // interrupted
+					continue;
+				else
+					throw std::runtime_error("Failed to wait on socket");
+
 			} else if (ret == 0) {
 				PLOG_VERBOSE << "TCP is idle";
 				lock.unlock(); // unlock now since the upper layer might send on incoming
@@ -340,10 +346,15 @@ void TcpTransport::runLoop() {
 				continue;
 			}
 
-			if (FD_ISSET(mSock, &writefds))
+			if (pfd[0].revents & POLLNVAL || pfd[0].revents & POLLERR) {
+				throw std::runtime_error("Error while waiting for socket connection");
+			}
+
+			if (pfd[0].revents & POLLOUT) {
 				trySendQueue();
+			}
 
-			if (FD_ISSET(mSock, &readfds)) {
+			if (pfd[0].revents & POLLIN) {
 				char buffer[bufferSize];
 				int len = ::recv(mSock, buffer, bufferSize, 0);
 				if (len < 0) {

+ 2 - 2
src/impl/tcptransport.hpp

@@ -21,7 +21,7 @@
 
 #include "common.hpp"
 #include "queue.hpp"
-#include "selectinterrupter.hpp"
+#include "pollinterrupter.hpp"
 #include "socket.hpp"
 #include "transport.hpp"
 
@@ -65,7 +65,7 @@ private:
 	socket_t mSock = INVALID_SOCKET;
 	std::mutex mSockMutex;
 	std::thread mThread;
-	SelectInterrupter mInterrupter;
+	PollInterrupter mInterrupter;
 	Queue<message_ptr> mSendQueue;
 };