Browse Source

Merge pull request #2050 from gravitl/GRA-1198-enrollment_keys

Gra 1198 enrollment keys
dcarns 2 years ago
parent
commit
ad4bab064b

+ 4 - 4
config/config_test.go

@@ -1,7 +1,7 @@
-//Environment file for getting variables
-//Currently the only thing it does is set the master password
-//Should probably have it take over functions from OS such as port and mongodb connection details
-//Reads from the config/environments/dev.yaml file by default
+// Environment file for getting variables
+// Currently the only thing it does is set the master password
+// Should probably have it take over functions from OS such as port and mongodb connection details
+// Reads from the config/environments/dev.yaml file by default
 package config
 package config
 
 
 import (
 import (

+ 1 - 0
controllers/controller.go

@@ -26,6 +26,7 @@ var HttpHandlers = []interface{}{
 	ipHandlers,
 	ipHandlers,
 	loggerHandlers,
 	loggerHandlers,
 	hostHandlers,
 	hostHandlers,
+	enrollmentKeyHandlers,
 }
 }
 
 
 // HandleRESTRequests - handles the rest requests
 // HandleRESTRequests - handles the rest requests

+ 238 - 0
controllers/enrollmentkeys.go

@@ -0,0 +1,238 @@
+package controller
+
+import (
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/gorilla/mux"
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/logic/hostactions"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/mq"
+	"github.com/gravitl/netmaker/servercfg"
+)
+
+func enrollmentKeyHandlers(r *mux.Router) {
+	r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))).Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))).Methods(http.MethodGet)
+	r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))).Methods(http.MethodDelete)
+	r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).Methods(http.MethodPost)
+}
+
+// swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys
+//
+// Lists all EnrollmentKeys for admins.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: getEnrollmentKeysSlice
+func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) {
+	currentKeys, err := logic.GetAllEnrollmentKeys()
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	for i := range currentKeys {
+		currentKey := currentKeys[i]
+		if err = logic.Tokenize(currentKey, servercfg.GetAPIHost()); err != nil {
+			logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error())
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+			return
+		}
+	}
+	// return JSON/API formatted keys
+	logger.Log(2, r.Header.Get("user"), "fetched enrollment keys")
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(currentKeys)
+}
+
+// swagger:route DELETE /api/v1/enrollment-keys/{keyID} enrollmentKeys deleteEnrollmentKey
+//
+// Deletes an EnrollmentKey from Netmaker server.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: deleteEnrollmentKeyResponse
+func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	keyID := params["keyID"]
+	err := logic.DeleteEnrollmentKey(keyID)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	logger.Log(2, r.Header.Get("user"), "deleted enrollment key", keyID)
+	w.WriteHeader(http.StatusOK)
+}
+
+// swagger:route POST /api/v1/enrollment-keys enrollmentKeys createEnrollmentKey
+//
+// Creates an EnrollmentKey for hosts to use on Netmaker server.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: createEnrollmentKeyResponse
+func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
+
+	var enrollmentKeyBody models.APIEnrollmentKey
+
+	err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
+			err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	var newTime time.Time
+	if enrollmentKeyBody.Expiration > 0 {
+		newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
+	}
+
+	newEnrollmentKey, err := logic.CreateEnrollmentKey(enrollmentKeyBody.UsesRemaining, newTime, enrollmentKeyBody.Networks, enrollmentKeyBody.Tags, enrollmentKeyBody.Unlimited)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil {
+		logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	logger.Log(2, r.Header.Get("user"), "created enrollment key")
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(newEnrollmentKey)
+}
+
+// swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister
+//
+// Handles a Netclient registration with server and add nodes accordingly.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: handleHostRegisterResponse
+func handleHostRegister(w http.ResponseWriter, r *http.Request) {
+	var params = mux.Vars(r)
+	token := params["token"]
+	logger.Log(0, "received registration attempt with token", token)
+	// check if token exists
+	enrollmentKey, err := logic.DeTokenize(token)
+	if err != nil {
+		logger.Log(0, "invalid enrollment key used", token, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	// get the host
+	var newHost models.Host
+	if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil {
+		logger.Log(0, r.Header.Get("user"), "error decoding request body: ",
+			err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	hostExists := false
+	// check if host already exists
+	if hostExists = logic.HostExists(&newHost); hostExists && len(enrollmentKey.Networks) == 0 {
+		logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register with no networks")
+		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest"))
+		return
+	}
+	// version check
+	if !logic.IsVersionComptatible(newHost.Version) || newHost.TrafficKeyPublic == nil {
+		err := fmt.Errorf("incompatible netclient")
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+	key, keyErr := logic.RetrievePublicTrafficKey()
+	if keyErr != nil {
+		logger.Log(0, "error retrieving key:", keyErr.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	// use the token
+	if ok := logic.TryToUseEnrollmentKey(enrollmentKey); !ok {
+		logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration")
+		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest"))
+		return
+	}
+	if !hostExists {
+		// register host
+		logic.CheckHostPorts(&newHost)
+		if err = logic.CreateHost(&newHost); err != nil {
+			logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration -", err.Error())
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+			return
+		}
+	} else {
+		// need to revise the list of networks from key
+		// based on the ones host currently has
+		var networksToAdd = []string{}
+		currentNets := logic.GetHostNetworks(newHost.ID.String())
+		for _, newNet := range enrollmentKey.Networks {
+			if !logic.StringSliceContains(currentNets, newNet) {
+				networksToAdd = append(networksToAdd, newNet)
+			}
+		}
+		enrollmentKey.Networks = networksToAdd
+	}
+	// ready the response
+	server := servercfg.GetServerInfo()
+	server.TrafficKey = key
+	logger.Log(0, newHost.Name, newHost.ID.String(), "registered with Netmaker")
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(&server)
+	// notify host of changes, peer and node updates
+	go checkNetRegAndHostUpdate(enrollmentKey.Networks, &newHost)
+}
+
+// run through networks and send a host update
+func checkNetRegAndHostUpdate(networks []string, h *models.Host) {
+	// publish host update through MQ
+	for i := range networks {
+		network := networks[i]
+		if ok, _ := logic.NetworkExists(network); ok {
+			newNode, err := logic.UpdateHostNetwork(h, network, true)
+			if err != nil {
+				logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
+				continue
+			}
+			logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
+			hostactions.AddAction(models.HostUpdate{
+				Action: models.JoinHostToNetwork,
+				Host:   *h,
+				Node:   *newNode,
+			})
+		}
+	}
+	if servercfg.IsMessageQueueBackend() {
+		mq.HostUpdate(&models.HostUpdate{
+			Action: models.RequestAck,
+			Host:   *h,
+		})
+		if err := mq.PublishPeerUpdate(); err != nil {
+			logger.Log(0, "failed to publish peer update during registration -", err.Error())
+		}
+	}
+}

+ 9 - 8
controllers/hosts.go

@@ -10,8 +10,10 @@ import (
 	"github.com/gorilla/mux"
 	"github.com/gorilla/mux"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/logic/hostactions"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/mq"
+	"github.com/gravitl/netmaker/servercfg"
 	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/crypto/bcrypt"
 )
 )
 
 
@@ -230,18 +232,17 @@ func addHostToNetwork(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 	logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name)
 	logger.Log(1, "added new node", newNode.ID.String(), "to host", currHost.Name)
-	if err = mq.HostUpdate(&models.HostUpdate{
+	hostactions.AddAction(models.HostUpdate{
 		Action: models.JoinHostToNetwork,
 		Action: models.JoinHostToNetwork,
 		Host:   *currHost,
 		Host:   *currHost,
 		Node:   *newNode,
 		Node:   *newNode,
-	}); err != nil {
-		logger.Log(0, r.Header.Get("user"), "failed to update host to join network:", hostid, network, err.Error())
+	})
+	if servercfg.IsMessageQueueBackend() {
+		mq.HostUpdate(&models.HostUpdate{
+			Action: models.RequestAck,
+			Host:   *currHost,
+		})
 	}
 	}
-	go func() { // notify of peer change
-		if err := mq.PublishPeerUpdate(); err != nil {
-			logger.Log(1, "error publishing peer update ", err.Error())
-		}
-	}()
 
 
 	logger.Log(2, r.Header.Get("user"), fmt.Sprintf("added host %s to network %s", currHost.Name, network))
 	logger.Log(2, r.Header.Get("user"), fmt.Sprintf("added host %s to network %s", currHost.Name, network))
 	w.WriteHeader(http.StatusOK)
 	w.WriteHeader(http.StatusOK)

+ 3 - 0
database/database.go

@@ -57,6 +57,8 @@ const (
 	CACHE_TABLE_NAME = "cache"
 	CACHE_TABLE_NAME = "cache"
 	// HOSTS_TABLE_NAME - the table name for hosts
 	// HOSTS_TABLE_NAME - the table name for hosts
 	HOSTS_TABLE_NAME = "hosts"
 	HOSTS_TABLE_NAME = "hosts"
+	// ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys
+	ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys"
 
 
 	// == ERROR CONSTS ==
 	// == ERROR CONSTS ==
 	// NO_RECORD - no singular result found
 	// NO_RECORD - no singular result found
@@ -138,6 +140,7 @@ func createTables() {
 	createTable(USER_GROUPS_TABLE_NAME)
 	createTable(USER_GROUPS_TABLE_NAME)
 	createTable(CACHE_TABLE_NAME)
 	createTable(CACHE_TABLE_NAME)
 	createTable(HOSTS_TABLE_NAME)
 	createTable(HOSTS_TABLE_NAME)
+	createTable(ENROLLMENT_KEYS_TABLE_NAME)
 }
 }
 
 
 func createTable(tableName string) error {
 func createTable(tableName string) error {

+ 0 - 2
logic/accesskeys_test.go

@@ -5,7 +5,6 @@ import "testing"
 func Test_genKeyName(t *testing.T) {
 func Test_genKeyName(t *testing.T) {
 	for i := 0; i < 100; i++ {
 	for i := 0; i < 100; i++ {
 		kname := genKeyName()
 		kname := genKeyName()
-		t.Log(kname)
 		if len(kname) != 20 {
 		if len(kname) != 20 {
 			t.Fatalf("improper length of key name, expected 20 got :%d", len(kname))
 			t.Fatalf("improper length of key name, expected 20 got :%d", len(kname))
 		}
 		}
@@ -15,7 +14,6 @@ func Test_genKeyName(t *testing.T) {
 func Test_genKey(t *testing.T) {
 func Test_genKey(t *testing.T) {
 	for i := 0; i < 100; i++ {
 	for i := 0; i < 100; i++ {
 		kname := GenKey()
 		kname := GenKey()
-		t.Log(kname)
 		if len(kname) != 16 {
 		if len(kname) != 16 {
 			t.Fatalf("improper length of key name, expected 16 got :%d", len(kname))
 			t.Fatalf("improper length of key name, expected 16 got :%d", len(kname))
 		}
 		}

+ 221 - 0
logic/enrollmentkey.go

@@ -0,0 +1,221 @@
+package logic
+
+import (
+	b64 "encoding/base64"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/netclient/ncutils"
+)
+
+// EnrollmentErrors - struct for holding EnrollmentKey error messages
+var EnrollmentErrors = struct {
+	InvalidCreate      error
+	NoKeyFound         error
+	InvalidKey         error
+	NoUsesRemaining    error
+	FailedToTokenize   error
+	FailedToDeTokenize error
+}{
+	InvalidCreate:      fmt.Errorf("invalid enrollment key created"),
+	NoKeyFound:         fmt.Errorf("no enrollmentkey found"),
+	InvalidKey:         fmt.Errorf("invalid key provided"),
+	NoUsesRemaining:    fmt.Errorf("no uses remaining"),
+	FailedToTokenize:   fmt.Errorf("failed to tokenize"),
+	FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
+}
+
+// CreateEnrollmentKey - creates a new enrollment key in db
+func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
+	newKeyID, err := getUniqueEnrollmentID()
+	if err != nil {
+		return nil, err
+	}
+	k = &models.EnrollmentKey{
+		Value:         newKeyID,
+		Expiration:    time.Time{},
+		UsesRemaining: 0,
+		Unlimited:     unlimited,
+		Networks:      []string{},
+		Tags:          []string{},
+	}
+	if uses > 0 {
+		k.UsesRemaining = uses
+	}
+	if !expiration.IsZero() {
+		k.Expiration = expiration
+	}
+	if len(networks) > 0 {
+		k.Networks = networks
+	}
+	if len(tags) > 0 {
+		k.Tags = tags
+	}
+	if ok := k.Validate(); !ok {
+		return nil, EnrollmentErrors.InvalidCreate
+	}
+	if err = upsertEnrollmentKey(k); err != nil {
+		return nil, err
+	}
+	return
+}
+
+// GetAllEnrollmentKeys - fetches all enrollment keys from DB
+func GetAllEnrollmentKeys() ([]*models.EnrollmentKey, error) {
+	currentKeys, err := getEnrollmentKeysMap()
+	if err != nil {
+		return nil, err
+	}
+	var currentKeysList = []*models.EnrollmentKey{}
+	for k := range currentKeys {
+		currentKeysList = append(currentKeysList, currentKeys[k])
+	}
+	return currentKeysList, nil
+}
+
+// GetEnrollmentKey - fetches a single enrollment key
+// returns nil and error if not found
+func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
+	currentKeys, err := getEnrollmentKeysMap()
+	if err != nil {
+		return nil, err
+	}
+	if key, ok := currentKeys[value]; ok {
+		return key, nil
+	}
+	return nil, EnrollmentErrors.NoKeyFound
+}
+
+// DeleteEnrollmentKey - delete's a given enrollment key by value
+func DeleteEnrollmentKey(value string) error {
+	_, err := GetEnrollmentKey(value)
+	if err != nil {
+		return err
+	}
+	return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
+}
+
+// TryToUseEnrollmentKey - checks first if key can be decremented
+// returns true if it is decremented or isvalid
+func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool {
+	key, err := decrementEnrollmentKey(k.Value)
+	if err != nil {
+		if errors.Is(err, EnrollmentErrors.NoUsesRemaining) {
+			return k.IsValid()
+		}
+	} else {
+		k.UsesRemaining = key.UsesRemaining
+		return true
+	}
+	return false
+}
+
+// Tokenize - tokenizes an enrollment key to be used via registration
+// and attaches it to the Token field on the struct
+func Tokenize(k *models.EnrollmentKey, serverAddr string) error {
+	if len(serverAddr) == 0 || k == nil {
+		return EnrollmentErrors.FailedToTokenize
+	}
+	newToken := models.EnrollmentToken{
+		Server: serverAddr,
+		Value:  k.Value,
+	}
+	data, err := json.Marshal(&newToken)
+	if err != nil {
+		return err
+	}
+	k.Token = b64.StdEncoding.EncodeToString(data)
+	return nil
+}
+
+// DeTokenize - detokenizes a base64 encoded string
+// and finds the associated enrollment key
+func DeTokenize(b64Token string) (*models.EnrollmentKey, error) {
+	if len(b64Token) == 0 {
+		return nil, EnrollmentErrors.FailedToDeTokenize
+	}
+	tokenData, err := b64.StdEncoding.DecodeString(b64Token)
+	if err != nil {
+		return nil, err
+	}
+
+	var newToken models.EnrollmentToken
+	err = json.Unmarshal(tokenData, &newToken)
+	if err != nil {
+		return nil, err
+	}
+	k, err := GetEnrollmentKey(newToken.Value)
+	if err != nil {
+		return nil, err
+	}
+	return k, nil
+}
+
+// == private ==
+
+// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining
+func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) {
+	k, err := GetEnrollmentKey(value)
+	if err != nil {
+		return nil, err
+	}
+	if k.UsesRemaining == 0 {
+		return nil, EnrollmentErrors.NoUsesRemaining
+	}
+	k.UsesRemaining = k.UsesRemaining - 1
+	if err = upsertEnrollmentKey(k); err != nil {
+		return nil, err
+	}
+
+	return k, nil
+}
+
+func upsertEnrollmentKey(k *models.EnrollmentKey) error {
+	if k == nil {
+		return EnrollmentErrors.InvalidKey
+	}
+	data, err := json.Marshal(k)
+	if err != nil {
+		return err
+	}
+	return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
+}
+
+func getUniqueEnrollmentID() (string, error) {
+	currentKeys, err := getEnrollmentKeysMap()
+	if err != nil {
+		return "", err
+	}
+	newID := ncutils.MakeRandomString(models.EnrollmentKeyLength)
+	for _, ok := currentKeys[newID]; ok; {
+		newID = ncutils.MakeRandomString(models.EnrollmentKeyLength)
+	}
+	return newID, nil
+}
+
+func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
+	records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
+	if err != nil {
+		if !database.IsEmptyRecord(err) {
+			return nil, err
+		}
+	}
+	if records == nil {
+		records = make(map[string]string)
+	}
+	currentKeys := make(map[string]*models.EnrollmentKey, 0)
+	if len(records) > 0 {
+		for k := range records {
+			var currentKey models.EnrollmentKey
+			if err = json.Unmarshal([]byte(records[k]), &currentKey); err != nil {
+				continue
+			}
+			currentKeys[k] = &currentKey
+		}
+	}
+	return currentKeys, nil
+}

+ 206 - 0
logic/enrollmentkey_test.go

@@ -0,0 +1,206 @@
+package logic
+
+import (
+	"testing"
+	"time"
+
+	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/models"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCreateEnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	t.Run("Can_Not_Create_Key", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
+		assert.Nil(t, newKey)
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
+	})
+	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
+		assert.Nil(t, err)
+		assert.Equal(t, 1, newKey.UsesRemaining)
+		assert.True(t, newKey.IsValid())
+	})
+	t.Run("Can_Create_Key_Time", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+	})
+	t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+	})
+	t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+		assert.True(t, len(newKey.Networks) == 2)
+	})
+	t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true)
+		assert.Nil(t, err)
+		assert.True(t, newKey.IsValid())
+		assert.True(t, len(newKey.Tags) == 2)
+	})
+
+	t.Run("Can_Get_List_of_Keys", func(t *testing.T) {
+		keys, err := GetAllEnrollmentKeys()
+		assert.Nil(t, err)
+		assert.True(t, len(keys) > 0)
+		for i := range keys {
+			assert.Equal(t, len(keys[i].Value), models.EnrollmentKeyLength)
+		}
+	})
+	removeAllEnrollments()
+}
+
+func TestDelete_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+	t.Run("Can_Delete_Key", func(t *testing.T) {
+		assert.True(t, newKey.IsValid())
+		err := DeleteEnrollmentKey(newKey.Value)
+		assert.Nil(t, err)
+		oldKey, err := GetEnrollmentKey(newKey.Value)
+		assert.Nil(t, oldKey)
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
+	})
+	t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) {
+		err := DeleteEnrollmentKey("notakey")
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
+	})
+	removeAllEnrollments()
+}
+
+func TestDecrement_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
+	t.Run("Check_initial_uses", func(t *testing.T) {
+		assert.True(t, newKey.IsValid())
+		assert.Equal(t, newKey.UsesRemaining, 1)
+	})
+	t.Run("Check can decrement", func(t *testing.T) {
+		assert.Equal(t, newKey.UsesRemaining, 1)
+		k, err := decrementEnrollmentKey(newKey.Value)
+		assert.Nil(t, err)
+		newKey = k
+	})
+	t.Run("Check can not decrement", func(t *testing.T) {
+		assert.Equal(t, newKey.UsesRemaining, 0)
+		_, err := decrementEnrollmentKey(newKey.Value)
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.NoUsesRemaining)
+	})
+
+	removeAllEnrollments()
+}
+
+func TestUsability_EnrollmentKey(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
+	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
+	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
+	t.Run("Check if valid use key can be used", func(t *testing.T) {
+		assert.Equal(t, key1.UsesRemaining, 1)
+		ok := TryToUseEnrollmentKey(key1)
+		assert.True(t, ok)
+		assert.Equal(t, 0, key1.UsesRemaining)
+	})
+
+	t.Run("Check if valid time key can be used", func(t *testing.T) {
+		assert.True(t, !key2.Expiration.IsZero())
+		ok := TryToUseEnrollmentKey(key2)
+		assert.True(t, ok)
+	})
+
+	t.Run("Check if valid unlimited key can be used", func(t *testing.T) {
+		assert.True(t, key3.Unlimited)
+		ok := TryToUseEnrollmentKey(key3)
+		assert.True(t, ok)
+	})
+
+	t.Run("check invalid key can not be used", func(t *testing.T) {
+		ok := TryToUseEnrollmentKey(key1)
+		assert.False(t, ok)
+	})
+}
+
+func removeAllEnrollments() {
+	database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
+}
+
+//Test that cheks if it can tokenize
+//Test that cheks if it can't tokenize
+
+func TestTokenize_EnrollmentKeys(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+	const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
+	const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
+	const serverAddr = "api.myserver.com"
+	t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) {
+		err := Tokenize(nil, "ServerAddress")
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.FailedToTokenize)
+	})
+	t.Run("Can_Not_Tokenize_Empty_Server_Address", func(t *testing.T) {
+		err := Tokenize(newKey, "")
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.FailedToTokenize)
+	})
+
+	t.Run("Can_Tokenize", func(t *testing.T) {
+		err := Tokenize(newKey, serverAddr)
+		assert.Nil(t, err)
+		assert.True(t, len(newKey.Token) > 0)
+	})
+
+	t.Run("Is_Correct_B64_Token", func(t *testing.T) {
+		newKey.Value = defaultValue
+		err := Tokenize(newKey, serverAddr)
+		assert.Nil(t, err)
+		assert.Equal(t, newKey.Token, b64value)
+	})
+	removeAllEnrollments()
+}
+
+func TestDeTokenize_EnrollmentKeys(t *testing.T) {
+	database.InitializeDatabase()
+	defer database.CloseDB()
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+	const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
+	const serverAddr = "api.myserver.com"
+
+	t.Run("Can_Not_DeTokenize", func(t *testing.T) {
+		value, err := DeTokenize("")
+		assert.Nil(t, value)
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.FailedToDeTokenize)
+	})
+	t.Run("Can_Not_Find_Key", func(t *testing.T) {
+		value, err := DeTokenize(b64Value)
+		assert.Nil(t, value)
+		assert.NotNil(t, err)
+		assert.Equal(t, err, EnrollmentErrors.NoKeyFound)
+	})
+	t.Run("Can_DeTokenize", func(t *testing.T) {
+		err := Tokenize(newKey, serverAddr)
+		assert.Nil(t, err)
+		output, err := DeTokenize(newKey.Token)
+		assert.Nil(t, err)
+		assert.NotNil(t, output)
+		assert.Equal(t, newKey.Value, output.Value)
+	})
+
+	removeAllEnrollments()
+}

+ 4 - 13
logic/host_test.go

@@ -2,37 +2,28 @@ package logic
 
 
 import (
 import (
 	"context"
 	"context"
+	"fmt"
 	"net"
 	"net"
 	"testing"
 	"testing"
 
 
 	"github.com/google/uuid"
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/matryer/is"
 	"github.com/matryer/is"
 )
 )
 
 
-func TestMain(m *testing.M) {
+func TestCheckPorts(t *testing.T) {
 	database.InitializeDatabase()
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	defer database.CloseDB()
-	CreateAdmin(&models.User{
-		UserName: "admin",
-		Password: "password",
-		IsAdmin:  true,
-		Networks: []string{},
-		Groups:   []string{},
-	})
 	peerUpdate := make(chan *models.Node)
 	peerUpdate := make(chan *models.Node)
 	go ManageZombies(context.Background(), peerUpdate)
 	go ManageZombies(context.Background(), peerUpdate)
 	go func() {
 	go func() {
-		for update := range peerUpdate {
+		for y := range peerUpdate {
+			fmt.Printf("Pointless %v\n", y)
 			//do nothing
 			//do nothing
-			logger.Log(3, "received node update", update.Action)
 		}
 		}
 	}()
 	}()
-}
 
 
-func TestCheckPorts(t *testing.T) {
 	h := models.Host{
 	h := models.Host{
 		ID:              uuid.New(),
 		ID:              uuid.New(),
 		EndpointIP:      net.ParseIP("192.168.1.1"),
 		EndpointIP:      net.ParseIP("192.168.1.1"),

+ 38 - 0
logic/hostactions/hostactions.go

@@ -0,0 +1,38 @@
+package hostactions
+
+import (
+	"sync"
+
+	"github.com/gravitl/netmaker/models"
+)
+
+// nodeActionHandler - handles the storage of host action updates
+var nodeActionHandler sync.Map
+
+// AddAction - adds a host action to a host's list to be retrieved from broker update
+func AddAction(hu models.HostUpdate) {
+	currentRecords, ok := nodeActionHandler.Load(hu.Host.ID.String())
+	if !ok { // no list exists yet
+		nodeActionHandler.Store(hu.Host.ID.String(), []models.HostUpdate{hu})
+	} else { // list exists, append to it
+		currentList := currentRecords.([]models.HostUpdate)
+		currentList = append(currentList, hu)
+		nodeActionHandler.Store(hu.Host.ID.String(), currentList)
+	}
+}
+
+// GetAction - gets an action if exists
+// TODO: may need to move to DB rather than sync map for HA
+func GetAction(id string) *models.HostUpdate {
+	currentRecords, ok := nodeActionHandler.Load(id)
+	if !ok {
+		return nil
+	}
+	currentList := currentRecords.([]models.HostUpdate)
+	if len(currentList) > 0 {
+		hu := currentList[0]
+		nodeActionHandler.Store(hu.Host.ID.String(), currentList[1:])
+		return &hu
+	}
+	return nil
+}

+ 7 - 2
logic/hosts.go

@@ -90,7 +90,7 @@ func CreateHost(h *models.Host) error {
 	if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {
 	if (err != nil && !database.IsEmptyRecord(err)) || (err == nil) {
 		return ErrHostExists
 		return ErrHostExists
 	}
 	}
-	//encrypt that password so we never see it
+	// encrypt that password so we never see it
 	hash, err := bcrypt.GenerateFromPassword([]byte(h.HostPass), 5)
 	hash, err := bcrypt.GenerateFromPassword([]byte(h.HostPass), 5)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -236,7 +236,12 @@ func AssociateNodeToHost(n *models.Node, h *models.Host) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	h.Nodes = append(h.Nodes, n.ID.String())
+	currentHost, err := GetHost(h.ID.String())
+	if err != nil {
+		return err
+	}
+	h.HostPass = currentHost.HostPass
+	h.Nodes = append(currentHost.Nodes, n.ID.String())
 	return UpsertHost(h)
 	return UpsertHost(h)
 }
 }
 
 

+ 4 - 4
logic/peers.go

@@ -34,7 +34,6 @@ func GetProxyUpdateForHost(host *models.Host) (models.ProxyManagerPayload, error
 		} else {
 		} else {
 			logger.Log(0, "couldn't find relay host for:  ", host.ID.String())
 			logger.Log(0, "couldn't find relay host for:  ", host.ID.String())
 		}
 		}
-
 	}
 	}
 	if host.IsRelay {
 	if host.IsRelay {
 		relayedHosts := GetRelayedHosts(host)
 		relayedHosts := GetRelayedHosts(host)
@@ -142,9 +141,10 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
 	if deletedNode != nil {
 	if deletedNode != nil {
 		deletedNodes = append(deletedNodes, *deletedNode)
 		deletedNodes = append(deletedNodes, *deletedNode)
 	}
 	}
-	logger.Log(1, "peer update for host ", host.ID.String())
+	logger.Log(1, "peer update for host", host.ID.String())
 	peerIndexMap := make(map[string]int)
 	peerIndexMap := make(map[string]int)
 	for _, nodeID := range host.Nodes {
 	for _, nodeID := range host.Nodes {
+		nodeID := nodeID
 		node, err := GetNodeByID(nodeID)
 		node, err := GetNodeByID(nodeID)
 		if err != nil {
 		if err != nil {
 			continue
 			continue
@@ -163,7 +163,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
 		}
 		}
 		for _, peer := range currentPeers {
 		for _, peer := range currentPeers {
 			peer := peer
 			peer := peer
-			if peer.ID == node.ID {
+			if peer.ID.String() == node.ID.String() {
 				logger.Log(2, "peer update, skipping self")
 				logger.Log(2, "peer update, skipping self")
 				//skip yourself
 				//skip yourself
 				continue
 				continue
@@ -185,7 +185,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, deletedNode *models
 				continue
 				continue
 			}
 			}
 			if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) {
 			if !nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) {
-				log.Println("peer update, skipping node for acl")
+				logger.Log(2, "peer update, skipping node for acl")
 				//skip if not permitted by acl
 				//skip if not permitted by acl
 				continue
 				continue
 			}
 			}

+ 59 - 0
models/enrollment_key.go

@@ -0,0 +1,59 @@
+package models
+
+import (
+	"time"
+)
+
+// EnrollmentToken - the tokenized version of an enrollmentkey;
+// to be used for host registration
+type EnrollmentToken struct {
+	Server string `json:"server"`
+	Value  string `json:"value"`
+}
+
+// EnrollmentKeyLength - the length of an enrollment key - 62^16 unique possibilities
+const EnrollmentKeyLength = 32
+
+// EnrollmentKey - the key used to register hosts and join them to specific networks
+type EnrollmentKey struct {
+	Expiration    time.Time `json:"expiration"`
+	UsesRemaining int       `json:"uses_remaining"`
+	Value         string    `json:"value"`
+	Networks      []string  `json:"networks"`
+	Unlimited     bool      `json:"unlimited"`
+	Tags          []string  `json:"tags"`
+	Token         string    `json:"token,omitempty"` // B64 value of EnrollmentToken
+}
+
+// APIEnrollmentKey - used to create enrollment keys via API
+type APIEnrollmentKey struct {
+	Expiration    int64    `json:"expiration"`
+	UsesRemaining int      `json:"uses_remaining"`
+	Networks      []string `json:"networks"`
+	Unlimited     bool     `json:"unlimited"`
+	Tags          []string `json:"tags"`
+}
+
+// EnrollmentKey.IsValid - checks if the key is still valid to use
+func (k *EnrollmentKey) IsValid() bool {
+	if k == nil {
+		return false
+	}
+	if k.UsesRemaining > 0 {
+		return true
+	}
+	if !k.Expiration.IsZero() && time.Now().Before(k.Expiration) {
+		return true
+	}
+
+	return k.Unlimited
+}
+
+// EnrollmentKey.Validate - validate's an EnrollmentKey
+// should be used during creation
+func (k *EnrollmentKey) Validate() bool {
+	return k.Networks != nil &&
+		k.Tags != nil &&
+		len(k.Value) == EnrollmentKeyLength &&
+		k.IsValid()
+}

+ 4 - 0
models/host.go

@@ -74,6 +74,10 @@ const (
 	DeleteHost = "DELETE_HOST"
 	DeleteHost = "DELETE_HOST"
 	// JoinHostToNetwork - constant for host network join action
 	// JoinHostToNetwork - constant for host network join action
 	JoinHostToNetwork = "JOIN_HOST_TO_NETWORK"
 	JoinHostToNetwork = "JOIN_HOST_TO_NETWORK"
+	// Acknowledgement - ACK response for hosts
+	Acknowledgement = "ACK"
+	// RequestAck - request an ACK
+	RequestAck = "REQ_ACK"
 )
 )
 
 
 // HostUpdate - struct for host update
 // HostUpdate - struct for host update

+ 15 - 0
mq/handlers.go

@@ -9,6 +9,7 @@ import (
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/logic/hostactions"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/servercfg"
@@ -144,6 +145,19 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
 		logger.Log(3, fmt.Sprintf("recieved host update: %s\n", hostUpdate.Host.ID.String()))
 		logger.Log(3, fmt.Sprintf("recieved host update: %s\n", hostUpdate.Host.ID.String()))
 		var sendPeerUpdate bool
 		var sendPeerUpdate bool
 		switch hostUpdate.Action {
 		switch hostUpdate.Action {
+		case models.Acknowledgement:
+			hu := hostactions.GetAction(currentHost.ID.String())
+			if hu != nil {
+				if err = HostUpdate(hu); err != nil {
+					logger.Log(0, "failed to send new node to host", hostUpdate.Host.Name, currentHost.ID.String(), err.Error())
+					return
+				} else {
+					if err = PublishSingleHostPeerUpdate(currentHost, nil); err != nil {
+						logger.Log(0, "failed peers publish after join acknowledged", hostUpdate.Host.Name, currentHost.ID.String(), err.Error())
+						return
+					}
+				}
+			}
 		case models.UpdateHost:
 		case models.UpdateHost:
 			sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost)
 			sendPeerUpdate = logic.UpdateHostFromClient(&hostUpdate.Host, currentHost)
 			err := logic.UpsertHost(currentHost)
 			err := logic.UpsertHost(currentHost)
@@ -169,6 +183,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
 			}
 			}
 			sendPeerUpdate = true
 			sendPeerUpdate = true
 		}
 		}
+
 		if sendPeerUpdate {
 		if sendPeerUpdate {
 			err := PublishPeerUpdate()
 			err := PublishPeerUpdate()
 			if err != nil {
 			if err != nil {

+ 3 - 0
mq/publishers.go

@@ -61,6 +61,9 @@ func PublishSingleHostPeerUpdate(host *models.Host, deletedNode *models.Node) er
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
+	if len(peerUpdate.Peers) == 0 { // no peers to send
+		return nil
+	}
 	if host.ProxyEnabled {
 	if host.ProxyEnabled {
 		proxyUpdate, err := logic.GetProxyUpdateForHost(host)
 		proxyUpdate, err := logic.GetProxyUpdateForHost(host)
 		if err != nil {
 		if err != nil {