Browse Source

Turns out you do have to unpack and compare sockaddr structures due to sin_len / sin6_len not present on all platforms and other junk.

Adam Ierymenko 10 years ago
parent
commit
6f4b30add8
2 changed files with 83 additions and 72 deletions
  1. 73 27
      node/InetAddress.cpp
  2. 10 45
      node/InetAddress.hpp

+ 73 - 27
node/InetAddress.cpp

@@ -37,10 +37,8 @@
 
 
 namespace ZeroTier {
 namespace ZeroTier {
 
 
-const InetAddress InetAddress::LO4("127.0.0.1",0);
-const InetAddress InetAddress::LO6("::1",0);
-const InetAddress InetAddress::DEFAULT4((uint32_t)0,0);
-const InetAddress InetAddress::DEFAULT6((const void *)0,16,0);
+const InetAddress InetAddress::LO4((const void *)("\x7f\x00\x00\x01"),4,0);
+const InetAddress InetAddress::LO6((const void *)("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),16,0);
 
 
 InetAddress::IpScope InetAddress::ipScope() const
 InetAddress::IpScope InetAddress::ipScope() const
 	throw()
 	throw()
@@ -113,22 +111,19 @@ InetAddress::IpScope InetAddress::ipScope() const
 void InetAddress::set(const std::string &ip,unsigned int port)
 void InetAddress::set(const std::string &ip,unsigned int port)
 	throw()
 	throw()
 {
 {
+	memset(this,0,sizeof(InetAddress));
 	if (ip.find(':') != std::string::npos) {
 	if (ip.find(':') != std::string::npos) {
-		struct sockaddr_in6 sin6;
-		memset(&sin6,0,sizeof(sin6));
-		sin6.sin6_family = AF_INET6;
-		sin6.sin6_port = Utils::hton((uint16_t)port);
-		if (inet_pton(AF_INET6,ip.c_str(),(void *)&(sin6.sin6_addr.s6_addr)) <= 0)
+		struct sockaddr_in6 *sin6 = reinterpret_cast<struct sockaddr_in6 *>(this);
+		ss_family = AF_INET6;
+		sin6->sin6_port = Utils::hton((uint16_t)port);
+		if (inet_pton(AF_INET6,ip.c_str(),(void *)&(sin6->sin6_addr.s6_addr)) <= 0)
 			memset(this,0,sizeof(InetAddress));
 			memset(this,0,sizeof(InetAddress));
-		else *this = sin6;
 	} else {
 	} else {
-		struct sockaddr_in sin;
-		memset(&sin,0,sizeof(sin));
-		sin.sin_family = AF_INET;
-		sin.sin_port = Utils::hton((uint16_t)port);
-		if (inet_pton(AF_INET,ip.c_str(),(void *)&(sin.sin_addr.s_addr)) <= 0)
+		struct sockaddr_in *sin = reinterpret_cast<struct sockaddr_in *>(this);
+		ss_family = AF_INET;
+		sin->sin_port = Utils::hton((uint16_t)port);
+		if (inet_pton(AF_INET,ip.c_str(),(void *)&(sin->sin_addr.s_addr)) <= 0)
 			memset(this,0,sizeof(InetAddress));
 			memset(this,0,sizeof(InetAddress));
-		else *this = sin;
 	}
 	}
 }
 }
 
 
@@ -137,15 +132,14 @@ void InetAddress::set(const void *ipBytes,unsigned int ipLen,unsigned int port)
 {
 {
 	memset(this,0,sizeof(InetAddress));
 	memset(this,0,sizeof(InetAddress));
 	if (ipLen == 4) {
 	if (ipLen == 4) {
-		setV4();
-		if (ipBytes)
-			memcpy(rawIpData(),ipBytes,4);
-		setPort(port);
+		ss_family = AF_INET;
+		reinterpret_cast<struct sockaddr_in *>(this)->sin_addr.s_addr = *(reinterpret_cast<const uint32_t *>(ipBytes));
+		reinterpret_cast<struct sockaddr_in *>(this)->sin_port = Utils::hton((uint16_t)port);
 	} else if (ipLen == 16) {
 	} else if (ipLen == 16) {
-		setV6();
-		if (ipBytes)
-			memcpy(rawIpData(),ipBytes,16);
-		setPort(port);
+		ss_family = AF_INET6;
+		(reinterpret_cast<uint64_t *>(reinterpret_cast<struct sockaddr_in6 *>(this)->sin6_addr.s6_addr))[0] = ((const uint64_t *)ipBytes)[0];
+		(reinterpret_cast<uint64_t *>(reinterpret_cast<struct sockaddr_in6 *>(this)->sin6_addr.s6_addr))[1] = ((const uint64_t *)ipBytes)[1];
+		reinterpret_cast<struct sockaddr_in6 *>(this)->sin6_port = Utils::hton((uint16_t)port);
 	}
 	}
 }
 }
 
 
@@ -225,10 +219,10 @@ std::string InetAddress::toIpString() const
 
 
 void InetAddress::fromString(const std::string &ipSlashPort)
 void InetAddress::fromString(const std::string &ipSlashPort)
 {
 {
-	std::size_t slashAt = ipSlashPort.find('/');
-	if ((slashAt == std::string::npos)||(slashAt >= ipSlashPort.length()))
+	const std::size_t slashAt = ipSlashPort.find('/');
+	if (slashAt == std::string::npos) {
 		set(ipSlashPort,0);
 		set(ipSlashPort,0);
-	else {
+	} else {
 		long p = strtol(ipSlashPort.substr(slashAt+1).c_str(),(char **)0,10);
 		long p = strtol(ipSlashPort.substr(slashAt+1).c_str(),(char **)0,10);
 		if ((p > 0)&&(p <= 0xffff))
 		if ((p > 0)&&(p <= 0xffff))
 			set(ipSlashPort.substr(0,slashAt),(unsigned int)p);
 			set(ipSlashPort.substr(0,slashAt),(unsigned int)p);
@@ -280,6 +274,58 @@ InetAddress InetAddress::broadcast() const
 	return r;
 	return r;
 }
 }
 
 
+bool InetAddress::operator==(const InetAddress &a) const
+	throw()
+{
+	if (ss_family == a.ss_family) {
+		switch(ss_family) {
+			case AF_INET:
+				return (
+					(reinterpret_cast<const struct sockaddr_in *>(this)->sin_port == reinterpret_cast<const struct sockaddr_in *>(&a)->sin_port)&&
+					(reinterpret_cast<const struct sockaddr_in *>(this)->sin_addr.s_addr == reinterpret_cast<const struct sockaddr_in *>(&a)->sin_addr.s_addr));
+				break;
+			case AF_INET6:
+				return (
+					(reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_port == reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_port)&&
+					(reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_flowinfo == reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_flowinfo)&&
+					(memcmp(reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_addr.s6_addr,reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_addr.s6_addr,16) == 0)&&
+					(reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_scope_id == reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_scope_id));
+				break;
+			default:
+				return (memcmp(this,&a,sizeof(InetAddress)) == 0);
+		}
+	}
+	return false;
+}
+
+bool InetAddress::operator<(const InetAddress &a) const
+	throw()
+{
+	if (ss_family < a.ss_family) {
+		switch(ss_family) {
+			case AF_INET:
+				if (reinterpret_cast<const struct sockaddr_in *>(this)->sin_port < reinterpret_cast<const struct sockaddr_in *>(&a)->sin_port)
+					return true;
+				if (reinterpret_cast<const struct sockaddr_in *>(this)->sin_addr.s_addr < reinterpret_cast<const struct sockaddr_in *>(&a)->sin_addr.s_addr)
+					return true;
+				break;
+			case AF_INET6:
+				if (reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_port < reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_port)
+					return true;
+				if (reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_flowinfo < reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_flowinfo)
+					return true;
+				if (memcmp(reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_addr.s6_addr,reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_addr.s6_addr,16) < 0)
+					return true;
+				if (reinterpret_cast<const struct sockaddr_in6 *>(this)->sin6_scope_id < reinterpret_cast<const struct sockaddr_in6 *>(&a)->sin6_scope_id)
+					return true;
+				break;
+			default:
+				return (memcmp(this,&a,sizeof(InetAddress)) < 0);
+		}
+	}
+	return false;
+}
+
 InetAddress InetAddress::makeIpv6LinkLocal(const MAC &mac)
 InetAddress InetAddress::makeIpv6LinkLocal(const MAC &mac)
 	throw()
 	throw()
 {
 {

+ 10 - 45
node/InetAddress.hpp

@@ -61,16 +61,6 @@ struct InetAddress : public sockaddr_storage
 	 */
 	 */
 	static const InetAddress LO6;
 	static const InetAddress LO6;
 
 
-	/**
-	 * 0.0.0.0/0
-	 */
-	static const InetAddress DEFAULT4;
-
-	/**
-	 * ::/0
-	 */
-	static const InetAddress DEFAULT6;
-
 	/**
 	/**
 	 * IP address scope
 	 * IP address scope
 	 *
 	 *
@@ -91,6 +81,7 @@ struct InetAddress : public sockaddr_storage
 
 
 	InetAddress() throw() { memset(this,0,sizeof(InetAddress)); }
 	InetAddress() throw() { memset(this,0,sizeof(InetAddress)); }
 	InetAddress(const InetAddress &a) throw() { memcpy(this,&a,sizeof(InetAddress)); }
 	InetAddress(const InetAddress &a) throw() { memcpy(this,&a,sizeof(InetAddress)); }
+	InetAddress(const InetAddress *a) throw() { memcpy(this,a,sizeof(InetAddress)); }
 	InetAddress(const struct sockaddr_storage &ss) throw() { *this = ss; }
 	InetAddress(const struct sockaddr_storage &ss) throw() { *this = ss; }
 	InetAddress(const struct sockaddr_storage *ss) throw() { *this = ss; }
 	InetAddress(const struct sockaddr_storage *ss) throw() { *this = ss; }
 	InetAddress(const struct sockaddr &sa) throw() { *this = sa; }
 	InetAddress(const struct sockaddr &sa) throw() { *this = sa; }
@@ -112,6 +103,13 @@ struct InetAddress : public sockaddr_storage
 		return *this;
 		return *this;
 	}
 	}
 
 
+	inline InetAddress &operator=(const InetAddress *a)
+		throw()
+	{
+		memcpy(this,a,sizeof(InetAddress));
+		return *this;
+	}
+
 	inline InetAddress &operator=(const struct sockaddr_storage &ss)
 	inline InetAddress &operator=(const struct sockaddr_storage &ss)
 		throw()
 		throw()
 	{
 	{
@@ -294,39 +292,6 @@ struct InetAddress : public sockaddr_storage
 	 */
 	 */
 	inline bool isV6() const throw() { return (ss_family == AF_INET6); }
 	inline bool isV6() const throw() { return (ss_family == AF_INET6); }
 
 
-	/**
-	 * Force type to IPv4
-	 */
-	inline void setV4() throw() { ss_family = AF_INET; }
-
-	/**
-	 * Force type to IPv6
-	 */
-	inline void setV6() throw() { ss_family = AF_INET6; }
-
-	/**
-	 * @return Length of sockaddr_in if IPv4, sockaddr_in6 if IPv6
-	 */
-	inline unsigned int saddrLen() const
-		throw()
-	{
-		switch(ss_family) {
-			case AF_INET: return sizeof(struct sockaddr_in);
-			case AF_INET6: return sizeof(struct sockaddr_in6);
-			default: return 0;
-		}
-	}
-
-	/**
-	 * @return Raw sockaddr_in structure (valid if IPv4)
-	 */
-	inline const struct sockaddr_in *saddr4() const throw() { return reinterpret_cast<const struct sockaddr_in *>(this); }
-
-	/**
-	 * @return Raw sockaddr_in6 structure (valid if IPv6)
-	 */
-	inline const struct sockaddr_in6 *saddr6() const throw() { return reinterpret_cast<const struct sockaddr_in6 *>(this); }
-
 	/**
 	/**
 	 * @return pointer to raw IP address bytes
 	 * @return pointer to raw IP address bytes
 	 */
 	 */
@@ -376,9 +341,9 @@ struct InetAddress : public sockaddr_storage
 	 */
 	 */
 	inline operator bool() const throw() { return (ss_family != 0); }
 	inline operator bool() const throw() { return (ss_family != 0); }
 
 
-	inline bool operator==(const InetAddress &a) const throw() { return (memcmp(this,&a,sizeof(InetAddress)) == 0); }
+	bool operator==(const InetAddress &a) const throw();
+	bool operator<(const InetAddress &a) const throw();
 	inline bool operator!=(const InetAddress &a) const throw() { return !(*this == a); }
 	inline bool operator!=(const InetAddress &a) const throw() { return !(*this == a); }
-	inline bool operator<(const InetAddress &a) const throw() { return (memcmp(this,&a,sizeof(InetAddress)) < 0); }
 	inline bool operator>(const InetAddress &a) const throw() { return (a < *this); }
 	inline bool operator>(const InetAddress &a) const throw() { return (a < *this); }
 	inline bool operator<=(const InetAddress &a) const throw() { return !(a < *this); }
 	inline bool operator<=(const InetAddress &a) const throw() { return !(a < *this); }
 	inline bool operator>=(const InetAddress &a) const throw() { return !(*this < a); }
 	inline bool operator>=(const InetAddress &a) const throw() { return !(*this < a); }