浏览代码

[cert-v2] nebula-cert should verify all certs (#1291)

Jack Doan 8 月之前
父节点
当前提交
3f31517018
共有 4 个文件被更改,包括 31 次插入17 次删除
  1. 1 1
      cert/ca_pool.go
  2. 1 0
      cert/errors.go
  3. 25 14
      cmd/nebula-cert/verify.go
  4. 4 2
      cmd/nebula-cert/verify_test.go

+ 1 - 1
cert/ca_pool.go

@@ -213,7 +213,7 @@ func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) {
 		return signer, nil
 	}
 
-	return nil, fmt.Errorf("could not find ca for the certificate")
+	return nil, ErrCaNotFound
 }
 
 // GetFingerprints returns an array of trusted CA fingerprints

+ 1 - 0
cert/errors.go

@@ -17,6 +17,7 @@ var (
 	ErrInvalidPrivateKey          = errors.New("invalid private key")
 	ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
 	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
+	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
 
 	ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
 

+ 25 - 14
cmd/nebula-cert/verify.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"errors"
 	"flag"
 	"fmt"
 	"io"
@@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCACert, err := os.ReadFile(*vf.caPath)
 	if err != nil {
-		return fmt.Errorf("error while reading ca: %s", err)
+		return fmt.Errorf("error while reading ca: %w", err)
 	}
 
 	caPool := cert.NewCAPool()
 	for {
 		rawCACert, err = caPool.AddCAFromPEM(rawCACert)
 		if err != nil {
-			return fmt.Errorf("error while adding ca cert to pool: %s", err)
+			return fmt.Errorf("error while adding ca cert to pool: %w", err)
 		}
 
 		if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" {
@@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error {
 
 	rawCert, err := os.ReadFile(*vf.certPath)
 	if err != nil {
-		return fmt.Errorf("unable to read crt; %s", err)
+		return fmt.Errorf("unable to read crt: %w", err)
 	}
-
-	c, _, err := cert.UnmarshalCertificateFromPEM(rawCert)
-	if err != nil {
-		return fmt.Errorf("error while parsing crt: %s", err)
-	}
-
-	_, err = caPool.VerifyCertificate(time.Now(), c)
-	if err != nil {
-		return err
+	var errs []error
+	for {
+		if len(rawCert) == 0 {
+			break
+		}
+		c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert)
+		if err != nil {
+			return fmt.Errorf("error while parsing crt: %w", err)
+		}
+		rawCert = extra
+		_, err = caPool.VerifyCertificate(time.Now(), c)
+		if err != nil {
+			switch {
+			case errors.Is(err, cert.ErrCaNotFound):
+				errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err))
+			default:
+				errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err))
+			}
+		}
 	}
 
-	return nil
+	return errors.Join(errs...)
 }
 
 func verifySummary() string {
@@ -80,7 +91,7 @@ func verifySummary() string {
 
 func verifyHelp(out io.Writer) {
 	vf := newVerifyFlags()
-	out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
+	_, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n"))
 	vf.set.SetOutput(out)
 	vf.set.PrintDefaults()
 }

+ 4 - 2
cmd/nebula-cert/verify_test.go

@@ -3,10 +3,12 @@ package main
 import (
 	"bytes"
 	"crypto/rand"
+	"errors"
 	"os"
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/cert"
 	"github.com/stretchr/testify/assert"
 	"golang.org/x/crypto/ed25519"
 )
@@ -76,7 +78,7 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError)
+	assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
 
 	// invalid crt at path
 	ob.Reset()
@@ -106,7 +108,7 @@ func Test_verify(t *testing.T) {
 	err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
 	assert.Equal(t, "", ob.String())
 	assert.Equal(t, "", eb.String())
-	assert.EqualError(t, err, "certificate signature did not match")
+	assert.True(t, errors.Is(err, cert.ErrSignatureMismatch))
 
 	// verified cert at path
 	crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)