ソースを参照

check for overlapping networks on network creation

Matthew R. Kasun 3 年 前
コミット
f0795bd7ae
2 ファイル変更78 行追加1 行削除
  1. 35 0
      logic/network_test.go
  2. 43 1
      logic/networks.go

+ 35 - 0
logic/network_test.go

@@ -0,0 +1,35 @@
+package logic
+
+import (
+	"testing"
+
+	"github.com/gravitl/netmaker/netclient/ncutils"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCheckOverlap(t *testing.T) {
+	_, err := ncutils.RunCmd("sudo ip link add nm-0 type wireguard", false)
+	assert.Nil(t, err)
+	_, err = ncutils.RunCmd("sudo ip a add 10.0.255.254/16 dev nm-0", false)
+	assert.Nil(t, err)
+	_, err = ncutils.RunCmd("sudo ip -6 a add 2001:db8::/64 dev nm-0", false)
+	assert.Nil(t, err)
+	t.Run("4Good", func(t *testing.T) {
+		err = CheckOverlap("10.10.10.0/24", "")
+		assert.Nil(t, err)
+	})
+	t.Run("4Bad", func(t *testing.T) {
+		err = CheckOverlap("10.0.1.0/24", "")
+		assert.NotNil(t, err)
+	})
+	t.Run("6Good", func(t *testing.T) {
+		err = CheckOverlap("", "3001:fe8::/64")
+		assert.Nil(t, err)
+	})
+	t.Run("6Bad", func(t *testing.T) {
+		err = CheckOverlap("", "2001:db8::1:0/64")
+		assert.NotNil(t, err)
+	})
+	_, err = ncutils.RunCmd("sudo ip link del nm-0", false)
+
+}

+ 43 - 1
logic/networks.go

@@ -74,7 +74,11 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 	network.SetNodesLastModified()
 	network.SetNetworkLastModified()
 
-	err := ValidateNetwork(&network, false)
+	err := CheckOverlap(network.AddressRange, network.AddressRange6)
+	if err != nil {
+		return models.Network{}, err
+	}
+	err = ValidateNetwork(&network, false)
 	if err != nil {
 		//returnErrorResponse(w, r, formatError(err, "badrequest"))
 		return models.Network{}, err
@@ -758,3 +762,41 @@ func isInterfacePresent(iface string, address string) (string, bool) {
 	// logger.Log(2, "failed to find iface", iface)
 	return "", true
 }
+
+// CheckOverlap check if new network overlaps with existing networks
+func CheckOverlap(network4, network6 string) error {
+	var net4, net6 *net.IPNet
+	locals, err := net.InterfaceAddrs()
+	if err != nil {
+		return errors.New("unable to parse local interfaces")
+	}
+	if network4 != "" {
+		_, net4, err = net.ParseCIDR(network4)
+		if err != nil {
+			return errors.New("invalid network range" + network4)
+		}
+	}
+	if network6 != "" {
+		_, net6, err = net.ParseCIDR(network6)
+		if err != nil {
+			return errors.New("invalid network range" + network6)
+		}
+	}
+	for _, local := range locals {
+		_, net, err := net.ParseCIDR(local.String())
+		if err != nil {
+			return errors.New("invalid local address")
+		}
+		if network4 != "" {
+			if net4.Contains(net.IP) || net.Contains(net4.IP) {
+				return errors.New("overlapping networks")
+			}
+		}
+		if network6 != "" {
+			if net6.Contains(net.IP) || net.Contains(net6.IP) {
+				return errors.New("overlapping networks")
+			}
+		}
+	}
+	return nil
+}