Ver código fonte

Handle license validation failures with a middleware

gabrielseibel1 2 anos atrás
pai
commit
a7266c76fa

+ 7 - 0
controllers/controller.go

@@ -14,6 +14,9 @@ import (
 	"github.com/gravitl/netmaker/servercfg"
 )
 
+// HttpMiddlewares - middleware functions for REST interactions
+var HttpMiddlewares []mux.MiddlewareFunc
+
 // HttpHandlers - handler functions for REST interactions
 var HttpHandlers = []interface{}{
 	nodeHandlers,
@@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
 	originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
 	methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})
 
+	for _, middleware := range HttpMiddlewares {
+		r.Use(middleware)
+	}
+
 	for _, handler := range HttpHandlers {
 		handler.(func(*mux.Router))(r)
 	}

+ 9 - 4
controllers/server.go

@@ -69,15 +69,20 @@ func getUsage(w http.ResponseWriter, r *http.Request) {
 //				200: serverConfigResponse
 func getStatus(w http.ResponseWriter, r *http.Request) {
 	type status struct {
-		DB           bool `json:"db_connected"`
-		Broker       bool `json:"broker_connected"`
-		UnlicensedEE bool `json:"unlicensed_ee"`
+		DB           bool   `json:"db_connected"`
+		Broker       bool   `json:"broker_connected"`
+		LicenseError string `json:"license_error"`
+	}
+
+	licenseErr := ""
+	if servercfg.ErrLicenseValidation != nil {
+		licenseErr = servercfg.ErrLicenseValidation.Error()
 	}
 
 	currentServerStatus := status{
 		DB:           database.IsConnected(),
 		Broker:       mq.IsConnected(),
-		UnlicensedEE: servercfg.Is_EE && servercfg.IsUnlicensed,
+		LicenseError: licenseErr,
 	}
 
 	w.Header().Set("Content-Type", "application/json")

+ 18 - 0
ee/ee_controllers/middleware.go

@@ -0,0 +1,18 @@
+package ee_controllers
+
+import (
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/servercfg"
+	"net/http"
+	"strings"
+)
+
+func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler {
+	return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
+		if servercfg.ErrLicenseValidation != nil && !strings.HasPrefix(request.URL.Path, "/api/server") {
+			logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "unauthorized"))
+			return
+		}
+		handler.ServeHTTP(writer, request)
+	})
+}

+ 11 - 4
ee/initialize.go

@@ -7,10 +7,10 @@ import (
 	controller "github.com/gravitl/netmaker/controllers"
 	"github.com/gravitl/netmaker/ee/ee_controllers"
 	eelogic "github.com/gravitl/netmaker/ee/logic"
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/exp/slog"
 )
 
 // InitEE - Initialize EE Logic
@@ -18,6 +18,10 @@ func InitEE() {
 	setIsEnterprise()
 	servercfg.Is_EE = true
 	models.SetLogo(retrieveEELogo())
+	controller.HttpMiddlewares = append(
+		controller.HttpMiddlewares,
+		ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware,
+	)
 	controller.HttpHandlers = append(
 		controller.HttpHandlers,
 		ee_controllers.MetricHandlers,
@@ -27,8 +31,11 @@ func InitEE() {
 	)
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
 		// == License Handling ==
-		ValidateLicense()
-		logger.Log(0, "proceeding with Paid Tier license")
+		if err := ValidateLicense(); err != nil {
+			slog.Error(err.Error())
+			return
+		}
+		slog.Info("proceeding with Paid Tier license")
 		logic.SetFreeTierForTelemetry(false)
 		// == End License Handling ==
 		AddLicenseHooks()
@@ -48,7 +55,7 @@ func resetFailover() {
 		for _, net := range nets {
 			err = eelogic.ResetFailover(net.NetID)
 			if err != nil {
-				logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error())
+				slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error())
 			}
 		}
 	}

+ 33 - 18
ee/license.go

@@ -12,7 +12,6 @@ import (
 	"golang.org/x/exp/slog"
 	"io"
 	"net/http"
-	"os"
 	"time"
 
 	"github.com/gravitl/netmaker/database"
@@ -44,29 +43,40 @@ func AddLicenseHooks() {
 	}
 }
 
-// ValidateLicense - the initial license check for netmaker server
+// ValidateLicense - the initial and periodic license check for netmaker server
 // checks if a license is valid + limits are not exceeded
-// if license is free_tier and limits exceeds, then server should terminate
-// if license is not valid, server should terminate
-func ValidateLicense() error {
+// if license is free_tier and limits exceeds, then function should error
+// if license is not valid, function should error
+func ValidateLicense() (err error) {
+	defer func() {
+		if err != nil {
+			err = fmt.Errorf("%w: %s", errValidation, err.Error())
+			servercfg.ErrLicenseValidation = err
+		}
+	}()
+
 	licenseKeyValue := servercfg.GetLicenseKey()
 	netmakerTenantID := servercfg.GetNetmakerTenantID()
 	slog.Info("proceeding with Netmaker license validation...")
 	if len(licenseKeyValue) == 0 {
-		failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
+		err = errors.New("empty license-key (LICENSE_KEY environment variable)")
+		return err
 	}
 	if len(netmakerTenantID) == 0 {
-		failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
+		err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")
+		return err
 	}
 
 	apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to get license public key: %w", err))
+		err = fmt.Errorf("failed to get license public key: %w", err)
+		return err
 	}
 
 	tempPubKey, tempPrivKey, err := FetchApiServerKeys()
 	if err != nil {
-		failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
+		err = fmt.Errorf("failed to fetch api server keys: %w", err)
+		return err
 	}
 
 	licenseSecret := LicenseSecret{
@@ -76,35 +86,42 @@ func ValidateLicense() error {
 
 	secretData, err := json.Marshal(&licenseSecret)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
+		err = fmt.Errorf("failed to marshal license secret: %w", err)
+		return err
 	}
 
 	encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
+		err = fmt.Errorf("failed to encrypt license secret data: %w", err)
+		return err
 	}
 
 	validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to validate license key: %w", err))
+		err = fmt.Errorf("failed to validate license key: %w", err)
+		return err
 	}
 	if len(validationResponse) == 0 {
-		failValidation(errors.New("empty validation response"))
+		err = errors.New("empty validation response")
+		return err
 	}
 
 	var licenseResponse ValidatedLicense
 	if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
-		failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
+		err = fmt.Errorf("failed to unmarshal validation response: %w", err)
+		return err
 	}
 
 	respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to decrypt license: %w", err))
+		err = fmt.Errorf("failed to decrypt license: %w", err)
+		return err
 	}
 
 	license := LicenseKey{}
 	if err = json.Unmarshal(respData, &license); err != nil {
-		failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
+		err = fmt.Errorf("failed to unmarshal license key: %w", err)
+		return err
 	}
 
 	slog.Info("License validation succeeded!")
@@ -159,8 +176,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
 }
 
 func failValidation(err error) {
-	slog.Error(errValidation.Error(), "error", err)
-	os.Exit(0)
 }
 
 func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {

+ 8 - 4
logic/timer.go

@@ -3,10 +3,11 @@ package logic
 import (
 	"context"
 	"fmt"
+	"github.com/gravitl/netmaker/logger"
+	"golang.org/x/exp/slog"
 	"sync"
 	"time"
 
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 )
 
@@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) {
 	for {
 		select {
 		case <-ctx.Done():
-			logger.Log(0, "## Stopping Hook Manager")
+			slog.Error("## Stopping Hook Manager")
 			return
 		case newhook := <-HookManagerCh:
 			wg.Add(1)
@@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er
 		case <-ctx.Done():
 			return
 		case <-ticker.C:
-			hook()
+			if err := hook(); err != nil {
+				slog.Error(err.Error())
+			}
 		}
 	}
 
@@ -85,6 +88,7 @@ var timeHooks = []interface{}{
 }
 
 func loggerDump() error {
+	// TODO use slog?
 	logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
 	return nil
 }
@@ -93,7 +97,7 @@ func loggerDump() error {
 func runHooks() {
 	for _, hook := range timeHooks {
 		if err := hook.(func() error)(); err != nil {
-			logger.Log(1, "error occurred when running timer function:", err.Error())
+			slog.Error("error occurred when running timer function", "error", err.Error())
 		}
 	}
 }

+ 3 - 3
servercfg/serverconf.go

@@ -18,9 +18,9 @@ import (
 const EmqxBrokerType = "emqx"
 
 var (
-	Version      = "dev"
-	Is_EE        = false
-	IsUnlicensed = false
+	Version              = "dev"
+	Is_EE                = false
+	ErrLicenseValidation error
 )
 
 // SetHost - sets the host ip