2
0
Эх сурвалжийг харах

BSD routing table works... that hurt much worse than it should have.

Adam Ierymenko 11 жил өмнө
parent
commit
51766e6549

+ 37 - 4
node/BSDRoutingTable.cpp

@@ -28,6 +28,7 @@
 #include <stdint.h>
 #include <stdint.h>
 #include <stdio.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <stdlib.h>
+#include <string.h>
 #include <unistd.h>
 #include <unistd.h>
 #include <sys/param.h>
 #include <sys/param.h>
 #include <sys/sysctl.h>
 #include <sys/sysctl.h>
@@ -35,7 +36,9 @@
 #include <netinet/in.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
 #include <arpa/inet.h>
 #include <net/route.h>
 #include <net/route.h>
+#include <net/if.h>
 #include <net/if_dl.h>
 #include <net/if_dl.h>
+#include <ifaddrs.h>
 
 
 #include <algorithm>
 #include <algorithm>
 #include <utility>
 #include <utility>
@@ -55,7 +58,7 @@ BSDRoutingTable::~BSDRoutingTable()
 {
 {
 }
 }
 
 
-std::vector<RoutingTable::Entry> BSDRoutingTable::get() const
+std::vector<RoutingTable::Entry> BSDRoutingTable::get(bool includeLinkLocal,bool includeLoopback) const
 {
 {
 	std::vector<RoutingTable::Entry> entries;
 	std::vector<RoutingTable::Entry> entries;
 	int mib[6];
 	int mib[6];
@@ -82,6 +85,7 @@ std::vector<RoutingTable::Entry> BSDRoutingTable::get() const
 
 
 					if (((rtm->rtm_flags & RTF_LLINFO) == 0)&&((rtm->rtm_flags & RTF_HOST) == 0)&&((rtm->rtm_flags & RTF_UP) != 0)&&((rtm->rtm_flags & RTF_MULTICAST) == 0)) {
 					if (((rtm->rtm_flags & RTF_LLINFO) == 0)&&((rtm->rtm_flags & RTF_HOST) == 0)&&((rtm->rtm_flags & RTF_UP) != 0)&&((rtm->rtm_flags & RTF_MULTICAST) == 0)) {
 						RoutingTable::Entry e;
 						RoutingTable::Entry e;
+						e.deviceIndex = -9999; // unset
 
 
 						int which = 0;
 						int which = 0;
 						while (saptr < saend) {
 						while (saptr < saend) {
@@ -120,7 +124,15 @@ std::vector<RoutingTable::Entry> BSDRoutingTable::get() const
 									break;
 									break;
 								case 1:
 								case 1:
 									//printf("RTA_GATEWAY\n");
 									//printf("RTA_GATEWAY\n");
-									e.gateway.set(sa);
+									switch(sa->sa_family) {
+										case AF_LINK:
+											e.deviceIndex = (int)((const struct sockaddr_dl *)sa)->sdl_index;
+											break;
+										case AF_INET:
+										case AF_INET6:
+											e.gateway.set(sa);
+											break;
+									}
 									break;
 									break;
 								case 2: {
 								case 2: {
 									if (e.destination.isV6()) {
 									if (e.destination.isV6()) {
@@ -148,6 +160,7 @@ std::vector<RoutingTable::Entry> BSDRoutingTable::get() const
 									}
 									}
 									//printf("RTA_NETMASK\n");
 									//printf("RTA_NETMASK\n");
 								}	break;
 								}	break;
+								/*
 								case 3:
 								case 3:
 									//printf("RTA_GENMASK\n");
 									//printf("RTA_GENMASK\n");
 									break;
 									break;
@@ -160,6 +173,7 @@ std::vector<RoutingTable::Entry> BSDRoutingTable::get() const
 								case 6:
 								case 6:
 									//printf("RTA_AUTHOR\n");
 									//printf("RTA_AUTHOR\n");
 									break;
 									break;
+								*/
 							}
 							}
 
 
 							saptr += salen;
 							saptr += salen;
@@ -167,8 +181,8 @@ std::vector<RoutingTable::Entry> BSDRoutingTable::get() const
 
 
 						e.metric = (int)rtm->rtm_rmx.rmx_hopcount;
 						e.metric = (int)rtm->rtm_rmx.rmx_hopcount;
 
 
-						entries.push_back(e);
-						printf("%s\n",e.toString().c_str());
+						if (((includeLinkLocal)||(!e.destination.isLinkLocal()))&&((includeLoopback)||((!e.destination.isLoopback())&&(!e.gateway.isLoopback()))))
+							entries.push_back(e);
 					}
 					}
 
 
 					next = saend;
 					next = saend;
@@ -179,7 +193,24 @@ std::vector<RoutingTable::Entry> BSDRoutingTable::get() const
 		}
 		}
 	}
 	}
 
 
+	for(std::vector<ZeroTier::RoutingTable::Entry>::iterator e1(entries.begin());e1!=entries.end();++e1) {
+		if ((!e1->device[0])&&(e1->deviceIndex >= 0))
+			if_indextoname(e1->deviceIndex,e1->device);
+	}
+	for(std::vector<ZeroTier::RoutingTable::Entry>::iterator e1(entries.begin());e1!=entries.end();++e1) {
+		if ((!e1->device[0])&&(e1->gateway)) {
+			int bestMetric = 9999999;
+			for(std::vector<ZeroTier::RoutingTable::Entry>::iterator e2(entries.begin());e2!=entries.end();++e2) {
+				if ((e1->gateway.within(e2->destination))&&(e2->metric <= bestMetric)) {
+					bestMetric = e2->metric;
+					Utils::scopy(e1->device,sizeof(e1->device),e2->device);
+				}
+			}
+		}
+	}
+
 	std::sort(entries.begin(),entries.end());
 	std::sort(entries.begin(),entries.end());
+
 	return entries;
 	return entries;
 }
 }
 
 
@@ -196,6 +227,8 @@ int main(int argc,char **argv)
 {
 {
 	ZeroTier::BSDRoutingTable rt;
 	ZeroTier::BSDRoutingTable rt;
 	std::vector<ZeroTier::RoutingTable::Entry> ents(rt.get());
 	std::vector<ZeroTier::RoutingTable::Entry> ents(rt.get());
+	for(std::vector<ZeroTier::RoutingTable::Entry>::iterator e(ents.begin());e!=ents.end();++e)
+		printf("%s\n",e->toString().c_str());
 	return 0;
 	return 0;
 }
 }
 //*/
 //*/

+ 1 - 1
node/BSDRoutingTable.hpp

@@ -42,7 +42,7 @@ class BSDRoutingTable : public RoutingTable
 public:
 public:
 	BSDRoutingTable();
 	BSDRoutingTable();
 	virtual ~BSDRoutingTable();
 	virtual ~BSDRoutingTable();
-	virtual std::vector<RoutingTable::Entry> get() const;
+	virtual std::vector<RoutingTable::Entry> get(bool includeLinkLocal = false,bool includeLoopback = false) const;
 	virtual bool set(const RoutingTable::Entry &re);
 	virtual bool set(const RoutingTable::Entry &re);
 };
 };
 
 

+ 30 - 0
node/InetAddress.cpp

@@ -215,6 +215,36 @@ bool InetAddress::sameNetworkAs(const InetAddress &ipnet) const
 	return ((*a >> bits) == (*b >> bits));
 	return ((*a >> bits) == (*b >> bits));
 }
 }
 
 
+bool InetAddress::within(const InetAddress &ipnet) const
+	throw()
+{
+	if (_sa.saddr.sa_family != ipnet._sa.saddr.sa_family)
+		return false;
+
+	unsigned int bits = ipnet.netmaskBits();
+	switch(_sa.saddr.sa_family) {
+		case AF_INET:
+			if (bits > 32) return false;
+			break;
+		case AF_INET6:
+			if (bits > 128) return false;
+			break;
+		default: return false;
+	}
+
+	const uint8_t *a = (const uint8_t *)rawIpData();
+	const uint8_t *b = (const uint8_t *)ipnet.rawIpData();
+	while (bits >= 8) {
+		if (*(a++) != *(b++))
+			return false;
+		bits -= 8;
+	}
+	if (bits) {
+		uint8_t mask = ((0xff << (8 - bits)) & 0xff);
+		return ((*a & mask) == (*b & mask));
+	} else return true;
+}
+
 bool InetAddress::operator==(const InetAddress &a) const
 bool InetAddress::operator==(const InetAddress &a) const
 	throw()
 	throw()
 {
 {

+ 18 - 0
node/InetAddress.hpp

@@ -147,6 +147,15 @@ public:
 	bool isLinkLocal() const
 	bool isLinkLocal() const
 		throw();
 		throw();
 
 
+	/**
+	 * @return True if this is a loopback address
+	 */
+	inline bool isLoopback() const
+		throw()
+	{
+		return ((*this == LO4)||(*this == LO6));
+	}
+
 	/**
 	/**
 	 * @return ASCII IP/port format representation
 	 * @return ASCII IP/port format representation
 	 */
 	 */
@@ -286,6 +295,15 @@ public:
 	bool sameNetworkAs(const InetAddress &ipnet) const
 	bool sameNetworkAs(const InetAddress &ipnet) const
 		throw();
 		throw();
 
 
+	/**
+	 * Determine whether this address is within an ip/netmask
+	 *
+	 * @param ipnet IP/netmask
+	 * @return True if this address is within this network
+	 */
+	bool within(const InetAddress &ipnet) const
+		throw();
+
 	/**
 	/**
 	 * Set to null/zero
 	 * Set to null/zero
 	 */
 	 */

+ 1 - 1
node/RoutingTable.cpp

@@ -39,7 +39,7 @@ namespace ZeroTier {
 std::string RoutingTable::Entry::toString() const
 std::string RoutingTable::Entry::toString() const
 {
 {
 	char tmp[1024];
 	char tmp[1024];
-	Utils::snprintf(tmp,sizeof(tmp),"%s %s %s %d",destination.toString().c_str(),((gateway) ? gateway.toIpString().c_str() : "(link)"),device,metric);
+	Utils::snprintf(tmp,sizeof(tmp),"%s %s %s %d",destination.toString().c_str(),((gateway) ? gateway.toIpString().c_str() : "<link>"),device,metric);
 	return std::string(tmp);
 	return std::string(tmp);
 }
 }
 
 

+ 7 - 2
node/RoutingTable.hpp

@@ -50,7 +50,8 @@ public:
 		InetAddress destination;
 		InetAddress destination;
 		InetAddress gateway; // port/netmaskBits field not used, should be 0 -- null if direct-to-device route
 		InetAddress gateway; // port/netmaskBits field not used, should be 0 -- null if direct-to-device route
 		char device[128];
 		char device[128];
-		int metric;
+		int deviceIndex; // may not always be set, depending on OS -- for internal use only
+		int metric; // higher = lower priority -- on some OSes this is "hop count," etc.
 
 
 		std::string toString() const;
 		std::string toString() const;
 
 
@@ -66,9 +67,13 @@ public:
 	virtual ~RoutingTable();
 	virtual ~RoutingTable();
 
 
 	/**
 	/**
+	 * Get routing table
+	 *
+	 * @param includeLinkLocal If true, include link-local address routes (default: false)
+	 * @param includeLoopback Include loopback (default: false)
 	 * @return Sorted routing table entries
 	 * @return Sorted routing table entries
 	 */
 	 */
-	virtual std::vector<Entry> get() const = 0;
+	virtual std::vector<Entry> get(bool includeLinkLocal = false,bool includeLoopback = false) const = 0;
 
 
 	/**
 	/**
 	 * Add or update a routing table entry
 	 * Add or update a routing table entry

+ 0 - 26
node/Utils.hpp

@@ -492,32 +492,6 @@ public:
 		return true;
 		return true;
 	}
 	}
 
 
-	/**
-	 * Match two strings with bits masked netmask-style
-	 *
-	 * @param a First string
-	 * @param abits Number of bits in first string
-	 * @param b Second string
-	 * @param bbits Number of bits in second string
-	 * @return True if min(abits,bbits) match between a and b
-	 */
-	static inline bool matchNetmask(const void *a,unsigned int abits,const void *b,unsigned int bbits)
-		throw()
-	{
-		const unsigned char *aptr = (const unsigned char *)a;
-		const unsigned char *bptr = (const unsigned char *)b;
-
-		while ((abits >= 8)&&(bbits >= 8)) {
-			if (*aptr++ != *bptr++)
-				return false;
-			abits -= 8;
-			bbits -= 8;
-		}
-
-		unsigned char mask = 0xff << (8 - ((abits > bbits) ? bbits : abits));
-		return ((*aptr & mask) == (*aptr & mask));
-	}
-
 	/**
 	/**
 	 * Compute SDBM hash of a binary string
 	 * Compute SDBM hash of a binary string
 	 *
 	 *