浏览代码

try to make certificate addition/removal reloadable in some cases

JackDoan 3 天之前
父节点
当前提交
6b78397f30
共有 2 个文件被更改,包括 74 次插入47 次删除
  1. 30 10
      connection_manager.go
  2. 44 37
      pki.go

+ 30 - 10
connection_manager.go

@@ -476,8 +476,8 @@ func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
 }
 }
 
 
 // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
 // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
-// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
-// check and return true.
+// the certificate is no longer valid, or if we no longer have a certificate of the same version as the remote.
+// Blocklisted certificates will skip the pki.disconnect_invalid check and return true.
 func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
 	if remoteCert == nil {
@@ -488,18 +488,38 @@ func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostI
 	err := caPool.VerifyCachedCertificate(now, remoteCert)
 	err := caPool.VerifyCachedCertificate(now, remoteCert)
 	if err == nil {
 	if err == nil {
 		return false
 		return false
-	}
-
-	if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
+	} else if err == cert.ErrBlockListed { //avoiding errors.Is for speed
 		// Block listed certificates should always be disconnected
 		// Block listed certificates should always be disconnected
-		return false
+		hostinfo.logger(cm.l).WithError(err).
+			WithField("fingerprint", remoteCert.Fingerprint).
+			Info("Remote certificate is blocked, tearing down the tunnel")
+		return true
+	} else if cm.intf.disconnectInvalid.Load() {
+		hostinfo.logger(cm.l).WithError(err).
+			WithField("fingerprint", remoteCert.Fingerprint).
+			Info("Remote certificate is no longer valid, tearing down the tunnel")
+		return true
 	}
 	}
 
 
-	hostinfo.logger(cm.l).WithError(err).
-		WithField("fingerprint", remoteCert.Fingerprint).
-		Info("Remote certificate is no longer valid, tearing down the tunnel")
+	//check that we still have a cert version in common with this connection. If we do not, disconnect.
+	remoteVersion := remoteCert.Certificate.Version()
+	cs := cm.intf.pki.getCertState()
+	out := false
+	switch remoteVersion {
+	case cert.Version1:
+		out = cs.v1Cert == nil
+	case cert.Version2:
+		out = cs.v2Cert == nil
+	default:
+		out = true
+	}
 
 
-	return true
+	if out {
+		hostinfo.logger(cm.l).WithField("fingerprint", remoteCert.Fingerprint).
+			WithField("version", remoteVersion).
+			Info("We no longer have a certificate in common with remote, tearing down the tunnel")
+	}
+	return out
 }
 }
 
 
 func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
 func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {

+ 44 - 37
pki.go

@@ -100,55 +100,62 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError {
 		currentState := p.cs.Load()
 		currentState := p.cs.Load()
 		if newState.v1Cert != nil {
 		if newState.v1Cert != nil {
 			if currentState.v1Cert == nil {
 			if currentState.v1Cert == nil {
-				return util.NewContextualError("v1 certificate was added, restart required", nil, err)
+				//adding certs is fine, actually. Networks-in-common confirmed in newCertState().
+			} else {
+				// did IP in cert change? if so, don't set
+				if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
+					return util.NewContextualError(
+						"Networks in new cert was different from old",
+						m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
+						nil,
+					)
+				}
+
+				if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
+					return util.NewContextualError(
+						"Curve in new cert was different from old",
+						m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
+						nil,
+					)
+				}
 			}
 			}
-
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()},
-					nil,
-				)
-			}
-
-			if currentState.v1Cert.Curve() != newState.v1Cert.Curve() {
-				return util.NewContextualError(
-					"Curve in new cert was different from old",
-					m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()},
-					nil,
-				)
-			}
-
-		} else if currentState.v1Cert != nil {
-			//TODO: CERT-V2 we should be able to tear this down
-			return util.NewContextualError("v1 certificate was removed, restart required", nil, err)
 		}
 		}
 
 
 		if newState.v2Cert != nil {
 		if newState.v2Cert != nil {
 			if currentState.v2Cert == nil {
 			if currentState.v2Cert == nil {
-				return util.NewContextualError("v2 certificate was added, restart required", nil, err)
+				//adding certs is fine, actually
+			} else {
+				// did IP in cert change? if so, don't set
+				if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
+					return util.NewContextualError(
+						"Networks in new cert was different from old",
+						m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
+						nil,
+					)
+				}
+
+				if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+					return util.NewContextualError(
+						"Curve in new cert was different from old",
+						m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+						nil,
+					)
+				}
 			}
 			}
 
 
-			// did IP in cert change? if so, don't set
-			if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) {
-				return util.NewContextualError(
-					"Networks in new cert was different from old",
-					m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
-					nil,
-				)
+		} else if currentState.v2Cert != nil {
+			//newState.v1Cert is non-nil bc empty certstates aren't permitted
+			if newState.v1Cert == nil {
+				return util.NewContextualError("v1 and v2 certs are nil, this should be impossible", nil, err)
 			}
 			}
-
-			if currentState.v2Cert.Curve() != newState.v2Cert.Curve() {
+			//if we're going to v1-only, we need to make sure we didn't orphan any v2-cert vpnaddrs
+			if !slices.Equal(currentState.v2Cert.Networks(), newState.v1Cert.Networks()) {
 				return util.NewContextualError(
 				return util.NewContextualError(
-					"Curve in new cert was different from old",
-					m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()},
+					"Removing a V2 cert is not permitted unless it has identical networks to the new V1 cert",
+					m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v2Cert.Networks()},
 					nil,
 					nil,
 				)
 				)
 			}
 			}
-
-		} else if currentState.v2Cert != nil {
-			return util.NewContextualError("v2 certificate was removed, restart required", nil, err)
 		}
 		}
 
 
 		// Cipher cant be hot swapped so just leave it at what it was before
 		// Cipher cant be hot swapped so just leave it at what it was before