Browse Source

[NET-398] Netmaker should indicate licensing/payment issues (#2465)

* Add more details to license validation errors

* Use err.Error() to get string to log

* Enhance error handling for license validation

* Use slog in validation
Gabriel de Souza Seibel 2 years ago
parent
commit
4bcb3d0196
1 changed files with 29 additions and 16 deletions
  1. 29 16
      ee/license.go

+ 29 - 16
ee/license.go

@@ -7,13 +7,15 @@ import (
 	"bytes"
 	"bytes"
 	"crypto/rand"
 	"crypto/rand"
 	"encoding/json"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"fmt"
+	"golang.org/x/exp/slog"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
+	"os"
 	"time"
 	"time"
 
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/netclient/ncutils"
@@ -49,19 +51,22 @@ func AddLicenseHooks() {
 func ValidateLicense() error {
 func ValidateLicense() error {
 	licenseKeyValue := servercfg.GetLicenseKey()
 	licenseKeyValue := servercfg.GetLicenseKey()
 	netmakerTenantID := servercfg.GetNetmakerTenantID()
 	netmakerTenantID := servercfg.GetNetmakerTenantID()
-	logger.Log(0, "proceeding with Netmaker license validation...")
-	if len(licenseKeyValue) == 0 || len(netmakerTenantID) == 0 {
-		logger.FatalLog0(errValidation.Error())
+	slog.Info("proceeding with Netmaker license validation...")
+	if len(licenseKeyValue) == 0 {
+		failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
+	}
+	if len(netmakerTenantID) == 0 {
+		failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
 	}
 	}
 
 
 	apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
 	apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
 	if err != nil {
 	if err != nil {
-		logger.FatalLog0(errValidation.Error())
+		failValidation(fmt.Errorf("failed to get license public key: %w", err))
 	}
 	}
 
 
 	tempPubKey, tempPrivKey, err := FetchApiServerKeys()
 	tempPubKey, tempPrivKey, err := FetchApiServerKeys()
 	if err != nil {
 	if err != nil {
-		logger.FatalLog0(errValidation.Error())
+		failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
 	}
 	}
 
 
 	licenseSecret := LicenseSecret{
 	licenseSecret := LicenseSecret{
@@ -71,35 +76,38 @@ func ValidateLicense() error {
 
 
 	secretData, err := json.Marshal(&licenseSecret)
 	secretData, err := json.Marshal(&licenseSecret)
 	if err != nil {
 	if err != nil {
-		logger.FatalLog0(errValidation.Error())
+		failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
 	}
 	}
 
 
 	encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
 	encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
 	if err != nil {
 	if err != nil {
-		logger.FatalLog0(errValidation.Error())
+		failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
 	}
 	}
 
 
 	validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
 	validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
-	if err != nil || len(validationResponse) == 0 {
-		logger.FatalLog0(errValidation.Error())
+	if err != nil {
+		failValidation(fmt.Errorf("failed to validate license key: %w", err))
+	}
+	if len(validationResponse) == 0 {
+		failValidation(errors.New("empty validation response"))
 	}
 	}
 
 
 	var licenseResponse ValidatedLicense
 	var licenseResponse ValidatedLicense
 	if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
 	if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
-		logger.FatalLog0(errValidation.Error())
+		failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
 	}
 	}
 
 
 	respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
 	respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
 	if err != nil {
 	if err != nil {
-		logger.FatalLog0(errValidation.Error())
+		failValidation(fmt.Errorf("failed to decrypt license: %w", err))
 	}
 	}
 
 
 	license := LicenseKey{}
 	license := LicenseKey{}
 	if err = json.Unmarshal(respData, &license); err != nil {
 	if err = json.Unmarshal(respData, &license); err != nil {
-		logger.FatalLog0(errValidation.Error())
+		failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
 	}
 	}
 
 
-	logger.Log(0, "License validation succeeded!")
+	slog.Info("License validation succeeded!")
 	return nil
 	return nil
 }
 }
 
 
@@ -150,6 +158,11 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
 	return pub, priv, nil
 	return pub, priv, nil
 }
 }
 
 
+func failValidation(err error) {
+	slog.Error(errValidation.Error(), "error", err)
+	os.Exit(0)
+}
+
 func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
 func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
 	decodedPubKey := base64decode(licensePubKeyEncoded)
 	decodedPubKey := base64decode(licensePubKeyEncoded)
 	return ncutils.ConvertBytesToKey(decodedPubKey)
 	return ncutils.ConvertBytesToKey(decodedPubKey)
@@ -187,11 +200,11 @@ func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, erro
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
-		logger.Log(3, "proceeding with cached response, Netmaker API may be down")
+		slog.Warn("proceeding with cached response, Netmaker API may be down")
 	} else {
 	} else {
 		defer validateResponse.Body.Close()
 		defer validateResponse.Body.Close()
 		if validateResponse.StatusCode != 200 {
 		if validateResponse.StatusCode != 200 {
-			return nil, fmt.Errorf("could not validate license")
+			return nil, fmt.Errorf("could not validate license, got status code %d", validateResponse.StatusCode)
 		} // if you received a 200 cache the response locally
 		} // if you received a 200 cache the response locally
 
 
 		body, err = io.ReadAll(validateResponse.Body)
 		body, err = io.ReadAll(validateResponse.Body)