Prechádzať zdrojové kódy

Handle license validation failures with a (re)boot in a limited state

gabrielseibel1 2 rokov pred
rodič
commit
a4febe0312
2 zmenil súbory, kde vykonal 30 pridanie a 16 odobranie
  1. 22 16
      ee/license.go
  2. 8 0
      main.go

+ 22 - 16
ee/license.go

@@ -12,7 +12,6 @@ import (
 	"golang.org/x/exp/slog"
 	"io"
 	"net/http"
-	"os"
 	"time"
 
 	"github.com/gravitl/netmaker/database"
@@ -35,7 +34,14 @@ type apiServerConf struct {
 // AddLicenseHooks - adds the validation and cache clear hooks
 func AddLicenseHooks() {
 	logic.HookManagerCh <- models.HookDetails{
-		Hook:     ValidateLicense,
+		Hook: func() error {
+			if err := ValidateLicense(); err != nil {
+				// stop the program when license is not valid anymore
+				// if the server restarts and still fails the license check, it can reboot in a limited mode
+				return fmt.Errorf("%w: %s", logic.HookManagerFatalError, err.Error())
+			}
+			return nil
+		},
 		Interval: time.Hour,
 	}
 	logic.HookManagerCh <- models.HookDetails{
@@ -48,25 +54,26 @@ func AddLicenseHooks() {
 // checks if a license is valid + limits are not exceeded
 // if license is free_tier and limits exceeds, then server should terminate
 // if license is not valid, server should terminate
+// TODO update comment
 func ValidateLicense() error {
 	licenseKeyValue := servercfg.GetLicenseKey()
 	netmakerTenantID := servercfg.GetNetmakerTenantID()
 	slog.Info("proceeding with Netmaker license validation...")
 	if len(licenseKeyValue) == 0 {
-		failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
+		return wrappedInErrValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
 	}
 	if len(netmakerTenantID) == 0 {
-		failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
+		return wrappedInErrValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
 	}
 
 	apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to get license public key: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to get license public key: %w", err))
 	}
 
 	tempPubKey, tempPrivKey, err := FetchApiServerKeys()
 	if err != nil {
-		failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
 	}
 
 	licenseSecret := LicenseSecret{
@@ -76,35 +83,35 @@ func ValidateLicense() error {
 
 	secretData, err := json.Marshal(&licenseSecret)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to marshal license secret: %w", err))
 	}
 
 	encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
 	}
 
 	validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to validate license key: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to validate license key: %w", err))
 	}
 	if len(validationResponse) == 0 {
-		failValidation(errors.New("empty validation response"))
+		return wrappedInErrValidation(errors.New("empty validation response"))
 	}
 
 	var licenseResponse ValidatedLicense
 	if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
-		failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
 	}
 
 	respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to decrypt license: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to decrypt license: %w", err))
 	}
 
 	license := LicenseKey{}
 	if err = json.Unmarshal(respData, &license); err != nil {
-		failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
+		return wrappedInErrValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
 	}
 
 	slog.Info("License validation succeeded!")
@@ -158,9 +165,8 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
 	return pub, priv, nil
 }
 
-func failValidation(err error) {
-	slog.Error(errValidation.Error(), "error", err)
-	os.Exit(0)
+func wrappedInErrValidation(err error) error {
+	return fmt.Errorf("%w: %s", ErrValidation, err.Error())
 }
 
 func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {

+ 8 - 0
main.go

@@ -124,6 +124,14 @@ func initialize() { // Client Mode Prereq Check
 }
 
 func startControllers(wg *sync.WaitGroup, ctx context.Context) {
+	// limit the controllers when unlicensed
+	if servercfg.IsUnlicensed {
+		wg.Add(2)
+		go controller.HandleRESTRequests(wg, ctx, controller.LimitedHttpHandlers)
+		go logic.StartHookManager(ctx, wg)
+		return
+	}
+
 	if servercfg.IsDNSMode() {
 		err := logic.SetDNS()
 		if err != nil {