Browse Source

feat(NET-688): auto relaying via enrollment keys (#2647)

* feat(NET-688): auto relaying via enrollment keys

* feat(NET-688): address pr comments
Aceix 1 year ago
parent
commit
61ef6142ff

+ 11 - 2
auth/host_session.go

@@ -15,6 +15,7 @@ import (
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/exp/slog"
 )
 
 // SessionHandler - called by the HTTP router when user
@@ -202,7 +203,7 @@ func SessionHandler(conn *websocket.Conn) {
 		if err = conn.WriteMessage(messageType, reponseData); err != nil {
 			logger.Log(0, "error during message writing:", err.Error())
 		}
-		go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host)
+		go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil)
 	case <-timeout: // the read from req.answerCh has timed out
 		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
 			logger.Log(0, "error during timeout message writing:", err.Error())
@@ -221,7 +222,7 @@ func SessionHandler(conn *websocket.Conn) {
 }
 
 // CheckNetRegAndHostUpdate - run through networks and send a host update
-func CheckNetRegAndHostUpdate(networks []string, h *models.Host) {
+func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID) {
 	// publish host update through MQ
 	for i := range networks {
 		network := networks[i]
@@ -231,6 +232,14 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host) {
 				logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
 				continue
 			}
+			if relayNodeId != uuid.Nil && !newNode.IsRelayed {
+				newNode.IsRelayed = true
+				newNode.RelayedBy = relayNodeId.String()
+				slog.Info(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), relayNodeId.String(), network))
+				if err := logic.UpsertNode(newNode); err != nil {
+					slog.Error("failed to update node", "nodeid", relayNodeId.String())
+				}
+			}
 			logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
 			hostactions.AddAction(models.HostUpdate{
 				Action: models.JoinHostToNetwork,

+ 66 - 1
controllers/enrollmentkeys.go

@@ -6,6 +6,7 @@ import (
 	"net/http"
 	"time"
 
+	"github.com/google/uuid"
 	"github.com/gorilla/mux"
 
 	"github.com/gravitl/netmaker/auth"
@@ -26,6 +27,8 @@ func enrollmentKeyHandlers(r *mux.Router) {
 		Methods(http.MethodDelete)
 	r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)).
 		Methods(http.MethodPost)
+	r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(updateEnrollmentKey))).
+		Methods(http.MethodPut)
 }
 
 // swagger:route GET /api/v1/enrollment-keys enrollmentKeys getEnrollmentKeys
@@ -113,12 +116,23 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
 		newTime = time.Unix(enrollmentKeyBody.Expiration, 0)
 	}
 
+	relayId := uuid.Nil
+	if enrollmentKeyBody.Relay != "" {
+		relayId, err = uuid.Parse(enrollmentKeyBody.Relay)
+		if err != nil {
+			logger.Log(0, r.Header.Get("user"), "error parsing relay id: ", err.Error())
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+			return
+		}
+	}
+
 	newEnrollmentKey, err := logic.CreateEnrollmentKey(
 		enrollmentKeyBody.UsesRemaining,
 		newTime,
 		enrollmentKeyBody.Networks,
 		enrollmentKeyBody.Tags,
 		enrollmentKeyBody.Unlimited,
+		relayId,
 	)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error())
@@ -136,6 +150,57 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) {
 	json.NewEncoder(w).Encode(newEnrollmentKey)
 }
 
+// swagger:route PUT /api/v1/enrollment-keys/:id enrollmentKeys updateEnrollmentKey
+//
+// Updates an EnrollmentKey for hosts to use on Netmaker server. Updates only the relay to use.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: EnrollmentKey
+func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) {
+	var enrollmentKeyBody models.APIEnrollmentKey
+	params := mux.Vars(r)
+	keyId := params["keyID"]
+
+	err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody)
+	if err != nil {
+		slog.Error("error decoding request body", "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	relayId := uuid.Nil
+	if enrollmentKeyBody.Relay != "" {
+		relayId, err = uuid.Parse(enrollmentKeyBody.Relay)
+		if err != nil {
+			slog.Error("error parsing relay id", "error", err)
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+			return
+		}
+	}
+
+	newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, relayId)
+	if err != nil {
+		slog.Error("failed to update enrollment key", "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil {
+		slog.Error("failed to update enrollment key", "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	slog.Info("updated enrollment key", "id", keyId)
+	w.WriteHeader(http.StatusOK)
+	json.NewEncoder(w).Encode(newEnrollmentKey)
+}
+
 // swagger:route POST /api/v1/enrollment-keys/{token} enrollmentKeys handleHostRegister
 //
 // Handles a Netclient registration with server and add nodes accordingly.
@@ -286,5 +351,5 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) {
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(&response)
 	// notify host of changes, peer and node updates
-	go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost)
+	go auth.CheckNetRegAndHostUpdate(enrollmentKey.Networks, &newHost, enrollmentKey.Relay)
 }

+ 47 - 3
logic/enrollmentkey.go

@@ -7,8 +7,10 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
+	"golang.org/x/exp/slices"
 )
 
 // EnrollmentErrors - struct for holding EnrollmentKey error messages
@@ -29,12 +31,12 @@ var EnrollmentErrors = struct {
 }
 
 // CreateEnrollmentKey - creates a new enrollment key in db
-func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool) (k *models.EnrollmentKey, err error) {
+func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
 	newKeyID, err := getUniqueEnrollmentID()
 	if err != nil {
 		return nil, err
 	}
-	k = &models.EnrollmentKey{
+	k := &models.EnrollmentKey{
 		Value:         newKeyID,
 		Expiration:    time.Time{},
 		UsesRemaining: 0,
@@ -42,6 +44,7 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
 		Networks:      []string{},
 		Tags:          []string{},
 		Type:          models.Undefined,
+		Relay:         relay,
 	}
 	if uses > 0 {
 		k.UsesRemaining = uses
@@ -61,10 +64,51 @@ func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string
 	if ok := k.Validate(); !ok {
 		return nil, EnrollmentErrors.InvalidCreate
 	}
+	if relay != uuid.Nil {
+		relayNode, err := GetNodeByID(relay.String())
+		if err != nil {
+			return nil, err
+		}
+		if !slices.Contains(k.Networks, relayNode.Network) {
+			return nil, errors.New("relay node not in key's networks")
+		}
+		if !relayNode.IsRelay {
+			return nil, errors.New("relay node is not a relay")
+		}
+	}
 	if err = upsertEnrollmentKey(k); err != nil {
 		return nil, err
 	}
-	return
+	return k, nil
+}
+
+// UpdateEnrollmentKey - updates an existing enrollment key's associated relay
+func UpdateEnrollmentKey(keyId string, relayId uuid.UUID) (*models.EnrollmentKey, error) {
+	key, err := GetEnrollmentKey(keyId)
+	if err != nil {
+		return nil, err
+	}
+
+	if relayId != uuid.Nil {
+		relayNode, err := GetNodeByID(relayId.String())
+		if err != nil {
+			return nil, err
+		}
+		if !slices.Contains(key.Networks, relayNode.Network) {
+			return nil, errors.New("relay node not in key's networks")
+		}
+		if !relayNode.IsRelay {
+			return nil, errors.New("relay node is not a relay")
+		}
+	}
+
+	key.Relay = relayId
+
+	if err = upsertEnrollmentKey(key); err != nil {
+		return nil, err
+	}
+
+	return key, nil
 }
 
 // GetAllEnrollmentKeys - fetches all enrollment keys from DB

+ 14 - 13
logic/enrollmentkey_test.go

@@ -4,6 +4,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
 	"github.com/stretchr/testify/assert"
@@ -13,35 +14,35 @@ func TestCreateEnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	t.Run("Can_Not_Create_Key", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, false, uuid.Nil)
 		assert.Nil(t, newKey)
 		assert.NotNil(t, err)
 		assert.Equal(t, err, EnrollmentErrors.InvalidCreate)
 	})
 	t.Run("Can_Create_Key_Uses", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
+		newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
 		assert.Nil(t, err)
 		assert.Equal(t, 1, newKey.UsesRemaining)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Time", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false)
+		newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, false, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_Unlimited", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 	})
 	t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Networks) == 2)
 	})
 	t.Run("Can_Create_Key_WithTags", func(t *testing.T) {
-		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true)
+		newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, true, uuid.Nil)
 		assert.Nil(t, err)
 		assert.True(t, newKey.IsValid())
 		assert.True(t, len(newKey.Tags) == 2)
@@ -61,7 +62,7 @@ func TestCreateEnrollmentKey(t *testing.T) {
 func TestDelete_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
 	t.Run("Can_Delete_Key", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		err := DeleteEnrollmentKey(newKey.Value)
@@ -82,7 +83,7 @@ func TestDelete_EnrollmentKey(t *testing.T) {
 func TestDecrement_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
+	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
 	t.Run("Check_initial_uses", func(t *testing.T) {
 		assert.True(t, newKey.IsValid())
 		assert.Equal(t, newKey.UsesRemaining, 1)
@@ -106,9 +107,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
 func TestUsability_EnrollmentKey(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false)
-	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false)
-	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true)
+	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, false, uuid.Nil)
+	key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, false, uuid.Nil)
+	key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, true, uuid.Nil)
 	t.Run("Check if valid use key can be used", func(t *testing.T) {
 		assert.Equal(t, key1.UsesRemaining, 1)
 		ok := TryToUseEnrollmentKey(key1)
@@ -144,7 +145,7 @@ func removeAllEnrollments() {
 func TestTokenize_EnrollmentKeys(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
 	const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5"
 	const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
@@ -177,7 +178,7 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
 func TestDeTokenize_EnrollmentKeys(t *testing.T) {
 	database.InitializeDatabase()
 	defer database.CloseDB()
-	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true)
+	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, true, uuid.Nil)
 	const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9"
 	const serverAddr = "api.myserver.com"
 

+ 4 - 0
models/enrollment_key.go

@@ -2,6 +2,8 @@ package models
 
 import (
 	"time"
+
+	"github.com/google/uuid"
 )
 
 const (
@@ -39,6 +41,7 @@ type EnrollmentKey struct {
 	Tags          []string  `json:"tags"`
 	Token         string    `json:"token,omitempty"` // B64 value of EnrollmentToken
 	Type          KeyType   `json:"type"`
+	Relay         uuid.UUID `json:"relay"`
 }
 
 // APIEnrollmentKey - used to create enrollment keys via API
@@ -49,6 +52,7 @@ type APIEnrollmentKey struct {
 	Unlimited     bool     `json:"unlimited"`
 	Tags          []string `json:"tags"`
 	Type          KeyType  `json:"type"`
+	Relay         string   `json:"relay"`
 }
 
 // RegisterResponse - the response to a successful enrollment register

+ 1 - 0
models/host.go

@@ -160,4 +160,5 @@ type RegisterMsg struct {
 	User         string `json:"user,omitempty"`
 	Password     string `json:"password,omitempty"`
 	JoinAll      bool   `json:"join_all,omitempty"`
+	Relay        string `json:"relay,omitempty"`
 }

+ 1 - 0
pro/logic/relays.go

@@ -149,6 +149,7 @@ func RelayUpdates(currentNode, newNode *models.Node) bool {
 	return relayUpdates
 }
 
+// UpdateRelayed - updates a relay's relayed nodes, and sends updates to the relayed nodes over MQ
 func UpdateRelayed(currentNode, newNode *models.Node) {
 	updatenodes := updateRelayNodes(currentNode.ID.String(), currentNode.RelayedNodes, newNode.RelayedNodes)
 	if len(updatenodes) > 0 {