abhishek9686 6 месяцев назад
Родитель
Сommit
d21411392b
2 измененных файлов с 25 добавлено и 3 удалено
  1. 4 0
      controllers/node_test.go
  2. 21 3
      logic/gateway.go

+ 4 - 0
controllers/node_test.go

@@ -21,6 +21,10 @@ var linuxHost models.Host
 func TestCreateEgressGateway(t *testing.T) {
 	var gateway models.EgressGatewayRequest
 	gateway.Ranges = []string{"10.100.100.0/24"}
+	gateway.RangesWithMetric = append(gateway.RangesWithMetric, models.EgressRangeMetric{
+		Network:     "10.100.100.0/24",
+		RouteMetric: 256,
+	})
 	gateway.NetID = "skynet"
 	deleteAllNetworks()
 	createNet()

+ 21 - 3
logic/gateway.go

@@ -77,6 +77,14 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro
 	if host.FirewallInUse == models.FIREWALL_NONE {
 		return models.Node{}, errors.New("please install iptables or nftables on the device")
 	}
+	if len(gateway.RangesWithMetric) == 0 && len(gateway.Ranges) > 0 {
+		for _, rangeI := range gateway.Ranges {
+			gateway.RangesWithMetric = append(gateway.RangesWithMetric, models.EgressRangeMetric{
+				Network:     rangeI,
+				RouteMetric: 256,
+			})
+		}
+	}
 	for i := len(gateway.Ranges) - 1; i >= 0; i-- {
 		// check if internet gateway IPv4
 		if gateway.Ranges[i] == "0.0.0.0/0" || gateway.Ranges[i] == "::/0" {
@@ -105,9 +113,19 @@ func CreateEgressGateway(gateway models.EgressGatewayRequest) (models.Node, erro
 	node.EgressGatewayRanges = gateway.Ranges
 	node.EgressGatewayNatEnabled = models.ParseBool(gateway.NatEnabled)
 	rangesWithMetric := []string{}
-	for i, rangeI := range gateway.RangesWithMetric {
-		rangesWithMetric = append(rangesWithMetric, rangeI.Network)
-		if rangeI.RouteMetric <= 0 || rangeI.RouteMetric > 999 {
+	for i := len(gateway.RangesWithMetric) - 1; i >= 0; i-- {
+		if gateway.RangesWithMetric[i].Network == "0.0.0.0/0" || gateway.RangesWithMetric[i].Network == "::/0" {
+			// remove inet range
+			gateway.RangesWithMetric = append(gateway.RangesWithMetric[:i], gateway.RangesWithMetric[i+1:]...)
+			continue
+		}
+		normalized, err := NormalizeCIDR(gateway.Ranges[i])
+		if err != nil {
+			return models.Node{}, err
+		}
+		gateway.RangesWithMetric[i].Network = normalized
+		rangesWithMetric = append(rangesWithMetric, gateway.RangesWithMetric[i].Network)
+		if gateway.RangesWithMetric[i].RouteMetric <= 0 || gateway.RangesWithMetric[i].RouteMetric > 999 {
 			gateway.RangesWithMetric[i].RouteMetric = 256
 		}
 	}