Browse Source

add validate check for inet request

abhishek9686 1 year ago
parent
commit
5fa03c2334
2 changed files with 24 additions and 5 deletions
  1. 2 2
      pro/controllers/inet_gws.go
  2. 22 3
      pro/logic/nodes.go

+ 2 - 2
pro/controllers/inet_gws.go

@@ -51,7 +51,7 @@ func createInternetGw(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	err = proLogic.ValidateInetGwReq(request)
+	err = proLogic.ValidateInetGwReq(node, request)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
@@ -100,7 +100,7 @@ func updateInternetGw(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("node is not a internet gw"), "badrequest"))
 		return
 	}
-	err = proLogic.ValidateInetGwReq(request)
+	err = proLogic.ValidateInetGwReq(node, request)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return

+ 22 - 3
pro/logic/nodes.go

@@ -2,19 +2,38 @@ package logic
 
 import (
 	"errors"
+	"fmt"
 
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 )
 
-func ValidateInetGwReq(req models.InetNodeReq) error {
+func ValidateInetGwReq(inetNode models.Node, req models.InetNodeReq) error {
 	for _, clientNodeID := range req.InetNodeClientIDs {
 		clientNode, err := logic.GetNodeByID(clientNodeID)
 		if err != nil {
-			continue
+			return err
+		}
+		clientHost, err := logic.GetHost(clientNode.HostID.String())
+		if err != nil {
+			return err
 		}
 		if clientNode.IsInternetGateway {
-			return errors.New("node acting as internet gateway cannot use another internet gateway")
+			return fmt.Errorf("node %s acting as internet gateway cannot use another internet gateway", clientHost.Name)
+		}
+		if clientNode.InternetGwID != "" {
+			return fmt.Errorf("node %s is already using a internet gateway", clientHost.Name)
+		}
+
+		for _, nodeID := range clientHost.Nodes {
+			node, err := logic.GetNodeByID(nodeID)
+			if err != nil {
+				continue
+			}
+			if node.InternetGwID != "" && node.InternetGwID != inetNode.ID.String() {
+				return errors.New("nodes on same host cannot use different internet gateway")
+			}
+
 		}
 	}
 	return nil