Browse Source

Tweak protocol getPacketId(), unit tests for defragmenter, AES fix.

Adam Ierymenko 5 years ago
parent
commit
a58f11e601
6 changed files with 124 additions and 46 deletions
  1. 2 2
      node/AES.cpp
  2. 18 6
      node/Defragmenter.hpp
  3. 8 5
      node/Protocol.cpp
  4. 6 3
      node/Protocol.hpp
  5. 90 29
      node/Tests.cpp
  6. 0 1
      node/Tests.h

+ 2 - 2
node/AES.cpp

@@ -535,7 +535,7 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 		out += totalLen;
 		out += totalLen;
 		_len = (totalLen + len);
 		_len = (totalLen + len);
 
 
-		if (likely((c1 + len) > c1)) { // it's incredibly likely that we can ignore carry in counter increment
+		if (likely((c1 + len) > c1)) { // if this is true we can just increment c1 and ignore c0
 			while (len >= 64) {
 			while (len >= 64) {
 				__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 				__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1),(long long)c0);
 				__m128i d1 = _mm_set_epi64x((long long)Utils::hton(c1 + 1ULL),(long long)c0);
 				__m128i d1 = _mm_set_epi64x((long long)Utils::hton(c1 + 1ULL),(long long)c0);
@@ -663,7 +663,7 @@ void AES::CTR::crypt(const void *const input,unsigned int len) noexcept
 					out += 16;
 					out += 16;
 				} while (len >= 16);
 				} while (len >= 16);
 			}
 			}
-		} else {
+		} else { // in the unlikely case c1 is near uint64_max, we must add with carry
 			while (len >= 64) {
 			while (len >= 64) {
 				__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1++),(long long)c0);
 				__m128i d0 = _mm_set_epi64x((long long)Utils::hton(c1++),(long long)c0);
 				if (unlikely(c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);
 				if (unlikely(c1 == 0ULL)) c0 = Utils::hton(Utils::ntoh(c0) + 1ULL);

+ 18 - 6
node/Defragmenter.hpp

@@ -43,7 +43,7 @@ namespace ZeroTier {
  * @tparam GCS Garbage collection target size for the incoming message queue
  * @tparam GCS Garbage collection target size for the incoming message queue
  * @tparam GCT Garbage collection trigger threshold, usually 2X GCS
  * @tparam GCT Garbage collection trigger threshold, usually 2X GCS
  */
  */
-template<unsigned int MF,unsigned int GCS = 32,unsigned int GCT = 64>
+template<unsigned int MF = 16,unsigned int GCS = 32,unsigned int GCT = 64>
 class Defragmenter
 class Defragmenter
 {
 {
 public:
 public:
@@ -248,7 +248,7 @@ public:
 		s.e = fragmentDataIndex + fragmentDataSize;
 		s.e = fragmentDataIndex + fragmentDataSize;
 
 
 		// If we now have all fragments then assemble them.
 		// If we now have all fragments then assemble them.
-		if ((e->message.size() >= e->totalFragmentsExpected)&&(e->totalFragmentsExpected > 0)) {
+		if ((++e->fragmentsReceived >= e->totalFragmentsExpected)&&(e->totalFragmentsExpected > 0)) {
 			// This message is done so de-register it with its path if one is associated.
 			// This message is done so de-register it with its path if one is associated.
 			if (e->via) {
 			if (e->via) {
 				e->via->_inboundFragmentedMessages_l.lock();
 				e->via->_inboundFragmentedMessages_l.lock();
@@ -277,19 +277,31 @@ public:
 		_messages.clear();
 		_messages.clear();
 	}
 	}
 
 
+	/**
+	 * @return Number of entries currently in message defragmentation cache
+	 */
+	ZT_ALWAYS_INLINE unsigned int cacheSize() noexcept
+	{
+		RWMutex::RLock ml(_messages_l);
+		return _messages.size();
+	}
+
 private:
 private:
 	struct _E
 	struct _E
 	{
 	{
-		ZT_ALWAYS_INLINE _E() : id(0),lastUsed(0),totalFragmentsExpected(0),via(),message(),lock() {}
+		ZT_ALWAYS_INLINE _E() : id(0),lastUsed(0),totalFragmentsExpected(0),fragmentsReceived(0),via(),message(),lock() {}
 		ZT_ALWAYS_INLINE ~_E()
 		ZT_ALWAYS_INLINE ~_E()
 		{
 		{
-			via->_inboundFragmentedMessages_l.lock();
-			via->_inboundFragmentedMessages.erase(id);
-			via->_inboundFragmentedMessages_l.unlock();
+			if (via) {
+				via->_inboundFragmentedMessages_l.lock();
+				via->_inboundFragmentedMessages.erase(id);
+				via->_inboundFragmentedMessages_l.unlock();
+			}
 		}
 		}
 		uint64_t id;
 		uint64_t id;
 		volatile int64_t lastUsed;
 		volatile int64_t lastUsed;
 		unsigned int totalFragmentsExpected;
 		unsigned int totalFragmentsExpected;
+		unsigned int fragmentsReceived;
 		SharedPtr<Path> via;
 		SharedPtr<Path> via;
 		FCV< Buf::Slice,MF > message;
 		FCV< Buf::Slice,MF > message;
 		Mutex lock;
 		Mutex lock;

+ 8 - 5
node/Protocol.cpp

@@ -16,6 +16,13 @@
 #include "Utils.hpp"
 #include "Utils.hpp"
 
 
 #include <cstdlib>
 #include <cstdlib>
+#include <ctime>
+
+#ifdef __WINDOWS__
+#include <process.h>
+#else
+#include <unistd.h>
+#endif
 
 
 namespace ZeroTier {
 namespace ZeroTier {
 namespace Protocol {
 namespace Protocol {
@@ -30,11 +37,7 @@ uint64_t createProbe(const Identity &sender,const Identity &recipient,const uint
 	return hash[0];
 	return hash[0];
 }
 }
 
 
-uint64_t getPacketId() noexcept
-{
-	static std::atomic<uint64_t> s_packetIdCtr(Utils::getSecureRandomU64());
-	return ++s_packetIdCtr;
-}
+std::atomic<uint64_t> _s_packetIdCtr((uint64_t)time(nullptr) << 32U);
 
 
 void armor(Buf &pkt,int packetSize,const uint8_t key[ZT_PEER_SECRET_KEY_LENGTH],uint8_t cipherSuite) noexcept
 void armor(Buf &pkt,int packetSize,const uint8_t key[ZT_PEER_SECRET_KEY_LENGTH],uint8_t cipherSuite) noexcept
 {
 {

+ 6 - 3
node/Protocol.hpp

@@ -1048,12 +1048,15 @@ static ZT_ALWAYS_INLINE void salsa2012DeriveKey(const uint8_t *const in,uint8_t
  */
  */
 uint64_t createProbe(const Identity &sender,const Identity &recipient,const uint8_t key[ZT_PEER_SECRET_KEY_LENGTH]) noexcept;
 uint64_t createProbe(const Identity &sender,const Identity &recipient,const uint8_t key[ZT_PEER_SECRET_KEY_LENGTH]) noexcept;
 
 
+// Do not use directly
+extern std::atomic<uint64_t> _s_packetIdCtr;
+
 /**
 /**
- * Get a sequential non-repeating packet ID for the next packet (thread-safe)
+ * Get a packet ID (and nonce) for a new packet
  *
  *
- * @return Next packet ID / cryptographic nonce
+ * @return Next packet ID
  */
  */
-uint64_t getPacketId() noexcept;
+static ZT_ALWAYS_INLINE uint64_t getPacketId() noexcept { return ++_s_packetIdCtr; }
 
 
 /**
 /**
  * Encrypt and compute packet MAC
  * Encrypt and compute packet MAC

+ 90 - 29
node/Tests.cpp

@@ -465,6 +465,96 @@ extern "C" const char *ZTT_general()
 			}
 			}
 			ZT_T_PRINTF("OK" ZT_EOL_S);
 			ZT_T_PRINTF("OK" ZT_EOL_S);
 		}
 		}
+
+		{
+			// This doesn't check behavior when fragments are invalid or input is totally insane.
+			// That's done during fuzzing.
+			ZT_T_PRINTF("[general] Testing Defragmenter... ");
+			Defragmenter<> defrag;
+
+			const SharedPtr<Path> nullvia;
+			uint64_t messageId = 0;
+			int64_t ts = now();
+			for(int k=0;k<50000;++k) {
+				++messageId;
+				FCV<Buf::Slice,16> message;
+				FCV<Buf::Slice,16> ref;
+
+				int frags = 1 + (int)(Utils::random() % 16);
+				int skip = ((k & 3) == 1) ? -1 : (int)(Utils::random() % frags);
+				bool complete = false;
+				message.resize(frags);
+				ref.resize(frags);
+
+				for (int f=0;f<frags;++f) {
+					if (f != skip) {
+						ref[f].b.set(new Buf());
+						ref[f].s = (unsigned int)(Utils::random() % 24);
+						ref[f].e = ref[f].s + (unsigned int)(Utils::random() % 1000);
+						for (unsigned int i=ref[f].s;i<ref[f].e;++i)
+							ref[f].b->unsafeData[i] = (uint8_t)f;
+					}
+				}
+
+				for (int f=0;f<frags;++f) {
+					if (f != skip) {
+						if (complete) {
+							ZT_T_PRINTF("FAILED (message prematurely complete)" ZT_EOL_S);
+							return "Defragmenter test failed: message prematurely complete";
+						}
+						switch (defrag.assemble(messageId,message,ref[f].b,ref[f].s,ref[f].e - ref[f].s,f,frags,ts++,nullvia,0)) {
+							case Defragmenter<>::OK:
+								break;
+							case Defragmenter<>::COMPLETE:
+								complete = true;
+								break;
+							case Defragmenter<>::ERR_DUPLICATE_FRAGMENT:
+								break;
+							case Defragmenter<>::ERR_INVALID_FRAGMENT:
+								ZT_T_PRINTF("FAILED (invalid fragment)" ZT_EOL_S);
+								return "Defragmenter test failed: invalid fragment";
+							case Defragmenter<>::ERR_TOO_MANY_FRAGMENTS_FOR_PATH:
+								break;
+							case Defragmenter<>::ERR_OUT_OF_MEMORY:
+								ZT_T_PRINTF("FAILED (out of memory)" ZT_EOL_S);
+								return "Defragmenter test failed: out of memory";
+						}
+					}
+				}
+
+				if (skip == -1) {
+					if (complete) {
+						for(int f=0;f<frags;++f) {
+							if (!message[f].b) {
+								ZT_T_PRINTF("FAILED (fragment %d has null buffer)" ZT_EOL_S,f);
+								return "Defragmenter test failed: fragment has null buffer";
+							}
+							if ((message[f].s != ref[f].s)||(message[f].e != ref[f].e)) {
+								ZT_T_PRINTF("FAILED (fragment %d size and bounds incorrect (%u:%u, expected %u:%u))" ZT_EOL_S,f,message[f].s,message[f].e,ref[f].s,ref[f].e);
+								return "Defragmenter test failed: fragment size and bounds incorrect";
+							}
+							for(unsigned int i=message[f].s;i!=message[f].e;++i) {
+								if (message[f].b->unsafeData[i] != (uint8_t)f) {
+									ZT_T_PRINTF("FAILED (fragment %d data invalid (raw index %u: %d != %d))" ZT_EOL_S,f,i,(int)message[f].b->unsafeData[i],f);
+									return "Defragmenter test failed: fragment data invalid";
+								}
+							}
+						}
+					} else {
+						ZT_T_PRINTF("FAILED (message incomplete after all fragments)" ZT_EOL_S);
+						return "Defragmenter test failed: message incomplete after all fragments";
+					}
+				} else {
+					if (complete) {
+						ZT_T_PRINTF("FAILED (message completed without all fragments)" ZT_EOL_S);
+						return "Defragmenter test failed: message completed without all fragments";
+					}
+				}
+			}
+
+			Buf::freePool();
+			ZT_T_PRINTF("OK (cache remaining: %u)" ZT_EOL_S,defrag.cacheSize());
+		}
 	} catch (std::exception &e) {
 	} catch (std::exception &e) {
 		ZT_T_PRINTF(ZT_EOL_S "[general] Unexpected exception: %s" ZT_EOL_S,e.what());
 		ZT_T_PRINTF(ZT_EOL_S "[general] Unexpected exception: %s" ZT_EOL_S,e.what());
 		return e.what();
 		return e.what();
@@ -702,35 +792,6 @@ extern "C" const char *ZTT_crypto()
 	return nullptr;
 	return nullptr;
 }
 }
 
 
-extern "C" const char *ZTT_defragmenter()
-{
-#if 0
-	Defragmenter<11> defrag;
-
-/*
-	ZT_ALWAYS_INLINE ResultCode assemble(
-		const uint64_t messageId,
-		FCV< Buf::Slice,MF > &message,
-		SharedPtr<Buf> &fragment,
-		const unsigned int fragmentDataIndex,
-		const unsigned int fragmentDataSize,
-		const unsigned int fragmentNo,
-		const unsigned int totalFragmentsExpected,
-		const int64_t now,
-		const SharedPtr< Path > &via,
-		const unsigned int maxIncomingFragmentsPerPath)
-	{
-*/
-
-	uint64_t messageId = 1;
-	FCV< Buf::Slice,11 > message;
-	for(int kk=0;kk<16;++kk) {
-	}
-
-#endif
-	return nullptr;
-}
-
 extern "C" const char *ZTT_benchmarkCrypto()
 extern "C" const char *ZTT_benchmarkCrypto()
 {
 {
 	try {
 	try {

+ 0 - 1
node/Tests.h

@@ -56,7 +56,6 @@ extern "C" {
 
 
 const char *ZTT_general();
 const char *ZTT_general();
 const char *ZTT_crypto();
 const char *ZTT_crypto();
-const char *ZTT_defragmenter();
 
 
 // Benchmarks ---------------------------------------------------------------------------------------------------------
 // Benchmarks ---------------------------------------------------------------------------------------------------------