Просмотр исходного кода

Enhance error handling for license validation

gabrielseibel1 2 лет назад
Родитель
Сommit
7322b8b7de
1 измененных файлов с 23 добавлено и 14 удалено
  1. 23 14
      ee/license.go

+ 23 - 14
ee/license.go

@@ -7,6 +7,7 @@ import (
 	"bytes"
 	"crypto/rand"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -50,19 +51,21 @@ func ValidateLicense() error {
 	licenseKeyValue := servercfg.GetLicenseKey()
 	netmakerTenantID := servercfg.GetNetmakerTenantID()
 	logger.Log(0, "proceeding with Netmaker license validation...")
-	if len(licenseKeyValue) == 0 || len(netmakerTenantID) == 0 {
-		err := fmt.Errorf("license-key: %s, tenant-id: %s", licenseKeyValue, netmakerTenantID)
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+	if len(licenseKeyValue) == 0 {
+		failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
+	}
+	if len(netmakerTenantID) == 0 {
+		failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
 	}
 
 	apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
 	if err != nil {
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+		failValidation(fmt.Errorf("failed to get license public key: %w", err))
 	}
 
 	tempPubKey, tempPrivKey, err := FetchApiServerKeys()
 	if err != nil {
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+		failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
 	}
 
 	licenseSecret := LicenseSecret{
@@ -72,33 +75,35 @@ func ValidateLicense() error {
 
 	secretData, err := json.Marshal(&licenseSecret)
 	if err != nil {
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+		failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
 	}
 
 	encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
 	if err != nil {
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+		failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
 	}
 
 	validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
-	if err != nil || len(validationResponse) == 0 {
-		err := fmt.Errorf("err: %w, validation-response: %s", err, string(validationResponse))
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+	if err != nil {
+		failValidation(fmt.Errorf("failed to validate license key: %w", err))
+	}
+	if len(validationResponse) == 0 {
+		failValidation(errors.New("empty validation response"))
 	}
 
 	var licenseResponse ValidatedLicense
 	if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+		failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
 	}
 
 	respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
 	if err != nil {
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+		failValidation(fmt.Errorf("failed to decrypt license: %w", err))
 	}
 
 	license := LicenseKey{}
 	if err = json.Unmarshal(respData, &license); err != nil {
-		logger.FatalLog0(fmt.Errorf("%w: %w", errValidation, err).Error())
+		failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
 	}
 
 	logger.Log(0, "License validation succeeded!")
@@ -152,6 +157,10 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
 	return pub, priv, nil
 }
 
+func failValidation(err error) {
+	logger.FatalLog0(errValidation.Error(), ":", err.Error())
+}
+
 func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
 	decodedPubKey := base64decode(licensePubKeyEncoded)
 	return ncutils.ConvertBytesToKey(decodedPubKey)
@@ -193,7 +202,7 @@ func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, erro
 	} else {
 		defer validateResponse.Body.Close()
 		if validateResponse.StatusCode != 200 {
-			return nil, fmt.Errorf("could not validate license")
+			return nil, fmt.Errorf("could not validate license, got status code %d", validateResponse.StatusCode)
 		} // if you received a 200 cache the response locally
 
 		body, err = io.ReadAll(validateResponse.Body)