Browse Source

fix(NET-786): enhance enrollment key validation (#2726)

Aceix 1 year ago
parent
commit
033e203d91
4 changed files with 59 additions and 10 deletions
  1. 30 0
      controllers/enrollmentkeys.go
  2. 3 3
      logic/enrollmentkey.go
  3. 1 1
      logic/enrollmentkey_test.go
  4. 25 6
      models/enrollment_key.go

+ 30 - 0
controllers/enrollmentkeys.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"time"
 
+	"github.com/go-playground/validator/v10"
 	"github.com/google/uuid"
 	"github.com/gorilla/mux"
 
@@ -115,6 +116,35 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
 	if enrollmentKeyBody.Expiration > 0 {
 		newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
 	}
+	v := validator.New()
+	err = v.Struct(enrollmentKeyBody)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error validating request body: ",
+			err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("validation error: name length must be between 3 and 32: %w", err), "badrequest"))
+		return
+	}
+
+	if existingKeys, err := logic.GetAllEnrollmentKeys(); err != nil {
+		logger.Log(0, r.Header.Get("user"), "error validating request body: ",
+			err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	} else {
+		// check if any tags are duplicate
+		existingTags := make(map[string]struct{})
+		for _, existingKey := range existingKeys {
+			for _, t := range existingKey.Tags {
+				existingTags[t] = struct{}{}
+			}
+		}
+		for _, t := range enrollmentKeyBody.Tags {
+			if _, ok := existingTags[t]; ok {
+				logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("key names must be unique"), "badrequest"))
+				return
+			}
+		}
+	}
 
 	relayId := uuid.Nil
 	if enrollmentKeyBody.Relay != "" {

+ 3 - 3
logic/enrollmentkey.go

@@ -22,7 +22,7 @@ var EnrollmentErrors = struct {
 	FailedToTokenize   error
 	FailedToDeTokenize error
 }{
-	InvalidCreate:      fmt.Errorf("invalid enrollment key created"),
+	InvalidCreate:      fmt.Errorf("failed to create enrollment key. paramters invalid"),
 	NoKeyFound:         fmt.Errorf("no enrollmentkey found"),
 	InvalidKey:         fmt.Errorf("invalid key provided"),
 	NoUsesRemaining:    fmt.Errorf("no uses remaining"),
@@ -61,8 +61,8 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
 	if len(tags) > 0 {
 		k.Tags = tags
 	}
-	if ok := k.Validate(); !ok {
-		return nil, EnrollmentErrors.InvalidCreate
+	if err := k.Validate(); err != nil {
+		return nil, err
 	}
 	if relay != uuid.Nil {
 		relayNode, err := GetNodeByID(relay.String())

+ 1 - 1
logic/enrollmentkey_test.go

@@ -17,7 +17,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
 		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil)
 		assert.Nil(t, newKey)
 		assert.NotNil(t, err)
-		assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
+		assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
 	})
 	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
 		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)

+ 25 - 6
models/enrollment_key.go

@@ -1,6 +1,8 @@
 package models
 
 import (
+	"errors"
+	"fmt"
 	"time"
 
 	"github.com/google/uuid"
@@ -13,6 +15,14 @@ const (
 	Unlimited
 )
 
+var (
+	ErrNilEnrollmentKey          = errors.New("enrollment key is nil")
+	ErrNilNetworksEnrollmentKey  = errors.New("enrollment key networks is nil")
+	ErrNilTagsEnrollmentKey      = errors.New("enrollment key tags is nil")
+	ErrInvalidEnrollmentKey      = errors.New("enrollment key is not valid")
+	ErrInvalidEnrollmentKeyValue = errors.New("enrollment key value is not valid")
+)
+
 // KeyType - the type of enrollment key
 type KeyType int
 
@@ -50,7 +60,7 @@ type APIEnrollmentKey struct {
 	UsesRemaining int      `json:"uses_remaining"`
 	Networks      []string `json:"networks"`
 	Unlimited     bool     `json:"unlimited"`
-	Tags          []string `json:"tags"`
+	Tags          []string `json:"tags" validate:"required,dive,min=3,max=32"`
 	Type          KeyType  `json:"type"`
 	Relay         string   `json:"relay"`
 }
@@ -81,9 +91,18 @@ func (k *EnrollmentKey) IsValid() bool {
 
 // EnrollmentKey.Validate - validate's an EnrollmentKey
 // should be used during creation
-func (k *EnrollmentKey) Validate() bool {
-	return k.Networks != nil &&
-		k.Tags != nil &&
-		len(k.Value) == EnrollmentKeyLength &&
-		k.IsValid()
+func (k *EnrollmentKey) Validate() error {
+	if k == nil {
+		return ErrNilEnrollmentKey
+	}
+	if k.Tags == nil {
+		return ErrNilTagsEnrollmentKey
+	}
+	if len(k.Value) != EnrollmentKeyLength {
+		return fmt.Errorf("%w: length not %d characters", ErrInvalidEnrollmentKeyValue, EnrollmentKeyLength)
+	}
+	if !k.IsValid() {
+		return fmt.Errorf("%w: uses remaining: %d, expiration: %s, unlimited: %t", ErrInvalidEnrollmentKey, k.UsesRemaining, k.Expiration, k.Unlimited)
+	}
+	return nil
 }