| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 | package logicimport (	b64 "encoding/base64"	"encoding/json"	"errors"	"fmt"	"strings"	"sync"	"time"	"github.com/google/uuid"	"github.com/gravitl/netmaker/database"	"github.com/gravitl/netmaker/models"	"github.com/gravitl/netmaker/servercfg"	"golang.org/x/exp/slices")// EnrollmentErrors - struct for holding EnrollmentKey error messagesvar EnrollmentErrors = struct {	InvalidCreate      error	NoKeyFound         error	InvalidKey         error	NoUsesRemaining    error	FailedToTokenize   error	FailedToDeTokenize error}{	InvalidCreate:      fmt.Errorf("failed to create enrollment key. paramters invalid"),	NoKeyFound:         fmt.Errorf("no enrollmentkey found"),	InvalidKey:         fmt.Errorf("invalid key provided"),	NoUsesRemaining:    fmt.Errorf("no uses remaining"),	FailedToTokenize:   fmt.Errorf("failed to tokenize"),	FailedToDeTokenize: fmt.Errorf("failed to detokenize"),}var (	enrollmentkeyCacheMutex = &sync.RWMutex{}	enrollmentkeyCacheMap   = make(map[string]models.EnrollmentKey))// CreateEnrollmentKey - creates a new enrollment key in dbfunc CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, defaultKey, autoEgress bool) (*models.EnrollmentKey, error) {	newKeyID, err := getUniqueEnrollmentID()	if err != nil {		return nil, err	}	k := &models.EnrollmentKey{		Value:         newKeyID,		Expiration:    time.Time{},		UsesRemaining: 0,		Unlimited:     unlimited,		Networks:      []string{},		Tags:          []string{},		Type:          models.Undefined,		Relay:         relay,		Groups:        groups,		Default:       defaultKey,		AutoEgress:    autoEgress,	}	if uses > 0 {		k.UsesRemaining = uses		k.Type = models.Uses	} else if !expiration.IsZero() {		k.Expiration = expiration		k.Type = models.TimeExpiration	} else if k.Unlimited {		k.Type = models.Unlimited	}	if len(networks) > 0 {		k.Networks = networks	}	if len(tags) > 0 {		k.Tags = tags	}	if err := k.Validate(); err != nil {		return nil, err	}	if relay != uuid.Nil {		relayNode, err := GetNodeByID(relay.String())		if err != nil {			return nil, err		}		if !slices.Contains(k.Networks, relayNode.Network) {			return nil, errors.New("relay node not in key's networks")		}		if !relayNode.IsRelay {			return nil, errors.New("relay node is not a relay")		}	}	if err = upsertEnrollmentKey(k); err != nil {		return nil, err	}	return k, nil}// UpdateEnrollmentKey - updates an existing enrollment key's associated relayfunc UpdateEnrollmentKey(keyId string, relayId uuid.UUID, groups []models.TagID) (*models.EnrollmentKey, error) {	key, err := GetEnrollmentKey(keyId)	if err != nil {		return nil, err	}	if relayId != uuid.Nil {		relayNode, err := GetNodeByID(relayId.String())		if err != nil {			return nil, err		}		if !slices.Contains(key.Networks, relayNode.Network) {			return nil, errors.New("relay node not in key's networks")		}		if !relayNode.IsRelay {			return nil, errors.New("relay node is not a relay")		}	}	key.Relay = relayId	key.Groups = groups	if err = upsertEnrollmentKey(&key); err != nil {		return nil, err	}	return &key, nil}// GetAllEnrollmentKeys - fetches all enrollment keys from DBfunc GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) {	currentKeys, err := getEnrollmentKeysMap()	if err != nil {		return nil, err	}	var currentKeysList = []models.EnrollmentKey{}	for k := range currentKeys {		currentKeysList = append(currentKeysList, currentKeys[k])	}	return currentKeysList, nil}// GetEnrollmentKey - fetches a single enrollment key// returns nil and error if not foundfunc GetEnrollmentKey(value string) (key models.EnrollmentKey, err error) {	currentKeys, err := getEnrollmentKeysMap()	if err != nil {		return key, err	}	if key, ok := currentKeys[value]; ok {		return key, nil	}	return key, EnrollmentErrors.NoKeyFound}func deleteEnrollmentkeyFromCache(key string) {	enrollmentkeyCacheMutex.Lock()	delete(enrollmentkeyCacheMap, key)	enrollmentkeyCacheMutex.Unlock()}// DeleteEnrollmentKey - delete's a given enrollment key by valuefunc DeleteEnrollmentKey(value string, force bool) error {	key, err := GetEnrollmentKey(value)	if err != nil {		return err	}	if key.Default && !force {		return errors.New("cannot delete default network key")	}	err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)	if err == nil {		if servercfg.CacheEnabled() {			deleteEnrollmentkeyFromCache(value)		}	}	return err}// TryToUseEnrollmentKey - checks first if key can be decremented// returns true if it is decremented or isvalidfunc TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {	key, err := decrementEnrollmentKey(k.Value)	if err != nil {		if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {			return k.IsValid()		}	} else {		k.UsesRemaining = key.UsesRemaining		return true	}	return false}// Tokenize - tokenizes an enrollment key to be used via registration// and attaches it to the Token field on the structfunc Tokenize(k *models.EnrollmentKey, serverAddr string) error {	if len(serverAddr) == 0 || k == nil {		return EnrollmentErrors.FailedToTokenize	}	newToken := models.EnrollmentToken{		Server: serverAddr,		Value:  k.Value,	}	data, err := json.Marshal(&newToken)	if err != nil {		return err	}	k.Token = b64.StdEncoding.EncodeToString(data)	return nil}// DeTokenize - detokenizes a base64 encoded string// and finds the associated enrollment keyfunc DeTokenize(b64Token string) (*models.EnrollmentKey, error) {	if len(b64Token) == 0 {		return nil, EnrollmentErrors.FailedToDeTokenize	}	tokenData, err := b64.StdEncoding.DecodeString(b64Token)	if err != nil {		return nil, err	}	var newToken models.EnrollmentToken	err = json.Unmarshal(tokenData, &newToken)	if err != nil {		return nil, err	}	k, err := GetEnrollmentKey(newToken.Value)	if err != nil {		return nil, err	}	return &k, nil}// == private ==// decrementEnrollmentKey - decrements the uses on a key if above 0 remainingfunc decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {	k, err := GetEnrollmentKey(value)	if err != nil {		return nil, err	}	if k.UsesRemaining == 0 {		return nil, EnrollmentErrors.NoUsesRemaining	}	k.UsesRemaining = k.UsesRemaining - 1	if err = upsertEnrollmentKey(&k); err != nil {		return nil, err	}	return &k, nil}func upsertEnrollmentKey(k *models.EnrollmentKey) error {	if k == nil {		return EnrollmentErrors.InvalidKey	}	data, err := json.Marshal(k)	if err != nil {		return err	}	err = database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)	if err == nil {		if servercfg.CacheEnabled() {			storeEnrollmentkeyInCache(k.Value, *k)		}	}	return nil}func getUniqueEnrollmentID() (string, error) {	currentKeys, err := getEnrollmentKeysMap()	if err != nil {		return "", err	}	newID := RandomString(models.EnrollmentKeyLength)	for _, ok := currentKeys[newID]; ok; {		newID = RandomString(models.EnrollmentKeyLength)	}	return newID, nil}func getEnrollmentkeysFromCache() map[string]models.EnrollmentKey {	return enrollmentkeyCacheMap}func storeEnrollmentkeyInCache(key string, enrollmentkey models.EnrollmentKey) {	enrollmentkeyCacheMutex.Lock()	enrollmentkeyCacheMap[key] = enrollmentkey	enrollmentkeyCacheMutex.Unlock()}func getEnrollmentKeysMap() (map[string]models.EnrollmentKey, error) {	if servercfg.CacheEnabled() {		keys := getEnrollmentkeysFromCache()		if len(keys) != 0 {			return keys, nil		}	}	records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)	if err != nil {		if !database.IsEmptyRecord(err) {			return nil, err		}	}	if records == nil {		records = make(map[string]string)	}	currentKeys := make(map[string]models.EnrollmentKey, 0)	if len(records) > 0 {		for k := range records {			var currentKey models.EnrollmentKey			if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil {				continue			}			currentKeys[k] = currentKey			if servercfg.CacheEnabled() {				storeEnrollmentkeyInCache(currentKey.Value, currentKey)			}		}	}	return currentKeys, nil}func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) {	keys, _ := GetAllEnrollmentKeys()	for _, key := range keys {		newTags := []models.TagID{}		update := false		for _, tagID := range key.Groups {			if tagID == deletedTagID {				update = true				continue			}			newTags = append(newTags, tagID)		}		if update {			key.Groups = newTags			upsertEnrollmentKey(&key)		}	}}func UnlinkNetworkAndTagsFromEnrollmentKeys(network string, delete bool) error {	keys, err := GetAllEnrollmentKeys()	if err != nil {		return fmt.Errorf("failed to retrieve keys: %w", err)	}	var errs []error	for _, key := range keys {		newNetworks := []string{}		newTags := []models.TagID{}		update := false		// Check and update networks		for _, net := range key.Networks {			if net == network {				update = true				continue			}			newNetworks = append(newNetworks, net)		}		// Check and update tags		for _, tag := range key.Groups {			tagParts := strings.Split(tag.String(), ".")			if len(tagParts) == 0 {				continue			}			tagNetwork := tagParts[0]			if tagNetwork == network {				update = true				continue			}			newTags = append(newTags, tag)		}		if update && len(newNetworks) == 0 && delete {			if err := DeleteEnrollmentKey(key.Value, true); err != nil {				errs = append(errs, fmt.Errorf("failed to delete key %s: %w", key.Value, err))			}			continue		}		if update {			key.Networks = newNetworks			key.Groups = newTags			if err := upsertEnrollmentKey(&key); err != nil {				errs = append(errs, fmt.Errorf("failed to update key %s: %w", key.Value, err))			}		}	}	if len(errs) > 0 {		return fmt.Errorf("errors unlinking network/tags from keys: %v", errs)	}	return nil}
 |