Browse Source

Added WebSocket support to C API

Paul-Louis Ageneau 5 years ago
parent
commit
03a48c1c09
3 changed files with 155 additions and 83 deletions
  1. 26 19
      include/rtc/rtc.h
  2. 123 58
      src/rtc.cpp
  3. 6 6
      test/main.cpp

+ 26 - 19
include/rtc/rtc.h

@@ -42,8 +42,7 @@ typedef enum {
 	RTC_GATHERING_COMPLETE = 2
 } rtcGatheringState;
 
-// Don't change, it must match plog severity
-typedef enum {
+typedef enum { // Don't change, it must match plog severity
 	RTC_LOG_NONE = 0,
 	RTC_LOG_FATAL = 1,
 	RTC_LOG_ERROR = 2,
@@ -76,10 +75,10 @@ typedef void (*availableCallbackFunc)(void *ptr);
 void rtcInitLogger(rtcLogLevel level);
 
 // User pointer
-void rtcSetUserPointer(int i, void *ptr);
+void rtcSetUserPointer(int id, void *ptr);
 
 // PeerConnection
-int rtcCreatePeerConnection(const rtcConfiguration *config);
+int rtcCreatePeerConnection(const rtcConfiguration *config); // returns pc id
 int rtcDeletePeerConnection(int pc);
 
 int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb);
@@ -95,24 +94,32 @@ int rtcGetLocalAddress(int pc, char *buffer, int size);
 int rtcGetRemoteAddress(int pc, char *buffer, int size);
 
 // DataChannel
-int rtcCreateDataChannel(int pc, const char *label);
+int rtcCreateDataChannel(int pc, const char *label); // returns dc id
 int rtcDeleteDataChannel(int dc);
 
 int rtcGetDataChannelLabel(int dc, char *buffer, int size);
-int rtcSetOpenCallback(int dc, openCallbackFunc cb);
-int rtcSetClosedCallback(int dc, closedCallbackFunc cb);
-int rtcSetErrorCallback(int dc, errorCallbackFunc cb);
-int rtcSetMessageCallback(int dc, messageCallbackFunc cb);
-int rtcSendMessage(int dc, const char *data, int size);
-
-int rtcGetBufferedAmount(int dc); // total size buffered to send
-int rtcSetBufferedAmountLowThreshold(int dc, int amount);
-int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb);
-
-// DataChannel extended API
-int rtcGetAvailableAmount(int dc); // total size available to receive
-int rtcSetAvailableCallback(int dc, availableCallbackFunc cb);
-int rtcReceiveMessage(int dc, char *buffer, int *size);
+
+// WebSocket
+#if ENABLE_WEBSOCKET
+int rtcCreateWebSocket(const char *url); // returns ws id
+int rtcDeleteWebsocket(int ws);
+#endif
+
+// DataChannel and WebSocket common API
+int rtcSetOpenCallback(int id, openCallbackFunc cb);
+int rtcSetClosedCallback(int id, closedCallbackFunc cb);
+int rtcSetErrorCallback(int id, errorCallbackFunc cb);
+int rtcSetMessageCallback(int id, messageCallbackFunc cb);
+int rtcSendMessage(int id, const char *data, int size);
+
+int rtcGetBufferedAmount(int id); // total size buffered to send
+int rtcSetBufferedAmountLowThreshold(int id, int amount);
+int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb);
+
+// DataChannel and WebSocket common extended API
+int rtcGetAvailableAmount(int id); // total size available to receive
+int rtcSetAvailableCallback(int id, availableCallbackFunc cb);
+int rtcReceiveMessage(int id, char *buffer, int *size);
 
 #ifdef __cplusplus
 } // extern "C"

+ 123 - 58
src/rtc.cpp

@@ -16,10 +16,15 @@
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-#include "datachannel.hpp"
 #include "include.hpp"
+
+#include "datachannel.hpp"
 #include "peerconnection.hpp"
 
+#if ENABLE_WEBSOCKET
+#include "websocket.hpp"
+#endif
+
 #include <rtc.h>
 
 #include <exception>
@@ -43,6 +48,9 @@ namespace {
 
 std::unordered_map<int, shared_ptr<PeerConnection>> peerConnectionMap;
 std::unordered_map<int, shared_ptr<DataChannel>> dataChannelMap;
+#if ENABLE_WEBSOCKET
+std::unordered_map<int, shared_ptr<WebSocket>> webSocketMap;
+#endif
 std::unordered_map<int, void *> userPointerMap;
 std::mutex mutex;
 int lastId = 0;
@@ -103,6 +111,40 @@ bool eraseDataChannel(int dc) {
 	return true;
 }
 
+#if ENABLE_WEBSOCKET
+shared_ptr<WebSocket> getWebSocket(int id) {
+	std::lock_guard lock(mutex);
+	auto it = webSocketMap.find(id);
+	return it != webSocketMap.end() ? it->second : nullptr;
+}
+
+int emplaceWebSocket(shared_ptr<WebSocket> ptr) {
+	std::lock_guard lock(mutex);
+	int ws = ++lastId;
+	webSocketMap.emplace(std::make_pair(ws, ptr));
+	return ws;
+}
+
+bool eraseWebSocket(int ws) {
+	std::lock_guard lock(mutex);
+	if (webSocketMap.erase(ws) == 0)
+		return false;
+	userPointerMap.erase(ws);
+	return true;
+}
+#endif
+
+shared_ptr<Channel> getChannel(int id) {
+	std::lock_guard lock(mutex);
+	if (auto it = dataChannelMap.find(id); it != dataChannelMap.end())
+		return it->second;
+#if ENABLE_WEBSOCKET
+	if (auto it = webSocketMap.find(id); it != webSocketMap.end())
+		return it->second;
+#endif
+	return nullptr;
+}
+
 } // namespace
 
 void rtcInitLogger(rtcLogLevel level) { InitLogger(static_cast<LogLevel>(level)); }
@@ -164,6 +206,29 @@ int rtcDeleteDataChannel(int dc) {
 	return 0;
 }
 
+#if ENABLE_WEBSOCKET
+int rtcCreateWebSocket(const char *url) {
+	return emplaceWebSocket(std::make_shared<WebSocket>(url));
+}
+
+int rtcDeleteWebsocket(int ws) {
+	auto webSocket = getWebSocket(ws);
+	if (!webSocket)
+		return -1;
+
+	webSocket->onOpen(nullptr);
+	webSocket->onClosed(nullptr);
+	webSocket->onError(nullptr);
+	webSocket->onMessage(nullptr);
+	webSocket->onBufferedAmountLow(nullptr);
+	webSocket->onAvailable(nullptr);
+
+	eraseWebSocket(ws);
+	return 0;
+}
+
+#endif
+
 int rtcSetDataChannelCallback(int pc, dataChannelCallbackFunc cb) {
 	auto peerConnection = getPeerConnection(pc);
 	if (!peerConnection)
@@ -298,135 +363,135 @@ int rtcGetDataChannelLabel(int dc, char *buffer, int size) {
 	return size + 1;
 }
 
-int rtcSetOpenCallback(int dc, openCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetOpenCallback(int id, openCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onOpen(nullptr);
+		channel->onOpen(nullptr);
 	return 0;
 }
 
-int rtcSetClosedCallback(int dc, closedCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetClosedCallback(int id, closedCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onClosed([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onClosed([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onClosed(nullptr);
+		channel->onClosed(nullptr);
 	return 0;
 }
 
-int rtcSetErrorCallback(int dc, errorCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetErrorCallback(int id, errorCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onError(
-		    [dc, cb](const string &error) { cb(error.c_str(), getUserPointer(dc)); });
+		channel->onError([id, cb](const string &error) { cb(error.c_str(), getUserPointer(id)); });
 	else
-		dataChannel->onError(nullptr);
+		channel->onError(nullptr);
 	return 0;
 }
 
-int rtcSetMessageCallback(int dc, messageCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetMessageCallback(int id, messageCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onMessage(
-		    [dc, cb](const binary &b) {
-			    cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(dc));
+		channel->onMessage(
+		    [id, cb](const binary &b) {
+			    cb(reinterpret_cast<const char *>(b.data()), b.size(), getUserPointer(id));
 		    },
-		    [dc, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(dc)); });
+		    [id, cb](const string &s) { cb(s.c_str(), -1, getUserPointer(id)); });
 	else
-		dataChannel->onMessage(nullptr);
+		channel->onMessage(nullptr);
 
 	return 0;
 }
 
-int rtcSendMessage(int dc, const char *data, int size) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSendMessage(int id, const char *data, int size) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (size >= 0) {
 		auto b = reinterpret_cast<const byte *>(data);
-		CATCH(dataChannel->send(b, size));
+		CATCH(channel->send(binary(b, b + size)));
 		return size;
 	} else {
-		string s(data);
-		CATCH(dataChannel->send(s));
-		return s.size();
+		string str(data);
+		int len = str.size();
+		CATCH(channel->send(std::move(str)));
+		return len;
 	}
 }
 
-int rtcGetBufferedAmount(int dc) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcGetBufferedAmount(int id) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
-	CATCH(return int(dataChannel->bufferedAmount()));
+	CATCH(return int(channel->bufferedAmount()));
 }
 
-int rtcSetBufferedAmountLowThreshold(int dc, int amount) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetBufferedAmountLowThreshold(int id, int amount) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
-	CATCH(dataChannel->setBufferedAmountLowThreshold(size_t(amount)));
+	CATCH(channel->setBufferedAmountLowThreshold(size_t(amount)));
 	return 0;
 }
 
-int rtcSetBufferedAmountLowCallback(int dc, bufferedAmountLowCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetBufferedAmountLowCallback(int id, bufferedAmountLowCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onBufferedAmountLow([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onBufferedAmountLow([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onBufferedAmountLow(nullptr);
+		channel->onBufferedAmountLow(nullptr);
 	return 0;
 }
 
-int rtcGetAvailableAmount(int dc) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcGetAvailableAmount(int id) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
-	CATCH(return int(dataChannel->availableAmount()));
+	CATCH(return int(channel->availableAmount()));
 }
 
-int rtcSetAvailableCallback(int dc, availableCallbackFunc cb) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcSetAvailableCallback(int id, availableCallbackFunc cb) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (cb)
-		dataChannel->onOpen([dc, cb]() { cb(getUserPointer(dc)); });
+		channel->onOpen([id, cb]() { cb(getUserPointer(id)); });
 	else
-		dataChannel->onOpen(nullptr);
+		channel->onOpen(nullptr);
 	return 0;
 }
 
-int rtcReceiveMessage(int dc, char *buffer, int *size) {
-	auto dataChannel = getDataChannel(dc);
-	if (!dataChannel)
+int rtcReceiveMessage(int id, char *buffer, int *size) {
+	auto channel = getChannel(id);
+	if (!channel)
 		return -1;
 
 	if (!size)
 		return -1;
 
 	CATCH({
-		auto message = dataChannel->receive();
+		auto message = channel->receive();
 		if (!message)
 			return 0;
 

+ 6 - 6
test/main.cpp

@@ -25,19 +25,19 @@ void test_capi();
 
 int main(int argc, char **argv) {
 	try {
-		std::cout << "*** Running connectivity test..." << std::endl;
+		cout << endl << "*** Running connectivity test..." << endl;
 		test_connectivity();
-		std::cout << "*** Finished connectivity test" << std::endl;
+		cout << "*** Finished connectivity test" << endl;
 	} catch (const exception &e) {
-		std::cerr << "Connectivity test failed: " << e.what() << endl;
+		cerr << "Connectivity test failed: " << e.what() << endl;
 		return -1;
 	}
 	try {
-		std::cout << "*** Running C API test..." << std::endl;
+		cout << endl << "*** Running C API test..." << endl;
 		test_capi();
-		std::cout << "*** Finished C API test" << std::endl;
+		cout << "*** Finished C API test" << endl;
 	} catch (const exception &e) {
-		std::cerr << "C API test failed: " << e.what() << endl;
+		cerr << "C API test failed: " << e.what() << endl;
 		return -1;
 	}
 	return 0;