| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 | package logicimport (	b64 "encoding/base64"	"encoding/json"	"errors"	"fmt"	"golang.org/x/exp/slices"	"time"	"github.com/gravitl/netmaker/database"	"github.com/gravitl/netmaker/models")// EnrollmentErrors - struct for holding EnrollmentKey error messagesvar EnrollmentErrors = struct {	InvalidCreate      error	NoKeyFound         error	InvalidKey         error	NoUsesRemaining    error	FailedToTokenize   error	FailedToDeTokenize error}{	InvalidCreate:      fmt.Errorf("invalid enrollment key created"),	NoKeyFound:         fmt.Errorf("no enrollmentkey found"),	InvalidKey:         fmt.Errorf("invalid key provided"),	NoUsesRemaining:    fmt.Errorf("no uses remaining"),	FailedToTokenize:   fmt.Errorf("failed to tokenize"),	FailedToDeTokenize: fmt.Errorf("failed to detokenize"),}// CreateEnrollmentKey - creates a new enrollment key in dbfunc CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {	newKeyID, err := getUniqueEnrollmentID()	if err != nil {		return nil, err	}	k = &models.EnrollmentKey{		Value:         newKeyID,		Expiration:    time.Time{},		UsesRemaining: 0,		Unlimited:     unlimited,		Networks:      []string{},		Tags:          []string{},		Type:          models.Undefined,	}	if uses > 0 {		k.UsesRemaining = uses		k.Type = models.Uses	} else if !expiration.IsZero() {		k.Expiration = expiration		k.Type = models.TimeExpiration	} else if k.Unlimited {		k.Type = models.Unlimited	}	if len(networks) > 0 {		k.Networks = networks	}	if len(tags) > 0 {		k.Tags = tags	}	if ok := k.Validate(); !ok {		return nil, EnrollmentErrors.InvalidCreate	}	if err = upsertEnrollmentKey(k); err != nil {		return nil, err	}	return}// GetAllEnrollmentKeys - fetches all enrollment keys from DB// TODO drop double pointerfunc GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {	currentKeys, err := getEnrollmentKeysMap()	if err != nil {		return nil, err	}	var currentKeysList = []*models.EnrollmentKey{}	for k := range currentKeys {		currentKeysList = append(currentKeysList, currentKeys[k])	}	return currentKeysList, nil}// GetEnrollmentKey - fetches a single enrollment key// returns nil and error if not foundfunc GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {	currentKeys, err := getEnrollmentKeysMap()	if err != nil {		return nil, err	}	if key, ok := currentKeys[value]; ok {		return key, nil	}	return nil, EnrollmentErrors.NoKeyFound}// DeleteEnrollmentKey - delete's a given enrollment key by valuefunc DeleteEnrollmentKey(value string) error {	_, err := GetEnrollmentKey(value)	if err != nil {		return err	}	return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)}// TryToUseEnrollmentKey - checks first if key can be decremented// returns true if it is decremented or isvalidfunc TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {	key, err := decrementEnrollmentKey(k.Value)	if err != nil {		if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {			return k.IsValid()		}	} else {		k.UsesRemaining = key.UsesRemaining		return true	}	return false}// Tokenize - tokenizes an enrollment key to be used via registration// and attaches it to the Token field on the structfunc Tokenize(k *models.EnrollmentKey, serverAddr string) error {	if len(serverAddr) == 0 || k == nil {		return EnrollmentErrors.FailedToTokenize	}	newToken := models.EnrollmentToken{		Server: serverAddr,		Value:  k.Value,	}	data, err := json.Marshal(&newToken)	if err != nil {		return err	}	k.Token = b64.StdEncoding.EncodeToString(data)	return nil}// DeTokenize - detokenizes a base64 encoded string// and finds the associated enrollment keyfunc DeTokenize(b64Token string) (*models.EnrollmentKey, error) {	if len(b64Token) == 0 {		return nil, EnrollmentErrors.FailedToDeTokenize	}	tokenData, err := b64.StdEncoding.DecodeString(b64Token)	if err != nil {		return nil, err	}	var newToken models.EnrollmentToken	err = json.Unmarshal(tokenData, &newToken)	if err != nil {		return nil, err	}	k, err := GetEnrollmentKey(newToken.Value)	if err != nil {		return nil, err	}	return k, nil}// == private ==// decrementEnrollmentKey - decrements the uses on a key if above 0 remainingfunc decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {	k, err := GetEnrollmentKey(value)	if err != nil {		return nil, err	}	if k.UsesRemaining == 0 {		return nil, EnrollmentErrors.NoUsesRemaining	}	k.UsesRemaining = k.UsesRemaining - 1	if err = upsertEnrollmentKey(k); err != nil {		return nil, err	}	return k, nil}func upsertEnrollmentKey(k *models.EnrollmentKey) error {	if k == nil {		return EnrollmentErrors.InvalidKey	}	data, err := json.Marshal(k)	if err != nil {		return err	}	return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)}func getUniqueEnrollmentID() (string, error) {	currentKeys, err := getEnrollmentKeysMap()	if err != nil {		return "", err	}	newID := RandomString(models.EnrollmentKeyLength)	for _, ok := currentKeys[newID]; ok; {		newID = RandomString(models.EnrollmentKeyLength)	}	return newID, nil}func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {	records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)	if err != nil {		if !database.IsEmptyRecord(err) {			return nil, err		}	}	if records == nil {		records = make(map[string]string)	}	currentKeys := make(map[string]*models.EnrollmentKey, 0)	if len(records) > 0 {		for k := range records {			var currentKey models.EnrollmentKey			if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil {				continue			}			currentKeys[k] = ¤tKey		}	}	return currentKeys, nil}// UserHasNetworksAccess - checks if a user `u` has access to all `networks`func UserHasNetworksAccess(networks []string, u *models.User) bool {	if u.IsAdmin {		return true	}	for _, n := range networks {		if !slices.Contains(u.Networks, n) {			return false		}	}	return true}
 |