Browse Source

added try to use func and edited tests

0xdcarns 2 years ago
parent
commit
0e5e34ef0c
3 changed files with 76 additions and 27 deletions
  1. 32 16
      logic/enrollmentkey.go
  2. 36 11
      logic/enrollmentkey_test.go
  3. 8 0
      logic/host_test.go

+ 32 - 16
logic/enrollmentkey.go

@@ -2,6 +2,7 @@ package logic
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"fmt"
 	"time"
 	"time"
 
 
@@ -12,15 +13,15 @@ import (
 
 
 // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages
 // EnrollmentKeyErrors - struct for holding EnrollmentKey error messages
 var EnrollmentKeyErrors = struct {
 var EnrollmentKeyErrors = struct {
-	InvalidCreate   string
-	NoKeyFound      string
-	InvalidKey      string
-	NoUsesRemaining string
+	InvalidCreate   error
+	NoKeyFound      error
+	InvalidKey      error
+	NoUsesRemaining error
 }{
 }{
-	InvalidCreate:   "invalid enrollment key created",
-	NoKeyFound:      "no enrollmentkey found",
-	InvalidKey:      "invalid key provided",
-	NoUsesRemaining: "no uses remaining",
+	InvalidCreate:   fmt.Errorf("invalid enrollment key created"),
+	NoKeyFound:      fmt.Errorf("no enrollmentkey found"),
+	InvalidKey:      fmt.Errorf("invalid key provided"),
+	NoUsesRemaining: fmt.Errorf("no uses remaining"),
 }
 }
 
 
 // CreateEnrollmentKey - creates a new enrollment key in db
 // CreateEnrollmentKey - creates a new enrollment key in db
@@ -50,7 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
 		k.Tags = tags
 		k.Tags = tags
 	}
 	}
 	if ok := k.Validate(); !ok {
 	if ok := k.Validate(); !ok {
-		return nil, fmt.Errorf(EnrollmentKeyErrors.InvalidCreate)
+		return nil, EnrollmentKeyErrors.InvalidCreate
 	}
 	}
 	if err = upsertEnrollmentKey(k); err != nil {
 	if err = upsertEnrollmentKey(k); err != nil {
 		return nil, err
 		return nil, err
@@ -81,7 +82,7 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
 	if key, ok := currentKeys[value]; ok {
 	if key, ok := currentKeys[value]; ok {
 		return key, nil
 		return key, nil
 	}
 	}
-	return nil, fmt.Errorf(EnrollmentKeyErrors.NoKeyFound)
+	return nil, EnrollmentKeyErrors.NoKeyFound
 }
 }
 
 
 // DeleteEnrollmentKey - delete's a given enrollment key by value
 // DeleteEnrollmentKey - delete's a given enrollment key by value
@@ -93,14 +94,31 @@ func DeleteEnrollmentKey(value string) error {
 	return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
 	return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
 }
 }
 
 
-// DecrementEnrollmentKey - decrements the uses on a key if above 0 remaining
-func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
+// TryToUseEnrollmentKey - checks first if key can be decremented
+// returns true if it is decremented or isvalid
+func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
+	key, err := decrementEnrollmentKey(k.Value)
+	if err != nil {
+		if errors.Is(err, EnrollmentKeyErrors.NoUsesRemaining) {
+			return k.IsValid()
+		}
+	} else {
+		k.UsesRemaining = key.UsesRemaining
+		return true
+	}
+	return false
+}
+
+// == private ==
+
+// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
+func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
 	k, err := GetEnrollmentKey(value)
 	k, err := GetEnrollmentKey(value)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 	if k.UsesRemaining == 0 {
 	if k.UsesRemaining == 0 {
-		return nil, fmt.Errorf(EnrollmentKeyErrors.NoUsesRemaining)
+		return nil, EnrollmentKeyErrors.NoUsesRemaining
 	}
 	}
 	k.UsesRemaining = k.UsesRemaining - 1
 	k.UsesRemaining = k.UsesRemaining - 1
 	if err = upsertEnrollmentKey(k); err != nil {
 	if err = upsertEnrollmentKey(k); err != nil {
@@ -110,11 +128,9 @@ func DecrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
 	return k, nil
 	return k, nil
 }
 }
 
 
-// == private ==
-
 func upsertEnrollmentKey(k *models.EnrollmentKey) error {
 func upsertEnrollmentKey(k *models.EnrollmentKey) error {
 	if k == nil {
 	if k == nil {
-		return fmt.Errorf(EnrollmentKeyErrors.InvalidKey)
+		return EnrollmentKeyErrors.InvalidKey
 	}
 	}
 	data, err := json.Marshal(k)
 	data, err := json.Marshal(k)
 	if err != nil {
 	if err != nil {

+ 36 - 11
logic/enrollmentkey_test.go

@@ -15,7 +15,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
 		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
 		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
 		assert.Nil(t, newKey)
 		assert.Nil(t, newKey)
 		assert.NotNil(t, err)
 		assert.NotNil(t, err)
-		assert.Equal(t, err.Error(), EnrollmentKeyErrors.InvalidCreate)
+		assert.Equal(t, err, EnrollmentKeyErrors.InvalidCreate)
 	})
 	})
 	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
 	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
 		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
 		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
@@ -59,12 +59,12 @@ func TestDelete_EnrollmentKey(t *testing.T) {
 		oldKey, err := GetEnrollmentKey(newKey.Value)
 		oldKey, err := GetEnrollmentKey(newKey.Value)
 		assert.Nil(t, oldKey)
 		assert.Nil(t, oldKey)
 		assert.NotNil(t, err)
 		assert.NotNil(t, err)
-		assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound)
+		assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound)
 	})
 	})
 	t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
 	t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
 		err := DeleteEnrollmentKey("notakey")
 		err := DeleteEnrollmentKey("notakey")
 		assert.NotNil(t, err)
 		assert.NotNil(t, err)
-		assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoKeyFound)
+		assert.Equal(t, err, EnrollmentKeyErrors.NoKeyFound)
 	})
 	})
 	removeAllEnrollments()
 	removeAllEnrollments()
 }
 }
@@ -72,32 +72,57 @@ func TestDelete_EnrollmentKey(t *testing.T) {
 func TestDecrement_EnrollmentKey(t *testing.T) {
 func TestDecrement_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, true)
+	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
 	t.Run("Check_initial_uses", func(t *testing.T) {
 	t.Run("Check_initial_uses", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		assert.True(t, newKey.IsValid())
 		assert.Equal(t, newKey.UsesRemaining, 1)
 		assert.Equal(t, newKey.UsesRemaining, 1)
 	})
 	})
 	t.Run("Check can decrement", func(t *testing.T) {
 	t.Run("Check can decrement", func(t *testing.T) {
 		assert.Equal(t, newKey.UsesRemaining, 1)
 		assert.Equal(t, newKey.UsesRemaining, 1)
-		k, err := DecrementEnrollmentKey(newKey.Value)
+		k, err := decrementEnrollmentKey(newKey.Value)
 		assert.Nil(t, err)
 		assert.Nil(t, err)
 		newKey = k
 		newKey = k
 	})
 	})
 	t.Run("Check can not decrement", func(t *testing.T) {
 	t.Run("Check can not decrement", func(t *testing.T) {
 		assert.Equal(t, newKey.UsesRemaining, 0)
 		assert.Equal(t, newKey.UsesRemaining, 0)
-		_, err := DecrementEnrollmentKey(newKey.Value)
+		_, err := decrementEnrollmentKey(newKey.Value)
 		assert.NotNil(t, err)
 		assert.NotNil(t, err)
-		assert.Equal(t, err.Error(), EnrollmentKeyErrors.NoUsesRemaining)
+		assert.Equal(t, err, EnrollmentKeyErrors.NoUsesRemaining)
 	})
 	})
 
 
 	removeAllEnrollments()
 	removeAllEnrollments()
 }
 }
 
 
-// func TestValidity_EnrollmentKey(t *testing.T) {
-// 	database.InitializeDatabase()
-// 	defer database.CloseDB()
+func TestUsability_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
+	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
+	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
+	t.Run("Check if valid use key can be used", func(t *testing.T) {
+		assert.Equal(t, key1.UsesRemaining, 1)
+		ok := TryToUseEnrollmentKey(key1)
+		assert.True(t, ok)
+		assert.Equal(t, 0, key1.UsesRemaining)
+	})
+
+	t.Run("Check if valid time key can be used", func(t *testing.T) {
+		assert.True(t, !key2.Expiration.IsZero())
+		ok := TryToUseEnrollmentKey(key2)
+		assert.True(t, ok)
+	})
 
 
-// }
+	t.Run("Check if valid unlimited key can be used", func(t *testing.T) {
+		assert.True(t, key3.Unlimited)
+		ok := TryToUseEnrollmentKey(key3)
+		assert.True(t, ok)
+	})
+
+	t.Run("check invalid key can not be used", func(t *testing.T) {
+		ok := TryToUseEnrollmentKey(key1)
+		assert.False(t, ok)
+	})
+}
 
 
 func removeAllEnrollments() {
 func removeAllEnrollments() {
 	database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
 	database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)

+ 8 - 0
logic/host_test.go

@@ -1,6 +1,7 @@
 package logic
 package logic
 
 
 import (
 import (
+	"context"
 	"net"
 	"net"
 	"testing"
 	"testing"
 
 
@@ -13,6 +14,13 @@ import (
 func TestCheckPorts(t *testing.T) {
 func TestCheckPorts(t *testing.T) {
 	database.InitializeDatabase()
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	defer database.CloseDB()
+	peerUpdate := make(chan *models.Node)
+	go ManageZombies(context.Background(), peerUpdate)
+	go func() {
+		for _ = range peerUpdate {
+			//do nothing
+		}
+	}()
 
 
 	h := models.Host{
 	h := models.Host{
 		ID:              uuid.New(),
 		ID:              uuid.New(),