Browse Source

Net 500: validate network parameter passed to node endpoints (#2480)

* enforce unique names for ext client names

* only check for unique id on creation

* check for unique id if changed

* validate network parameter passed to node endpoints

---------

Co-authored-by: Abhishek K <[email protected]>
Matthew R Kasun 2 years ago
parent
commit
ab4ddbb042
1 changed files with 47 additions and 25 deletions
  1. 47 25
      controllers/node.go

+ 47 - 25
controllers/node.go

@@ -16,6 +16,7 @@ import (
 	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/servercfg"
 	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/crypto/bcrypt"
+	"golang.org/x/exp/slog"
 )
 )
 
 
 var hostIDHeader = "host-id"
 var hostIDHeader = "host-id"
@@ -373,11 +374,10 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 
 
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
-	node, err := logic.GetNodeByID(nodeid)
+
+	node, err := validateParams(nodeid, params["network"])
 	if err != nil {
 	if err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	host, err := logic.GetHost(node.HostID.String())
 	host, err := logic.GetHost(node.HostID.String())
@@ -442,16 +442,20 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 	var gateway models.EgressGatewayRequest
 	var gateway models.EgressGatewayRequest
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
-	w.Header().Set("Content-Type", "application/json")
-	err := json.NewDecoder(r.Body).Decode(&gateway)
+	node, err := validateParams(params["nodeid"], params["network"])
 	if err != nil {
 	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		return
+	}
+	w.Header().Set("Content-Type", "application/json")
+	if err := json.NewDecoder(r.Body).Decode(&gateway); err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 		return
 	}
 	}
 	gateway.NetID = params["network"]
 	gateway.NetID = params["network"]
 	gateway.NodeID = params["nodeid"]
 	gateway.NodeID = params["nodeid"]
-	node, err := logic.CreateEgressGateway(gateway)
+	node, err = logic.CreateEgressGateway(gateway)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v",
 			fmt.Sprintf("failed to create egress gateway on node [%s] on network [%s]: %v",
@@ -487,7 +491,12 @@ func deleteEgressGateway(w http.ResponseWriter, r *http.Request) {
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
 	netid := params["network"]
 	netid := params["network"]
-	node, err := logic.DeleteEgressGateway(netid, nodeid)
+	node, err := validateParams(nodeid, netid)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		return
+	}
+	node, err = logic.DeleteEgressGateway(netid, nodeid)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v",
 			fmt.Sprintf("failed to delete egress gateway on node [%s] on network [%s]: %v",
@@ -524,10 +533,14 @@ func createIngressGateway(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
 	netid := params["network"]
 	netid := params["network"]
+	node, err := validateParams(nodeid, netid)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		return
+	}
 	var request models.IngressRequest
 	var request models.IngressRequest
 	json.NewDecoder(r.Body).Decode(&request)
 	json.NewDecoder(r.Body).Decode(&request)
-
-	node, err := logic.CreateIngressGateway(netid, nodeid, request)
+	node, err = logic.CreateIngressGateway(netid, nodeid, request)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v",
 			fmt.Sprintf("failed to create ingress gateway on node [%s] on network [%s]: %v",
@@ -566,6 +579,11 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
 	netid := params["network"]
 	netid := params["network"]
+	node, err := validateParams(nodeid, netid)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
+		return
+	}
 	node, wasFailover, removedClients, err := logic.DeleteIngressGateway(nodeid)
 	node, wasFailover, removedClients, err := logic.DeleteIngressGateway(nodeid)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 		logger.Log(0, r.Header.Get("user"),
@@ -623,14 +641,11 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 
 
 	//start here
 	//start here
 	nodeid := params["nodeid"]
 	nodeid := params["nodeid"]
-	currentNode, err := logic.GetNodeByID(nodeid)
+	currentNode, err := validateParams(nodeid, params["network"])
 	if err != nil {
 	if err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
 		return
 		return
 	}
 	}
-
 	var newData models.ApiNode
 	var newData models.ApiNode
 	// we decode our body request params
 	// we decode our body request params
 	err = json.NewDecoder(r.Body).Decode(&newData)
 	err = json.NewDecoder(r.Body).Decode(&newData)
@@ -721,19 +736,13 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
 	// get params
 	// get params
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
 	var nodeid = params["nodeid"]
 	var nodeid = params["nodeid"]
-	forceDelete := r.URL.Query().Get("force") == "true"
-	fromNode := r.Header.Get("requestfrom") == "node"
-	node, err := logic.GetNodeByID(nodeid)
+	node, err := validateParams(nodeid, params["network"])
 	if err != nil {
 	if err != nil {
-		if logic.CheckAndRemoveLegacyNode(nodeid) {
-			logger.Log(0, "removed legacy node", nodeid)
-			logic.ReturnSuccessResponse(w, r, nodeid+" deleted.")
-		} else {
-			logger.Log(0, "error retrieving node to delete", err.Error())
-			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
-		}
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "bad request"))
 		return
 		return
 	}
 	}
+	forceDelete := r.URL.Query().Get("force") == "true"
+	fromNode := r.Header.Get("requestfrom") == "node"
 	if r.Header.Get("ismaster") != "yes" {
 	if r.Header.Get("ismaster") != "yes" {
 		username := r.Header.Get("user")
 		username := r.Header.Get("user")
 		if username != "" && !doesUserOwnNode(username, params["network"], nodeid) {
 		if username != "" && !doesUserOwnNode(username, params["network"], nodeid) {
@@ -816,3 +825,16 @@ func doesUserOwnNode(username, network, nodeID string) bool {
 
 
 	return logic.StringSliceContains(netUser.Nodes, nodeID)
 	return logic.StringSliceContains(netUser.Nodes, nodeID)
 }
 }
+
+func validateParams(nodeid, netid string) (models.Node, error) {
+	node, err := logic.GetNodeByID(nodeid)
+	if err != nil {
+		slog.Error("error fetching node", "node", nodeid, "error", err.Error())
+		return node, fmt.Errorf("error fetching node during parameter validation: %v", err)
+	}
+	if node.Network != netid {
+		slog.Error("network url param does not match node id", "url nodeid", netid, "node", node.Network)
+		return node, fmt.Errorf("network url param does not match node network")
+	}
+	return node, nil
+}