Browse Source

Merge pull request #182 from paullouisageneau/safe-queue-pop

Safer message queue
Paul-Louis Ageneau 4 years ago
parent
commit
a284af2000
8 changed files with 55 additions and 44 deletions
  1. 37 24
      include/rtc/queue.hpp
  2. 2 2
      src/datachannel.cpp
  3. 6 6
      src/dtlstransport.cpp
  4. 1 1
      src/sctptransport.cpp
  5. 1 1
      src/tcptransport.cpp
  6. 4 6
      src/tlstransport.cpp
  7. 2 2
      src/track.cpp
  8. 2 2
      src/websocket.cpp

+ 37 - 24
include/rtc/queue.hpp

@@ -44,11 +44,15 @@ public:
 	size_t amount() const; // amount
 	size_t amount() const; // amount
 	void push(T element);
 	void push(T element);
 	std::optional<T> pop();
 	std::optional<T> pop();
+	std::optional<T> tryPop();
 	std::optional<T> peek();
 	std::optional<T> peek();
 	std::optional<T> exchange(T element);
 	std::optional<T> exchange(T element);
 	bool wait(const std::optional<std::chrono::milliseconds> &duration = nullopt);
 	bool wait(const std::optional<std::chrono::milliseconds> &duration = nullopt);
 
 
 private:
 private:
+	void pushImpl(T element);
+	std::optional<T> popImpl();
+
 	const size_t mLimit;
 	const size_t mLimit;
 	size_t mAmount;
 	size_t mAmount;
 	std::queue<T> mQueue;
 	std::queue<T> mQueue;
@@ -99,43 +103,32 @@ template <typename T> size_t Queue<T>::amount() const {
 template <typename T> void Queue<T>::push(T element) {
 template <typename T> void Queue<T>::push(T element) {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
 	mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
 	mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
-	if (!mStopping) {
-		mAmount += mAmountFunction(element);
-		mQueue.emplace(std::move(element));
-		mPopCondition.notify_one();
-	}
+	pushImpl(std::move(element));
 }
 }
 
 
 template <typename T> std::optional<T> Queue<T>::pop() {
 template <typename T> std::optional<T> Queue<T>::pop() {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 	mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
-	if (!mQueue.empty()) {
-		mAmount -= mAmountFunction(mQueue.front());
-		std::optional<T> element{std::move(mQueue.front())};
-		mQueue.pop();
-		return element;
-	} else {
-		return nullopt;
-	}
+	return popImpl();
+}
+
+template <typename T> std::optional<T> Queue<T>::tryPop() {
+	std::unique_lock lock(mMutex);
+	return popImpl();
 }
 }
 
 
 template <typename T> std::optional<T> Queue<T>::peek() {
 template <typename T> std::optional<T> Queue<T>::peek() {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
-	if (!mQueue.empty()) {
-		return std::optional<T>{mQueue.front()};
-	} else {
-		return nullopt;
-	}
+	return !mQueue.empty() ? std::make_optional(mQueue.front()) : nullopt;
 }
 }
 
 
 template <typename T> std::optional<T> Queue<T>::exchange(T element) {
 template <typename T> std::optional<T> Queue<T>::exchange(T element) {
 	std::unique_lock lock(mMutex);
 	std::unique_lock lock(mMutex);
-	if (!mQueue.empty()) {
-		std::swap(mQueue.front(), element);
-		return std::optional<T>{element};
-	} else {
+	if (mQueue.empty())
 		return nullopt;
 		return nullopt;
-	}
+
+	std::swap(mQueue.front(), element);
+	return std::make_optional(std::move(element));
 }
 }
 
 
 template <typename T>
 template <typename T>
@@ -145,7 +138,27 @@ bool Queue<T>::wait(const std::optional<std::chrono::milliseconds> &duration) {
 		mPopCondition.wait_for(lock, *duration, [this]() { return !mQueue.empty() || mStopping; });
 		mPopCondition.wait_for(lock, *duration, [this]() { return !mQueue.empty() || mStopping; });
 	else
 	else
 		mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
 		mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
-	return !mStopping;
+
+	return !mQueue.empty();
+}
+
+template <typename T> void Queue<T>::pushImpl(T element) {
+	if (mStopping)
+		return;
+
+	mAmount += mAmountFunction(element);
+	mQueue.emplace(std::move(element));
+	mPopCondition.notify_one();
+}
+
+template <typename T> std::optional<T> Queue<T>::popImpl() {
+	if (mQueue.empty())
+		return nullopt;
+
+	mAmount -= mAmountFunction(mQueue.front());
+	std::optional<T> element{std::move(mQueue.front())};
+	mQueue.pop();
+	return element;
 }
 }
 
 
 } // namespace rtc
 } // namespace rtc

+ 2 - 2
src/datachannel.cpp

@@ -122,8 +122,8 @@ bool DataChannel::send(const byte *data, size_t size) {
 }
 }
 
 
 std::optional<message_variant> DataChannel::receive() {
 std::optional<message_variant> DataChannel::receive() {
-	while (!mRecvQueue.empty()) {
-		auto message = *mRecvQueue.pop();
+	while (auto next = mRecvQueue.tryPop()) {
+		message_ptr message = std::move(*next);
 		if (message->type == Message::Control) {
 		if (message->type == Message::Control) {
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());
 			auto raw = reinterpret_cast<const uint8_t *>(message->data());
 			if (!message->empty() && raw[0] == MESSAGE_CLOSE)
 			if (!message->empty() && raw[0] == MESSAGE_CLOSE)

+ 6 - 6
src/dtlstransport.cpp

@@ -258,7 +258,7 @@ ssize_t DtlsTransport::WriteCallback(gnutls_transport_ptr_t ptr, const void *dat
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	if (auto next = t->mIncomingQueue.pop()) {
 	if (auto next = t->mIncomingQueue.pop()) {
-		auto message = *next;
+		message_ptr message = std::move(*next);
 		ssize_t len = std::min(maxlen, message->size());
 		ssize_t len = std::min(maxlen, message->size());
 		std::memcpy(data, message->data(), len);
 		std::memcpy(data, message->data(), len);
 		gnutls_transport_set_errno(t->mSession, 0);
 		gnutls_transport_set_errno(t->mSession, 0);
@@ -271,9 +271,9 @@ ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size
 
 
 int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
 	DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
-	t->mIncomingQueue.wait(ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms))
-	                                                       : nullopt);
-	return !t->mIncomingQueue.empty() ? 1 : 0;
+	bool notEmpty = t->mIncomingQueue.wait(
+	    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
+	return notEmpty ? 1 : 0;
 }
 }
 
 
 #else // USE_GNUTLS==0
 #else // USE_GNUTLS==0
@@ -437,8 +437,8 @@ void DtlsTransport::runRecvLoop() {
 		byte buffer[bufferSize];
 		byte buffer[bufferSize];
 		while (true) {
 		while (true) {
 			// Process pending messages
 			// Process pending messages
-			while (!mIncomingQueue.empty()) {
-				auto message = *mIncomingQueue.pop();
+			while (auto next = mIncomingQueue.tryPop()) {
+				message_ptr message = std::move(*next);
 				BIO_write(mInBio, message->data(), int(message->size()));
 				BIO_write(mInBio, message->data(), int(message->size()));
 
 
 				if (state() == State::Connecting) {
 				if (state() == State::Connecting) {

+ 1 - 1
src/sctptransport.cpp

@@ -322,7 +322,7 @@ void SctpTransport::incoming(message_ptr message) {
 bool SctpTransport::trySendQueue() {
 bool SctpTransport::trySendQueue() {
 	// Requires mSendMutex to be locked
 	// Requires mSendMutex to be locked
 	while (auto next = mSendQueue.peek()) {
 	while (auto next = mSendQueue.peek()) {
-		auto message = *next;
+		message_ptr message = std::move(*next);
 		if (!trySendMessage(message))
 		if (!trySendMessage(message))
 			return false;
 			return false;
 		mSendQueue.pop();
 		mSendQueue.pop();

+ 1 - 1
src/tcptransport.cpp

@@ -271,7 +271,7 @@ void TcpTransport::close() {
 bool TcpTransport::trySendQueue() {
 bool TcpTransport::trySendQueue() {
 	// mSockMutex must be locked
 	// mSockMutex must be locked
 	while (auto next = mSendQueue.peek()) {
 	while (auto next = mSendQueue.peek()) {
-		auto message = *next;
+		message_ptr message = std::move(*next);
 		if (!trySendMessage(message)) {
 		if (!trySendMessage(message)) {
 			mSendQueue.exchange(message);
 			mSendQueue.exchange(message);
 			return false;
 			return false;

+ 4 - 6
src/tlstransport.cpp

@@ -238,11 +238,9 @@ ssize_t TlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_
 
 
 int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 int TlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
 	TlsTransport *t = static_cast<TlsTransport *>(ptr);
-	if (ms != GNUTLS_INDEFINITE_TIMEOUT)
-		t->mIncomingQueue.wait(milliseconds(ms));
-	else
-		t->mIncomingQueue.wait();
-	return !t->mIncomingQueue.empty() ? 1 : 0;
+	bool notEmpty = t->mIncomingQueue.wait(
+	    ms != GNUTLS_INDEFINITE_TIMEOUT ? std::make_optional(milliseconds(ms)) : nullopt);
+	return notEmpty ? 1 : 0;
 }
 }
 
 
 #else // USE_GNUTLS==0
 #else // USE_GNUTLS==0
@@ -413,7 +411,7 @@ void TlsTransport::runRecvLoop() {
 			if (!next)
 			if (!next)
 				break;
 				break;
 
 
-			message_ptr message = *next;
+			message_ptr message = std::move(*next);
 			if (message->size() > 0)
 			if (message->size() > 0)
 				BIO_write(mInBio, message->data(), int(message->size())); // Input
 				BIO_write(mInBio, message->data(), int(message->size())); // Input
 			else
 			else

+ 2 - 2
src/track.cpp

@@ -45,8 +45,8 @@ bool Track::send(const byte *data, size_t size) {
 }
 }
 
 
 std::optional<message_variant> Track::receive() {
 std::optional<message_variant> Track::receive() {
-	if (!mRecvQueue.empty())
-		return to_variant(std::move(**mRecvQueue.pop()));
+	if (auto next = mRecvQueue.tryPop())
+		return to_variant(std::move(**next));
 
 
 	return nullopt;
 	return nullopt;
 }
 }

+ 2 - 2
src/websocket.cpp

@@ -110,8 +110,8 @@ bool WebSocket::isClosed() const { return mState == State::Closed; }
 size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 size_t WebSocket::maxMessageSize() const { return DEFAULT_MAX_MESSAGE_SIZE; }
 
 
 std::optional<message_variant> WebSocket::receive() {
 std::optional<message_variant> WebSocket::receive() {
-	while (!mRecvQueue.empty()) {
-		auto message = *mRecvQueue.pop();
+	while (auto next = mRecvQueue.tryPop()) {
+		message_ptr message = std::move(*next);
 		if (message->type != Message::Control)
 		if (message->type != Message::Control)
 			return to_variant(std::move(*message));
 			return to_variant(std::move(*message));
 	}
 	}