Browse Source

add empty record check

afeiszli 3 years ago
parent
commit
ccd80eb10c
1 changed files with 3 additions and 3 deletions
  1. 3 3
      main.go

+ 3 - 3
main.go

@@ -191,7 +191,7 @@ func genCerts() error {
 	var private *ed25519.PrivateKey
 	var private *ed25519.PrivateKey
 	var err error
 	var err error
 	private, err = serverctl.ReadKeyFromDB(tls.ROOT_KEY_NAME)
 	private, err = serverctl.ReadKeyFromDB(tls.ROOT_KEY_NAME)
-	if errors.Is(err, os.ErrNotExist) {
+	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) {
 		logger.Log(0, "generating new root key")
 		logger.Log(0, "generating new root key")
 		_, newKey, err := ed25519.GenerateKey(rand.Reader)
 		_, newKey, err := ed25519.GenerateKey(rand.Reader)
 		if err != nil {
 		if err != nil {
@@ -207,7 +207,7 @@ func genCerts() error {
 	ca, err := serverctl.ReadCertFromDB(tls.ROOT_PEM_NAME)
 	ca, err := serverctl.ReadCertFromDB(tls.ROOT_PEM_NAME)
 	//if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect
 	//if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect
 	//if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 	//if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
-	if errors.Is(err, os.ErrNotExist) {
+	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) {
 		logger.Log(0, "generating new root CA")
 		logger.Log(0, "generating new root CA")
 		caName := tls.NewName("CA Root", "US", "Gravitl")
 		caName := tls.NewName("CA Root", "US", "Gravitl")
 		csr, err := tls.NewCSR(*private, caName)
 		csr, err := tls.NewCSR(*private, caName)
@@ -226,7 +226,7 @@ func genCerts() error {
 		return err
 		return err
 	}
 	}
 	cert, err := serverctl.ReadCertFromDB(tls.SERVER_PEM_NAME)
 	cert, err := serverctl.ReadCertFromDB(tls.SERVER_PEM_NAME)
-	if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
+	if errors.Is(err, os.ErrNotExist) || database.IsEmptyRecord(err) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 		//gen new key
 		//gen new key
 		logger.Log(0, "generating new server key/certificate")
 		logger.Log(0, "generating new server key/certificate")
 		_, key, err := ed25519.GenerateKey(rand.Reader)
 		_, key, err := ed25519.GenerateKey(rand.Reader)