فهرست منبع

NET-1288:add egress range check with netmaker network address (#2978)

* add egress range check with netmaker network address

* add egerssrange check for delete egressGW and extClientGW

* remove egress range check for delete
Max Ma 1 سال پیش
والد
کامیت
f63dfaf4b5
6فایلهای تغییر یافته به همراه101 افزوده شده و 1 حذف شده
  1. 22 0
      controllers/ext_client.go
  2. 6 1
      controllers/node.go
  3. 2 0
      go.mod
  4. 4 0
      go.sum
  5. 34 0
      logic/nodes.go
  6. 33 0
      logic/nodes_test.go

+ 22 - 0
controllers/ext_client.go

@@ -386,6 +386,17 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
+
+	var gateway models.EgressGatewayRequest
+	gateway.NetID = params["network"]
+	gateway.Ranges = customExtClient.ExtraAllowedIPs
+	err := logic.ValidateEgressRange(gateway)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
 	node, err := logic.GetNodeByID(nodeid)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
@@ -530,6 +541,17 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
+
+	var gateway models.EgressGatewayRequest
+	gateway.NetID = params["network"]
+	gateway.Ranges = update.ExtraAllowedIPs
+	err = logic.ValidateEgressRange(gateway)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
 	var changedID = update.ClientID != oldExtClient.ClientID
 
 	if !reflect.DeepEqual(update.DeniedACLs, oldExtClient.DeniedACLs) {

+ 6 - 1
controllers/node.go

@@ -414,7 +414,12 @@ func createEgressGateway(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	gateway.NetID = params["network"]
-	gateway.NodeID = params["nodeid"]
+	err = logic.ValidateEgressRange(gateway)
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"), "error validating egress range: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
 	node, err = logic.CreateEgressGateway(gateway)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),

+ 2 - 0
go.mod

@@ -12,6 +12,7 @@ require (
 	github.com/lib/pq v1.10.9
 	github.com/mattn/go-sqlite3 v1.14.22
 	github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa
+	github.com/seancfoley/ipaddress-go v1.6.0
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 	github.com/stretchr/testify v1.9.0
 	github.com/txn2/txeh v1.5.5
@@ -49,6 +50,7 @@ require (
 	github.com/gabriel-vasile/mimetype v1.4.3 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/rivo/uniseg v0.2.0 // indirect
+	github.com/seancfoley/bintree v1.3.1 // indirect
 	github.com/spf13/pflag v1.0.5 // indirect
 )
 

+ 4 - 0
go.sum

@@ -70,6 +70,10 @@ github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa h1:hxMLFbj+F444JAS
 github.com/rqlite/gorqlite v0.0.0-20240122221808-a8a425b1a6aa/go.mod h1:xF/KoXmrRyahPfo5L7Szb5cAAUl53dMWBh9cMruGEZg=
 github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
+github.com/seancfoley/bintree v1.3.1 h1:cqmmQK7Jm4aw8gna0bP+huu5leVOgHGSJBEpUx3EXGI=
+github.com/seancfoley/bintree v1.3.1/go.mod h1:hIUabL8OFYyFVTQ6azeajbopogQc2l5C/hiXMcemWNU=
+github.com/seancfoley/ipaddress-go v1.6.0 h1:9z7yGmOnV4P2ML/dlR/kCJiv5tp8iHOOetJvxJh/R5w=
+github.com/seancfoley/ipaddress-go v1.6.0/go.mod h1:TQRZgv+9jdvzHmKoPGBMxyiaVmoI0rYpfEk8Q/sL/Iw=
 github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M=

+ 34 - 0
logic/nodes.go

@@ -19,6 +19,7 @@ import (
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/validation"
+	"github.com/seancfoley/ipaddress-go/ipaddr"
 	"golang.org/x/exp/slog"
 )
 
@@ -626,6 +627,39 @@ func ValidateParams(nodeid, netid string) (models.Node, error) {
 	return node, nil
 }
 
+func ValidateEgressRange(gateway models.EgressGatewayRequest) error {
+	network, err := GetNetworkSettings(gateway.NetID)
+	if err != nil {
+		slog.Error("error getting network with netid", "error", gateway.NetID, err.Error)
+		return errors.New("error getting network with netid:  " + gateway.NetID + " " + err.Error())
+	}
+	ipv4Net := network.AddressRange
+	ipv6Net := network.AddressRange6
+
+	for _, v := range gateway.Ranges {
+		if ipv4Net != "" {
+			if ContainsCIDR(ipv4Net, v) {
+				slog.Error("egress range should not be the same as or contained in the netmaker network address", "error", v, ipv4Net)
+				return errors.New("egress range should not be the same as or contained in the netmaker network address" + v + " " + ipv4Net)
+			}
+		}
+		if ipv6Net != "" {
+			if ContainsCIDR(ipv6Net, v) {
+				slog.Error("egress range should not be the same as or contained in the netmaker network address", "error", v, ipv6Net)
+				return errors.New("egress range should not be the same as or contained in the netmaker network address" + v + " " + ipv6Net)
+			}
+		}
+	}
+
+	return nil
+}
+
+func ContainsCIDR(net1, net2 string) bool {
+	one, two := ipaddr.NewIPAddressString(net1),
+		ipaddr.NewIPAddressString(net2)
+	return one.Contains(two) || two.Contains(one)
+}
+
 // GetAllFailOvers - gets all the nodes that are failovers
 func GetAllFailOvers() ([]models.Node, error) {
 	nodes, err := GetAllNodes()

+ 33 - 0
logic/nodes_test.go

@@ -0,0 +1,33 @@
+package logic
+
+import (
+	"testing"
+)
+
+func TestContainsCIDR(t *testing.T) {
+
+	b := ContainsCIDR("10.1.1.2/32", "10.1.1.0/24")
+	if !b {
+		t.Errorf("expected true, returned %v", b)
+	}
+
+	b = ContainsCIDR("10.1.1.2/32", "10.5.1.0/24")
+	if b {
+		t.Errorf("expected false, returned %v", b)
+	}
+
+	b = ContainsCIDR("fd52:65f5:d685:d11d::1/64", "fd52:65f5:d685:d11d::/64")
+	if !b {
+		t.Errorf("expected true, returned %v", b)
+	}
+
+	b1 := ContainsCIDR("fd10:10::/64", "fd10::/16")
+	if !b1 {
+		t.Errorf("expected true, returned %v", b1)
+	}
+
+	b1 = ContainsCIDR("fd10:10::/64", "fd10::/64")
+	if b1 {
+		t.Errorf("expected false, returned %v", b1)
+	}
+}