Browse Source

Wire selftest, and passes all tests.

Adam Ierymenko 10 years ago
parent
commit
3c1a59fa24
2 changed files with 182 additions and 29 deletions
  1. 48 29
      osnet/Wire.hpp
  2. 134 0
      selftest.cpp

+ 48 - 29
osnet/Wire.hpp

@@ -178,7 +178,8 @@ public:
 		ON_TCP_CLOSE_FUNCTION tcpCloseHandler,
 		ON_TCP_DATA_FUNCTION tcpDataHandler,
 		ON_TCP_WRITABLE_FUNCTION tcpWritableHandler,
-		bool noDelay) :
+		bool noDelay
+			) :
 		_datagramHandler(datagramHandler),
 		_tcpConnectHandler(tcpConnectHandler),
 		_tcpAcceptHandler(tcpAcceptHandler),
@@ -263,11 +264,11 @@ public:
 	 * Bind a UDP socket
 	 *
 	 * @param localAddress Local endpoint address and port
-	 * @param uptr Initial value of user pointer associated with this socket
-	 * @param bufferSize Desired socket receive/send buffer size -- will set as close to this as possible (0 to accept default)
+	 * @param uptr Initial value of user pointer associated with this socket (default: NULL)
+	 * @param bufferSize Desired socket receive/send buffer size -- will set as close to this as possible (default: 0, leave alone)
 	 * @return Socket or NULL on failure to bind
 	 */
-	inline WireSocket *udpBind(const struct sockaddr *localAddress,void *uptr,int bufferSize)
+	inline WireSocket *udpBind(const struct sockaddr *localAddress,void *uptr = (void *)0,int bufferSize = 0)
 	{
 		if (_socks.size() >= ZT_WIRE_MAX_SOCKETS)
 			return (WireSocket *)0;
@@ -358,26 +359,25 @@ public:
 	 * Send a UDP packet
 	 *
 	 * @param sock UDP socket
-	 * @param addr Destination address (must be correct type for socket)
-	 * @param addrlen Length of sockaddr_X structure
+	 * @param remoteAddress Destination address (must be correct type for socket)
 	 * @param data Data to send
 	 * @param len Length of packet
 	 * @return True if packet appears to have been sent successfully
 	 */
-	inline bool udpSend(WireSocket *sock,const struct sockaddr *addr,unsigned int addrlen,WireSocket *data,unsigned long len)
+	inline bool udpSend(WireSocket *sock,const struct sockaddr *remoteAddress,const void *data,unsigned long len)
 	{
 		WireSocketImpl &sws = *(const_cast <WireSocketImpl *>(reinterpret_cast<const WireSocketImpl *>(sock)));
-		return ((long)::sendto(sws.sock,data,len,0,addr,(socklen_t)addrlen) == (long)len);
+		return ((long)::sendto(sws.sock,data,len,0,remoteAddress,(remoteAddress->sa_family == AF_INET6) ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in)) == (long)len);
 	}
 
 	/**
 	 * Bind a local listen socket to listen for new TCP connections
 	 *
 	 * @param localAddress Local address and port
-	 * @param uptr Initial value of uptr for new socket
+	 * @param uptr Initial value of uptr for new socket (default: NULL)
 	 * @return Socket or NULL on failure to bind
 	 */
-	inline WireSocket *tcpListen(const struct sockaddr *localAddress,void *uptr)
+	inline WireSocket *tcpListen(const struct sockaddr *localAddress,void *uptr = (void *)0)
 	{
 		if (_socks.size() >= ZT_WIRE_MAX_SOCKETS)
 			return (WireSocket *)0;
@@ -438,30 +438,35 @@ public:
 	/**
 	 * Start a non-blocking connect; CONNECT handler is called on success or failure
 	 *
-	 * Note that if NULL is returned here, the handler is not called. Such
-	 * a return would indicate failure to allocate the socket, too many
-	 * open sockets, etc.
+	 * A return value of NULL indicates a synchronous failure such as a
+	 * failure to open a socket. The TCP connection handler is not called
+	 * in this case.
 	 *
-	 * Also note that an "instant connect" may occur for e.g. loopback
-	 * connections. If this happens the 'connected' result paramter will
-	 * be true. If callConnectHandlerOnInstantConnect is true, the
-	 * TCP connect handler will be called before the function returns
-	 * as well in this case. Otherwise it will not.
+	 * It is possible on some platforms for an "instant connect" to occur,
+	 * such as when connecting to a loopback address. In this case, the
+	 * 'connected' result parameter will be set to 'true' and if the
+	 * 'callConnectHandler' flag is true (the default) the TCP connect
+	 * handler will be called before the function returns.
+	 *
+	 * These semantics can be a bit confusing, but they're less so than
+	 * the underlying semantics of asynchronous TCP connect.
 	 *
 	 * @param remoteAddress Remote address
-	 * @param uptr Initial value of uptr for new socket
-	 * @param callConnectHandlerOnInstantConnect If true, call TCP connect handler now if an "instant connect" occurs
-	 * @param connected Reference to result paramter set to true if "instant connect" occurs, false otherwise
+	 * @param connected Result parameter: set to whether an "instant connect" has occurred (true if yes)
+	 * @param uptr Initial value of uptr for new socket (default: NULL)
+	 * @param callConnectHandler If true, call TCP connect handler even if result is known before function exit (default: true)
 	 * @return New socket or NULL on failure
 	 */
-	inline WireSocket *tcpConnect(const struct sockaddr *remoteAddress,void *uptr,bool callConnectHandlerOnInstantConnect,bool &connected)
+	inline WireSocket *tcpConnect(const struct sockaddr *remoteAddress,bool &connected,void *uptr = (void *)0,bool callConnectHandler = true)
 	{
 		if (_socks.size() >= ZT_WIRE_MAX_SOCKETS)
 			return (WireSocket *)0;
 
 		ZT_WIRE_SOCKFD_TYPE s = ::socket(remoteAddress->sa_family,SOCK_STREAM,0);
-		if (!ZT_WIRE_SOCKFD_VALID(s))
+		if (!ZT_WIRE_SOCKFD_VALID(s)) {
+			connected = false;
 			return (WireSocket *)0;
+		}
 
 #if defined(_WIN32) || defined(_WIN64)
 		{
@@ -484,6 +489,7 @@ public:
 
 		connected = true;
 		if (::connect(s,remoteAddress,(remoteAddress->sa_family == AF_INET6) ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in))) {
+			connected = false;
 #if defined(_WIN32) || defined(_WIN64)
 			if (WSAGetLastError() != WSAEWOULDBLOCK) {
 #else
@@ -491,7 +497,7 @@ public:
 #endif
 				ZT_WIRE_CLOSE_SOCKET(s);
 				return (WireSocket *)0;
-			} else connected = false;
+			} // else connection is proceeding asynchronously...
 		}
 
 		try {
@@ -519,9 +525,9 @@ public:
 		memset(&(sws.saddr),0,sizeof(struct sockaddr_storage));
 		memcpy(&(sws.saddr),remoteAddress,(remoteAddress->sa_family == AF_INET6) ? sizeof(struct sockaddr_in6) : sizeof(struct sockaddr_in));
 
-		if ((callConnectHandlerOnInstantConnect)&&(connected)) {
+		if ((callConnectHandler)&&(connected)) {
 			try {
-				_tcpConnectHandler((WireSocket *)&sws,uptr,true);
+				_tcpConnectHandler((WireSocket *)&sws,&(sws.uptr),true);
 			} catch ( ... ) {}
 		}
 
@@ -541,7 +547,7 @@ public:
 	 * @param callCloseHandler If true, call close handler on socket closing failure condition
 	 * @return Number of bytes actually sent or -1 on fatal error (socket closure)
 	 */
-	inline long tcpSend(WireSocket *sock,WireSocket *data,unsigned long len,bool callCloseHandler)
+	inline long tcpSend(WireSocket *sock,const void *data,unsigned long len,bool callCloseHandler)
 	{
 		WireSocketImpl &sws = *(const_cast <WireSocketImpl *>(reinterpret_cast<const WireSocketImpl *>(sock)));
 		long n = (long)::send(sws.sock,data,len,0);
@@ -706,7 +712,7 @@ public:
 								if ((long)newSock > _nfds)
 									_nfds = (long)newSock;
 								sws.type = ZT_WIRE_SOCKET_TCP_IN;
-								sws.sock = s;
+								sws.sock = newSock;
 								sws.uptr = (void *)0;
 								memcpy(&(sws.saddr),&ss,sizeof(struct sockaddr_storage));
 								try {
@@ -774,7 +780,7 @@ public:
 		long oldSock = (long)sws.sock;
 
 		for(typename std::list<WireSocketImpl>::iterator s(_socks.begin());s!=_socks.end();++s) {
-			if (&(*s) == sock) {
+			if (reinterpret_cast<WireSocket *>(&(*s)) == sock) {
 				_socks.erase(s);
 				break;
 			}
@@ -793,6 +799,19 @@ public:
 	}
 };
 
+// Typedefs for using regular naked functions as template parameters to Wire<>
+typedef void (*Wire_OnDatagramFunctionPtr)(WireSocket *sock,void **uptr,const struct sockaddr *from,void *data,unsigned long len);
+typedef void (*Wire_OnTcpConnectFunction)(WireSocket *sock,void **uptr,bool success);
+typedef void (*Wire_OnTcpAcceptFunction)(WireSocket *sockL,WireSocket *sockN,void **uptrL,void **uptrN,const struct sockaddr *from);
+typedef void (*Wire_OnTcpCloseFunction)(WireSocket *sock,void **uptr);
+typedef void (*Wire_OnTcpDataFunction)(WireSocket *sock,void **uptr,void *data,unsigned long len);
+typedef void (*Wire_OnTcpWritableFunction)(WireSocket *sock,void **uptr);
+
+/**
+ * Wire<> typedef'd to use simple naked function pointers
+ */
+typedef Wire<Wire_OnDatagramFunctionPtr,Wire_OnTcpConnectFunction,Wire_OnTcpAcceptFunction,Wire_OnTcpCloseFunction,Wire_OnTcpDataFunction,Wire_OnTcpWritableFunction> SimpleFunctionWire;
+
 } // namespace ZeroTier
 
 #endif

+ 134 - 0
selftest.cpp

@@ -646,6 +646,139 @@ static int testOther()
 	return 0;
 }
 
+#ifdef ZT_TEST_WIRE
+#define ZT_TEST_WIRE_NUM_UDP_PACKETS 10000
+#define ZT_TEST_WIRE_UDP_PACKET_SIZE 1000
+#define ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS 10
+#define ZT_TEST_WIRE_NUM_INVALID_TCP_CONNECTS 2
+#define ZT_TEST_WIRE_TCP_MESSAGE_SIZE 1000000
+#define ZT_TEST_WIRE_TIMEOUT_MS 20000
+static unsigned long wireTestUdpPacketCount = 0;
+static unsigned long wireTestTcpByteCount = 0;
+static unsigned long wireTestTcpConnectSuccessCount = 0;
+static unsigned long wireTestTcpConnectFailCount = 0;
+static unsigned long wireTestTcpAcceptCount = 0;
+static SimpleFunctionWire *testWireInstance = (SimpleFunctionWire *)0;
+static void testWireOnDatagramFunction(WireSocket *sock,void **uptr,const struct sockaddr *from,void *data,unsigned long len)
+{
+	++wireTestUdpPacketCount;
+}
+static void testWireOnTcpConnectFunction(WireSocket *sock,void **uptr,bool success)
+{
+	if (success) {
+		++wireTestTcpConnectSuccessCount;
+	} else {
+		++wireTestTcpConnectFailCount;
+	}
+}
+static void testWireOnTcpAcceptFunction(WireSocket *sockL,WireSocket *sockN,void **uptrL,void **uptrN,const struct sockaddr *from)
+{
+	++wireTestTcpAcceptCount;
+	*uptrN = new std::string(ZT_TEST_WIRE_TCP_MESSAGE_SIZE,(char)0xff);
+	testWireInstance->tcpSetNotifyWritable(sockN,true);
+}
+static void testWireOnTcpCloseFunction(WireSocket *sock,void **uptr)
+{
+	delete (std::string *)*uptr; // delete testMessage if any
+}
+static void testWireOnTcpDataFunction(WireSocket *sock,void **uptr,void *data,unsigned long len)
+{
+	wireTestTcpByteCount += len;
+}
+static void testWireOnTcpWritableFunction(WireSocket *sock,void **uptr)
+{
+	std::string *testMessage = (std::string *)*uptr;
+	if ((testMessage)&&(testMessage->length() > 0)) {
+		long sent = testWireInstance->tcpSend(sock,(const void *)testMessage->data(),testMessage->length(),true);
+		if (sent > 0)
+			testMessage->erase(0,sent);
+	}
+	if ((!testMessage)||(!testMessage->length())) {
+		testWireInstance->close(sock,true);
+	}
+}
+#endif // ZT_TEST_WIRE
+
+static int testWire()
+{
+#ifdef ZT_TEST_WIRE
+	char udpTestPayload[ZT_TEST_WIRE_UDP_PACKET_SIZE];
+	memset(udpTestPayload,0xff,sizeof(udpTestPayload));
+
+	struct sockaddr_in bindaddr;
+	memset(&bindaddr,0,sizeof(bindaddr));
+	bindaddr.sin_family = AF_INET;
+	bindaddr.sin_port = Utils::hton((uint16_t)60002);
+	bindaddr.sin_addr.s_addr = Utils::hton((uint32_t)0x7f000001);
+	struct sockaddr_in invalidAddr;
+	memset(&bindaddr,0,sizeof(bindaddr));
+	bindaddr.sin_family = AF_INET;
+	bindaddr.sin_port = Utils::hton((uint16_t)60004);
+	bindaddr.sin_addr.s_addr = Utils::hton((uint32_t)0x7f000001);
+
+	std::cout << "[wire] Creating wire endpoint..." << std::endl;
+	testWireInstance = new SimpleFunctionWire(testWireOnDatagramFunction,testWireOnTcpConnectFunction,testWireOnTcpAcceptFunction,testWireOnTcpCloseFunction,testWireOnTcpDataFunction,testWireOnTcpWritableFunction,false);
+
+	std::cout << "[wire] Binding UDP listen socket to 127.0.0.1/60002... ";
+	WireSocket *udpListenSock = testWireInstance->udpBind((const struct sockaddr *)&bindaddr);
+	if (!udpListenSock) {
+		std::cout << "FAILED." << std::endl;
+		return -1;
+	}
+	std::cout << "OK" << std::endl;
+
+	std::cout << "[wire] Binding TCP listen socket to 127.0.0.1/60002... ";
+	WireSocket *tcpListenSock = testWireInstance->tcpListen((const struct sockaddr *)&bindaddr);
+	if (!tcpListenSock) {
+		std::cout << "FAILED." << std::endl;
+		return -1;
+	}
+	std::cout << "OK" << std::endl;
+
+	unsigned long wireTestUdpPacketsSent = 0;
+	unsigned long wireTestTcpValidConnectionsAttempted = 0;
+	unsigned long wireTestTcpInvalidConnectionsAttempted = 0;
+
+	std::cout << "[wire] Testing UDP send/receive... "; std::cout.flush();
+	uint64_t timeoutAt = Utils::now() + ZT_TEST_WIRE_TIMEOUT_MS;
+	while ((Utils::now() < timeoutAt)&&(wireTestUdpPacketCount < ZT_TEST_WIRE_NUM_UDP_PACKETS)) {
+		if (wireTestUdpPacketsSent < ZT_TEST_WIRE_NUM_UDP_PACKETS) {
+			if (!testWireInstance->udpSend(udpListenSock,(const struct sockaddr *)&bindaddr,udpTestPayload,sizeof(udpTestPayload))) {
+				std::cout << "FAILED." << std::endl;
+				return -1;
+			} else ++wireTestUdpPacketsSent;
+		}
+		testWireInstance->poll(100);
+	}
+	std::cout << "got " << wireTestUdpPacketCount << " packets, OK" << std::endl;
+
+	std::cout << "[wire] Testing TCP... "; std::cout.flush();
+	timeoutAt = Utils::now() + ZT_TEST_WIRE_TIMEOUT_MS;
+	while ((Utils::now() < timeoutAt)&&(wireTestTcpByteCount < (ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS * ZT_TEST_WIRE_TCP_MESSAGE_SIZE))) {
+		if (wireTestTcpValidConnectionsAttempted < ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS) {
+			++wireTestTcpValidConnectionsAttempted;
+			bool connected = false;
+			if (!testWireInstance->tcpConnect((const struct sockaddr *)&bindaddr,connected,(void *)0,true))
+				++wireTestTcpConnectFailCount;
+		}
+		if (wireTestTcpInvalidConnectionsAttempted < ZT_TEST_WIRE_NUM_INVALID_TCP_CONNECTS) {
+			++wireTestTcpInvalidConnectionsAttempted;
+			bool connected = false;
+			if (!testWireInstance->tcpConnect((const struct sockaddr *)&invalidAddr,connected,(void *)0,true))
+				++wireTestTcpConnectFailCount;
+		}
+		testWireInstance->poll(100);
+	}
+	if (wireTestTcpByteCount < (ZT_TEST_WIRE_NUM_VALID_TCP_CONNECTS * ZT_TEST_WIRE_TCP_MESSAGE_SIZE)) {
+		std::cout << "got " << wireTestTcpConnectSuccessCount << " connect successes, " << wireTestTcpConnectFailCount << " failures, and " << wireTestTcpByteCount << " bytes, FAILED." << std::endl;
+		return -1;
+	} else {
+		std::cout << "got " << wireTestTcpConnectSuccessCount << " connect successes, " << wireTestTcpConnectFailCount << " failures, and " << wireTestTcpByteCount << " bytes, OK" << std::endl;
+	}
+#endif // ZT_TEST_WIRE
+	return 0;
+}
+
 static int testSqliteNetconfMaster()
 {
 #ifdef ZT_ENABLE_NETCONF_MASTER
@@ -717,6 +850,7 @@ int main(int argc,char **argv)
 
 	srand((unsigned int)time(0));
 
+	r |= testWire();
 	r |= testSqliteNetconfMaster();
 	r |= testCrypto();
 	r |= testHttp();