Kaynağa Gözat

Always disconnect block listed hosts (#858)

Nate Brown 2 yıl önce
ebeveyn
işleme
702e1c59bd
4 değiştirilmiş dosya ile 22 ekleme ve 14 silme
  1. 4 4
      cert/cert.go
  2. 1 1
      cert/cert_test.go
  3. 9 4
      cert/errors.go
  4. 8 5
      connection_manager.go

+ 4 - 4
cert/cert.go

@@ -393,7 +393,7 @@ func (nc *NebulaCertificate) Expired(t time.Time) bool {
 // Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc)
 func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error) {
 	if ncp.IsBlocklisted(nc) {
-		return false, fmt.Errorf("certificate has been blocked")
+		return false, ErrBlockListed
 	}
 
 	signer, err := ncp.GetCAForCert(nc)
@@ -402,15 +402,15 @@ func (nc *NebulaCertificate) Verify(t time.Time, ncp *NebulaCAPool) (bool, error
 	}
 
 	if signer.Expired(t) {
-		return false, fmt.Errorf("root certificate is expired")
+		return false, ErrRootExpired
 	}
 
 	if nc.Expired(t) {
-		return false, fmt.Errorf("certificate is expired")
+		return false, ErrExpired
 	}
 
 	if !nc.CheckSignature(signer.Details.PublicKey) {
-		return false, fmt.Errorf("certificate signature did not match")
+		return false, ErrSignatureMismatch
 	}
 
 	if err := nc.CheckRootConstrains(signer); err != nil {

+ 1 - 1
cert/cert_test.go

@@ -177,7 +177,7 @@ func TestNebulaCertificate_Verify(t *testing.T) {
 
 	v, err := c.Verify(time.Now(), caPool)
 	assert.False(t, v)
-	assert.EqualError(t, err, "certificate has been blocked")
+	assert.EqualError(t, err, "certificate is in the block list")
 
 	caPool.ResetCertBlocklist()
 	v, err = c.Verify(time.Now(), caPool)

+ 9 - 4
cert/errors.go

@@ -1,9 +1,14 @@
 package cert
 
-import "errors"
+import (
+	"errors"
+)
 
 var (
-	ErrExpired       = errors.New("certificate is expired")
-	ErrNotCA         = errors.New("certificate is not a CA")
-	ErrNotSelfSigned = errors.New("certificate is not self-signed")
+	ErrRootExpired       = errors.New("root certificate is expired")
+	ErrExpired           = errors.New("certificate is expired")
+	ErrNotCA             = errors.New("certificate is not a CA")
+	ErrNotSelfSigned     = errors.New("certificate is not self-signed")
+	ErrBlockListed       = errors.New("certificate is in the block list")
+	ErrSignatureMismatch = errors.New("certificate signature did not match")
 )

+ 8 - 5
connection_manager.go

@@ -8,6 +8,7 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
@@ -419,12 +420,9 @@ func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 }
 
 // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
-// the certificate is no longer valid
+// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
+// check and return true.
 func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
-	if !n.intf.disconnectInvalid {
-		return false
-	}
-
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
 		return false
@@ -435,6 +433,11 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
+	if !n.intf.disconnectInvalid && err != cert.ErrBlockListed {
+		// Block listed certificates should always be disconnected
+		return false
+	}
+
 	fingerprint, _ := remoteCert.Sha256Sum()
 	hostinfo.logger(n.l).WithError(err).
 		WithField("fingerprint", fingerprint).