浏览代码

Enhanced poll interrupt on Windows consistently with libjuice

Paul-Louis Ageneau 3 年之前
父节点
当前提交
aa0583632e

+ 65 - 17
src/impl/pollinterrupter.cpp

@@ -29,10 +29,54 @@
 namespace rtc::impl {
 
 PollInterrupter::PollInterrupter() {
-#ifndef _WIN32
+#ifdef _WIN32
+	struct addrinfo *ai = NULL;
+	struct addrinfo hints;
+	memset(&hints, 0, sizeof(hints));
+	hints.ai_family = AF_UNSPEC;
+	hints.ai_socktype = SOCK_DGRAM;
+	hints.ai_protocol = IPPROTO_UDP;
+	hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;
+	if (getaddrinfo("localhost", "0", &hints, &ai) != 0)
+		throw std::runtime_error("Resolution failed for localhost address");
+
+	try {
+		mSock = ::socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
+		if (mSock == INVALID_SOCKET)
+			throw std::runtime_error("UDP socket creation failed");
+
+		// Set non-blocking
+		ctl_t nbio = 1;
+		::ioctlsocket(mSock, FIONBIO, &nbio);
+
+		// Bind
+		if (::bind(mSock, ai->ai_addr, (socklen_t)ai->ai_addrlen) < 0)
+			throw std::runtime_error("Failed to bind UDP socket");
+
+		// Connect to self
+		struct sockaddr_storage addr;
+		socklen_t addrlen = sizeof(addr);
+		if (::getsockname(mSock, reinterpret_cast<struct sockaddr *>(&addr), &addrlen) < 0)
+			throw std::runtime_error("getsockname failed");
+
+		if (::connect(mSock, reinterpret_cast<struct sockaddr *>(&addr), addrlen) < 0)
+			throw std::runtime_error("Failed to connect UDP socket");
+
+	} catch (...) {
+		freeaddrinfo(ai);
+		if (mSock != INVALID_SOCKET)
+			::closesocket(mSock);
+
+		throw;
+	}
+
+	freeaddrinfo(ai);
+
+#else
 	int pipefd[2];
 	if (::pipe(pipefd) != 0)
 		throw std::runtime_error("Failed to create pipe");
+
 	::fcntl(pipefd[0], F_SETFL, O_NONBLOCK);
 	::fcntl(pipefd[1], F_SETFL, O_NONBLOCK);
 	mPipeOut = pipefd[1]; // read
@@ -41,10 +85,8 @@ PollInterrupter::PollInterrupter() {
 }
 
 PollInterrupter::~PollInterrupter() {
-	std::lock_guard lock(mMutex);
 #ifdef _WIN32
-	if (mDummySock != INVALID_SOCKET)
-		::closesocket(mDummySock);
+	::closesocket(mSock);
 #else
 	::close(mPipeIn);
 	::close(mPipeOut);
@@ -52,28 +94,34 @@ PollInterrupter::~PollInterrupter() {
 }
 
 void PollInterrupter::prepare(struct pollfd &pfd) {
-	std::lock_guard lock(mMutex);
 #ifdef _WIN32
-	if (mDummySock == INVALID_SOCKET)
-		mDummySock = ::socket(AF_INET, SOCK_DGRAM, 0);
-	pfd.fd = mDummySock;
-	pfd.events = POLLIN;
+	pfd.fd = mSock;
 #else
-	char dummy;
-	if (::read(mPipeIn, &dummy, 1) < 0 && errno != EAGAIN && errno != EWOULDBLOCK) {
-		PLOG_WARNING << "Reading from interrupter pipe failed, errno=" << errno;
-	}
 	pfd.fd = mPipeIn;
+#endif
 	pfd.events = POLLIN;
+}
+
+void PollInterrupter::process(struct pollfd &pfd) {
+	if (pfd.revents & POLLIN) {
+#ifdef _WIN32
+		char dummy;
+		while (::recv(pfd.fd, &dummy, 1, 0) >= 0) {
+			// Ignore
+		}
+#else
+		char dummy;
+		while (::read(pfd.fd, &dummy, 1) > 0) {
+			// Ignore
+		}
 #endif
+	}
 }
 
 void PollInterrupter::interrupt() {
-	std::lock_guard lock(mMutex);
 #ifdef _WIN32
-	if (mDummySock != INVALID_SOCKET) {
-		::closesocket(mDummySock);
-		mDummySock = INVALID_SOCKET;
+	if (::send(mSock, NULL, 0, 0) < 0 && sockerrno != SEAGAIN && sockerrno != SEWOULDBLOCK) {
+		PLOG_WARNING << "Writing to interrupter socket failed, errno=" << sockerrno;
 	}
 #else
 	char dummy = 0;

+ 5 - 4
src/impl/pollinterrupter.hpp

@@ -24,8 +24,6 @@
 
 #if RTC_ENABLE_WEBSOCKET
 
-#include <mutex>
-
 namespace rtc::impl {
 
 // Utility class to interrupt poll()
@@ -34,13 +32,16 @@ public:
 	PollInterrupter();
 	~PollInterrupter();
 
+	PollInterrupter(const PollInterrupter &other) = delete;
+	void operator=(const PollInterrupter &other) = delete;
+
 	void prepare(struct pollfd &pfd);
+	void process(struct pollfd &pfd);
 	void interrupt();
 
 private:
-	std::mutex mMutex;
 #ifdef _WIN32
-	socket_t mDummySock = INVALID_SOCKET;
+	socket_t mSock;
 #else // assume POSIX
 	int mPipeIn, mPipeOut;
 #endif

+ 53 - 55
src/impl/pollservice.cpp

@@ -40,6 +40,7 @@ PollService::~PollService() {}
 
 void PollService::start() {
 	mSocks = std::make_unique<SocketMap>();
+	mInterrupter = std::make_unique<PollInterrupter>();
 	mStopped = false;
 	mThread = std::thread(&PollService::runLoop, this);
 }
@@ -51,35 +52,35 @@ void PollService::join() {
 
 	lock.unlock();
 
-	mInterrupter.interrupt();
+	mInterrupter->interrupt();
 	mThread.join();
+
 	mSocks.reset();
+	mInterrupter.reset();
 }
 
 void PollService::add(socket_t sock, Params params) {
 	std::unique_lock lock(mMutex);
-	assert(mSocks);
-
-	mSocks->erase(sock);
-
-	if (!params.callback)
-		return;
+	assert(params.callback);
 
 	PLOG_VERBOSE << "Registering socket in poll service, direction=" << params.direction;
 	auto until = params.timeout ? std::make_optional(clock::now() + *params.timeout) : nullopt;
-	mSocks->emplace(sock, SocketEntry{std::move(params), std::move(until)});
+	assert(mSocks);
+	mSocks->insert_or_assign(sock, SocketEntry{std::move(params), std::move(until)});
 
-	mInterrupter.interrupt();
+	assert(mInterrupter);
+	mInterrupter->interrupt();
 }
 
 void PollService::remove(socket_t sock) {
 	std::unique_lock lock(mMutex);
-	assert(mSocks);
 
 	PLOG_VERBOSE << "Unregistering socket in poll service";
+	assert(mSocks);
 	mSocks->erase(sock);
 
-	mInterrupter.interrupt();
+	assert(mInterrupter);
+	mInterrupter->interrupt();
 }
 
 void PollService::prepare(std::vector<struct pollfd> &pfds, optional<clock::time_point> &next) {
@@ -88,7 +89,7 @@ void PollService::prepare(std::vector<struct pollfd> &pfds, optional<clock::time
 	next.reset();
 
 	auto it = pfds.begin();
-	mInterrupter.prepare(*it++);
+	mInterrupter->prepare(*it++);
 	for (const auto &[sock, entry] : *mSocks) {
 		it->fd = sock;
 		switch (entry.params.direction) {
@@ -111,55 +112,52 @@ void PollService::prepare(std::vector<struct pollfd> &pfds, optional<clock::time
 
 void PollService::process(std::vector<struct pollfd> &pfds) {
 	std::unique_lock lock(mMutex);
-	for (auto it = pfds.begin(); it != pfds.end(); ++it) {
+
+	auto it = pfds.begin();
+	mInterrupter->process(*it++);
+	while (it != pfds.end()) {
 		socket_t sock = it->fd;
 		auto jt = mSocks->find(sock);
-		if (jt == mSocks->end())
-			continue; // removed
-
-		try {
-			auto &entry = jt->second;
-			const auto &params = entry.params;
-
-			if (it->revents & POLLNVAL || it->revents & POLLERR) {
-				PLOG_VERBOSE << "Poll error event";
-				auto callback = std::move(params.callback);
-				mSocks->erase(sock);
-				callback(Event::Error);
-				continue;
-			}
-
-			if (it->revents & POLLIN || it->revents & POLLOUT) {
-				entry.until =
-				    params.timeout ? std::make_optional(clock::now() + *params.timeout) : nullopt;
-
-				auto callback = params.callback;
-
-				if (it->revents & POLLIN) {
-					PLOG_VERBOSE << "Poll in event";
-					params.callback(Event::In);
-				}
-
-				if (it->revents & POLLOUT) {
-					PLOG_VERBOSE << "Poll out event";
-					params.callback(Event::Out);
+		if (jt != mSocks->end()) {
+			try {
+				auto &entry = jt->second;
+				const auto &params = entry.params;
+
+				if (it->revents & POLLNVAL || it->revents & POLLERR) {
+					PLOG_VERBOSE << "Poll error event";
+					auto callback = std::move(params.callback);
+					mSocks->erase(sock);
+					callback(Event::Error);
+
+				} else if (it->revents & POLLIN || it->revents & POLLOUT) {
+					entry.until = params.timeout
+					                  ? std::make_optional(clock::now() + *params.timeout)
+					                  : nullopt;
+
+					auto callback = params.callback;
+					if (it->revents & POLLIN) {
+						PLOG_VERBOSE << "Poll in event";
+						callback(Event::In);
+					}
+					if (it->revents & POLLOUT) {
+						PLOG_VERBOSE << "Poll out event";
+						callback(Event::Out);
+					}
+
+				} else if (entry.until && clock::now() >= *entry.until) {
+					PLOG_VERBOSE << "Poll timeout event";
+					auto callback = std::move(params.callback);
+					mSocks->erase(sock);
+					callback(Event::Timeout);
 				}
 
-				continue;
-			}
-
-			if (entry.until && clock::now() >= *entry.until) {
-				PLOG_VERBOSE << "Poll timeout event";
-				auto callback = std::move(params.callback);
+			} catch (const std::exception &e) {
+				PLOG_WARNING << e.what();
 				mSocks->erase(sock);
-				callback(Event::Timeout);
-				continue;
 			}
-
-		} catch (const std::exception &e) {
-			PLOG_WARNING << e.what();
-			mSocks->erase(sock);
 		}
+
+		++it;
 	}
 }
 
@@ -180,7 +178,7 @@ void PollService::runLoop() {
 					auto msecs = duration_cast<milliseconds>(
 					    std::max(clock::duration::zero(), *next - clock::now()));
 					PLOG_VERBOSE << "Entering poll, timeout=" << msecs.count() << "ms";
-					timeout = msecs.count();
+					timeout = int(msecs.count());
 				} else {
 					PLOG_VERBOSE << "Entering poll";
 					timeout = -1;

+ 1 - 1
src/impl/pollservice.hpp

@@ -76,11 +76,11 @@ private:
 
 	using SocketMap = std::unordered_map<socket_t, SocketEntry>;
 	unique_ptr<SocketMap> mSocks;
+	unique_ptr<PollInterrupter> mInterrupter;
 
 	std::recursive_mutex mMutex;
 	std::thread mThread;
 	bool mStopped;
-	PollInterrupter mInterrupter;
 };
 
 std::ostream &operator<<(std::ostream &out, PollService::Direction direction);

+ 12 - 9
src/impl/tcpserver.cpp

@@ -46,9 +46,10 @@ shared_ptr<TcpTransport> TcpServer::accept() {
 			break;
 
 		struct pollfd pfd[2];
-		pfd[0].fd = mSock;
-		pfd[0].events = POLLIN;
-		mInterrupter.prepare(pfd[1]);
+		mInterrupter.prepare(pfd[0]);
+		pfd[1].fd = mSock;
+		pfd[1].events = POLLIN;
+
 		lock.unlock();
 		int ret = ::poll(pfd, 2, -1);
 		lock.lock();
@@ -63,11 +64,13 @@ shared_ptr<TcpTransport> TcpServer::accept() {
 				throw std::runtime_error("Failed to wait for socket connection");
 		}
 
-		if (pfd[0].revents & POLLNVAL || pfd[0].revents & POLLERR) {
+		mInterrupter.process(pfd[0]);
+
+		if (pfd[1].revents & POLLNVAL || pfd[1].revents & POLLERR) {
 			throw std::runtime_error("Error while waiting for socket connection");
 		}
 
-		if (pfd[0].revents & POLLIN) {
+		if (pfd[1].revents & POLLIN) {
 			struct sockaddr_storage addr;
 			socklen_t addrlen = sizeof(addr);
 			socket_t incomingSock = ::accept(mSock, (struct sockaddr *)&addr, &addrlen);
@@ -106,7 +109,7 @@ void TcpServer::listen(uint16_t port) {
 	hints.ai_flags = AI_PASSIVE | AI_NUMERICSERV;
 
 	struct addrinfo *result = nullptr;
-	if (::getaddrinfo(nullptr, std::to_string(port).c_str(), &hints, &result))
+	if (getaddrinfo(nullptr, std::to_string(port).c_str(), &hints, &result))
 		throw std::runtime_error("Resolution failed for local address");
 
 	try {
@@ -135,7 +138,7 @@ void TcpServer::listen(uint16_t port) {
 
 		// Enable REUSEADDR
 		::setsockopt(mSock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&enabled),
-			             sizeof(enabled));
+		             sizeof(enabled));
 
 		// Listen on both IPv6 and IPv4
 		if (ai->ai_family == AF_INET6)
@@ -143,8 +146,8 @@ void TcpServer::listen(uint16_t port) {
 			             reinterpret_cast<const char *>(&disabled), sizeof(disabled));
 
 		// Set non-blocking
-		ctl_t b = 1;
-		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+		ctl_t nbio = 1;
+		if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0)
 			throw std::runtime_error("Failed to set socket non-blocking mode");
 
 		// Bind socket

+ 4 - 1
src/impl/tcpserver.hpp

@@ -29,11 +29,14 @@
 
 namespace rtc::impl {
 
-class TcpServer {
+class TcpServer final {
 public:
 	TcpServer(uint16_t port);
 	~TcpServer();
 
+	TcpServer(const TcpServer &other) = delete;
+	void operator=(const TcpServer &other) = delete;
+
 	shared_ptr<TcpTransport> accept();
 	void close();
 

+ 4 - 4
src/impl/tcptransport.cpp

@@ -48,8 +48,8 @@ TcpTransport::TcpTransport(socket_t sock, state_callback callback)
 	PLOG_DEBUG << "Initializing TCP transport with socket";
 
 	// Set non-blocking
-	ctl_t b = 1;
-	if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+	ctl_t nbio = 1;
+	if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0)
 		throw std::runtime_error("Failed to set socket non-blocking mode");
 
 	// Retrieve hostname and service
@@ -221,8 +221,8 @@ void TcpTransport::prepare(const sockaddr *addr, socklen_t addrlen) {
 			throw std::runtime_error("TCP socket creation failed");
 
 		// Set non-blocking
-		ctl_t b = 1;
-		if (::ioctlsocket(mSock, FIONBIO, &b) < 0)
+		ctl_t nbio = 1;
+		if (::ioctlsocket(mSock, FIONBIO, &nbio) < 0)
 			throw std::runtime_error("Failed to set socket non-blocking mode");
 
 #ifdef __APPLE__

+ 1 - 1
test/capi_websocketserver.cpp

@@ -144,7 +144,7 @@ int test_capi_websocketserver_main() {
 	while (!success && !failed && attempts--)
 		sleep(1);
 
-	if (failed)
+	if (!success || failed)
 		goto error;
 
 	rtcDeleteWebSocket(wsclient);