Browse Source

Code cleanup, and fix some unsafe pointer handling in Network.

Adam Ierymenko 10 years ago
parent
commit
95f421024a

+ 3 - 7
node/Network.cpp

@@ -535,15 +535,11 @@ void Network::threadMain()
 
 void Network::_CBhandleTapData(void *arg,const MAC &from,const MAC &to,unsigned int etherType,const Buffer<4096> &data)
 {
-	if ((!((Network *)arg)->_enabled)||(((Network *)arg)->status() != NETWORK_OK))
+	SharedPtr<Network> network((Network *)arg,true);
+	if ((!network)||(!network->_enabled)||(network->status() != NETWORK_OK))
 		return;
-
-	const RuntimeEnvironment *RR = ((Network *)arg)->RR;
-	if (RR->shutdownInProgress)
-		return;
-
 	try {
-		RR->sw->onLocalEthernet(SharedPtr<Network>((Network *)arg),from,to,etherType,data);
+		network->RR->sw->onLocalEthernet(network,from,to,etherType,data);
 	} catch (std::exception &exc) {
 		TRACE("unexpected exception handling local packet: %s",exc.what());
 	} catch ( ... ) {

+ 4 - 10
node/Node.cpp

@@ -99,8 +99,6 @@ struct _NodeImpl
 		RuntimeEnvironment *RR = &renv;
 		LOG("terminating: %s",reasonForTerminationStr.c_str());
 
-		renv.shutdownInProgress = true;
-
 		running = false;
 
 #ifndef __WINDOWS__
@@ -109,9 +107,9 @@ struct _NodeImpl
 		delete renv.updater;  renv.updater = (SoftwareUpdater *)0;
 		delete renv.nc;       renv.nc = (NodeConfig *)0;            // shut down all networks, close taps, etc.
 		delete renv.topology; renv.topology = (Topology *)0;        // now we no longer need routing info
-		delete renv.sw;       renv.sw = (Switch *)0;                // order matters less from here down
 		delete renv.mc;       renv.mc = (Multicaster *)0;
 		delete renv.antiRec;  renv.antiRec = (AntiRecursion *)0;
+		delete renv.sw;       renv.sw = (Switch *)0;                // order matters less from here down
 		delete renv.http;     renv.http = (HttpClient *)0;
 		delete renv.prng;     renv.prng = (CMWC4096 *)0;
 		delete renv.log;      renv.log = (Logger *)0;               // but stop logging last of all
@@ -271,16 +269,12 @@ Node::~Node()
 
 static void _CBztTraffic(const SharedPtr<Socket> &fromSock,void *arg,const InetAddress &from,Buffer<ZT_SOCKET_MAX_MESSAGE_LEN> &data)
 {
-	const RuntimeEnvironment *RR = (const RuntimeEnvironment *)arg;
-	if ((RR->sw)&&(!RR->shutdownInProgress))
-		RR->sw->onRemotePacket(fromSock,from,data);
+	((const RuntimeEnvironment *)arg)->sw->onRemotePacket(fromSock,from,data);
 }
 
 static void _cbHandleGetRootTopology(void *arg,int code,const std::string &url,const std::string &body)
 {
 	RuntimeEnvironment *RR = (RuntimeEnvironment *)arg;
-	if (RR->shutdownInProgress)
-		return;
 
 	if ((code != 200)||(body.length() == 0)) {
 		TRACE("failed to retrieve %s",url.c_str());
@@ -391,9 +385,9 @@ Node::ReasonForTermination Node::run()
 		}
 
 		RR->http = new HttpClient();
-		RR->antiRec = new AntiRecursion();
-		RR->mc = new Multicaster(RR);
 		RR->sw = new Switch(RR);
+		RR->mc = new Multicaster(RR);
+		RR->antiRec = new AntiRecursion();
 		RR->topology = new Topology(RR);
 		try {
 			RR->nc = new NodeConfig(RR);

+ 4 - 1
node/NodeConfig.hpp

@@ -145,7 +145,7 @@ public:
 	inline bool hasNetwork(uint64_t nwid)
 	{
 		Mutex::Lock _l(_networks_m);
-		return (_networks.count(nwid) > 0);
+		return (_networks.find(nwid) != _networks.end());
 	}
 
 	/**
@@ -163,12 +163,15 @@ public:
 		return tapDevs;
 	}
 
+private:
 	void _readLocalConfig();
 	void _writeLocalConfig();
 
 	const RuntimeEnvironment *RR;
+
 	Dictionary _localConfig; // persisted as local.conf
 	Mutex _localConfig_m;
+
 	std::map< uint64_t,SharedPtr<Network> > _networks; // persisted in networks.d/
 	Mutex _networks_m;
 };

+ 4 - 8
node/RuntimeEnvironment.hpp

@@ -69,7 +69,6 @@ public:
 		homePath(),
 		identity(),
 		initialized(false),
-		shutdownInProgress(false),
 		tcpTunnelingEnabled(false),
 		timeOfLastResynchronize(0),
 		timeOfLastPacketReceived(0),
@@ -79,9 +78,9 @@ public:
 		log((Logger *)0),
 		prng((CMWC4096 *)0),
 		http((HttpClient *)0),
-		antiRec((AntiRecursion *)0),
-		mc((Multicaster *)0),
 		sw((Switch *)0),
+		mc((Multicaster *)0),
+		antiRec((AntiRecursion *)0),
 		topology((Topology *)0),
 		nc((NodeConfig *)0),
 		node((Node *)0),
@@ -101,9 +100,6 @@ public:
 	// Are we initialized?
 	volatile bool initialized;
 
-	// Indicates that we are shutting down -- this is hacky, want to factor out
-	volatile bool shutdownInProgress;
-
 	// Are we in outgoing TCP failover mode?
 	volatile bool tcpTunnelingEnabled;
 
@@ -130,9 +126,9 @@ public:
 	Logger *log; // null if logging is disabled
 	CMWC4096 *prng;
 	HttpClient *http;
-	AntiRecursion *antiRec;
-	Multicaster *mc;
 	Switch *sw;
+	Multicaster *mc;
+	AntiRecursion *antiRec;
 	Topology *topology;
 	NodeConfig *nc;
 	Node *node;

+ 14 - 0
node/SharedPtr.hpp

@@ -64,6 +64,20 @@ public:
 		++obj->__refCount;
 	}
 
+	SharedPtr(T *obj,bool runAwayFromZombies)
+		throw() :
+		_ptr(obj)
+	{
+		// HACK: this is used in "handlers" to take ownership of naked pointers,
+		// an ugly pattern that really ought to be factored out.
+		if (runAwayFromZombies) {
+			if ((int)(++obj->__refCount) < 2) {
+				--obj->__refCount;
+				_ptr = (T *)0;
+			}
+		} else ++obj->__refCount;
+	}
+
 	SharedPtr(const SharedPtr &sp)
 		throw() :
 		_ptr(sp._getAndInc())

+ 29 - 27
testnet.cpp

@@ -527,11 +527,11 @@ static void doUnicast(const std::vector<std::string> &cmd)
 		for(std::vector<Address>::iterator r(receivers.begin());r!=receivers.end();++r) {
 			if (*s == *r)
 				continue;
+
 			SimNode *sender = nodes[*s];
 			SimNode *receiver = nodes[*r];
-
-			SharedPtr<TestEthernetTap> stap(sender->tapFactory.getByNwid(nwid));
-			SharedPtr<TestEthernetTap> rtap(receiver->tapFactory.getByNwid(nwid));
+			TestEthernetTap *stap = sender->tapFactory.getByNwid(nwid);
+			TestEthernetTap *rtap = receiver->tapFactory.getByNwid(nwid);
 
 			if ((stap)&&(rtap)) {
 				pkt.i[0] = s->toInt();
@@ -557,21 +557,21 @@ static void doUnicast(const std::vector<std::string> &cmd)
 	do {
 		for(std::vector<Address>::iterator r(receivers.begin());r!=receivers.end();++r) {
 			SimNode *receiver = nodes[*r];
-			SharedPtr<TestEthernetTap> rtap(receiver->tapFactory.getByNwid(nwid));
-			if (rtap) {
-				if (rtap->getNextReceivedFrame(frame,5)) {
-					if ((frame.len == frameLen)&&(!memcmp(frame.data + 16,pkt.data + 16,frameLen - 16))) {
-						uint64_t ints[2];
-						memcpy(ints,frame.data,16);
-						printf("%s <- %.10llx received test packet, length == %u, latency == %llums"ZT_EOL_S,r->toString().c_str(),ints[0],frame.len,frame.timestamp - ints[1]);
-						receivedPairs.insert(std::pair<Address,Address>(Address(ints[0]),*r));
-					} else {
-						printf("%s !! got spurious packet, length == %u, etherType == 0x%.4x"ZT_EOL_S,r->toString().c_str(),frame.len,frame.etherType);
-					}
+			TestEthernetTap *rtap = receiver->tapFactory.getByNwid(nwid);
+
+			if ((rtap)&&(rtap->getNextReceivedFrame(frame,5))) {
+				if ((frame.len == frameLen)&&(!memcmp(frame.data + 16,pkt.data + 16,frameLen - 16))) {
+					uint64_t ints[2];
+					memcpy(ints,frame.data,16);
+					printf("%s <- %.10llx received test packet, length == %u, latency == %llums"ZT_EOL_S,r->toString().c_str(),ints[0],frame.len,frame.timestamp - ints[1]);
+					receivedPairs.insert(std::pair<Address,Address>(Address(ints[0]),*r));
+				} else {
+					printf("%s !! got spurious packet, length == %u, etherType == 0x%.4x"ZT_EOL_S,r->toString().c_str(),frame.len,frame.etherType);
 				}
 			}
 		}
-		Thread::sleep(50);
+
+		Thread::sleep(100);
 	} while ((receivedPairs.size() < sentPairs.size())&&(Utils::now() < toutend));
 
 	for(std::vector<Address>::iterator s(senders.begin());s!=senders.end();++s) {
@@ -634,7 +634,8 @@ static void doMulticast(const std::vector<std::string> &cmd)
 
 	for(std::vector<Address>::iterator s(senders.begin());s!=senders.end();++s) {
 		SimNode *sender = nodes[*s];
-		SharedPtr<TestEthernetTap> stap(sender->tapFactory.getByNwid(nwid));
+		TestEthernetTap *stap = sender->tapFactory.getByNwid(nwid);
+
 		if (stap) {
 			pkt.i[0] = s->toInt();
 			pkt.i[1] = Utils::now();
@@ -653,20 +654,21 @@ static void doMulticast(const std::vector<std::string> &cmd)
 	do {
 		for(std::map< Address,SimNode * >::iterator nn(nodes.begin());nn!=nodes.end();++nn) {
 			SimNode *receiver = nn->second;
-			SharedPtr<TestEthernetTap> rtap(receiver->tapFactory.getByNwid(nwid));
-			if (rtap) {
-				if (rtap->getNextReceivedFrame(frame,5)) {
-					if ((frame.len == frameLen)&&(!memcmp(frame.data + 16,pkt.data + 16,frameLen - 16))) {
-						uint64_t ints[2];
-						memcpy(ints,frame.data,16);
-						printf("%s <- %.10llx received test packet, length == %u, latency == %llums"ZT_EOL_S,nn->first.toString().c_str(),ints[0],frame.len,frame.timestamp - ints[1]);
-						++receiveCount;
-					} else {
-						printf("%s !! got spurious packet, length == %u, etherType == 0x%.4x"ZT_EOL_S,nn->first.toString().c_str(),frame.len,frame.etherType);
-					}
+			TestEthernetTap *rtap = receiver->tapFactory.getByNwid(nwid);
+
+			if ((rtap)&&(rtap->getNextReceivedFrame(frame,5))) {
+				if ((frame.len == frameLen)&&(!memcmp(frame.data + 16,pkt.data + 16,frameLen - 16))) {
+					uint64_t ints[2];
+					memcpy(ints,frame.data,16);
+					printf("%s <- %.10llx received test packet, length == %u, latency == %llums"ZT_EOL_S,nn->first.toString().c_str(),ints[0],frame.len,frame.timestamp - ints[1]);
+					++receiveCount;
+				} else {
+					printf("%s !! got spurious packet, length == %u, etherType == 0x%.4x"ZT_EOL_S,nn->first.toString().c_str(),frame.len,frame.etherType);
 				}
 			}
 		}
+
+		Thread::sleep(100);
 	} while (Utils::now() < toutend);
 
 	printf("---------- test multicast received by %u peers"ZT_EOL_S,receiveCount);

+ 0 - 2
testnet/TestEthernetTap.cpp

@@ -43,7 +43,6 @@
 namespace ZeroTier {
 
 TestEthernetTap::TestEthernetTap(
-	TestEthernetTapFactory *parent,
 	const MAC &mac,
 	unsigned int mtu,
 	unsigned int metric,
@@ -54,7 +53,6 @@ TestEthernetTap::TestEthernetTap(
 	void *arg) :
 	EthernetTap("TestEthernetTap",mac,mtu,metric),
 	_nwid(nwid),
-	_parent(parent),
 	_handler(handler),
 	_arg(arg),
 	_enabled(true)

+ 0 - 8
testnet/TestEthernetTap.hpp

@@ -36,8 +36,6 @@
 
 #include "../node/Constants.hpp"
 #include "../node/EthernetTap.hpp"
-#include "../node/AtomicCounter.hpp"
-#include "../node/SharedPtr.hpp"
 #include "../node/Thread.hpp"
 #include "../node/Mutex.hpp"
 
@@ -57,8 +55,6 @@ class TestEthernetTapFactory;
  */
 class TestEthernetTap : public EthernetTap
 {
-	friend class SharedPtr<TestEthernetTap>;
-
 public:
 	struct TestFrame
 	{
@@ -82,7 +78,6 @@ public:
 	};
 
 	TestEthernetTap(
-		TestEthernetTapFactory *parent,
 		const MAC &mac,
 		unsigned int mtu,
 		unsigned int metric,
@@ -113,7 +108,6 @@ public:
 
 private:
 	uint64_t _nwid;
-	TestEthernetTapFactory *_parent;
 
 	void (*_handler)(void *,const MAC &,const MAC &,unsigned int,const Buffer<4096> &);
 	void *_arg;
@@ -123,8 +117,6 @@ private:
 
 	MTQ<TestFrame> _pq;
 	MTQ<TestFrame> _gq;
-
-	AtomicCounter __refCount;
 };
 
 } // namespace ZeroTier

+ 19 - 20
testnet/TestEthernetTapFactory.cpp

@@ -36,6 +36,11 @@ TestEthernetTapFactory::TestEthernetTapFactory()
 
 TestEthernetTapFactory::~TestEthernetTapFactory()
 {
+	Mutex::Lock _l1(_taps_m);
+	Mutex::Lock _l2(_tapsByMac_m);
+	Mutex::Lock _l3(_tapsByNwid_m);
+	for(std::set<EthernetTap *>::iterator t(_taps.begin());t!=_taps.end();++t)
+		delete *t;
 }
 
 EthernetTap *TestEthernetTapFactory::open(
@@ -48,33 +53,27 @@ EthernetTap *TestEthernetTapFactory::open(
 	void (*handler)(void *,const MAC &,const MAC &,unsigned int,const Buffer<4096> &),
 	void *arg)
 {
-	SharedPtr<TestEthernetTap> tap(new TestEthernetTap(this,mac,mtu,metric,nwid,desiredDevice,friendlyName,handler,arg));
-	{
-		Mutex::Lock _l(_taps_m);
-		_taps.insert(tap);
-	}
-	{
-		Mutex::Lock _l(_tapsByMac_m);
-		_tapsByMac[mac] = tap;
-	}
-	{
-		Mutex::Lock _l(_tapsByNwid_m);
-		_tapsByNwid[nwid] = tap;
-	}
-	return tap.ptr();
+	TestEthernetTap *tap = new TestEthernetTap(mac,mtu,metric,nwid,desiredDevice,friendlyName,handler,arg);
+	Mutex::Lock _l1(_taps_m);
+	Mutex::Lock _l2(_tapsByMac_m);
+	Mutex::Lock _l3(_tapsByNwid_m);
+	_taps.insert(tap);
+	_tapsByMac[mac] = tap;
+	_tapsByNwid[nwid] = tap;
+	return tap;
 }
 
 void TestEthernetTapFactory::close(EthernetTap *tap,bool destroyPersistentDevices)
 {
-	if (!tap)
-		return;
-	SharedPtr<TestEthernetTap> tapp((TestEthernetTap *)tap);
 	Mutex::Lock _l1(_taps_m);
 	Mutex::Lock _l2(_tapsByMac_m);
 	Mutex::Lock _l3(_tapsByNwid_m);
-	_taps.erase(tapp);
-	_tapsByMac.erase(tapp->mac());
-	_tapsByNwid.erase(tapp->nwid());
+	if (!tap)
+		return;
+	_taps.erase(tap);
+	_tapsByMac.erase(tap->mac());
+	_tapsByNwid.erase(((TestEthernetTap *)tap)->nwid());
+	delete tap;
 }
 
 } // namespace ZeroTier

+ 9 - 14
testnet/TestEthernetTapFactory.hpp

@@ -32,11 +32,9 @@
 #include <string>
 #include <set>
 
-#include "../node/SharedPtr.hpp"
 #include "../node/EthernetTapFactory.hpp"
 #include "../node/Mutex.hpp"
 #include "../node/MAC.hpp"
-#include "../node/CMWC4096.hpp"
 #include "TestEthernetTap.hpp"
 
 namespace ZeroTier {
@@ -59,36 +57,33 @@ public:
 
 	virtual void close(EthernetTap *tap,bool destroyPersistentDevices);
 
-	inline SharedPtr<TestEthernetTap> getByMac(const MAC &mac) const
+	inline TestEthernetTap *getByMac(const MAC &mac) const
 	{
 		Mutex::Lock _l(_tapsByMac_m);
-		std::map< MAC,SharedPtr<TestEthernetTap> >::const_iterator t(_tapsByMac.find(mac));
+		std::map< MAC,TestEthernetTap * >::const_iterator t(_tapsByMac.find(mac));
 		if (t == _tapsByMac.end())
-			return SharedPtr<TestEthernetTap>();
+			return (TestEthernetTap *)0;
 		return t->second;
 	}
 
-	inline SharedPtr<TestEthernetTap> getByNwid(uint64_t nwid) const
+	inline TestEthernetTap *getByNwid(uint64_t nwid) const
 	{
 		Mutex::Lock _l(_tapsByNwid_m);
-		std::map< uint64_t,SharedPtr<TestEthernetTap> >::const_iterator t(_tapsByNwid.find(nwid));
+		std::map< uint64_t,TestEthernetTap * >::const_iterator t(_tapsByNwid.find(nwid));
 		if (t == _tapsByNwid.end())
-			return SharedPtr<TestEthernetTap>();
+			return (TestEthernetTap *)0;
 		return t->second;
 	}
 
 private:
-	std::set< SharedPtr<TestEthernetTap> > _taps;
+	std::set< EthernetTap * > _taps;
 	Mutex _taps_m;
 
-	std::map< MAC,SharedPtr<TestEthernetTap> > _tapsByMac;
+	std::map< MAC,TestEthernetTap * > _tapsByMac;
 	Mutex _tapsByMac_m;
 
-	std::map< uint64_t,SharedPtr<TestEthernetTap> > _tapsByNwid;
+	std::map< uint64_t,TestEthernetTap * > _tapsByNwid;
 	Mutex _tapsByNwid_m;
-
-	CMWC4096 _prng;
-	Mutex _prng_m;
 };
 
 } // namespace ZeroTier