Browse Source

initial commit, began unit tests

0xdcarns 2 years ago
parent
commit
a5e7147b69
4 changed files with 246 additions and 0 deletions
  1. 3 0
      database/database.go
  2. 133 0
      logic/enrollment_key.go
  3. 70 0
      logic/enrollment_key_test.go
  4. 40 0
      models/enrollment_key.go

+ 3 - 0
database/database.go

@@ -57,6 +57,8 @@ const (
 	CACHE_TABLE_NAME = "cache"
 	// HOSTS_TABLE_NAME - the table name for hosts
 	HOSTS_TABLE_NAME = "hosts"
+	// ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys
+	ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys"
 
 	// == ERROR CONSTS ==
 	// NO_RECORD - no singular result found
@@ -138,6 +140,7 @@ func createTables() {
 	createTable(USER_GROUPS_TABLE_NAME)
 	createTable(CACHE_TABLE_NAME)
 	createTable(HOSTS_TABLE_NAME)
+	createTable(ENROLLMENT_KEYS_TABLE_NAME)
 }
 
 func createTable(tableName string) error {

+ 133 - 0
logic/enrollment_key.go

@@ -0,0 +1,133 @@
+package logic
+
+import (
+	"encoding/json"
+	"fmt"
+	"time"
+
+	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/netclient/ncutils"
+)
+
+// EnrollmentKeyErrors - struct for holding EnrollmentKey error messages
+var EnrollmentKeyErrors = struct {
+	InvalidCreate string
+	NoKeyFound    string
+	InvalidKey    string
+}{
+	InvalidCreate: "invalid enrollment key created",
+	NoKeyFound:    "no enrollmentkey found",
+	InvalidKey:    "invalid key provided",
+}
+
+// CreateEnrollmentKey - creates a new enrollment key in db
+func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
+	newKeyID, err := getUniqueEnrollmentID()
+	if err != nil {
+		return nil, err
+	}
+	k = &models.EnrollmentKey{
+		Value:         newKeyID,
+		Expiration:    time.Time{},
+		UsesRemaining: 0,
+		Unlimited:     unlimited,
+		Networks:      []string{},
+		Tags:          []string{},
+	}
+	if uses > 0 {
+		k.UsesRemaining = uses
+	}
+	if !expiration.IsZero() {
+		k.Expiration = expiration
+	}
+	if len(networks) > 0 {
+		k.Networks = networks
+	}
+	if len(tags) > 0 {
+		k.Tags = tags
+	}
+	if ok := k.Validate(); !ok {
+		return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate)
+	}
+	if err = upsertEnrollmentKey(k); err != nil {
+		return nil, err
+	}
+	return
+}
+
+// GetAllEnrollmentKeys - fetches all enrollment keys from DB
+func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
+	currentKeys, err := getEnrollmentKeysMap()
+	if err != nil {
+		return nil, err
+	}
+	var currentKeysList = make([]*models.EnrollmentKey, len(currentKeys))
+	for k := range currentKeys {
+		currentKeysList = append(currentKeysList, currentKeys[k])
+	}
+	return currentKeysList, nil
+}
+
+// GetEnrollmentKey - fetches a single enrollment key
+// returns nil and error if not found
+func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
+	currentKeys, err := getEnrollmentKeysMap()
+	if err != nil {
+		return nil, err
+	}
+	if key, ok := currentKeys[value]; ok {
+		return key, nil
+	}
+	return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound)
+}
+
+// DeleteEnrollmentKey - delete's a given enrollment key by value
+func DeleteEnrollmentKey(value string) error {
+	return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
+}
+
+// == private ==
+
+func upsertEnrollmentKey(k *models.EnrollmentKey) error {
+	if k == nil {
+		return fmt.Errorf(EnrollmentKeyErrors.InvalidKey)
+	}
+	data, err := json.Marshal(k)
+	if err != nil {
+		return err
+	}
+	return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
+}
+
+func getUniqueEnrollmentID() (string, error) {
+	currentKeys, err := getEnrollmentKeysMap()
+	if err != nil {
+		return "", err
+	}
+	newID := ncutils.MakeRandomString(32)
+	for _, ok := currentKeys[newID]; !ok; {
+		newID = ncutils.MakeRandomString(32)
+	}
+	return newID, nil
+}
+
+func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
+	records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
+	if err != nil {
+		if !database.IsEmptyRecord(err) {
+			return nil, err
+		}
+	}
+	currentKeys := make(map[string]*models.EnrollmentKey)
+	if len(records) > 0 {
+		for k := range records {
+			var currentKey models.EnrollmentKey
+			if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
+				continue
+			}
+			currentKeys[k] = &currentKey
+		}
+	}
+	return currentKeys, nil
+}

+ 70 - 0
logic/enrollment_key_test.go

@@ -0,0 +1,70 @@
+package logic
+
+import (
+	"testing"
+	"time"
+
+	"github.com/gravitl/netmaker/database"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCreate_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	t.Run("Can_Not_Create_Key", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
+		assert.Nil(t, newKey)
+		assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate)
+	})
+	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
+		assert.Nil(t, err)
+		assert.Equal(t, 1, newKey.UsesRemaining)
+		assert.True(t, newKey.IsValid())
+	})
+	t.Run("Can_Create_Key_Time", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+	})
+	t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+	})
+	t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+		assert.True(t, len(newKey.Networks) == 2)
+	})
+	t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+		assert.True(t, len(newKey.Tags) == 2)
+	})
+	removeAllEnrollments()
+}
+
+func TestDelete_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+
+}
+
+func TestDecrement_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+
+}
+
+func TestValidity_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+
+}
+
+func removeAllEnrollments() {
+	database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
+}

+ 40 - 0
models/enrollment_key.go

@@ -0,0 +1,40 @@
+package models
+
+import "time"
+
+// EnrollmentKeyLength - the length of an enrollment key
+const EnrollmentKeyLength = 32
+
+// EnrollmentKey - the
+type EnrollmentKey struct {
+	Expiration    time.Time `json:"expiration"`
+	UsesRemaining int       `json:"uses_remaining"`
+	Value         string    `json:"value"`
+	Networks      []string  `json:"networks"`
+	Unlimited     bool      `json:"unlimited"`
+	Tags          []string  `json:"tags"`
+}
+
+// EnrollmentKey.IsValid - checks if the key is still valid to use
+func (k *EnrollmentKey) IsValid() bool {
+	if k == nil {
+		return false
+	}
+	if k.UsesRemaining > 0 {
+		return true
+	}
+	if !k.Expiration.IsZero() && time.Now().After(k.Expiration) {
+		return true
+	}
+
+	return k.Unlimited
+}
+
+// EnrollmentKey.Validate - validate's an EnrollmentKey
+// should be used during creation
+func (k *EnrollmentKey) Validate() bool {
+	return k.Networks != nil &&
+		k.Tags != nil &&
+		len(k.Value) == EnrollmentKeyLength &&
+		k.IsValid()
+}