|
@@ -5,6 +5,7 @@ import (
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
+ "strings"
|
|
|
"sync"
|
|
|
"time"
|
|
|
|
|
@@ -120,7 +121,6 @@ func UpdateEnrollmentKey(keyId string, relayId uuid.UUID, groups []models.TagID)
|
|
|
}
|
|
|
|
|
|
// GetAllEnrollmentKeys - fetches all enrollment keys from DB
|
|
|
-// TODO drop double pointer
|
|
|
func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) {
|
|
|
currentKeys, err := getEnrollmentKeysMap()
|
|
|
if err != nil {
|
|
@@ -335,3 +335,59 @@ func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) {
|
|
|
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+func UnlinkNetworkAndTagsFromEnrollmentKeys(network string, delete bool) error {
|
|
|
+ keys, err := GetAllEnrollmentKeys()
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("failed to retrieve keys: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ var errs []error
|
|
|
+ for _, key := range keys {
|
|
|
+ newNetworks := []string{}
|
|
|
+ newTags := []models.TagID{}
|
|
|
+ update := false
|
|
|
+
|
|
|
+ // Check and update networks
|
|
|
+ for _, net := range key.Networks {
|
|
|
+ if net == network {
|
|
|
+ update = true
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ newNetworks = append(newNetworks, net)
|
|
|
+ }
|
|
|
+
|
|
|
+ // Check and update tags
|
|
|
+ for _, tag := range key.Groups {
|
|
|
+ tagParts := strings.Split(tag.String(), ".")
|
|
|
+ if len(tagParts) == 0 {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ tagNetwork := tagParts[0]
|
|
|
+ if tagNetwork == network {
|
|
|
+ update = true
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ newTags = append(newTags, tag)
|
|
|
+ }
|
|
|
+
|
|
|
+ if update && len(newNetworks) == 0 && delete {
|
|
|
+ if err := DeleteEnrollmentKey(key.Value, true); err != nil {
|
|
|
+ errs = append(errs, fmt.Errorf("failed to delete key %s: %w", key.Value, err))
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if update {
|
|
|
+ key.Networks = newNetworks
|
|
|
+ key.Groups = newTags
|
|
|
+ if err := upsertEnrollmentKey(&key); err != nil {
|
|
|
+ errs = append(errs, fmt.Errorf("failed to update key %s: %w", key.Value, err))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(errs) > 0 {
|
|
|
+ return fmt.Errorf("errors unlinking network/tags from keys: %v", errs)
|
|
|
+ }
|
|
|
+ return nil
|
|
|
+}
|