Browse Source

Merge pull request #1338 from gravitl/hotfix_v0.14.5_cert_writes

always save certs on server start
Matthew R Kasun 3 years ago
parent
commit
6487ea17d1
1 changed files with 47 additions and 23 deletions
  1. 47 23
      main.go

+ 47 - 23
main.go

@@ -192,6 +192,9 @@ func genCerts() error {
 	logger.Log(0, "checking keys and certificates")
 	logger.Log(0, "checking keys and certificates")
 	var private *ed25519.PrivateKey
 	var private *ed25519.PrivateKey
 	var err error
 	var err error
+
+	// == ROOT key handling ==
+
 	private, err = serverctl.ReadKeyFromDB(tls.ROOT_KEY_NAME)
 	private, err = serverctl.ReadKeyFromDB(tls.ROOT_KEY_NAME)
 	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) {
 	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) {
 		logger.Log(0, "generating new root key")
 		logger.Log(0, "generating new root key")
@@ -199,13 +202,17 @@ func genCerts() error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_KEY_NAME, newKey); err != nil {
-			return err
-		}
 		private = &newKey
 		private = &newKey
 	} else if err != nil {
 	} else if err != nil {
 		return err
 		return err
 	}
 	}
+	logger.Log(2, "saving root.key")
+	if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_KEY_NAME, *private); err != nil {
+		return err
+	}
+
+	// == ROOT cert handling ==
+
 	ca, err := serverctl.ReadCertFromDB(tls.ROOT_PEM_NAME)
 	ca, err := serverctl.ReadCertFromDB(tls.ROOT_PEM_NAME)
 	//if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect
 	//if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect
 	//if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 	//if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
@@ -220,13 +227,17 @@ func genCerts() error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_PEM_NAME, rootCA); err != nil {
-			return err
-		}
 		ca = rootCA
 		ca = rootCA
 	} else if err != nil {
 	} else if err != nil {
 		return err
 		return err
 	}
 	}
+	logger.Log(2, "saving root.pem")
+	if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_PEM_NAME, ca); err != nil {
+		return err
+	}
+
+	// == SERVER cert handling ==
+
 	cert, err := serverctl.ReadCertFromDB(tls.SERVER_PEM_NAME)
 	cert, err := serverctl.ReadCertFromDB(tls.SERVER_PEM_NAME)
 	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 		//gen new key
 		//gen new key
@@ -240,21 +251,32 @@ func genCerts() error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		cert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY)
+		newCert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
 		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_KEY_NAME, key); err != nil {
 		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_KEY_NAME, key); err != nil {
 			return err
 			return err
 		}
 		}
-		if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_PEM_NAME, cert); err != nil {
+		cert = newCert
+	} else if err != nil {
+		return err
+	} else if err == nil {
+		if serverKey, err := serverctl.ReadKeyFromDB(tls.SERVER_KEY_NAME); err == nil {
+			logger.Log(2, "saving server.key")
+			if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_KEY_NAME, *serverKey); err != nil {
+				return err
+			}
+		} else {
 			return err
 			return err
 		}
 		}
-	} else if err != nil {
+	}
+	logger.Log(2, "saving server.pem")
+	if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_PEM_NAME, cert); err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	logger.Log(2, "ensure the root.pem, root.key, server.pem, and server.key files are updated on your broker")
+	// == SERVER-CLIENT connection cert handling ==
 
 
 	serverClientCert, err := serverctl.ReadCertFromDB(tls.SERVER_CLIENT_PEM)
 	serverClientCert, err := serverctl.ReadCertFromDB(tls.SERVER_CLIENT_PEM)
 	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || serverClientCert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || serverClientCert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
@@ -269,7 +291,7 @@ func genCerts() error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		serverClientCert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY)
+		newServerClientCert, err := tls.NewEndEntityCert(*private, csr, ca, tls.CERTIFICATE_VALIDITY)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -277,25 +299,27 @@ func genCerts() error {
 		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, key); err != nil {
 		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, key); err != nil {
 			return err
 			return err
 		}
 		}
-		if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil {
-			return err
-		}
+		serverClientCert = newServerClientCert
 	} else if err != nil {
 	} else if err != nil {
 		return err
 		return err
 	} else if err == nil {
 	} else if err == nil {
-		logger.Log(0, "detected valid server client cert, re-saving for future consumption")
-		key, err := serverctl.ReadKeyFromDB(tls.SERVER_CLIENT_KEY)
-		if err != nil {
-			return err
-		}
-		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, *key); err != nil {
-			return err
-		}
-		if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil {
+		logger.Log(2, "saving serverclient.key")
+		if serverClientKey, err := serverctl.ReadKeyFromDB(tls.SERVER_CLIENT_KEY); err == nil {
+			if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_KEY, *serverClientKey); err != nil {
+				return err
+			}
+		} else {
 			return err
 			return err
 		}
 		}
 	}
 	}
 
 
+	logger.Log(2, "saving serverclient.pem")
+	if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_CLIENT_PEM, serverClientCert); err != nil {
+		return err
+	}
+
+	logger.Log(1, "ensure the root.pem, root.key, server.pem, and server.key files are updated on your broker")
+
 	return serverctl.SetClientTLSConf(
 	return serverctl.SetClientTLSConf(
 		functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_PEM,
 		functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_PEM,
 		functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_KEY,
 		functions.GetNetmakerPath()+ncutils.GetSeparator()+tls.SERVER_CLIENT_KEY,