浏览代码

enforce certificate correctness in TBSCertificate.SignWith (#1266)

* enforce certificate correctness in TBSCertificate.SignWith

* check length, not nil

* Address review comments

* github hates me

---------

Co-authored-by: Nate Brown <[email protected]>
Co-authored-by: Jack Doan <[email protected]>
Jack Doan 8 月之前
父节点
当前提交
8704047395
共有 8 个文件被更改,包括 188 次插入21 次删除
  1. 1 0
      cert/cert.go
  2. 57 0
      cert/cert_v1.go
  3. 82 7
      cert/cert_v2.go
  4. 14 2
      cert/errors.go
  5. 4 0
      cert/helper_test.go
  6. 1 0
      cert/pem.go
  7. 14 7
      cert/sign.go
  8. 15 5
      cmd/nebula-cert/print_test.go

+ 1 - 0
cert/cert.go

@@ -143,6 +143,7 @@ func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, cu
 	var err error
 
 	switch v {
+	// Implementations must ensure the result is a valid cert!
 	case VersionPre1, Version1:
 		c, err = unmarshalCertificateV1(b, publicKey)
 	case Version2:

+ 57 - 0
cert/cert_v1.go

@@ -317,6 +317,58 @@ func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error {
 		issuer:         t.issuer,
 	}
 
+	return c.validate()
+}
+
+func (c *certificateV1) validate() error {
+	// Empty names are allowed
+
+	if len(c.details.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	// Original v1 rules allowed multiple networks to be present but ignored all but the first one.
+	// Continue to allow this behavior
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificates must contain exactly one network")
+	}
+
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 networks: %v", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Is6() {
+			return NewErrInvalidCertificateProperties("certificate may not contain IPv6 unsafe networks: %v", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+	}
+
+	// v1 doesn't bother with sort order or uniqueness of networks or unsafe networks.
+	// We can't modify the unmarshalled data because verification requires re-marshalling and a re-ordered
+	// unsafe networks would result in a different signature.
+
 	return nil
 }
 
@@ -404,6 +456,11 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error)
 		}
 	}
 
+	err = nc.validate()
+	if err != nil {
+		return nil, err
+	}
+
 	return &nc, nil
 }
 

+ 82 - 7
cert/cert_v2.go

@@ -65,8 +65,8 @@ type certificateV2 struct {
 
 type detailsV2 struct {
 	name           string
-	networks       []netip.Prefix
-	unsafeNetworks []netip.Prefix
+	networks       []netip.Prefix // MUST BE SORTED
+	unsafeNetworks []netip.Prefix // MUST BE SORTED
 	groups         []string
 	isCA           bool
 	notBefore      time.Time
@@ -376,6 +376,77 @@ func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error {
 	}
 	c.curve = t.Curve
 	c.publicKey = t.PublicKey
+	return c.validate()
+}
+
+func (c *certificateV2) validate() error {
+	// Empty names are allowed
+
+	if len(c.publicKey) == 0 {
+		return ErrInvalidPublicKey
+	}
+
+	if !c.details.isCA && len(c.details.networks) == 0 {
+		return NewErrInvalidCertificateProperties("non-CA certificate must contain at least 1 network")
+	}
+
+	hasV4Networks := false
+	hasV6Networks := false
+	for _, network := range c.details.networks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid network: %s", network)
+		}
+
+		if network.Addr().IsUnspecified() {
+			return NewErrInvalidCertificateProperties("non-CA certificates must not use the zero address as a network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("networks may not contain zones: %s", network)
+		}
+
+		if network.Addr().Is4In6() {
+			return NewErrInvalidCertificateProperties("4in6 networks are not allowed: %s", network)
+		}
+
+		hasV4Networks = hasV4Networks || network.Addr().Is4()
+		hasV6Networks = hasV6Networks || network.Addr().Is6()
+	}
+
+	slices.SortFunc(c.details.networks, comparePrefix)
+	err := findDuplicatePrefix(c.details.networks)
+	if err != nil {
+		return err
+	}
+
+	for _, network := range c.details.unsafeNetworks {
+		if !network.IsValid() || !network.Addr().IsValid() {
+			return NewErrInvalidCertificateProperties("invalid unsafe network: %s", network)
+		}
+
+		if network.Addr().Zone() != "" {
+			return NewErrInvalidCertificateProperties("unsafe networks may not contain zones: %s", network)
+		}
+
+		if !c.details.isCA {
+			if network.Addr().Is6() {
+				if !hasV6Networks {
+					return NewErrInvalidCertificateProperties("IPv6 unsafe networks require an IPv6 address assignment: %s", network)
+				}
+			} else if network.Addr().Is4() {
+				if !hasV4Networks {
+					return NewErrInvalidCertificateProperties("IPv4 unsafe networks require an IPv4 address assignment: %s", network)
+				}
+			}
+		}
+	}
+
+	slices.SortFunc(c.details.unsafeNetworks, comparePrefix)
+	err = findDuplicatePrefix(c.details.unsafeNetworks)
+	if err != nil {
+		return err
+	}
+
 	return nil
 }
 
@@ -536,13 +607,20 @@ func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certifica
 		return nil, err
 	}
 
-	return &certificateV2{
+	c := &certificateV2{
 		details:    details,
 		rawDetails: rawDetails,
 		curve:      curve,
 		publicKey:  rawPublicKey,
 		signature:  rawSignature,
-	}, nil
+	}
+
+	err = c.validate()
+	if err != nil {
+		return nil, err
+	}
+
+	return c, nil
 }
 
 func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
@@ -639,9 +717,6 @@ func unmarshalDetails(b cryptobyte.String) (detailsV2, error) {
 		return detailsV2{}, ErrBadFormat
 	}
 
-	slices.SortFunc(networks, comparePrefix)
-	slices.SortFunc(unsafeNetworks, comparePrefix)
-
 	return detailsV2{
 		name:           string(name),
 		networks:       networks,

+ 14 - 2
cert/errors.go

@@ -2,6 +2,7 @@ package cert
 
 import (
 	"errors"
+	"fmt"
 )
 
 var (
@@ -17,10 +18,9 @@ var (
 	ErrInvalidPrivateKey          = errors.New("invalid private key")
 	ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve")
 	ErrPublicPrivateKeyMismatch   = errors.New("public key and private key are not a pair")
+	ErrPrivateKeyEncrypted        = errors.New("private key must be decrypted")
 	ErrCaNotFound                 = errors.New("could not find ca for the certificate")
 
-	ErrPrivateKeyEncrypted = errors.New("private key must be decrypted")
-
 	ErrInvalidPEMBlock                   = errors.New("input did not contain a valid PEM encoded block")
 	ErrInvalidPEMCertificateBanner       = errors.New("bytes did not contain a proper certificate banner")
 	ErrInvalidPEMX25519PublicKeyBanner   = errors.New("bytes did not contain a proper X25519 public key banner")
@@ -35,3 +35,15 @@ var (
 	ErrEmptySignature  = errors.New("empty signature")
 	ErrEmptyRawDetails = errors.New("empty rawDetails not allowed")
 )
+
+type ErrInvalidCertificateProperties struct {
+	str string
+}
+
+func NewErrInvalidCertificateProperties(format string, a ...any) error {
+	return &ErrInvalidCertificateProperties{fmt.Sprintf(format, a...)}
+}
+
+func (e *ErrInvalidCertificateProperties) Error() string {
+	return e.str
+}

+ 4 - 0
cert/helper_test.go

@@ -77,6 +77,10 @@ func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string
 		after = time.Now().Add(time.Second * 60).Round(time.Second)
 	}
 
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
 	var pub, priv []byte
 	switch curve {
 	case Curve_CURVE25519:

+ 1 - 0
cert/pem.go

@@ -34,6 +34,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
 	var err error
 
 	switch p.Type {
+	// Implementations must validate the resulting certificate contains valid information
 	case CertificateBanner:
 		c, err = unmarshalCertificateV1(p.Bytes, nil)
 	case CertificateV2Banner:

+ 14 - 7
cert/sign.go

@@ -9,7 +9,6 @@ import (
 	"fmt"
 	"math/big"
 	"net/netip"
-	"slices"
 	"time"
 )
 
@@ -31,6 +30,7 @@ type TBSCertificate struct {
 
 type beingSignedCertificate interface {
 	// fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation
+	// Implementations must validate the resulting certificate contains valid information
 	fromTBSCertificate(*TBSCertificate) error
 
 	// marshalForSigning returns the bytes that should be signed
@@ -83,9 +83,6 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
 		return nil, fmt.Errorf("curve in cert and private key supplied don't match")
 	}
 
-	//TODO: make sure we have all minimum properties to sign, like a public key
-	//TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs
-
 	if signer != nil {
 		if t.IsCA {
 			return nil, fmt.Errorf("can not sign a CA certificate with another")
@@ -107,9 +104,6 @@ func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLamb
 		}
 	}
 
-	slices.SortFunc(t.Networks, comparePrefix)
-	slices.SortFunc(t.UnsafeNetworks, comparePrefix)
-
 	var c beingSignedCertificate
 	switch t.Version {
 	case Version1:
@@ -158,3 +152,16 @@ func comparePrefix(a, b netip.Prefix) int {
 	}
 	return addr
 }
+
+// findDuplicatePrefix returns an error if there is a duplicate prefix in the pre-sorted input slice sortedPrefixes
+func findDuplicatePrefix(sortedPrefixes []netip.Prefix) error {
+	if len(sortedPrefixes) < 2 {
+		return nil
+	}
+	for i := 1; i < len(sortedPrefixes); i++ {
+		if comparePrefix(sortedPrefixes[i], sortedPrefixes[i-1]) == 0 {
+			return NewErrInvalidCertificateProperties("duplicate network detected: %v", sortedPrefixes[i])
+		}
+	}
+	return nil
+}

+ 15 - 5
cmd/nebula-cert/print_test.go

@@ -73,7 +73,7 @@ func Test_printCert(t *testing.T) {
 	tf.Truncate(0)
 	tf.Seek(0, 0)
 	ca, caKey := NewTestCaCert("test ca", nil, nil, time.Time{}, time.Time{}, nil, nil, nil)
-	c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, []string{"hi"})
+	c, _ := NewTestCert(ca, caKey, "test", time.Time{}, time.Time{}, []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}, nil, []string{"hi"})
 
 	p, _ := c.MarshalPEM()
 	tf.Write(p)
@@ -97,7 +97,9 @@ func Test_printCert(t *testing.T) {
 		"isCa": false,
 		"issuer": "`+c.Issuer()+`",
 		"name": "test",
-		"networks": [],
+		"networks": [
+			"10.0.0.123/8"
+		],
 		"notAfter": "0001-01-01T00:00:00Z",
 		"notBefore": "0001-01-01T00:00:00Z",
 		"publicKey": "`+pk+`",
@@ -116,7 +118,9 @@ func Test_printCert(t *testing.T) {
 		"isCa": false,
 		"issuer": "`+c.Issuer()+`",
 		"name": "test",
-		"networks": [],
+		"networks": [
+			"10.0.0.123/8"
+		],
 		"notAfter": "0001-01-01T00:00:00Z",
 		"notBefore": "0001-01-01T00:00:00Z",
 		"publicKey": "`+pk+`",
@@ -135,7 +139,9 @@ func Test_printCert(t *testing.T) {
 		"isCa": false,
 		"issuer": "`+c.Issuer()+`",
 		"name": "test",
-		"networks": [],
+		"networks": [
+			"10.0.0.123/8"
+		],
 		"notAfter": "0001-01-01T00:00:00Z",
 		"notBefore": "0001-01-01T00:00:00Z",
 		"publicKey": "`+pk+`",
@@ -166,7 +172,7 @@ func Test_printCert(t *testing.T) {
 	assert.Nil(t, err)
 	assert.Equal(
 		t,
-		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
+		`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]
 `,
 		ob.String(),
 	)
@@ -212,6 +218,10 @@ func NewTestCert(ca cert.Certificate, signerKey []byte, name string, before, aft
 		after = ca.NotAfter()
 	}
 
+	if len(networks) == 0 {
+		networks = []netip.Prefix{netip.MustParsePrefix("10.0.0.123/8")}
+	}
+
 	pub, rawPriv := x25519Keypair()
 	nc := &cert.TBSCertificate{
 		Version:        cert.Version1,