Forráskód Böngészése

[v1.9.x] do not panic when loading a V2 CA certificate (#1282)

Co-authored-by: Jack Doan <[email protected]>
Nate Brown 7 hónapja
szülő
commit
2e85d138cd
5 módosított fájl, 64 hozzáadás és 39 törlés
  1. 18 10
      cert/ca.go
  2. 4 0
      cert/cert.go
  3. 29 8
      cert/cert_test.go
  4. 7 6
      cert/errors.go
  5. 6 15
      pki.go

+ 18 - 10
cert/ca.go

@@ -24,31 +24,39 @@ func NewCAPool() *NebulaCAPool {
 
 // NewCAPoolFromBytes will create a new CA pool from the provided
 // input bytes, which must be a PEM-encoded set of nebula certificates.
+// If the pool contains unsupported certificates, they will generate warnings
+// in the []error return arg.
 // If the pool contains any expired certificates, an ErrExpired will be
 // returned along with the pool. The caller must handle any such errors.
-func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, error) {
+func NewCAPoolFromBytes(caPEMs []byte) (*NebulaCAPool, []error, error) {
 	pool := NewCAPool()
 	var err error
-	var expired bool
+	var warnings []error
+	good := 0
+
 	for {
 		caPEMs, err = pool.AddCACertificate(caPEMs)
 		if errors.Is(err, ErrExpired) {
-			expired = true
-			err = nil
-		}
-		if err != nil {
-			return nil, err
+			warnings = append(warnings, err)
+		} else if errors.Is(err, ErrInvalidPEMCertificateUnsupported) {
+			warnings = append(warnings, err)
+		} else if err != nil {
+			return nil, warnings, err
+		} else {
+			// Only consider a good certificate if there were no errors present
+			good++
 		}
+
 		if len(caPEMs) == 0 || strings.TrimSpace(string(caPEMs)) == "" {
 			break
 		}
 	}
 
-	if expired {
-		return pool, ErrExpired
+	if good == 0 {
+		return nil, warnings, errors.New("no valid CA certificates present")
 	}
 
-	return pool, nil
+	return pool, warnings, nil
 }
 
 // AddCACertificate verifies a Nebula CA certificate and adds it to the pool

+ 4 - 0
cert/cert.go

@@ -28,6 +28,7 @@ const publicKeyLen = 32
 
 const (
 	CertBanner                       = "NEBULA CERTIFICATE"
+	CertificateV2Banner              = "NEBULA CERTIFICATE V2"
 	X25519PrivateKeyBanner           = "NEBULA X25519 PRIVATE KEY"
 	X25519PublicKeyBanner            = "NEBULA X25519 PUBLIC KEY"
 	EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY"
@@ -163,6 +164,9 @@ func UnmarshalNebulaCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, er
 	if p == nil {
 		return nil, r, fmt.Errorf("input did not contain a valid PEM encoded block")
 	}
+	if p.Type == CertificateV2Banner {
+		return nil, r, fmt.Errorf("%w: %s", ErrInvalidPEMCertificateUnsupported, p.Type)
+	}
 	if p.Type != CertBanner {
 		return nil, r, fmt.Errorf("bytes did not contain a proper nebula certificate banner")
 	}

+ 29 - 8
cert/cert_test.go

@@ -5,6 +5,7 @@ import (
 	"crypto/ecdsa"
 	"crypto/elliptic"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -572,6 +573,13 @@ CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
 76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC
 IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
 -----END NEBULA CERTIFICATE-----
+`
+
+	v2 := `
+# valid PEM with the V2 header
+-----BEGIN NEBULA CERTIFICATE V2-----
+CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2
+-----END NEBULA CERTIFICATE V2-----
 `
 
 	rootCA := NebulaCertificate{
@@ -592,33 +600,46 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX
 		},
 	}
 
-	p, err := NewCAPoolFromBytes([]byte(noNewLines))
+	p, warn, err := NewCAPoolFromBytes([]byte(noNewLines))
 	assert.Nil(t, err)
+	assert.Nil(t, warn)
 	assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
 	assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
 
-	pp, err := NewCAPoolFromBytes([]byte(withNewLines))
+	pp, warn, err := NewCAPoolFromBytes([]byte(withNewLines))
 	assert.Nil(t, err)
+	assert.Nil(t, warn)
 	assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
 	assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
 
 	// expired cert, no valid certs
-	ppp, err := NewCAPoolFromBytes([]byte(expired))
-	assert.Equal(t, ErrExpired, err)
-	assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
+	ppp, warn, err := NewCAPoolFromBytes([]byte(expired))
+	assert.Error(t, err, "no valid CA certificates present")
+	assert.Len(t, warn, 1)
+	assert.Error(t, warn[0], ErrExpired)
+	assert.Nil(t, ppp)
 
 	// expired cert, with valid certs
-	pppp, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...))
-	assert.Equal(t, ErrExpired, err)
+	pppp, warn, err := NewCAPoolFromBytes(append([]byte(expired), noNewLines...))
+	assert.Len(t, warn, 1)
+	assert.Nil(t, err)
+	assert.Error(t, warn[0], ErrExpired)
 	assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name)
 	assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name)
 	assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired")
 	assert.Equal(t, len(pppp.CAs), 3)
 
-	ppppp, err := NewCAPoolFromBytes([]byte(p256))
+	ppppp, warn, err := NewCAPoolFromBytes([]byte(p256))
 	assert.Nil(t, err)
+	assert.Nil(t, warn)
 	assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
 	assert.Equal(t, len(ppppp.CAs), 1)
+
+	pppppp, warn, err := NewCAPoolFromBytes(append([]byte(p256), []byte(v2)...))
+	assert.Nil(t, err)
+	assert.True(t, errors.Is(warn[0], ErrInvalidPEMCertificateUnsupported))
+	assert.Equal(t, pppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name)
+	assert.Equal(t, len(pppppp.CAs), 1)
 }
 
 func appendByteSlices(b ...[]byte) []byte {

+ 7 - 6
cert/errors.go

@@ -5,10 +5,11 @@ import (
 )
 
 var (
-	ErrRootExpired       = errors.New("root certificate is expired")
-	ErrExpired           = errors.New("certificate is expired")
-	ErrNotCA             = errors.New("certificate is not a CA")
-	ErrNotSelfSigned     = errors.New("certificate is not self-signed")
-	ErrBlockListed       = errors.New("certificate is in the block list")
-	ErrSignatureMismatch = errors.New("certificate signature did not match")
+	ErrRootExpired                      = errors.New("root certificate is expired")
+	ErrExpired                          = errors.New("certificate is expired")
+	ErrNotCA                            = errors.New("certificate is not a CA")
+	ErrNotSelfSigned                    = errors.New("certificate is not self-signed")
+	ErrBlockListed                      = errors.New("certificate is in the block list")
+	ErrSignatureMismatch                = errors.New("certificate signature did not match")
+	ErrInvalidPEMCertificateUnsupported = errors.New("bytes contain an unsupported certificate format")
 )

+ 6 - 15
pki.go

@@ -223,22 +223,13 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, er
 		}
 	}
 
-	caPool, err := cert.NewCAPoolFromBytes(rawCA)
-	if errors.Is(err, cert.ErrExpired) {
-		var expired int
-		for _, crt := range caPool.CAs {
-			if crt.Expired(time.Now()) {
-				expired++
-				l.WithField("cert", crt).Warn("expired certificate present in CA pool")
-			}
-		}
-
-		if expired >= len(caPool.CAs) {
-			return nil, errors.New("no valid CA certificates present")
-		}
+	caPool, warnings, err := cert.NewCAPoolFromBytes(rawCA)
+	for _, w := range warnings {
+		l.WithError(w).Warn("parsing a CA certificate failed")
+	}
 
-	} else if err != nil {
-		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
+	if err != nil {
+		return nil, fmt.Errorf("could not create CA certificate pool: %s", err)
 	}
 
 	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {