Quellcode durchsuchen

improve nebula-cert sign version auto-select (#1535)

Jack Doan vor 3 Wochen
Ursprung
Commit
6d7cf611c9
2 geänderte Dateien mit 19 neuen und 15 gelöschten Zeilen
  1. 17 13
      cmd/nebula-cert/sign.go
  2. 2 2
      cmd/nebula-cert/sign_test.go

+ 17 - 13
cmd/nebula-cert/sign.go

@@ -43,7 +43,7 @@ type signFlags struct {
 func newSignFlags() *signFlags {
 	sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)}
 	sf.set.Usage = func() {}
-	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.")
+	sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use. The default is to match the version of the signing CA")
 	sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key")
 	sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert")
 	sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname")
@@ -167,6 +167,10 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		return fmt.Errorf("ca certificate is expired")
 	}
 
+	if version == 0 {
+		version = caCert.Version()
+	}
+
 	// if no duration is given, expire one second before the root expires
 	if *sf.duration <= 0 {
 		*sf.duration = time.Until(caCert.NotAfter()) - time.Second*1
@@ -279,21 +283,19 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 	notBefore := time.Now()
 	notAfter := notBefore.Add(*sf.duration)
 
-	if version == 0 || version == cert.Version1 {
-		// Make sure we at least have an ip
+	switch version {
+	case cert.Version1:
+		// Make sure we have only one ipv4 address
 		if len(v4Networks) != 1 {
 			return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address")
 		}
 
-		if version == cert.Version1 {
-			// If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses
-			if len(v6Networks) > 0 {
-				return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4")
-			}
+		if len(v6Networks) > 0 {
+			return newHelpErrorf("invalid -networks definition: v1 certificates can only contain ipv4 addresses")
+		}
 
-			if len(v6UnsafeNetworks) > 0 {
-				return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4")
-			}
+		if len(v6UnsafeNetworks) > 0 {
+			return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
 		}
 
 		t := &cert.TBSCertificate{
@@ -323,9 +325,8 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		}
 
 		crts = append(crts, nc)
-	}
 
-	if version == 0 || version == cert.Version2 {
+	case cert.Version2:
 		t := &cert.TBSCertificate{
 			Version:        cert.Version2,
 			Name:           *sf.name,
@@ -353,6 +354,9 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader)
 		}
 
 		crts = append(crts, nc)
+	default:
+		// this should be unreachable
+		return fmt.Errorf("invalid version: %d", version)
 	}
 
 	if !isP11 && *sf.inPubPath == "" {

+ 2 - 2
cmd/nebula-cert/sign_test.go

@@ -55,7 +55,7 @@ func Test_signHelp(t *testing.T) {
 			"  -unsafe-networks string\n"+
 			"    \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+
 			"  -version uint\n"+
-			"    \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n",
+			"    \tOptional: version of the certificate format to use. The default is to match the version of the signing CA\n",
 		ob.String(),
 	)
 }
@@ -204,7 +204,7 @@ func Test_signCert(t *testing.T) {
 	ob.Reset()
 	eb.Reset()
 	args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"}
-	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4")
+	assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only contain ipv4 addresses")
 	assert.Empty(t, ob.String())
 	assert.Empty(t, eb.String())