Browse Source

add trial license logic

abhishek9686 1 year ago
parent
commit
6749fb4516
6 changed files with 209 additions and 41 deletions
  1. 22 22
      database/database.go
  2. 6 6
      logic/telemetry.go
  3. 4 3
      logic/timer.go
  4. 2 2
      logic/traffic.go
  5. 29 8
      pro/initialize.go
  6. 146 0
      pro/trial.go

+ 22 - 22
database/database.go

@@ -124,29 +124,29 @@ func InitializeDatabase() error {
 }
 
 func createTables() {
-	createTable(NETWORKS_TABLE_NAME)
-	createTable(NODES_TABLE_NAME)
-	createTable(CERTS_TABLE_NAME)
-	createTable(DELETED_NODES_TABLE_NAME)
-	createTable(USERS_TABLE_NAME)
-	createTable(DNS_TABLE_NAME)
-	createTable(EXT_CLIENT_TABLE_NAME)
-	createTable(PEERS_TABLE_NAME)
-	createTable(SERVERCONF_TABLE_NAME)
-	createTable(SERVER_UUID_TABLE_NAME)
-	createTable(GENERATED_TABLE_NAME)
-	createTable(NODE_ACLS_TABLE_NAME)
-	createTable(SSO_STATE_CACHE)
-	createTable(METRICS_TABLE_NAME)
-	createTable(NETWORK_USER_TABLE_NAME)
-	createTable(USER_GROUPS_TABLE_NAME)
-	createTable(CACHE_TABLE_NAME)
-	createTable(HOSTS_TABLE_NAME)
-	createTable(ENROLLMENT_KEYS_TABLE_NAME)
-	createTable(HOST_ACTIONS_TABLE_NAME)
+	CreateTable(NETWORKS_TABLE_NAME)
+	CreateTable(NODES_TABLE_NAME)
+	CreateTable(CERTS_TABLE_NAME)
+	CreateTable(DELETED_NODES_TABLE_NAME)
+	CreateTable(USERS_TABLE_NAME)
+	CreateTable(DNS_TABLE_NAME)
+	CreateTable(EXT_CLIENT_TABLE_NAME)
+	CreateTable(PEERS_TABLE_NAME)
+	CreateTable(SERVERCONF_TABLE_NAME)
+	CreateTable(SERVER_UUID_TABLE_NAME)
+	CreateTable(GENERATED_TABLE_NAME)
+	CreateTable(NODE_ACLS_TABLE_NAME)
+	CreateTable(SSO_STATE_CACHE)
+	CreateTable(METRICS_TABLE_NAME)
+	CreateTable(NETWORK_USER_TABLE_NAME)
+	CreateTable(USER_GROUPS_TABLE_NAME)
+	CreateTable(CACHE_TABLE_NAME)
+	CreateTable(HOSTS_TABLE_NAME)
+	CreateTable(ENROLLMENT_KEYS_TABLE_NAME)
+	CreateTable(HOST_ACTIONS_TABLE_NAME)
 }
 
-func createTable(tableName string) error {
+func CreateTable(tableName string) error {
 	return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
 }
 
@@ -194,7 +194,7 @@ func DeleteAllRecords(tableName string) error {
 	if err != nil {
 		return err
 	}
-	err = createTable(tableName)
+	err = CreateTable(tableName)
 	if err != nil {
 		return err
 	}

+ 6 - 6
logic/telemetry.go

@@ -32,12 +32,12 @@ func sendTelemetry() error {
 		return nil
 	}
 
-	var telRecord, err = fetchTelemetryRecord()
+	var telRecord, err = FetchTelemetryRecord()
 	if err != nil {
 		return err
 	}
 	// get telemetry data
-	d, err := fetchTelemetryData()
+	d, err := FetchTelemetryData()
 	if err != nil {
 		return err
 	}
@@ -71,8 +71,8 @@ func sendTelemetry() error {
 	})
 }
 
-// fetchTelemetry - fetches telemetry data: count of various object types in DB
-func fetchTelemetryData() (telemetryData, error) {
+// FetchTelemetryData - fetches telemetry data: count of various object types in DB
+func FetchTelemetryData() (telemetryData, error) {
 	var data telemetryData
 
 	data.IsPro = servercfg.IsPro
@@ -138,8 +138,8 @@ func getClientCount(nodes []models.Node) clientCount {
 	return count
 }
 
-// fetchTelemetryRecord - get the existing UUID and Timestamp from the DB
-func fetchTelemetryRecord() (models.Telemetry, error) {
+// FetchTelemetryRecord - get the existing UUID and Timestamp from the DB
+func FetchTelemetryRecord() (models.Telemetry, error) {
 	var rawData string
 	var telObj models.Telemetry
 	var err error

+ 4 - 3
logic/timer.go

@@ -3,11 +3,12 @@ package logic
 import (
 	"context"
 	"fmt"
-	"github.com/gravitl/netmaker/logger"
-	"golang.org/x/exp/slog"
 	"sync"
 	"time"
 
+	"github.com/gravitl/netmaker/logger"
+	"golang.org/x/exp/slog"
+
 	"github.com/gravitl/netmaker/models"
 )
 
@@ -24,7 +25,7 @@ var HookManagerCh = make(chan models.HookDetails, 3)
 // TimerCheckpoint - Checks if 24 hours has passed since telemetry was last sent. If so, sends telemetry data to posthog
 func TimerCheckpoint() error {
 	// get the telemetry record in the DB, which contains a timestamp
-	telRecord, err := fetchTelemetryRecord()
+	telRecord, err := FetchTelemetryRecord()
 	if err != nil {
 		return err
 	}

+ 2 - 2
logic/traffic.go

@@ -2,7 +2,7 @@ package logic
 
 // RetrievePrivateTrafficKey - retrieves private key of server
 func RetrievePrivateTrafficKey() ([]byte, error) {
-	var telRecord, err = fetchTelemetryRecord()
+	var telRecord, err = FetchTelemetryRecord()
 	if err != nil {
 		return nil, err
 	}
@@ -12,7 +12,7 @@ func RetrievePrivateTrafficKey() ([]byte, error) {
 
 // RetrievePublicTrafficKey - retrieves public key of server
 func RetrievePublicTrafficKey() ([]byte, error) {
-	var telRecord, err = fetchTelemetryRecord()
+	var telRecord, err = FetchTelemetryRecord()
 	if err != nil {
 		return nil, err
 	}

+ 29 - 8
pro/initialize.go

@@ -4,6 +4,8 @@
 package pro
 
 import (
+	"time"
+
 	controller "github.com/gravitl/netmaker/controllers"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
@@ -17,6 +19,7 @@ import (
 // InitPro - Initialize Pro Logic
 func InitPro() {
 	servercfg.IsPro = true
+	proLogic.InitTrial()
 	models.SetLogo(retrieveProLogo())
 	controller.HttpMiddlewares = append(
 		controller.HttpMiddlewares,
@@ -31,18 +34,36 @@ func InitPro() {
 	)
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
 		// == License Handling ==
-		ClearLicenseCache()
-		if err := ValidateLicense(); err != nil {
-			slog.Error(err.Error())
-			return
+		enableLicenseHook := false
+		trialEndDate, err := getTrialEndDate()
+		if err != nil {
+			slog.Error("failed to get trial end date", "error", err)
+			enableLicenseHook = true
+		}
+		// check if trial ended
+		if time.Now().After(trialEndDate) {
+			// trial ended already
+			enableLicenseHook = true
+		}
+		if enableLicenseHook {
+			slog.Info("starting license checker")
+			ClearLicenseCache()
+			if err := ValidateLicense(); err != nil {
+				slog.Error(err.Error())
+				return
+			}
+			slog.Info("proceeding with Paid Tier license")
+			logic.SetFreeTierForTelemetry(false)
+			// == End License Handling ==
+			AddLicenseHooks()
+		} else {
+			addTrialLicenseHook()
 		}
-		slog.Info("proceeding with Paid Tier license")
-		logic.SetFreeTierForTelemetry(false)
-		// == End License Handling ==
-		AddLicenseHooks()
+
 		if servercfg.GetServerConfig().RacAutoDisable {
 			AddRacHooks()
 		}
+
 	})
 	logic.ResetFailOver = proLogic.ResetFailOver
 	logic.ResetFailedOverPeer = proLogic.ResetFailedOverPeer

+ 146 - 0
pro/trial.go

@@ -0,0 +1,146 @@
+//go:build ee
+// +build ee
+
+package pro
+
+import (
+	"crypto/rand"
+	"encoding/json"
+	"errors"
+	"time"
+
+	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/netclient/ncutils"
+	"golang.org/x/crypto/nacl/box"
+	"golang.org/x/exp/slog"
+)
+
+type TrialInfo struct {
+	PrivKey []byte `json:"priv_key"`
+	PubKey  []byte `json:"pub_key"`
+	Secret  string `json:"secret"`
+}
+
+func addTrialLicenseHook() {
+	logic.HookManagerCh <- models.HookDetails{
+		Hook:     TrialLicenseHook,
+		Interval: time.Hour,
+	}
+}
+
+type TrialDates struct {
+	TrialStartedAt time.Time `json:"trial_started_at"`
+	TrialEndsAt    time.Time `json:"trial_ends_at"`
+}
+
+const trial_table_name = "trial"
+
+const trial_data_key = "trialdata"
+
+// store trial date
+func InitTrial() error {
+	telData, err := logic.FetchTelemetryData()
+	if err != nil {
+		return err
+	}
+	if telData.Hosts > 0 || telData.Networks > 0 || telData.Users > 0 {
+		return nil
+	}
+	err = database.CreateTable(trial_table_name)
+	if err != nil {
+		slog.Error("failed to create table", "table name", trial_table_name, "err", err.Error())
+		return err
+	}
+	// setup encryption keys
+	trafficPubKey, trafficPrivKey, err := box.GenerateKey(rand.Reader) // generate traffic keys
+	if err != nil {
+		return err
+	}
+	tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
+	if err != nil {
+		return err
+	}
+
+	tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
+	if err != nil {
+		return err
+	}
+	trialDates := TrialDates{
+		TrialStartedAt: time.Now(),
+		TrialEndsAt:    time.Now().Add(time.Hour * 24 * 30),
+	}
+	t := TrialInfo{
+		PrivKey: tPriv,
+		PubKey:  tPub,
+	}
+	tel, err := logic.FetchTelemetryRecord()
+	if err != nil {
+		return err
+	}
+
+	trialDatesData, err := json.Marshal(trialDates)
+	if err != nil {
+		return err
+	}
+	trialDatesSecret, err := ncutils.BoxEncrypt(trialDatesData, (*[32]byte)(tel.TrafficKeyPub), (*[32]byte)(t.PrivKey))
+	if err != nil {
+		return err
+	}
+	t.Secret = string(trialDatesSecret)
+	trialData, err := json.Marshal(t)
+	if err != nil {
+		return err
+	}
+	err = database.Insert(trial_data_key, string(trialData), trial_table_name)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func TrialLicenseHook() error {
+	endDate, err := getTrialEndDate()
+	if err != nil {
+		logger.FatalLog0("failed to trial end date", err.Error())
+	}
+	if time.Now().After(endDate) {
+		logger.FatalLog0("***IMPORTANT: Your Trial Has Ended, to continue using pro version, please visit https://app.netmaker.io/ and create on-prem tenant to obtain a license***\nIf you wish to downgrade to community version, please run this command `/root/nm-quick.sh -d`")
+
+	}
+	return nil
+}
+
+// get trial date
+func getTrialEndDate() (time.Time, error) {
+	record, err := database.FetchRecord(trial_table_name, trial_data_key)
+	if err != nil {
+		return time.Time{}, err
+	}
+	var trialInfo TrialInfo
+	err = json.Unmarshal([]byte(record), &trialInfo)
+	if err != nil {
+		return time.Time{}, err
+	}
+	tel, err := logic.FetchTelemetryRecord()
+	if err != nil {
+		return time.Time{}, err
+	}
+	// decrypt secret
+	secretDecrypt, err := ncutils.BoxDecrypt([]byte(trialInfo.Secret), (*[32]byte)(trialInfo.PubKey), (*[32]byte)(tel.TrafficKeyPriv))
+	if err != nil {
+		return time.Time{}, err
+	}
+	trialDates := TrialDates{}
+	err = json.Unmarshal(secretDecrypt, &trialDates)
+	if err != nil {
+		return time.Time{}, err
+	}
+	if trialDates.TrialEndsAt.IsZero() {
+		return time.Time{}, errors.New("invalid date")
+	}
+	return trialDates.TrialEndsAt, nil
+
+}