Browse Source

Merge branch 'dev' into edge

Adam Ierymenko 6 years ago
parent
commit
ce67abc32f
6 changed files with 92 additions and 30 deletions
  1. 1 0
      controller/DB.cpp
  2. 14 1
      controller/DB.hpp
  3. 52 2
      controller/DBMirrorSet.cpp
  4. 3 0
      controller/DBMirrorSet.hpp
  5. 16 24
      controller/LFDB.cpp
  6. 6 3
      controller/PostgreSQL.cpp

+ 1 - 0
controller/DB.cpp

@@ -52,6 +52,7 @@ void DB::initNetwork(nlohmann::json &network)
 	if (!network.count("mtu")) network["mtu"] = ZT_DEFAULT_MTU;
 	if (!network.count("remoteTraceTarget")) network["remoteTraceTarget"] = nlohmann::json();
 	if (!network.count("removeTraceLevel")) network["remoteTraceLevel"] = 0;
+	if (!network.count("rulesSource")) network["rulesSource"] = "";
 	if (!network.count("rules")) {
 		// If unspecified, rules are set to allow anything and behave like a flat L2 segment
 		network["rules"] = {{

+ 14 - 1
controller/DB.hpp

@@ -100,6 +100,19 @@ public:
 
 	void networks(std::set<uint64_t> &networks);
 
+	template<typename F>
+	inline void each(F f)
+	{
+		nlohmann::json nullJson;
+		std::lock_guard<std::mutex> lck(_networks_l);
+		for(auto nw=_networks.begin();nw!=_networks.end();++nw) {
+			f(nw->first,nw->second->config,0,nullJson); // first provide network with 0 for member ID
+			for(auto m=nw->second->members.begin();m!=nw->second->members.end();++m) {
+				f(nw->first,nw->second->config,m->first,m->second);
+			}
+		}
+	}
+
 	virtual bool save(nlohmann::json &record,bool notifyListeners) = 0;
 
 	virtual void eraseNetwork(const uint64_t networkId) = 0;
@@ -114,7 +127,7 @@ public:
 	}
 
 protected:
-	inline bool _compareRecords(const nlohmann::json &a,const nlohmann::json &b)
+	static inline bool _compareRecords(const nlohmann::json &a,const nlohmann::json &b)
 	{
 		if (a.is_object() == b.is_object()) {
 			if (a.is_object()) {

+ 52 - 2
controller/DBMirrorSet.cpp

@@ -29,12 +29,62 @@
 namespace ZeroTier {
 
 DBMirrorSet::DBMirrorSet(DB::ChangeListener *listener) :
-	_listener(listener)
-{
+	_listener(listener),
+	_running(true)
+{
+	_syncCheckerThread = std::thread([this]() {
+		for(;;) {
+			for(int i=0;i<120;++i) { // 1 minute delay between checks
+				if (!_running)
+					return;
+				std::this_thread::sleep_for(std::chrono::milliseconds(500));
+			}
+
+			std::vector< std::shared_ptr<DB> > dbs;
+			{
+				std::lock_guard<std::mutex> l(_dbs_l);
+				if (_dbs.size() <= 1)
+					continue; // no need to do this if there's only one DB, so skip the iteration
+				dbs = _dbs;
+			}
+
+			for(auto db=dbs.begin();db!=dbs.end();++db) {
+				(*db)->each([this,&dbs,&db](uint64_t networkId,const nlohmann::json &network,uint64_t memberId,const nlohmann::json &member) {
+					try {
+						if (network.is_object()) {
+							if (memberId == 0) {
+								for(auto db2=dbs.begin();db2!=dbs.end();++db2) {
+									if (db->get() != db2->get()) {
+										nlohmann::json nw2;
+										if ((!(*db2)->get(networkId,nw2))||((nw2.is_object())&&(OSUtils::jsonInt(nw2["revision"],0) < OSUtils::jsonInt(network["revision"],0)))) {
+											nw2 = network;
+											(*db2)->save(nw2,false);
+										}
+									}
+								}
+							} else if (member.is_object()) {
+								for(auto db2=dbs.begin();db2!=dbs.end();++db2) {
+									if (db->get() != db2->get()) {
+										nlohmann::json nw2,m2;
+										if ((!(*db2)->get(networkId,nw2,memberId,m2))||((m2.is_object())&&(OSUtils::jsonInt(m2["revision"],0) < OSUtils::jsonInt(member["revision"],0)))) {
+											m2 = member;
+											(*db2)->save(m2,false);
+										}
+									}
+								}
+							}
+						}
+					} catch ( ... ) {} // skip entries that generate JSON errors
+				});
+			}
+		}
+	});
 }
 
 DBMirrorSet::~DBMirrorSet()
 {
+	_running = false;
+	_syncCheckerThread.join();
 }
 
 bool DBMirrorSet::hasNetwork(const uint64_t networkId) const

+ 3 - 0
controller/DBMirrorSet.hpp

@@ -33,6 +33,7 @@
 #include <memory>
 #include <mutex>
 #include <set>
+#include <thread>
 
 namespace ZeroTier {
 
@@ -72,6 +73,8 @@ public:
 
 private:
 	DB::ChangeListener *const _listener;
+	std::atomic_bool _running;
+	std::thread _syncCheckerThread;
 	std::vector< std::shared_ptr< DB > > _dbs;
 	mutable std::mutex _dbs_l;
 };

+ 16 - 24
controller/LFDB.cpp

@@ -220,20 +220,16 @@ LFDB::LFDB(const Identity &myId,const char *path,const char *lfOwnerPrivate,cons
 												const uint64_t id = Utils::hexStrToU64(idstr.c_str());
 												if ((id >> 24) == controllerAddressInt) { // sanity check
 
-													std::lock_guard<std::mutex> sl(_state_l);
-													_NetworkState &ns = _state[id];
-													if (!ns.dirty) {
-														nlohmann::json oldNetwork;
-														if ((timeRangeStart > 0)&&(get(id,oldNetwork))) {
-															const uint64_t revision = network["revision"];
-															const uint64_t prevRevision = oldNetwork["revision"];
-															if (prevRevision < revision) {
-																_networkChanged(oldNetwork,network,timeRangeStart > 0);
-															}
-														} else {
-															nlohmann::json nullJson;
-															_networkChanged(nullJson,network,timeRangeStart > 0);
+													nlohmann::json oldNetwork;
+													if ((timeRangeStart > 0)&&(get(id,oldNetwork))) {
+														const uint64_t revision = network["revision"];
+														const uint64_t prevRevision = oldNetwork["revision"];
+														if (prevRevision < revision) {
+															_networkChanged(oldNetwork,network,timeRangeStart > 0);
 														}
+													} else {
+														nlohmann::json nullJson;
+														_networkChanged(nullJson,network,timeRangeStart > 0);
 													}
 
 												}
@@ -294,17 +290,13 @@ LFDB::LFDB(const Identity &myId,const char *path,const char *lfOwnerPrivate,cons
 												const uint64_t id = Utils::hexStrToU64(idstr.c_str());
 												if ((id)&&((nwid >> 24) == controllerAddressInt)) { // sanity check
 
-													std::lock_guard<std::mutex> sl(_state_l);
-													auto ns = _state.find(nwid);
-													if ((ns == _state.end())||(!ns->second.members[id].dirty)) {
-														nlohmann::json network,oldMember;
-														if ((timeRangeStart > 0)&&(get(nwid,network,id,oldMember))) {
-															const uint64_t revision = member["revision"];
-															const uint64_t prevRevision = oldMember["revision"];
-															if (prevRevision < revision)
-																_memberChanged(oldMember,member,timeRangeStart > 0);
-														}
-													} else {
+													nlohmann::json network,oldMember;
+													if ((timeRangeStart > 0)&&(get(nwid,network,id,oldMember))) {
+														const uint64_t revision = member["revision"];
+														const uint64_t prevRevision = oldMember["revision"];
+														if (prevRevision < revision)
+															_memberChanged(oldMember,member,timeRangeStart > 0);
+													} else if (hasNetwork(nwid)) {
 														nlohmann::json nullJson;
 														_memberChanged(nullJson,member,timeRangeStart > 0);
 													}

+ 6 - 3
controller/PostgreSQL.cpp

@@ -1047,7 +1047,10 @@ void PostgreSQL::commitThread()
 					if (!(*config)["remoteTraceTarget"].is_null()) {
 						remoteTraceTarget = (*config)["remoteTraceTarget"];
 					}
-					std::string rulesSource = (*config)["rulesSource"];
+					std::string rulesSource;
+					if ((*config)["rulesSource"].is_string()) {
+						rulesSource = (*config)["rulesSource"];
+					}
 					std::string caps = OSUtils::jsonDump((*config)["capabilitles"], -1);
 					std::string now = std::to_string(OSUtils::now());
 					std::string mtu = std::to_string((int)(*config)["mtu"]);
@@ -1081,13 +1084,13 @@ void PostgreSQL::commitThread()
 
 					PGresult *res = PQexecParams(conn,
 						"INSERT INTO ztc_network (id, controller_id, capabilities, enable_broadcast, "
-						"last_updated, mtu, multicast_limit, name, private, "
+						"last_modified, mtu, multicast_limit, name, private, "
 						"remote_trace_level, remote_trace_target, rules, rules_source, "
 						"tags, v4_assign_mode, v6_assign_mode) VALUES ("
 						"$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) "
 						"ON CONFLICT (id) DO UPDATE set controller_id = EXCLUDED.controller_id, "
 						"capabilities = EXCLUDED.capabilities, enable_broadcast = EXCLUDED.enable_broadcast, "
-						"last_updated = EXCLUDED.last_updated, mtu = EXCLUDED.mtu, "
+						"last_modified = EXCLUDED.last_modified, mtu = EXCLUDED.mtu, "
 						"multicast_limit = EXCLUDED.multicast_limit, name = EXCLUDED.name, "
 						"private = EXCLUDED.private, remote_trace_level = EXCLUDED.remote_trace_level, "
 						"remote_trace_target = EXCLUDED.remote_trace_target, rules = EXCLUDED.rules, "