Browse Source

check if new network is overlapping (#2655)

Abhishek K 1 year ago
parent
commit
a9a237cafc
2 changed files with 40 additions and 1 deletions
  1. 24 1
      logic/networks.go
  2. 16 0
      models/network.go

+ 24 - 1
logic/networks.go

@@ -23,7 +23,6 @@ func GetNetworks() ([]models.Network, error) {
 	var networks []models.Network
 
 	collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
-
 	if err != nil {
 		return networks, err
 	}
@@ -72,6 +71,9 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 		}
 		network.AddressRange6 = normalizedRange
 	}
+	if !IsNetworkCIDRUnique(network.GetNetworkNetworkCIDR4(), network.GetNetworkNetworkCIDR6()) {
+		return models.Network{}, errors.New("network cidr already in use")
+	}
 
 	network.SetDefaults()
 	network.SetNodesLastModified()
@@ -101,6 +103,27 @@ func GetNetworkNonServerNodeCount(networkName string) (int, error) {
 	return len(nodes), err
 }
 
+func IsNetworkCIDRUnique(cidr4 *net.IPNet, cidr6 *net.IPNet) bool {
+	networks, err := GetNetworks()
+	if err != nil {
+		return database.IsEmptyRecord(err)
+	}
+	for _, network := range networks {
+		if intersect(network.GetNetworkNetworkCIDR4(), cidr4) ||
+			intersect(network.GetNetworkNetworkCIDR6(), cidr6) {
+			return false
+		}
+	}
+	return true
+}
+
+func intersect(n1, n2 *net.IPNet) bool {
+	if n1 == nil || n2 == nil {
+		return false
+	}
+	return n2.Contains(n1.IP) || n1.Contains(n2.IP)
+}
+
 // GetParentNetwork - get parent network
 func GetParentNetwork(networkname string) (models.Network, error) {
 

+ 16 - 0
models/network.go

@@ -1,6 +1,7 @@
 package models
 
 import (
+	"net"
 	"time"
 )
 
@@ -81,3 +82,18 @@ func (network *Network) SetDefaults() {
 		network.DefaultACL = "yes"
 	}
 }
+
+func (network *Network) GetNetworkNetworkCIDR4() *net.IPNet {
+	if network.AddressRange == "" {
+		return nil
+	}
+	_, netCidr, _ := net.ParseCIDR(network.AddressRange)
+	return netCidr
+}
+func (network *Network) GetNetworkNetworkCIDR6() *net.IPNet {
+	if network.AddressRange6 == "" {
+		return nil
+	}
+	_, netCidr, _ := net.ParseCIDR(network.AddressRange6)
+	return netCidr
+}