Browse Source

updateAndCheckMulticastBalance and friends

Adam Ierymenko 12 years ago
parent
commit
cdb96726df
4 changed files with 105 additions and 78 deletions
  1. 23 34
      node/BandwidthAccount.hpp
  2. 7 21
      node/Network.cpp
  3. 33 20
      node/Network.hpp
  4. 42 3
      node/Utils.hpp

+ 23 - 34
node/BandwidthAccount.hpp

@@ -28,16 +28,14 @@
 #ifndef _ZT_BWACCOUNT_HPP
 #ifndef _ZT_BWACCOUNT_HPP
 #define _ZT_BWACCOUNT_HPP
 #define _ZT_BWACCOUNT_HPP
 
 
+#include <stdint.h>
 #include <math.h>
 #include <math.h>
 
 
+#include <algorithm>
+
 #include "Constants.hpp"
 #include "Constants.hpp"
 #include "Utils.hpp"
 #include "Utils.hpp"
 
 
-#ifdef __WINDOWS__
-#define fmin(a,b) (((a) <= (b)) ? (a) : (b))
-#define fmax(a,b) (((a) >= (b)) ? (a) : (b))
-#endif
-
 namespace ZeroTier {
 namespace ZeroTier {
 
 
 /**
 /**
@@ -56,27 +54,6 @@ namespace ZeroTier {
 class BandwidthAccount
 class BandwidthAccount
 {
 {
 public:
 public:
-	/**
-	 * Rate of balance accrual and min/max
-	 */
-	struct Accrual
-	{
-		/**
-		 * Rate of balance accrual in bytes per second
-		 */
-		double bytesPerSecond;
-
-		/**
-		 * Maximum balance that can ever be accrued (should be > 0.0)
-		 */
-		double maxBalance;
-
-		/**
-		 * Minimum balance, or maximum allowable "debt" (should be <= 0.0)
-		 */
-		double minBalance;
-	};
-
 	/**
 	/**
 	 * Create an uninitialized account
 	 * Create an uninitialized account
 	 *
 	 *
@@ -88,43 +65,55 @@ public:
 	 * Create and initialize
 	 * Create and initialize
 	 *
 	 *
 	 * @param preload Initial balance to place in account
 	 * @param preload Initial balance to place in account
+	 * @param minb Minimum allowed balance (or maximum debt) (<= 0)
+	 * @param maxb Maximum allowed balance (> 0)
+	 * @param acc Rate of accrual in bytes per second
 	 */
 	 */
-	BandwidthAccount(double preload)
+	BandwidthAccount(int32_t preload,int32_t minb,int32_t maxb,int32_t acc)
 		throw()
 		throw()
 	{
 	{
-		init(preload);
+		init(preload,minb,maxb,acc);
 	}
 	}
 
 
 	/**
 	/**
 	 * Initialize or re-initialize account
 	 * Initialize or re-initialize account
 	 *
 	 *
 	 * @param preload Initial balance to place in account
 	 * @param preload Initial balance to place in account
+	 * @param minb Minimum allowed balance (or maximum debt) (<= 0)
+	 * @param maxb Maximum allowed balance (> 0)
+	 * @param acc Rate of accrual in bytes per second
 	 */
 	 */
-	inline void init(double preload)
+	inline void init(int32_t preload,int32_t minb,int32_t maxb,int32_t acc)
 		throw()
 		throw()
 	{
 	{
 		_lastTime = Utils::nowf();
 		_lastTime = Utils::nowf();
 		_balance = preload;
 		_balance = preload;
+		_minBalance = minb;
+		_maxBalance = maxb;
+		_accrual = acc;
 	}
 	}
 
 
 	/**
 	/**
 	 * Update balance by accruing and then deducting
 	 * Update balance by accruing and then deducting
 	 *
 	 *
-	 * @param ar Current rate of accrual
 	 * @param deduct Amount to deduct, or 0.0 to just update
 	 * @param deduct Amount to deduct, or 0.0 to just update
 	 * @return New balance with deduction applied
 	 * @return New balance with deduction applied
 	 */
 	 */
-	inline double update(const Accrual &ar,double deduct)
+	inline int32_t update(int32_t deduct)
 		throw()
 		throw()
 	{
 	{
 		double lt = _lastTime;
 		double lt = _lastTime;
-		double now = _lastTime = Utils::nowf();
-		return (_balance = fmax(ar.minBalance,fmin(ar.maxBalance,(_balance + (ar.bytesPerSecond * (now - lt))) - deduct)));
+		double now = Utils::nowf();
+		_lastTime = now;
+		return (_balance = std::max(_minBalance,std::min(_maxBalance,(int32_t)round(((double)_balance) + (((double)_accrual) * (now - lt))) - deduct)));
 	}
 	}
 
 
 private:
 private:
 	double _lastTime;
 	double _lastTime;
-	double _balance;
+	int32_t _balance;
+	int32_t _minBalance;
+	int32_t _maxBalance;
+	int32_t _accrual;
 };
 };
 
 
 } // namespace ZeroTier
 } // namespace ZeroTier

+ 7 - 21
node/Network.cpp

@@ -76,27 +76,13 @@ bool Network::Certificate::qualifyMembership(const Network::Certificate &mc) con
 				if (myField->second != theirField->second)
 				if (myField->second != theirField->second)
 					return false;
 					return false;
 			} else {
 			} else {
-				// Otherwise compare range with max delta. Presence of a dot in delta
-				// indicates a floating point comparison. Otherwise an integer
-				// comparison occurs.
-				if (deltaField->second.find('.') != std::string::npos) {
-					double my = Utils::strToDouble(myField->second.c_str());
-					double their = Utils::strToDouble(theirField->second.c_str());
-					double delta = Utils::strToDouble(deltaField->second.c_str());
-					if (fabs(my - their) > delta)
-						return false;
-				} else {
-					uint64_t my = Utils::hexStrToU64(myField->second.c_str());
-					uint64_t their = Utils::hexStrToU64(theirField->second.c_str());
-					uint64_t delta = Utils::hexStrToU64(deltaField->second.c_str());
-					if (my > their) {
-						if ((my - their) > delta)
-							return false;
-					} else {
-						if ((their - my) > delta)
-							return false;
-					}
-				}
+				// Otherwise compare the absolute value of the difference between
+				// the two values against the max delta.
+				int64_t my = Utils::hexStrTo64(myField->second.c_str());
+				int64_t their = Utils::hexStrTo64(theirField->second.c_str());
+				int64_t delta = Utils::hexStrTo64(deltaField->second.c_str());
+				if (llabs((long long)(my - their)) > delta)
+					return false;
 			}
 			}
 		}
 		}
 	}
 	}

+ 33 - 20
node/Network.hpp

@@ -142,7 +142,8 @@ public:
 	 * Key is multicast group in lower case hex format: MAC (without :s) /
 	 * Key is multicast group in lower case hex format: MAC (without :s) /
 	 * ADI (hex). Value is a comma-delimited list of: preload, min, max,
 	 * ADI (hex). Value is a comma-delimited list of: preload, min, max,
 	 * rate of accrual for bandwidth accounts. A key called '*' indicates
 	 * rate of accrual for bandwidth accounts. A key called '*' indicates
-	 * the default for unlisted groups.
+	 * the default for unlisted groups. Values are in hexadecimal and may
+	 * be prefixed with '-' to indicate a negative value.
 	 */
 	 */
 	class MulticastRates : private Dictionary
 	class MulticastRates : private Dictionary
 	{
 	{
@@ -153,16 +154,17 @@ public:
 		struct Rate
 		struct Rate
 		{
 		{
 			Rate() {}
 			Rate() {}
-			Rate(double pl,double minr,double maxr,double bps)
+			Rate(int32_t pl,int32_t minb,int32_t maxb,int32_t acc)
 			{
 			{
 				preload = pl;
 				preload = pl;
-				accrual.bytesPerSecond = bps;
-				accrual.maxBalance = maxr;
-				accrual.minBalance = minr;
+				minBalance = minb;
+				maxBalance = maxb;
+				accrual = acc;
 			}
 			}
-
-			double preload;
-			BandwidthAccount::Accrual accrual;
+			int32_t preload;
+			int32_t minBalance;
+			int32_t maxBalance;
+			int32_t accrual;
 		};
 		};
 
 
 		MulticastRates() {}
 		MulticastRates() {}
@@ -178,7 +180,7 @@ public:
 		/**
 		/**
 		 * @return Default rate, or GLOBAL_DEFAULT_RATE if not specified
 		 * @return Default rate, or GLOBAL_DEFAULT_RATE if not specified
 		 */
 		 */
-		Rate defaultRate() const
+		inline Rate defaultRate() const
 		{
 		{
 			Rate r;
 			Rate r;
 			const_iterator dfl(find("*"));
 			const_iterator dfl(find("*"));
@@ -193,7 +195,7 @@ public:
 		 * @param mg Multicast group
 		 * @param mg Multicast group
 		 * @return Rate or default() rate if not specified
 		 * @return Rate or default() rate if not specified
 		 */
 		 */
-		Rate get(const MulticastGroup &mg) const
+		inline Rate get(const MulticastGroup &mg) const
 		{
 		{
 			const_iterator r(find(mg.toString()));
 			const_iterator r(find(mg.toString()));
 			if (r == end())
 			if (r == end())
@@ -206,26 +208,22 @@ public:
 		{
 		{
 			char tmp[16384];
 			char tmp[16384];
 			Utils::scopy(tmp,sizeof(tmp),s.c_str());
 			Utils::scopy(tmp,sizeof(tmp),s.c_str());
-			Rate r;
-			r.preload = 0.0;
-			r.accrual.bytesPerSecond = 0.0;
-			r.accrual.maxBalance = 0.0;
-			r.accrual.minBalance = 0.0;
+			Rate r(0,0,0,0);
 			char *saveptr = (char *)0;
 			char *saveptr = (char *)0;
 			unsigned int fn = 0;
 			unsigned int fn = 0;
 			for(char *f=Utils::stok(tmp,",",&saveptr);(f);f=Utils::stok((char *)0,",",&saveptr)) {
 			for(char *f=Utils::stok(tmp,",",&saveptr);(f);f=Utils::stok((char *)0,",",&saveptr)) {
 				switch(fn++) {
 				switch(fn++) {
 					case 0:
 					case 0:
-						r.preload = Utils::strToDouble(f);
+						r.preload = (int32_t)Utils::hexStrToLong(f);
 						break;
 						break;
 					case 1:
 					case 1:
-						r.accrual.minBalance = Utils::strToDouble(f);
+						r.minBalance = (int32_t)Utils::hexStrToLong(f);
 						break;
 						break;
 					case 2:
 					case 2:
-						r.accrual.maxBalance = Utils::strToDouble(f);
+						r.maxBalance = (int32_t)Utils::hexStrToLong(f);
 						break;
 						break;
 					case 3:
 					case 3:
-						r.accrual.bytesPerSecond = Utils::strToDouble(f);
+						r.accrual = (int32_t)Utils::hexStrToLong(f);
 						break;
 						break;
 				}
 				}
 			}
 			}
@@ -538,10 +536,24 @@ public:
 		else return ((_etWhitelist[etherType / 8] & (unsigned char)(1 << (etherType % 8))) != 0);
 		else return ((_etWhitelist[etherType / 8] & (unsigned char)(1 << (etherType % 8))) != 0);
 	}
 	}
 
 
+	/**
+	 * Update multicast balance for an address and multicast group, return whether packet is allowed
+	 *
+	 * @param a Address that wants to send/relay packet
+	 * @param mg Multicast group
+	 * @param bytes Size of packet
+	 * @return True if packet is within budget
+	 */
 	inline bool updateAndCheckMulticastBalance(const Address &a,const MulticastGroup &mg,unsigned int bytes)
 	inline bool updateAndCheckMulticastBalance(const Address &a,const MulticastGroup &mg,unsigned int bytes)
 	{
 	{
 		Mutex::Lock _l(_lock);
 		Mutex::Lock _l(_lock);
-		std::map< std::pair<Address,MulticastGroup>,BandwidthAccount >::iterator bal(_multicastRateAccounts.find(std::pair<Address,MulticastGroup>(a,mg)));
+		std::pair<Address,MulticastGroup> k(a,mg);
+		std::map< std::pair<Address,MulticastGroup>,BandwidthAccount >::iterator bal(_multicastRateAccounts.find(k));
+		if (bal == _multicastRateAccounts.end()) {
+			MulticastRates::Rate r(_mcRates.get(mg));
+			bal = _multicastRateAccounts.insert(std::make_pair(k,BandwidthAccount(r.preload,r.minBalance,r.maxBalance,r.accrual))).first;
+		}
+		return (bal->second.update((int32_t)bytes) < (int32_t)bytes);
 	}
 	}
 
 
 private:
 private:
@@ -563,6 +575,7 @@ private:
 	// Configuration from network master node
 	// Configuration from network master node
 	Config _configuration;
 	Config _configuration;
 	Certificate _myCertificate;
 	Certificate _myCertificate;
+	MulticastRates _mcRates;
 
 
 	// Ethertype whitelist bit field, set from config, for really fast lookup
 	// Ethertype whitelist bit field, set from config, for really fast lookup
 	unsigned char _etWhitelist[65536 / 8];
 	unsigned char _etWhitelist[65536 / 8];

+ 42 - 3
node/Utils.hpp

@@ -461,24 +461,44 @@ public:
 #endif
 #endif
 	}
 	}
 
 
-	// String to number converters
+	// String to number converters -- defined here to permit portability
+	// ifdefs for platforms that lack some of the strtoXX functions.
 	static inline unsigned int strToUInt(const char *s)
 	static inline unsigned int strToUInt(const char *s)
 		throw()
 		throw()
 	{
 	{
 		return (unsigned int)strtoul(s,(char **)0,10);
 		return (unsigned int)strtoul(s,(char **)0,10);
 	}
 	}
+	static inline int strToInt(const char *s)
+		throw()
+	{
+		return (int)strtol(s,(char **)0,10);
+	}
 	static inline unsigned long strToULong(const char *s)
 	static inline unsigned long strToULong(const char *s)
 		throw()
 		throw()
 	{
 	{
 		return strtoul(s,(char **)0,10);
 		return strtoul(s,(char **)0,10);
 	}
 	}
+	static inline long strToLong(const char *s)
+		throw()
+	{
+		return strtol(s,(char **)0,10);
+	}
 	static inline unsigned long long strToU64(const char *s)
 	static inline unsigned long long strToU64(const char *s)
 		throw()
 		throw()
 	{
 	{
 #ifdef __WINDOWS__
 #ifdef __WINDOWS__
-		return _strtoui64(s,(char **)0,10);
+		return (unsigned long long)_strtoui64(s,(char **)0,10);
 #else
 #else
 		return strtoull(s,(char **)0,10);
 		return strtoull(s,(char **)0,10);
+#endif
+	}
+	static inline long long strTo64(const char *s)
+		throw()
+	{
+#ifdef __WINDOWS__
+		return (long long)_strtoi64(s,(char **)0,10);
+#else
+		return strtoll(s,(char **)0,10);
 #endif
 #endif
 	}
 	}
 	static inline unsigned int hexStrToUInt(const char *s)
 	static inline unsigned int hexStrToUInt(const char *s)
@@ -486,18 +506,37 @@ public:
 	{
 	{
 		return (unsigned int)strtoul(s,(char **)0,16);
 		return (unsigned int)strtoul(s,(char **)0,16);
 	}
 	}
+	static inline int hexStrToInt(const char *s)
+		throw()
+	{
+		return (int)strtol(s,(char **)0,16);
+	}
 	static inline unsigned long hexStrToULong(const char *s)
 	static inline unsigned long hexStrToULong(const char *s)
 		throw()
 		throw()
 	{
 	{
 		return strtoul(s,(char **)0,16);
 		return strtoul(s,(char **)0,16);
 	}
 	}
+	static inline long hexStrToLong(const char *s)
+		throw()
+	{
+		return strtol(s,(char **)0,16);
+	}
 	static inline unsigned long long hexStrToU64(const char *s)
 	static inline unsigned long long hexStrToU64(const char *s)
 		throw()
 		throw()
 	{
 	{
 #ifdef __WINDOWS__
 #ifdef __WINDOWS__
-		return _strtoui64(s,(char **)0,16);
+		return (unsigned long long)_strtoui64(s,(char **)0,16);
 #else
 #else
 		return strtoull(s,(char **)0,16);
 		return strtoull(s,(char **)0,16);
+#endif
+	}
+	static inline long long hexStrTo64(const char *s)
+		throw()
+	{
+#ifdef __WINDOWS__
+		return (long long)_strtoi64(s,(char **)0,16);
+#else
+		return strtoll(s,(char **)0,16);
 #endif
 #endif
 	}
 	}
 	static inline double strToDouble(const char *s)
 	static inline double strToDouble(const char *s)