Browse Source

Optimize AntiRecursion.

Adam Ierymenko 9 years ago
parent
commit
b6725c4415
2 changed files with 42 additions and 29 deletions
  1. 42 24
      node/AntiRecursion.hpp
  2. 0 5
      node/Constants.hpp

+ 42 - 24
node/AntiRecursion.hpp

@@ -35,28 +35,28 @@
 
 namespace ZeroTier {
 
-#define ZT_ANTIRECURSION_TAIL_LEN 128
+/**
+ * Size of anti-recursion history
+ */
+#define ZT_ANTIRECURSION_HISTORY_SIZE 16
 
 /**
  * Filter to prevent recursion (ZeroTier-over-ZeroTier)
  *
  * This works by logging ZeroTier packets that we send. It's then invoked
- * again against packets read from local Ethernet taps. If the last N
+ * again against packets read from local Ethernet taps. If the last 32
  * bytes representing the ZeroTier packet match in the tap frame, then
  * the frame is a re-injection of a frame that we sent and is rejected.
  *
  * This means that ZeroTier packets simply will not traverse ZeroTier
  * networks, which would cause all sorts of weird problems.
  *
- * NOTE: this is applied to low-level packets before they are sent to
- * SocketManager and/or sockets, not to fully assembled packets before
- * (possible) fragmentation.
+ * This is highly optimized code since it's checked for every packet.
  */
 class AntiRecursion
 {
 public:
 	AntiRecursion()
-		throw()
 	{
 		memset(_history,0,sizeof(_history));
 		_ptr = 0;
@@ -68,13 +68,20 @@ public:
 	 * @param data ZT packet data
 	 * @param len Length of packet
 	 */
-	inline void logOutgoingZT(const void *data,unsigned int len)
-		throw()
+	inline void logOutgoingZT(const void *const data,const unsigned int len)
 	{
-		ArItem *i = &(_history[_ptr++ % ZT_ANTIRECURSION_HISTORY_SIZE]);
-		const unsigned int tl = (len > ZT_ANTIRECURSION_TAIL_LEN) ? ZT_ANTIRECURSION_TAIL_LEN : len;
-		memcpy(i->tail,((const unsigned char *)data) + (len - tl),tl);
-		i->len = tl;
+		if (len < 32)
+			return;
+#ifdef ZT_NO_TYPE_PUNNING
+		memcpy(_history[++_ptr % ZT_ANTIRECURSION_HISTORY_SIZE].tail,reinterpret_cast<const uint8_t *>(data) + (len - 32),32);
+#else
+		uint64_t *t = _history[++_ptr % ZT_ANTIRECURSION_HISTORY_SIZE].tail;
+		const uint64_t *p = reinterpret_cast<const uint64_t *>(reinterpret_cast<const uint8_t *>(data) + (len - 32));
+		*(t++) = *(p++);
+		*(t++) = *(p++);
+		*(t++) = *(p++);
+		*t = *p;
+#endif
 	}
 
 	/**
@@ -84,25 +91,36 @@ public:
 	 * @param len Length of frame
 	 * @return True if frame is OK to be passed, false if it's a ZT frame that we sent
 	 */
-	inline bool checkEthernetFrame(const void *data,unsigned int len)
-		throw()
+	inline bool checkEthernetFrame(const void *const data,const unsigned int len) const
 	{
-		for(unsigned int h=0;h<ZT_ANTIRECURSION_HISTORY_SIZE;++h) {
-			ArItem *i = &(_history[h]);
-			if ((i->len > 0)&&(len >= i->len)&&(!memcmp(((const unsigned char *)data) + (len - i->len),i->tail,i->len)))
+		if (len < 32)
+			return true;
+		const uint8_t *const pp = reinterpret_cast<const uint8_t *>(data) + (len - 32);
+		const _ArItem *i = _history;
+		const _ArItem *const end = i + ZT_ANTIRECURSION_HISTORY_SIZE;
+		while (i != end) {
+#ifdef ZT_NO_TYPE_PUNNING
+			if (!memcmp(pp,i->tail,32))
 				return false;
+#else
+			const uint64_t *t = i->tail;
+			const uint64_t *p = reinterpret_cast<const uint64_t *>(pp);
+			uint64_t bits = *(t++) ^ *(p++);
+			bits |= *(t++) ^ *(p++);
+			bits |= *(t++) ^ *(p++);
+			bits |= *t ^ *p;
+			if (!bits)
+				return false;
+#endif
+			++i;
 		}
 		return true;
 	}
 
 private:
-	struct ArItem
-	{
-		unsigned char tail[ZT_ANTIRECURSION_TAIL_LEN];
-		unsigned int len;
-	};
-	ArItem _history[ZT_ANTIRECURSION_HISTORY_SIZE];
-	volatile unsigned int _ptr;
+	struct _ArItem { uint64_t tail[4]; };
+	_ArItem _history[ZT_ANTIRECURSION_HISTORY_SIZE];
+	volatile unsigned long _ptr;
 };
 
 } // namespace ZeroTier

+ 0 - 5
node/Constants.hpp

@@ -309,11 +309,6 @@
  */
 #define ZT_NAT_T_TACTICAL_ESCALATION_DELAY 1000
 
-/**
- * Size of anti-recursion history (see AntiRecursion.hpp)
- */
-#define ZT_ANTIRECURSION_HISTORY_SIZE 16
-
 /**
  * Minimum delay between attempts to confirm new paths to peers (to avoid HELLO flooding)
  */