2
0
Эх сурвалжийг харах

add tag groups to enrollment key

abhishek9686 11 сар өмнө
parent
commit
db2550b7bd

+ 10 - 1
controllers/enrollmentkeys.go

@@ -156,6 +156,7 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
 		newTime,
 		enrollmentKeyBody.Networks,
 		enrollmentKeyBody.Tags,
+		enrollmentKeyBody.Groups,
 		enrollmentKeyBody.Unlimited,
 		relayId,
 	)
@@ -206,7 +207,7 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId)
+	newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId, enrollmentKeyBody.Groups)
 	if err != nil {
 		slog.Error("failed to update enrollment key", "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
@@ -307,6 +308,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 		}
+		newHost.Tags = make(map[models.TagID]struct{})
+		for _, tagI := range enrollmentKey.Groups {
+			newHost.Tags[tagI] = struct{}{}
+		}
 		if err = logic.CreateHost(&newHost); err != nil {
 			logger.Log(
 				0,
@@ -337,6 +342,10 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 		logic.UpdateHostFromClient(&newHost, currHost)
+		currHost.Tags = make(map[models.TagID]struct{})
+		for _, tagI := range enrollmentKey.Groups {
+			currHost.Tags[tagI] = struct{}{}
+		}
 		err = logic.UpsertHost(currHost)
 		if err != nil {
 			slog.Error("failed to update host", "id", currHost.ID, "error", err)

+ 4 - 3
logic/enrollmentkey.go

@@ -37,7 +37,7 @@ var (
 )
 
 // CreateEnrollmentKey - creates a new enrollment key in db
-func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
+func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
 	newKeyID, err := getUniqueEnrollmentID()
 	if err != nil {
 		return nil, err
@@ -51,6 +51,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
 		Tags:          []string{},
 		Type:          models.Undefined,
 		Relay:         relay,
+		Groups:        groups,
 	}
 	if uses > 0 {
 		k.UsesRemaining = uses
@@ -89,7 +90,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
 }
 
 // UpdateEnrollmentKey - updates an existing enrollment key's associated relay
-func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey, error) {
+func UpdateEnrollmentKey(keyId string, relayId uuid.UUID, groups []models.TagID) (*models.EnrollmentKey, error) {
 	key, err := GetEnrollmentKey(keyId)
 	if err != nil {
 		return nil, err
@@ -109,7 +110,7 @@ func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey
 	}
 
 	key.Relay = relayId
-
+	key.Groups = groups
 	if err = upsertEnrollmentKey(&key); err != nil {
 		return nil, err
 	}

+ 13 - 13
logic/enrollmentkey_test.go

@@ -14,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	t.Run("Can_Not_Create_Key", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil)
 		assert.Nil(t, newKey)
 		assert.NotNil(t, err)
 		assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey)
 	})
 	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
+		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil)
 		assert.Nil(t, err)
 		assert.Equal(t, 1, newKey.UsesRemaining)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Time", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false, uuid.Nil)
+		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Networks) == 2)
 	})
 	t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true, uuid.Nil)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Tags) == 2)
@@ -62,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
 func TestDelete_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
 	t.Run("Can_Delete_Key", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		err := DeleteEnrollmentKey(newKey.Value)
@@ -83,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
 func TestDecrement_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
+	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil)
 	t.Run("Check_initial_uses", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		assert.Equal(t, newKey.UsesRemaining, 1)
@@ -107,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
 func TestUsability_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
-	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false, uuid.Nil)
-	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
+	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil)
+	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil)
+	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil)
 	t.Run("Check if valid use key can be used", func(t *testing.T) {
 		assert.Equal(t, key1.UsesRemaining, 1)
 		ok := TryToUseEnrollmentKey(key1)
@@ -145,7 +145,7 @@ func removeAllEnrollments() {
 func TestTokenize_EnrollmentKeys(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
 	const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
 	const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
@@ -178,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
 func TestDeTokenize_EnrollmentKeys(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil)
 	const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
 

+ 2 - 0
models/enrollment_key.go

@@ -52,6 +52,7 @@ type EnrollmentKey struct {
 	Token         string    `json:"token,omitempty"` // B64 value of EnrollmentToken
 	Type          KeyType   `json:"type"`
 	Relay         uuid.UUID `json:"relay"`
+	Groups        []TagID   `json:"groups"`
 }
 
 // APIEnrollmentKey - used to create enrollment keys via API
@@ -63,6 +64,7 @@ type APIEnrollmentKey struct {
 	Tags          []string `json:"tags" validate:"required,dive,min=3,max=32"`
 	Type          KeyType  `json:"type"`
 	Relay         string   `json:"relay"`
+	Groups        []TagID  `json:"groups"`
 }
 
 // RegisterResponse - the response to a successful enrollment register