Browse Source

refactoring cert logic to use database

afeiszli 3 years ago
parent
commit
f28d361bea
5 changed files with 124 additions and 90 deletions
  1. 3 2
      controllers/server.go
  2. 7 7
      main.go
  3. 1 1
      netclient/functions/register.go
  4. 105 0
      serverctl/tls.go
  5. 8 80
      tls/tls.go

+ 3 - 2
controllers/server.go

@@ -15,6 +15,7 @@ import (
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/netclient/config"
 	"github.com/gravitl/netmaker/netclient/config"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/servercfg"
+	"github.com/gravitl/netmaker/serverctl"
 	"github.com/gravitl/netmaker/tls"
 	"github.com/gravitl/netmaker/tls"
 )
 )
 
 
@@ -142,12 +143,12 @@ func register(w http.ResponseWriter, r *http.Request) {
 
 
 // genCerts generates a client certificate and returns the certificate and root CA
 // genCerts generates a client certificate and returns the certificate and root CA
 func genCerts(clientKey *ed25519.PrivateKey, name *pkix.Name) (*x509.Certificate, *x509.Certificate, error) {
 func genCerts(clientKey *ed25519.PrivateKey, name *pkix.Name) (*x509.Certificate, *x509.Certificate, error) {
-	ca, err := tls.ReadCertFromFile("/etc/netmaker/root.pem")
+	ca, err := serverctl.ReadCertFromDB(tls.ROOT_PEM_NAME)
 	if err != nil {
 	if err != nil {
 		logger.Log(2, "root ca not found ", err.Error())
 		logger.Log(2, "root ca not found ", err.Error())
 		return nil, nil, fmt.Errorf("root ca not found %w", err)
 		return nil, nil, fmt.Errorf("root ca not found %w", err)
 	}
 	}
-	key, err := tls.ReadKeyFromFile("/etc/netmaker/root.key")
+	key, err := serverctl.ReadKeyFromDB(tls.ROOT_KEY_NAME)
 	if err != nil {
 	if err != nil {
 		logger.Log(2, "root key not found ", err.Error())
 		logger.Log(2, "root key not found ", err.Error())
 		return nil, nil, fmt.Errorf("root key not found %w", err)
 		return nil, nil, fmt.Errorf("root key not found %w", err)

+ 7 - 7
main.go

@@ -190,21 +190,21 @@ func genCerts() error {
 	logger.Log(0, "checking keys and certificates")
 	logger.Log(0, "checking keys and certificates")
 	var private *ed25519.PrivateKey
 	var private *ed25519.PrivateKey
 	var err error
 	var err error
-	private, err = tls.ReadKeyFromFile(functions.GetNetmakerPath() + "/root.key")
+	private, err = serverctl.ReadKeyFromDB(tls.ROOT_KEY_NAME)
 	if errors.Is(err, os.ErrNotExist) {
 	if errors.Is(err, os.ErrNotExist) {
 		logger.Log(0, "generating new root key")
 		logger.Log(0, "generating new root key")
 		_, newKey, err := ed25519.GenerateKey(rand.Reader)
 		_, newKey, err := ed25519.GenerateKey(rand.Reader)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		if err := tls.SaveKeyToFile(functions.GetNetmakerPath(), "/root.key", newKey); err != nil {
+		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_KEY_NAME, newKey); err != nil {
 			return err
 			return err
 		}
 		}
 		private = &newKey
 		private = &newKey
 	} else if err != nil {
 	} else if err != nil {
 		return err
 		return err
 	}
 	}
-	ca, err := tls.ReadCertFromFile(functions.GetNetmakerPath() + ncutils.GetSeparator() + "root.pem")
+	ca, err := serverctl.ReadCertFromDB(tls.ROOT_PEM_NAME)
 	//if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect
 	//if cert doesn't exist or will expire within 10 days --- but can't do this as clients won't be able to connect
 	//if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 	//if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 	if errors.Is(err, os.ErrNotExist) {
 	if errors.Is(err, os.ErrNotExist) {
@@ -218,14 +218,14 @@ func genCerts() error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		if err := tls.SaveCertToFile(functions.GetNetmakerPath(), ncutils.GetSeparator()+"root.pem", rootCA); err != nil {
+		if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.ROOT_PEM_NAME, rootCA); err != nil {
 			return err
 			return err
 		}
 		}
 		ca = rootCA
 		ca = rootCA
 	} else if err != nil {
 	} else if err != nil {
 		return err
 		return err
 	}
 	}
-	cert, err := tls.ReadCertFromFile(functions.GetNetmakerPath() + "/server.pem")
+	cert, err := serverctl.ReadCertFromDB(tls.SERVER_PEM_NAME)
 	if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 	if errors.Is(err, os.ErrNotExist) || cert.NotAfter.Before(time.Now().Add(time.Hour*24*10)) {
 		//gen new key
 		//gen new key
 		logger.Log(0, "generating new server key/certificate")
 		logger.Log(0, "generating new server key/certificate")
@@ -242,10 +242,10 @@ func genCerts() error {
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
-		if err := tls.SaveKeyToFile(functions.GetNetmakerPath(), "/server.key", key); err != nil {
+		if err := serverctl.SaveKey(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_KEY_NAME, key); err != nil {
 			return err
 			return err
 		}
 		}
-		if err := tls.SaveCertToFile(functions.GetNetmakerPath(), "/server.pem", cert); err != nil {
+		if err := serverctl.SaveCert(functions.GetNetmakerPath()+ncutils.GetSeparator(), tls.SERVER_PEM_NAME, cert); err != nil {
 			return err
 			return err
 		}
 		}
 	} else if err != nil {
 	} else if err != nil {

+ 1 - 1
netclient/functions/register.go

@@ -88,7 +88,7 @@ func RegisterWithServer(private *ed25519.PrivateKey, cfg *config.ClientConfig) e
 	//the pubkeys are included in the response so the values in the certificate can be updated appropriately
 	//the pubkeys are included in the response so the values in the certificate can be updated appropriately
 	resp.CA.PublicKey = resp.CAPubKey
 	resp.CA.PublicKey = resp.CAPubKey
 	resp.Cert.PublicKey = resp.CertPubKey
 	resp.Cert.PublicKey = resp.CertPubKey
-	if err := tls.SaveCertToFile(ncutils.GetNetclientServerPath(cfg.Server.Server)+ncutils.GetSeparator(), "root.pem", &resp.CA); err != nil {
+	if err := tls.SaveCertToFile(ncutils.GetNetclientServerPath(cfg.Server.Server)+ncutils.GetSeparator(), tls.ROOT_PEM_NAME, &resp.CA); err != nil {
 		return err
 		return err
 	}
 	}
 	if err := tls.SaveCertToFile(ncutils.GetNetclientServerPath(cfg.Server.Server)+ncutils.GetSeparator(), "client.pem", &resp.Cert); err != nil {
 	if err := tls.SaveCertToFile(ncutils.GetNetclientServerPath(cfg.Server.Server)+ncutils.GetSeparator(), "client.pem", &resp.Cert); err != nil {

+ 105 - 0
serverctl/tls.go

@@ -0,0 +1,105 @@
+package serverctl
+
+import (
+	"crypto/ed25519"
+	"crypto/x509"
+	"encoding/json"
+	"encoding/pem"
+	"errors"
+	"fmt"
+
+	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/tls"
+)
+
+// SaveCert - save a certificate to file and DB
+func SaveCert(path, name string, cert *x509.Certificate) error {
+	if err := SaveCertToDB(name, cert); err != nil {
+		return err
+	}
+	return tls.SaveCertToFile(path, name, cert)
+}
+
+// SaveCertToDB - save a certificate to the certs database
+func SaveCertToDB(name string, cert *x509.Certificate) error {
+	if certBytes := pem.EncodeToMemory(&pem.Block{
+		Type:  "CERTIFICATE",
+		Bytes: cert.Raw,
+	}); len(certBytes) > 0 {
+		data, err := json.Marshal(&certBytes)
+		if err != nil {
+			return fmt.Errorf("failed to marshal certificate - %v ", err)
+		}
+		return database.Insert(name, string(data), database.CERTS_TABLE_NAME)
+	} else {
+		return fmt.Errorf("failed to write cert to DB - %s ", name)
+	}
+}
+
+// SaveKey - save a private key (ed25519) to file and DB
+func SaveKey(path, name string, key ed25519.PrivateKey) error {
+	if err := SaveKeyToDB(name, key); err != nil {
+		return err
+	}
+	return tls.SaveKeyToFile(path, name, key)
+}
+
+// SaveKeyToDB - save a private key (ed25519) to the specified path
+func SaveKeyToDB(name string, key ed25519.PrivateKey) error {
+	privBytes, err := x509.MarshalPKCS8PrivateKey(key)
+	if err != nil {
+		return fmt.Errorf("failed to marshal key %v ", err)
+	}
+	if pemBytes := pem.EncodeToMemory(&pem.Block{
+		Type:  "PRIVATE KEY",
+		Bytes: privBytes,
+	}); len(pemBytes) > 0 {
+		data, err := json.Marshal(&pemBytes)
+		if err != nil {
+			return fmt.Errorf("failed to marshal key %v ", err)
+		}
+		return database.Insert(name, string(data), database.CERTS_TABLE_NAME)
+	} else {
+		return fmt.Errorf("failed to write key to DB - %v ", err)
+	}
+}
+
+// ReadCertFromDB - reads a certificate from the database
+func ReadCertFromDB(name string) (*x509.Certificate, error) {
+	certString, err := database.FetchRecord(database.CERTS_TABLE_NAME, name)
+	if err != nil {
+		return nil, fmt.Errorf("unable to read file %w", err)
+	}
+	var certBytes []byte
+	if err = json.Unmarshal([]byte(certString), &certBytes); err != nil {
+		return nil, fmt.Errorf("unable to unmarshal db cert %w", err)
+	}
+	block, _ := pem.Decode(certBytes)
+	if block == nil || block.Type != "CERTIFICATE" {
+		return nil, errors.New("not a cert " + block.Type)
+	}
+	cert, err := x509.ParseCertificate(block.Bytes)
+	if err != nil {
+		return nil, fmt.Errorf("unable to parse cert %w", err)
+	}
+	return cert, nil
+}
+
+// ReadKeyFromDB - reads a private key (ed25519) from the database
+func ReadKeyFromDB(name string) (*ed25519.PrivateKey, error) {
+	keyString, err := database.FetchRecord(database.CERTS_TABLE_NAME, name)
+	if err != nil {
+		return nil, fmt.Errorf("unable to read key value from db - %w", err)
+	}
+	var bytes []byte
+	if err = json.Unmarshal([]byte(keyString), &bytes); err != nil {
+		return nil, fmt.Errorf("unable to unmarshal db key - %w", err)
+	}
+	keyBytes, _ := pem.Decode(bytes)
+	key, err := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
+	if err != nil {
+		return nil, fmt.Errorf("unable to parse key from DB -  %w", err)
+	}
+	private := key.(ed25519.PrivateKey)
+	return &private, nil
+}

+ 8 - 80
tls/tls.go

@@ -6,7 +6,6 @@ import (
 	"crypto/x509"
 	"crypto/x509"
 	"crypto/x509/pkix"
 	"crypto/x509/pkix"
 	"encoding/base64"
 	"encoding/base64"
-	"encoding/json"
 	"encoding/pem"
 	"encoding/pem"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
@@ -15,7 +14,6 @@ import (
 	"time"
 	"time"
 
 
 	"filippo.io/edwards25519"
 	"filippo.io/edwards25519"
-	"github.com/gravitl/netmaker/database"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 )
 
 
@@ -25,10 +23,16 @@ const (
 	CERTIFICATE_VALIDITY = 365
 	CERTIFICATE_VALIDITY = 365
 
 
 	// SERVER_KEY_NAME - name of server cert private key
 	// SERVER_KEY_NAME - name of server cert private key
-	SERVER_KEY_NAME = "serverkey"
+	SERVER_KEY_NAME = "server.key"
 
 
 	// ROOT_KEY_NAME - name of root cert private key
 	// ROOT_KEY_NAME - name of root cert private key
-	ROOT_KEY_NAME = "rootkey"
+	ROOT_KEY_NAME = "root.key"
+
+	// SERVER_PEM_NAME - name of server pem
+	SERVER_PEM_NAME = "server.pem"
+
+	// ROOT_PEM_NAME - name of root pem
+	ROOT_PEM_NAME = "root.pem"
 )
 )
 
 
 type (
 type (
@@ -220,22 +224,6 @@ func SaveCertToFile(path, name string, cert *x509.Certificate) error {
 	return nil
 	return nil
 }
 }
 
 
-// SaveCertToDB - save a certificate to the certs database
-func SaveCertToDB(name string, cert *x509.Certificate) error {
-	if certBytes := pem.EncodeToMemory(&pem.Block{
-		Type:  "CERTIFICATE",
-		Bytes: cert.Raw,
-	}); len(certBytes) > 0 {
-		data, err := json.Marshal(&certBytes)
-		if err != nil {
-			return fmt.Errorf("failed to marshal certificate - %v ", err)
-		}
-		return database.Insert(name, string(data), database.CERTS_TABLE_NAME)
-	} else {
-		return fmt.Errorf("failed to write cert to DB - %s ", name)
-	}
-}
-
 // SaveKeyToFile save a private key (ed25519) to the certs database
 // SaveKeyToFile save a private key (ed25519) to the certs database
 func SaveKeyToFile(path, name string, key ed25519.PrivateKey) error {
 func SaveKeyToFile(path, name string, key ed25519.PrivateKey) error {
 	//func SaveKey(name string, key *ecdsa.PrivateKey) error {
 	//func SaveKey(name string, key *ecdsa.PrivateKey) error {
@@ -260,26 +248,6 @@ func SaveKeyToFile(path, name string, key ed25519.PrivateKey) error {
 	return nil
 	return nil
 }
 }
 
 
-// SaveKeyToDB - save a private key (ed25519) to the specified path
-func SaveKeyToDB(name string, key ed25519.PrivateKey) error {
-	privBytes, err := x509.MarshalPKCS8PrivateKey(key)
-	if err != nil {
-		return fmt.Errorf("failed to marshal key %v ", err)
-	}
-	if pemBytes := pem.EncodeToMemory(&pem.Block{
-		Type:  "PRIVATE KEY",
-		Bytes: privBytes,
-	}); len(pemBytes) > 0 {
-		data, err := json.Marshal(&pemBytes)
-		if err != nil {
-			return fmt.Errorf("failed to marshal key %v ", err)
-		}
-		return database.Insert(name, string(data), database.CERTS_TABLE_NAME)
-	} else {
-		return fmt.Errorf("failed to write key to DB - %v ", err)
-	}
-}
-
 // ReadCertFromFile reads a certificate from disk
 // ReadCertFromFile reads a certificate from disk
 func ReadCertFromFile(name string) (*x509.Certificate, error) {
 func ReadCertFromFile(name string) (*x509.Certificate, error) {
 	contents, err := os.ReadFile(name)
 	contents, err := os.ReadFile(name)
@@ -297,27 +265,6 @@ func ReadCertFromFile(name string) (*x509.Certificate, error) {
 	return cert, nil
 	return cert, nil
 }
 }
 
 
-// ReadCertFromDB - reads a certificate from the database
-func ReadCertFromDB(name string) (*x509.Certificate, error) {
-	certString, err := database.FetchRecord(database.CERTS_TABLE_NAME, name)
-	if err != nil {
-		return nil, fmt.Errorf("unable to read file %w", err)
-	}
-	var certBytes []byte
-	if err = json.Unmarshal([]byte(certString), &certBytes); err != nil {
-		return nil, fmt.Errorf("unable to unmarshal db cert %w", err)
-	}
-	block, _ := pem.Decode(certBytes)
-	if block == nil || block.Type != "CERTIFICATE" {
-		return nil, errors.New("not a cert " + block.Type)
-	}
-	cert, err := x509.ParseCertificate(block.Bytes)
-	if err != nil {
-		return nil, fmt.Errorf("unable to parse cert %w", err)
-	}
-	return cert, nil
-}
-
 // ReadKeyFromFile reads a private key (ed25519) from disk
 // ReadKeyFromFile reads a private key (ed25519) from disk
 func ReadKeyFromFile(name string) (*ed25519.PrivateKey, error) {
 func ReadKeyFromFile(name string) (*ed25519.PrivateKey, error) {
 	bytes, err := os.ReadFile(name)
 	bytes, err := os.ReadFile(name)
@@ -333,25 +280,6 @@ func ReadKeyFromFile(name string) (*ed25519.PrivateKey, error) {
 	return &private, nil
 	return &private, nil
 }
 }
 
 
-// ReadKeyFromDB - reads a private key (ed25519) from the database
-func ReadKeyFromDB(name string) (*ed25519.PrivateKey, error) {
-	keyString, err := database.FetchRecord(database.CERTS_TABLE_NAME, name)
-	if err != nil {
-		return nil, fmt.Errorf("unable to read key value from db - %w", err)
-	}
-	var bytes []byte
-	if err = json.Unmarshal([]byte(keyString), &bytes); err != nil {
-		return nil, fmt.Errorf("unable to unmarshal db key - %w", err)
-	}
-	keyBytes, _ := pem.Decode(bytes)
-	key, err := x509.ParsePKCS8PrivateKey(keyBytes.Bytes)
-	if err != nil {
-		return nil, fmt.Errorf("unable to parse key from DB -  %w", err)
-	}
-	private := key.(ed25519.PrivateKey)
-	return &private, nil
-}
-
 // serialNumber generates a serial number for a certificate
 // serialNumber generates a serial number for a certificate
 func serialNumber() *big.Int {
 func serialNumber() *big.Int {
 	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
 	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)