Browse Source

Basic plumbing for authentication requirement and piping through of URL information.

Adam Ierymenko 4 years ago
parent
commit
b270d527f4

+ 2 - 0
controller/DB.cpp

@@ -67,6 +67,8 @@ void DB::initMember(nlohmann::json &member)
 	if (!member.count("lastAuthorizedTime")) member["lastAuthorizedTime"] = 0ULL;
 	if (!member.count("lastAuthorizedTime")) member["lastAuthorizedTime"] = 0ULL;
 	if (!member.count("lastAuthorizedCredentialType")) member["lastAuthorizedCredentialType"] = nlohmann::json();
 	if (!member.count("lastAuthorizedCredentialType")) member["lastAuthorizedCredentialType"] = nlohmann::json();
 	if (!member.count("lastAuthorizedCredential")) member["lastAuthorizedCredential"] = nlohmann::json();
 	if (!member.count("lastAuthorizedCredential")) member["lastAuthorizedCredential"] = nlohmann::json();
+	if (!member.count("authenticationExpiryTime")) member["authenticationExpiryTime"] = -1LL;
+	if (!member.count("authenticationURL")) member["authenticationURL"] = nlohmann::json();
 	if (!member.count("vMajor")) member["vMajor"] = -1;
 	if (!member.count("vMajor")) member["vMajor"] = -1;
 	if (!member.count("vMinor")) member["vMinor"] = -1;
 	if (!member.count("vMinor")) member["vMinor"] = -1;
 	if (!member.count("vRev")) member["vRev"] = -1;
 	if (!member.count("vRev")) member["vRev"] = -1;

+ 28 - 6
controller/EmbeddedNetworkController.cpp

@@ -466,6 +466,14 @@ EmbeddedNetworkController::EmbeddedNetworkController(Node *node,const char *ztPa
 	_db(this),
 	_db(this),
 	_rc(rc)
 	_rc(rc)
 {
 {
+	memset(_ssoPsk, 0, sizeof(_ssoPsk));
+	char *const ssoPskHex = getenv("ZT_SSO_PSK");
+	if (ssoPskHex) {
+		// SECURITY: note that ssoPskHex will always be null-terminated if libc acatually
+		// returns something non-NULL. If the hex encodes something shorter than 48 bytes,
+		// it will be padded at the end with zeroes. If longer, it'll be truncated.
+		Utils::unhex(ssoPskHex, _ssoPsk, sizeof(_ssoPsk));
+	}
 }
 }
 
 
 EmbeddedNetworkController::~EmbeddedNetworkController()
 EmbeddedNetworkController::~EmbeddedNetworkController()
@@ -1248,7 +1256,7 @@ void EmbeddedNetworkController::_request(
 	Utils::hex(nwid,nwids);
 	Utils::hex(nwid,nwids);
 	_db.get(nwid,network,identity.address().toInt(),member,ns);
 	_db.get(nwid,network,identity.address().toInt(),member,ns);
 	if ((!network.is_object())||(network.empty())) {
 	if ((!network.is_object())||(network.empty())) {
-		_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_OBJECT_NOT_FOUND);
+		_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_OBJECT_NOT_FOUND, nullptr, 0);
 		return;
 		return;
 	}
 	}
 	const bool newMember = ((!member.is_object())||(member.empty()));
 	const bool newMember = ((!member.is_object())||(member.empty()));
@@ -1262,11 +1270,11 @@ void EmbeddedNetworkController::_request(
 			// known member.
 			// known member.
 			try {
 			try {
 				if (Identity(haveIdStr.c_str()) != identity) {
 				if (Identity(haveIdStr.c_str()) != identity) {
-					_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_ACCESS_DENIED);
+					_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_ACCESS_DENIED, nullptr, 0);
 					return;
 					return;
 				}
 				}
 			} catch ( ... ) {
 			} catch ( ... ) {
-				_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_ACCESS_DENIED);
+				_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_ACCESS_DENIED, nullptr, 0);
 				return;
 				return;
 			}
 			}
 		} else {
 		} else {
@@ -1348,16 +1356,30 @@ void EmbeddedNetworkController::_request(
 				ms.identity = identity;
 				ms.identity = identity;
 			}
 			}
 		}
 		}
+
+		const int64_t authenticationExpiryTime = member["authenticationExpiryTime"];
+		if ((authenticationExpiryTime >= 0)&&(authenticationExpiryTime < now)) {
+			const std::string authenticationURL = member["authenticationURL"];
+			if (authenticationURL.empty()) {
+				_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, nullptr, 0);
+				return;
+			} else {
+				Dictionary<1024> authInfo;
+				authInfo.add("aU", authenticationURL.c_str());
+				_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED, authInfo.data(), authInfo.sizeBytes());
+				return;
+			}
+		}
 	} else {
 	} else {
 		// If they are not authorized, STOP!
 		// If they are not authorized, STOP!
 		DB::cleanMember(member);
 		DB::cleanMember(member);
 		_db.save(member,true);
 		_db.save(member,true);
-		_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_ACCESS_DENIED);
+		_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_ACCESS_DENIED, nullptr, 0);
 		return;
 		return;
 	}
 	}
 
 
 	// -------------------------------------------------------------------------
 	// -------------------------------------------------------------------------
-	// If we made it this far, they are authorized.
+	// If we made it this far, they are authorized (and authenticated).
 	// -------------------------------------------------------------------------
 	// -------------------------------------------------------------------------
 
 
 	int64_t credentialtmd = ZT_NETWORKCONFIG_DEFAULT_CREDENTIAL_TIME_MAX_MAX_DELTA;
 	int64_t credentialtmd = ZT_NETWORKCONFIG_DEFAULT_CREDENTIAL_TIME_MAX_MAX_DELTA;
@@ -1734,7 +1756,7 @@ void EmbeddedNetworkController::_request(
 	if (com.sign(_signingId)) {
 	if (com.sign(_signingId)) {
 		nc->com = com;
 		nc->com = com;
 	} else {
 	} else {
-		_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_INTERNAL_SERVER_ERROR);
+		_sender->ncSendError(nwid,requestPacketId,identity.address(),NetworkController::NC_ERROR_INTERNAL_SERVER_ERROR, nullptr, 0);
 		return;
 		return;
 	}
 	}
 
 

+ 1 - 0
controller/EmbeddedNetworkController.hpp

@@ -140,6 +140,7 @@ private:
 	Identity _signingId;
 	Identity _signingId;
 	std::string _signingIdAddressString;
 	std::string _signingIdAddressString;
 	NetworkController::Sender *_sender;
 	NetworkController::Sender *_sender;
+	uint8_t _ssoPsk[48];
 
 
 	DBMirrorSet _db;
 	DBMirrorSet _db;
 	BlockingQueue< _RQEntry * > _queue;
 	BlockingQueue< _RQEntry * > _queue;

+ 16 - 1
include/ZeroTierOne.h

@@ -820,7 +820,12 @@ enum ZT_VirtualNetworkStatus
 	/**
 	/**
 	 * ZeroTier core version too old
 	 * ZeroTier core version too old
 	 */
 	 */
-	ZT_NETWORK_STATUS_CLIENT_TOO_OLD = 5
+	ZT_NETWORK_STATUS_CLIENT_TOO_OLD = 5,
+
+	/**
+	 * External authentication is required (e.g. SSO)
+	 */
+	ZT_NETWORK_STATUS_AUTHENTICATION_REQUIRED = 6
 };
 };
 
 
 /**
 /**
@@ -1339,6 +1344,16 @@ typedef struct
 	 * Network specific DNS configuration
 	 * Network specific DNS configuration
 	 */
 	 */
 	ZT_VirtualNetworkDNS dns;
 	ZT_VirtualNetworkDNS dns;
+
+	/**
+	 * If the status us AUTHENTICATION_REQUIRED, this may contain a URL for authentication.
+	 */
+	char authenticationURL[256];
+
+	/**
+	 * Time that current authentication expires or -1 if external authentication is not required.
+	 */
+	int64_t authenticationExpiryTime;
 } ZT_VirtualNetworkConfig;
 } ZT_VirtualNetworkConfig;
 
 
 /**
 /**

+ 3 - 0
node/Network.cpp

@@ -1429,6 +1429,9 @@ void Network::_externalConfig(ZT_VirtualNetworkConfig *ec) const
 	}
 	}
 
 
 	memcpy(&ec->dns, &_config.dns, sizeof(ZT_VirtualNetworkDNS));
 	memcpy(&ec->dns, &_config.dns, sizeof(ZT_VirtualNetworkDNS));
+
+	Utils::scopy(ec->authenticationURL, sizeof(ec->authenticationURL), _config.authenticationURL);
+	ec->authenticationExpiryTime = _config.authenticationExpiryTime;
 }
 }
 
 
 void Network::_sendUpdatesToMembers(void *tPtr,const MulticastGroup *const newMulticastGroup)
 void Network::_sendUpdatesToMembers(void *tPtr,const MulticastGroup *const newMulticastGroup)

+ 13 - 1
node/Network.hpp

@@ -220,6 +220,16 @@ public:
 		_netconfFailure = NETCONF_FAILURE_NOT_FOUND;
 		_netconfFailure = NETCONF_FAILURE_NOT_FOUND;
 	}
 	}
 
 
+	/**
+	 * Set netconf failure to 'authentication required' possibly with an authorization URL
+	 */
+	inline void setAuthenticationRequired(const char *url)
+	{
+		Mutex::Lock _l(_lock);
+		_netconfFailure = NETCONF_FAILURE_AUTHENTICATION_REQUIRED;
+		_authorizationURL = (url) ? url : "";
+	}
+
 	/**
 	/**
 	 * Causes this network to request an updated configuration from its master node now
 	 * Causes this network to request an updated configuration from its master node now
 	 *
 	 *
@@ -435,9 +445,11 @@ private:
 		NETCONF_FAILURE_NONE,
 		NETCONF_FAILURE_NONE,
 		NETCONF_FAILURE_ACCESS_DENIED,
 		NETCONF_FAILURE_ACCESS_DENIED,
 		NETCONF_FAILURE_NOT_FOUND,
 		NETCONF_FAILURE_NOT_FOUND,
-		NETCONF_FAILURE_INIT_FAILED
+		NETCONF_FAILURE_INIT_FAILED,
+		NETCONF_FAILURE_AUTHENTICATION_REQUIRED
 	} _netconfFailure;
 	} _netconfFailure;
 	int _portError; // return value from port config callback
 	int _portError; // return value from port config callback
+	std::string _authorizationURL;
 
 
 	Hashtable<Address,Membership> _memberships;
 	Hashtable<Address,Membership> _memberships;
 
 

+ 14 - 0
node/NetworkConfig.cpp

@@ -182,6 +182,13 @@ bool NetworkConfig::toDictionary(Dictionary<ZT_NETWORKCONFIG_DICT_CAPACITY> &d,b
 			if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_DNS,*tmp)) return false;
 			if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_DNS,*tmp)) return false;
 		}
 		}
 
 
+		if (this->authenticationURL[0]) {
+			if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_AUTHENTICATION_URL, this->authenticationURL)) return false;
+		}
+		if (this->authenticationExpiryTime >= 0) {
+			if (!d.add(ZT_NETWORKCONFIG_DICT_KEY_AUTHENTICATION_EXPIRY_TIME, this->authenticationExpiryTime)) return false;
+		}
+
 		delete tmp;
 		delete tmp;
 	} catch ( ... ) {
 	} catch ( ... ) {
 		delete tmp;
 		delete tmp;
@@ -365,6 +372,13 @@ bool NetworkConfig::fromDictionary(const Dictionary<ZT_NETWORKCONFIG_DICT_CAPACI
 				unsigned int p = 0;
 				unsigned int p = 0;
 				DNS::deserializeDNS(*tmp, p, &dns);
 				DNS::deserializeDNS(*tmp, p, &dns);
 			}
 			}
+
+			if (d.get(ZT_NETWORKCONFIG_DICT_KEY_AUTHENTICATION_URL, this->authenticationURL, (unsigned int)sizeof(this->authenticationURL)) > 0) {
+				this->authenticationURL[sizeof(this->authenticationURL) - 1] = 0; // ensure null terminated
+			} else {
+				this->authenticationURL[0] = 0;
+			}
+			this->authenticationExpiryTime = d.getI(ZT_NETWORKCONFIG_DICT_KEY_AUTHENTICATION_EXPIRY_TIME, -1);
 		}
 		}
 
 
 		//printf("~~~\n%s\n~~~\n",d.data());
 		//printf("~~~\n%s\n~~~\n",d.data());

+ 14 - 0
node/NetworkConfig.hpp

@@ -178,6 +178,10 @@ namespace ZeroTier {
 #define ZT_NETWORKCONFIG_DICT_KEY_CERTIFICATES_OF_OWNERSHIP "COO"
 #define ZT_NETWORKCONFIG_DICT_KEY_CERTIFICATES_OF_OWNERSHIP "COO"
 // dns (binary blobs)
 // dns (binary blobs)
 #define ZT_NETWORKCONFIG_DICT_KEY_DNS "DNS"
 #define ZT_NETWORKCONFIG_DICT_KEY_DNS "DNS"
+// authentication URL
+#define ZT_NETWORKCONFIG_DICT_KEY_AUTHENTICATION_URL "aurl"
+// authentication expiry
+#define ZT_NETWORKCONFIG_DICT_KEY_AUTHENTICATION_EXPIRY_TIME "aexpt"
 
 
 // Legacy fields -- these are obsoleted but are included when older clients query
 // Legacy fields -- these are obsoleted but are included when older clients query
 
 
@@ -604,6 +608,16 @@ public:
 	 * ZT pushed DNS configuration
 	 * ZT pushed DNS configuration
 	 */
 	 */
 	ZT_VirtualNetworkDNS dns;
 	ZT_VirtualNetworkDNS dns;
+
+	/**
+	 * Authentication URL if authentication is required
+	 */
+	char authenticationURL[256];
+
+	/**
+	 * Time current authentication expires or -1 if external authentication is disabled
+	 */
+	int64_t authenticationExpiryTime;
 };
 };
 
 
 } // namespace ZeroTier
 } // namespace ZeroTier

+ 8 - 2
node/NetworkController.hpp

@@ -38,7 +38,8 @@ public:
 		NC_ERROR_NONE = 0,
 		NC_ERROR_NONE = 0,
 		NC_ERROR_OBJECT_NOT_FOUND = 1,
 		NC_ERROR_OBJECT_NOT_FOUND = 1,
 		NC_ERROR_ACCESS_DENIED = 2,
 		NC_ERROR_ACCESS_DENIED = 2,
-		NC_ERROR_INTERNAL_SERVER_ERROR = 3
+		NC_ERROR_INTERNAL_SERVER_ERROR = 3,
+		NC_ERROR_AUTHENTICATION_REQUIRED = 4
 	};
 	};
 
 
 	/**
 	/**
@@ -69,12 +70,17 @@ public:
 		/**
 		/**
 		 * Send a network configuration request error
 		 * Send a network configuration request error
 		 *
 		 *
+		 * If errorData/errorDataSize are provided they must point to a valid serialized
+		 * Dictionary containing error data. They can be null/zero if not specified.
+		 * 
 		 * @param nwid Network ID
 		 * @param nwid Network ID
 		 * @param requestPacketId Request packet ID or 0 if none
 		 * @param requestPacketId Request packet ID or 0 if none
 		 * @param destination Destination peer Address
 		 * @param destination Destination peer Address
 		 * @param errorCode Error code
 		 * @param errorCode Error code
+		 * @param errorData Data associated with error or NULL if none
+		 * @param errorDataSize Size of errorData in bytes
 		 */
 		 */
-		virtual void ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &destination,NetworkController::ErrorCode errorCode) = 0;
+		virtual void ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &destination,NetworkController::ErrorCode errorCode, const void *errorData, unsigned int errorDataSize) = 0;
 	};
 	};
 
 
 	NetworkController() {}
 	NetworkController() {}

+ 12 - 1
node/Node.cpp

@@ -731,7 +731,7 @@ void Node::ncSendRevocation(const Address &destination,const Revocation &rev)
 	}
 	}
 }
 }
 
 
-void Node::ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &destination,NetworkController::ErrorCode errorCode)
+void Node::ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &destination,NetworkController::ErrorCode errorCode, const void *errorData, unsigned int errorDataSize)
 {
 {
 	if (destination == RR->identity.address()) {
 	if (destination == RR->identity.address()) {
 		SharedPtr<Network> n(network(nwid));
 		SharedPtr<Network> n(network(nwid));
@@ -744,6 +744,9 @@ void Node::ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &des
 			case NetworkController::NC_ERROR_ACCESS_DENIED:
 			case NetworkController::NC_ERROR_ACCESS_DENIED:
 				n->setAccessDenied();
 				n->setAccessDenied();
 				break;
 				break;
+			case NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED: {
+			}
+				break;
 
 
 			default: break;
 			default: break;
 		}
 		}
@@ -760,8 +763,16 @@ void Node::ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &des
 			case NetworkController::NC_ERROR_ACCESS_DENIED:
 			case NetworkController::NC_ERROR_ACCESS_DENIED:
 				outp.append((unsigned char)Packet::ERROR_NETWORK_ACCESS_DENIED_);
 				outp.append((unsigned char)Packet::ERROR_NETWORK_ACCESS_DENIED_);
 				break;
 				break;
+			case NetworkController::NC_ERROR_AUTHENTICATION_REQUIRED:
+				outp.append((unsigned char)Packet::ERROR_NETWORK_AUTHENTICATION_REQUIRED);
+				break;
 		}
 		}
+
 		outp.append(nwid);
 		outp.append(nwid);
+
+		if ((errorData)&&(errorDataSize > 0))
+			outp.append(errorData, errorDataSize);
+
 		RR->sw->send((void *)0,outp,true);
 		RR->sw->send((void *)0,outp,true);
 	} // else we can't send an ERROR() in response to nothing, so discard
 	} // else we can't send an ERROR() in response to nothing, so discard
 }
 }

+ 1 - 1
node/Node.hpp

@@ -245,7 +245,7 @@ public:
 
 
 	virtual void ncSendConfig(uint64_t nwid,uint64_t requestPacketId,const Address &destination,const NetworkConfig &nc,bool sendLegacyFormatConfig);
 	virtual void ncSendConfig(uint64_t nwid,uint64_t requestPacketId,const Address &destination,const NetworkConfig &nc,bool sendLegacyFormatConfig);
 	virtual void ncSendRevocation(const Address &destination,const Revocation &rev);
 	virtual void ncSendRevocation(const Address &destination,const Revocation &rev);
-	virtual void ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &destination,NetworkController::ErrorCode errorCode);
+	virtual void ncSendError(uint64_t nwid,uint64_t requestPacketId,const Address &destination,NetworkController::ErrorCode errorCode, const void *errorData, unsigned int errorDataSize);
 
 
 	inline const Address &remoteTraceTarget() const { return _remoteTraceTarget; }
 	inline const Address &remoteTraceTarget() const { return _remoteTraceTarget; }
 	inline Trace::Level remoteTraceLevel() const { return _remoteTraceLevel; }
 	inline Trace::Level remoteTraceLevel() const { return _remoteTraceLevel; }

+ 10 - 1
node/Packet.hpp

@@ -792,6 +792,12 @@ public:
 		 *
 		 *
 		 * ERROR response payload:
 		 * ERROR response payload:
 		 *   <[8] 64-bit network ID>
 		 *   <[8] 64-bit network ID>
+     *   <[2] 16-bit length of error-related data (optional)>
+     *   <[...] error-related data (optional)>
+     * 
+     * Error related data is a Dictionary containing things like a URL
+     * for authentication or a human-readable error message, and is
+     * optional and may be absent or empty.
 		 */
 		 */
 		VERB_NETWORK_CONFIG_REQUEST = 0x0b,
 		VERB_NETWORK_CONFIG_REQUEST = 0x0b,
 
 
@@ -1076,7 +1082,10 @@ public:
 		ERROR_NETWORK_ACCESS_DENIED_ = 0x07, /* extra _ at end to avoid Windows name conflict */
 		ERROR_NETWORK_ACCESS_DENIED_ = 0x07, /* extra _ at end to avoid Windows name conflict */
 
 
 		/* Multicasts to this group are not wanted */
 		/* Multicasts to this group are not wanted */
-		ERROR_UNWANTED_MULTICAST = 0x08
+		ERROR_UNWANTED_MULTICAST = 0x08,
+
+    /* Network requires external or 2FA authentication (e.g. SSO). */
+    ERROR_NETWORK_AUTHENTICATION_REQUIRED = 0x09
 	};
 	};
 
 
 	template<unsigned int C2>
 	template<unsigned int C2>

+ 2 - 0
service/OneService.cpp

@@ -251,6 +251,8 @@ static void _networkToJson(nlohmann::json &nj,const ZT_VirtualNetworkConfig *nc,
 	}
 	}
 	nj["dns"] = m;
 	nj["dns"] = m;
 
 
+	nj["authenticationURL"] = nc->authenticationURL;
+	nj["authenticationExpiryTime"] = nc->authenticationExpiryTime;
 }
 }
 
 
 static void _peerToJson(nlohmann::json &pj,const ZT_Peer *peer)
 static void _peerToJson(nlohmann::json &pj,const ZT_Peer *peer)