Browse Source

[NET-404] Run in limited mode when ee checks fail (#2474)

* Add limited http handlers functionality to rest handler

* Export ee.errValidation (ee.ErrValidation)

* Export a fatal error handled by the hook manager

* Export a new status variable for unlicensed server

* Mark server as unlicensed when ee checks fail

* Handle license validation failures with a (re)boot in a limited state

* Revert "Export a fatal error handled by the hook manager"

This reverts commit 069c21974a8d36e889c73ad78023448d787d62a5.

* Revert "Export ee.errValidation (ee.ErrValidation)"

This reverts commit 59dbab8c79773ca5d879f28cbaf53f3dd4297b9b.

* Revert "Add limited http handlers functionality to rest handler"

This reverts commit e2f1f28facaca54713db76a588839cd2733cf673.

* Revert "Handle license validation failures with a (re)boot in a limited state"

This reverts commit 58cfbbaf522a1345aac1fa67964ebff0a6d60cd8.

* Revert "Mark server as unlicensed when ee checks fail"

This reverts commit 77c6dbdd3c9cfa6e7d6becedef6251e8617ae367.

* Handle license validation failures with a middleware

* Forbid responses if unlicensed ee and not in status api

* Remove unused func
Gabriel de Souza Seibel 2 years ago
parent
commit
922e7dbf2c
7 changed files with 90 additions and 43 deletions
  1. 7 0
      controllers/controller.go
  2. 11 12
      controllers/server.go
  3. 17 0
      ee/ee_controllers/middleware.go
  4. 11 4
      ee/initialize.go
  5. 33 21
      ee/license.go
  6. 8 4
      logic/timer.go
  7. 3 2
      servercfg/serverconf.go

+ 7 - 0
controllers/controller.go

@@ -14,6 +14,9 @@ import (
 	"github.com/gravitl/netmaker/servercfg"
 )
 
+// HttpMiddlewares - middleware functions for REST interactions
+var HttpMiddlewares []mux.MiddlewareFunc
+
 // HttpHandlers - handler functions for REST interactions
 var HttpHandlers = []interface{}{
 	nodeHandlers,
@@ -42,6 +45,10 @@ func HandleRESTRequests(wg *sync.WaitGroup, ctx context.Context) {
 	originsOk := handlers.AllowedOrigins(strings.Split(servercfg.GetAllowedOrigin(), ","))
 	methodsOk := handlers.AllowedMethods([]string{http.MethodGet, http.MethodPut, http.MethodPost, http.MethodDelete})
 
+	for _, middleware := range HttpMiddlewares {
+		r.Use(middleware)
+	}
+
 	for _, handler := range HttpHandlers {
 		handler.(func(*mux.Router))(r)
 	}

+ 11 - 12
controllers/server.go

@@ -68,22 +68,21 @@ func getUsage(w http.ResponseWriter, r *http.Request) {
 //			Responses:
 //				200: serverConfigResponse
 func getStatus(w http.ResponseWriter, r *http.Request) {
-	// TODO
-	// - check health of broker
 	type status struct {
-		DB     bool `json:"db_connected"`
-		Broker bool `json:"broker_connected"`
-		Usage  struct {
-			Hosts    int `json:"hosts"`
-			Clients  int `json:"clients"`
-			Networks int `json:"networks"`
-			Users    int `json:"users"`
-		} `json:"usage"`
+		DB           bool   `json:"db_connected"`
+		Broker       bool   `json:"broker_connected"`
+		LicenseError string `json:"license_error"`
+	}
+
+	licenseErr := ""
+	if servercfg.ErrLicenseValidation != nil {
+		licenseErr = servercfg.ErrLicenseValidation.Error()
 	}
 
 	currentServerStatus := status{
-		DB:     database.IsConnected(),
-		Broker: mq.IsConnected(),
+		DB:           database.IsConnected(),
+		Broker:       mq.IsConnected(),
+		LicenseError: licenseErr,
 	}
 
 	w.Header().Set("Content-Type", "application/json")

+ 17 - 0
ee/ee_controllers/middleware.go

@@ -0,0 +1,17 @@
+package ee_controllers
+
+import (
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/servercfg"
+	"net/http"
+)
+
+func OnlyServerAPIWhenUnlicensedMiddleware(handler http.Handler) http.Handler {
+	return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
+		if servercfg.ErrLicenseValidation != nil && request.URL.Path != "/api/server/status" {
+			logic.ReturnErrorResponse(writer, request, logic.FormatError(servercfg.ErrLicenseValidation, "forbidden"))
+			return
+		}
+		handler.ServeHTTP(writer, request)
+	})
+}

+ 11 - 4
ee/initialize.go

@@ -7,10 +7,10 @@ import (
 	controller "github.com/gravitl/netmaker/controllers"
 	"github.com/gravitl/netmaker/ee/ee_controllers"
 	eelogic "github.com/gravitl/netmaker/ee/logic"
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/exp/slog"
 )
 
 // InitEE - Initialize EE Logic
@@ -18,6 +18,10 @@ func InitEE() {
 	setIsEnterprise()
 	servercfg.Is_EE = true
 	models.SetLogo(retrieveEELogo())
+	controller.HttpMiddlewares = append(
+		controller.HttpMiddlewares,
+		ee_controllers.OnlyServerAPIWhenUnlicensedMiddleware,
+	)
 	controller.HttpHandlers = append(
 		controller.HttpHandlers,
 		ee_controllers.MetricHandlers,
@@ -27,8 +31,11 @@ func InitEE() {
 	)
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
 		// == License Handling ==
-		ValidateLicense()
-		logger.Log(0, "proceeding with Paid Tier license")
+		if err := ValidateLicense(); err != nil {
+			slog.Error(err.Error())
+			return
+		}
+		slog.Info("proceeding with Paid Tier license")
 		logic.SetFreeTierForTelemetry(false)
 		// == End License Handling ==
 		AddLicenseHooks()
@@ -48,7 +55,7 @@ func resetFailover() {
 		for _, net := range nets {
 			err = eelogic.ResetFailover(net.NetID)
 			if err != nil {
-				logger.Log(0, "failed to reset failover on network", net.NetID, ":", err.Error())
+				slog.Error("failed to reset failover", "network", net.NetID, "error", err.Error())
 			}
 		}
 	}

+ 33 - 21
ee/license.go

@@ -12,7 +12,6 @@ import (
 	"golang.org/x/exp/slog"
 	"io"
 	"net/http"
-	"os"
 	"time"
 
 	"github.com/gravitl/netmaker/database"
@@ -44,29 +43,40 @@ func AddLicenseHooks() {
 	}
 }
 
-// ValidateLicense - the initial license check for netmaker server
+// ValidateLicense - the initial and periodic license check for netmaker server
 // checks if a license is valid + limits are not exceeded
-// if license is free_tier and limits exceeds, then server should terminate
-// if license is not valid, server should terminate
-func ValidateLicense() error {
+// if license is free_tier and limits exceeds, then function should error
+// if license is not valid, function should error
+func ValidateLicense() (err error) {
+	defer func() {
+		if err != nil {
+			err = fmt.Errorf("%w: %s", errValidation, err.Error())
+			servercfg.ErrLicenseValidation = err
+		}
+	}()
+
 	licenseKeyValue := servercfg.GetLicenseKey()
 	netmakerTenantID := servercfg.GetNetmakerTenantID()
 	slog.Info("proceeding with Netmaker license validation...")
 	if len(licenseKeyValue) == 0 {
-		failValidation(errors.New("empty license-key (LICENSE_KEY environment variable)"))
+		err = errors.New("empty license-key (LICENSE_KEY environment variable)")
+		return err
 	}
 	if len(netmakerTenantID) == 0 {
-		failValidation(errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)"))
+		err = errors.New("empty tenant-id (NETMAKER_TENANT_ID environment variable)")
+		return err
 	}
 
 	apiPublicKey, err := getLicensePublicKey(licenseKeyValue)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to get license public key: %w", err))
+		err = fmt.Errorf("failed to get license public key: %w", err)
+		return err
 	}
 
 	tempPubKey, tempPrivKey, err := FetchApiServerKeys()
 	if err != nil {
-		failValidation(fmt.Errorf("failed to fetch api server keys: %w", err))
+		err = fmt.Errorf("failed to fetch api server keys: %w", err)
+		return err
 	}
 
 	licenseSecret := LicenseSecret{
@@ -76,35 +86,42 @@ func ValidateLicense() error {
 
 	secretData, err := json.Marshal(&licenseSecret)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to marshal license secret: %w", err))
+		err = fmt.Errorf("failed to marshal license secret: %w", err)
+		return err
 	}
 
 	encryptedData, err := ncutils.BoxEncrypt(secretData, apiPublicKey, tempPrivKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to encrypt license secret data: %w", err))
+		err = fmt.Errorf("failed to encrypt license secret data: %w", err)
+		return err
 	}
 
 	validationResponse, err := validateLicenseKey(encryptedData, tempPubKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to validate license key: %w", err))
+		err = fmt.Errorf("failed to validate license key: %w", err)
+		return err
 	}
 	if len(validationResponse) == 0 {
-		failValidation(errors.New("empty validation response"))
+		err = errors.New("empty validation response")
+		return err
 	}
 
 	var licenseResponse ValidatedLicense
 	if err = json.Unmarshal(validationResponse, &licenseResponse); err != nil {
-		failValidation(fmt.Errorf("failed to unmarshal validation response: %w", err))
+		err = fmt.Errorf("failed to unmarshal validation response: %w", err)
+		return err
 	}
 
 	respData, err := ncutils.BoxDecrypt(base64decode(licenseResponse.EncryptedLicense), apiPublicKey, tempPrivKey)
 	if err != nil {
-		failValidation(fmt.Errorf("failed to decrypt license: %w", err))
+		err = fmt.Errorf("failed to decrypt license: %w", err)
+		return err
 	}
 
 	license := LicenseKey{}
 	if err = json.Unmarshal(respData, &license); err != nil {
-		failValidation(fmt.Errorf("failed to unmarshal license key: %w", err))
+		err = fmt.Errorf("failed to unmarshal license key: %w", err)
+		return err
 	}
 
 	slog.Info("License validation succeeded!")
@@ -158,11 +175,6 @@ func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) {
 	return pub, priv, nil
 }
 
-func failValidation(err error) {
-	slog.Error(errValidation.Error(), "error", err)
-	os.Exit(0)
-}
-
 func getLicensePublicKey(licensePubKeyEncoded string) (*[32]byte, error) {
 	decodedPubKey := base64decode(licensePubKeyEncoded)
 	return ncutils.ConvertBytesToKey(decodedPubKey)

+ 8 - 4
logic/timer.go

@@ -3,10 +3,11 @@ package logic
 import (
 	"context"
 	"fmt"
+	"github.com/gravitl/netmaker/logger"
+	"golang.org/x/exp/slog"
 	"sync"
 	"time"
 
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 )
 
@@ -52,7 +53,7 @@ func StartHookManager(ctx context.Context, wg *sync.WaitGroup) {
 	for {
 		select {
 		case <-ctx.Done():
-			logger.Log(0, "## Stopping Hook Manager")
+			slog.Error("## Stopping Hook Manager")
 			return
 		case newhook := <-HookManagerCh:
 			wg.Add(1)
@@ -70,7 +71,9 @@ func addHookWithInterval(ctx context.Context, wg *sync.WaitGroup, hook func() er
 		case <-ctx.Done():
 			return
 		case <-ticker.C:
-			hook()
+			if err := hook(); err != nil {
+				slog.Error(err.Error())
+			}
 		}
 	}
 
@@ -85,6 +88,7 @@ var timeHooks = []interface{}{
 }
 
 func loggerDump() error {
+	// TODO use slog?
 	logger.DumpFile(fmt.Sprintf("data/netmaker.log.%s", time.Now().Format(logger.TimeFormatDay)))
 	return nil
 }
@@ -93,7 +97,7 @@ func loggerDump() error {
 func runHooks() {
 	for _, hook := range timeHooks {
 		if err := hook.(func() error)(); err != nil {
-			logger.Log(1, "error occurred when running timer function:", err.Error())
+			slog.Error("error occurred when running timer function", "error", err.Error())
 		}
 	}
 }

+ 3 - 2
servercfg/serverconf.go

@@ -18,8 +18,9 @@ import (
 const EmqxBrokerType = "emqx"
 
 var (
-	Version = "dev"
-	Is_EE   = false
+	Version              = "dev"
+	Is_EE                = false
+	ErrLicenseValidation error
 )
 
 // SetHost - sets the host ip