2
0
Эх сурвалжийг харах

extend validaiton checks for egress ranges

abhishek9686 5 сар өмнө
parent
commit
4998dbdd9e

+ 2 - 2
controllers/ext_client.go

@@ -688,7 +688,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 	var gateway models.EgressGatewayRequest
 	var gateway models.EgressGatewayRequest
 	gateway.NetID = params["network"]
 	gateway.NetID = params["network"]
 	gateway.Ranges = customExtClient.ExtraAllowedIPs
 	gateway.Ranges = customExtClient.ExtraAllowedIPs
-	err := logic.ValidateEgressRange(gateway)
+	err := logic.ValidateEgressRange(gateway.NetID, gateway.Ranges)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
 		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
@@ -876,7 +876,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 	var gateway models.EgressGatewayRequest
 	var gateway models.EgressGatewayRequest
 	gateway.NetID = params["network"]
 	gateway.NetID = params["network"]
 	gateway.Ranges = update.ExtraAllowedIPs
 	gateway.Ranges = update.ExtraAllowedIPs
-	err = logic.ValidateEgressRange(gateway)
+	err = logic.ValidateEgressRange(gateway.NetID, gateway.Ranges)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
 		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))

+ 1 - 1
controllers/node.go

@@ -516,7 +516,7 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 	gateway.NetID = params["network"]
 	gateway.NetID = params["network"]
 	gateway.NodeID = params["nodeid"]
 	gateway.NodeID = params["nodeid"]
-	err = logic.ValidateEgressRange(gateway)
+	err = logic.ValidateEgressRange(gateway.NetID, gateway.Ranges)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
 		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))

+ 17 - 1
logic/egress.go

@@ -1,14 +1,30 @@
 package logic
 package logic
 
 
-import "github.com/gravitl/netmaker/models"
+import (
+	"net"
+
+	"github.com/gravitl/netmaker/models"
+)
 
 
 func ValidateEgressReq(e *models.Egress) bool {
 func ValidateEgressReq(e *models.Egress) bool {
 	if e.Network == "" {
 	if e.Network == "" {
 		return false
 		return false
 	}
 	}
+	_, err := GetNetwork(e.Network)
+	if err != nil {
+		return false
+	}
 	if e.Range == "" {
 	if e.Range == "" {
 		return false
 		return false
 	}
 	}
+	_, _, err = net.ParseCIDR(e.Range)
+	if err != nil {
+		return false
+	}
+	err = ValidateEgressRange(e.Network, []string{e.Range})
+	if err != nil {
+		return false
+	}
 	if len(e.Nodes) != 0 {
 	if len(e.Nodes) != 0 {
 		for k := range e.Nodes {
 		for k := range e.Nodes {
 			_, err := GetNodeByID(k)
 			_, err := GetNodeByID(k)

+ 5 - 5
logic/nodes.go

@@ -783,16 +783,16 @@ func ValidateNodeIp(currentNode *models.Node, newNode *models.ApiNode) error {
 	return nil
 	return nil
 }
 }
 
 
-func ValidateEgressRange(gateway models.EgressGatewayRequest) error {
-	network, err := GetNetworkSettings(gateway.NetID)
+func ValidateEgressRange(netID string, ranges []string) error {
+	network, err := GetNetworkSettings(netID)
 	if err != nil {
 	if err != nil {
-		slog.Error("error getting network with netid", "error", gateway.NetID, err.Error)
-		return errors.New("error getting network with netid:  " + gateway.NetID + " " + err.Error())
+		slog.Error("error getting network with netid", "error", netID, err.Error)
+		return errors.New("error getting network with netid:  " + netID + " " + err.Error())
 	}
 	}
 	ipv4Net := network.AddressRange
 	ipv4Net := network.AddressRange
 	ipv6Net := network.AddressRange6
 	ipv6Net := network.AddressRange6
 
 
-	for _, v := range gateway.Ranges {
+	for _, v := range ranges {
 		if ipv4Net != "" {
 		if ipv4Net != "" {
 			if ContainsCIDR(ipv4Net, v) {
 			if ContainsCIDR(ipv4Net, v) {
 				slog.Error("egress range should not be the same as or contained in the netmaker network address", "error", v, ipv4Net)
 				slog.Error("egress range should not be the same as or contained in the netmaker network address", "error", v, ipv4Net)