Przeglądaj źródła

use range field for domain check

abhishek9686 4 tygodni temu
rodzic
commit
be8f334796
5 zmienionych plików z 40 dodań i 27 usunięć
  1. 12 10
      controllers/egress.go
  2. 25 0
      logic/egress.go
  3. 1 1
      logic/peers.go
  4. 2 0
      models/mqtt.go
  5. 0 16
      schema/egress.go

+ 12 - 10
controllers/egress.go

@@ -46,20 +46,22 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	var egressRange string
+	var cidrErr error
 	if !req.IsInetGw {
-		if req.Domain != "" {
-			if !logic.IsFQDN(req.Domain) {
-				logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("bad domain name"), "badrequest"))
-				return
-			}
-		} else {
-			egressRange, err = logic.NormalizeCIDR(req.Range)
-			if err != nil {
+		egressRange, cidrErr = logic.NormalizeCIDR(req.Range)
+		isDomain := logic.IsFQDN(req.Range)
+		if cidrErr != nil && !isDomain {
+			if cidrErr != nil {
 				logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
-				return
+			} else {
+				logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("bad domain name"), "badrequest"))
 			}
+			return
+		}
+		if isDomain {
+			req.Domain = req.Range
+			egressRange = ""
 		}
-
 	} else {
 		egressRange = "*"
 		req.Domain = ""

+ 25 - 0
logic/egress.go

@@ -218,3 +218,28 @@ func GetEgressRanges(netID models.NetworkID) (map[string][]string, map[string]st
 	}
 	return nodeEgressMap, resultMap, nil
 }
+
+func ListAllByRoutingNodeWithDomain(ctx context.Context, egs []schema.Egress, nodeID string) (egWithDomain []models.EgressDomain) {
+	for _, egI := range egs {
+		if egI.Domain == "" {
+			continue
+		}
+		if _, ok := egI.Nodes[nodeID]; ok {
+			node, err := GetNodeByID(nodeID)
+			if err != nil {
+				continue
+			}
+			host, err := GetHost(node.HostID.String())
+			if err != nil {
+				continue
+			}
+			egWithDomain = append(egWithDomain, models.EgressDomain{
+				ID:     egI.ID,
+				Domain: egI.Domain,
+				Node:   node,
+				Host:   *host,
+			})
+		}
+	}
+	return
+}

+ 1 - 1
logic/peers.go

@@ -182,7 +182,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 		eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO()))
 		GetNodeEgressInfo(&node, eli, acls)
 		if node.IsEgressGateway {
-			egsWithDomain := (&schema.Egress{}).ListAllByRoutingNodeWithDomain(db.WithContext(context.TODO()), eli, node.ID.String())
+			egsWithDomain := ListAllByRoutingNodeWithDomain(db.WithContext(context.TODO()), eli, node.ID.String())
 			hostPeerUpdate.EgressWithDomains = append(hostPeerUpdate.EgressWithDomains, egsWithDomain...)
 		}
 		hostPeerUpdate = SetDefaultGw(node, hostPeerUpdate)

+ 2 - 0
models/mqtt.go

@@ -35,6 +35,8 @@ type HostPeerUpdate struct {
 
 type EgressDomain struct {
 	ID     string `json:"id"`
+	Node   Node   `json:"node"`
+	Host   Host   `json:"host"`
 	Domain string `json:"domain"`
 }
 

+ 0 - 16
schema/egress.go

@@ -5,7 +5,6 @@ import (
 	"time"
 
 	"github.com/gravitl/netmaker/db"
-	"github.com/gravitl/netmaker/models"
 	"gorm.io/datatypes"
 )
 
@@ -74,18 +73,3 @@ func (e *Egress) Count(ctx context.Context) (int, error) {
 func (e *Egress) Delete(ctx context.Context) error {
 	return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Delete(&e).Error
 }
-
-func (e *Egress) ListAllByRoutingNodeWithDomain(ctx context.Context, egs []Egress, nodeID string) (egWithDomain []models.EgressDomain) {
-	for _, egI := range egs {
-		if egI.Domain == "" {
-			continue
-		}
-		if _, ok := e.Nodes[nodeID]; ok {
-			egWithDomain = append(egWithDomain, models.EgressDomain{
-				ID:     egI.ID,
-				Domain: egI.Domain,
-			})
-		}
-	}
-	return
-}