Parcourir la source

redine acl firewall model

abhishek9686 il y a 9 mois
Parent
commit
031a0c14ac
6 fichiers modifiés avec 104 ajouts et 99 suppressions
  1. 21 21
      controllers/acls.go
  2. 56 64
      logic/acls.go
  3. 2 2
      logic/peers.go
  4. 7 7
      models/acl.go
  5. 5 5
      models/mqtt.go
  6. 13 0
      models/node.go

+ 21 - 21
controllers/acls.go

@@ -67,27 +67,27 @@ func aclPolicyTypes(w http.ResponseWriter, r *http.Request) {
 				},
 				PortRange: "443",
 			},
-			{
-				Name: "MySQL",
-				AllowedProtocols: []models.Protocol{
-					models.TCP,
-				},
-				PortRange: "3306",
-			},
-			{
-				Name: "DNS TCP",
-				AllowedProtocols: []models.Protocol{
-					models.TCP,
-				},
-				PortRange: "53",
-			},
-			{
-				Name: "DNS UDP",
-				AllowedProtocols: []models.Protocol{
-					models.UDP,
-				},
-				PortRange: "53",
-			},
+			// {
+			// 	Name: "MySQL",
+			// 	AllowedProtocols: []models.Protocol{
+			// 		models.TCP,
+			// 	},
+			// 	PortRange: "3306",
+			// },
+			// {
+			// 	Name: "DNS TCP",
+			// 	AllowedProtocols: []models.Protocol{
+			// 		models.TCP,
+			// 	},
+			// 	PortRange: "53",
+			// },
+			// {
+			// 	Name: "DNS UDP",
+			// 	AllowedProtocols: []models.Protocol{
+			// 		models.UDP,
+			// 	},
+			// 	PortRange: "53",
+			// },
 			{
 				Name: "All TCP",
 				AllowedProtocols: []models.Protocol{

+ 56 - 64
logic/acls.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"net"
 	"sort"
 	"sync"
 	"time"
@@ -652,19 +653,17 @@ func RemoveDeviceTagFromAclPolicies(tagID models.TagID, netID models.NetworkID)
 	return nil
 }
 
-func GetAclRulesForNode(node *models.Node) (rules map[string][]models.AclRule) {
+func GetAclRulesForNode(node *models.Node) (rules map[string]models.AclRule) {
 	defaultPolicy, err := GetDefaultPolicy(models.NetworkID(node.Network), models.DevicePolicy)
-	rules = make(map[string][]models.AclRule)
+	rules = make(map[string]models.AclRule)
 	if err == nil && defaultPolicy.Enabled {
-		return map[string][]models.AclRule{
+		return map[string]models.AclRule{
 			defaultPolicy.ID: {
-				{
-					SrcIP:     node.NetworkRange,
-					SrcIP6:    node.NetworkRange6,
-					Proto:     []models.Protocol{models.ALL},
-					Direction: models.TrafficDirectionBi,
-					Allowed:   true,
-				},
+				IPList:    []net.IPNet{node.NetworkRange},
+				IP6List:   []net.IPNet{node.NetworkRange6},
+				Proto:     []models.Protocol{models.ALL},
+				Direction: models.TrafficDirectionBi,
+				Allowed:   true,
 			},
 		}
 	}
@@ -679,36 +678,37 @@ func GetAclRulesForNode(node *models.Node) (rules map[string][]models.AclRule) {
 			}
 			srcTags := convAclTagToValueMap(acl.Src)
 			dstTags := convAclTagToValueMap(acl.Dst)
-			aclRules := []models.AclRule{}
+			aclRule := models.AclRule{
+				Proto:     acl.Proto,
+				Port:      acl.Port,
+				Direction: acl.AllowedDirection,
+				Allowed:   true,
+			}
 			if acl.AllowedDirection == models.TrafficDirectionBi {
 				var existsInSrcTag bool
 				var existsInDstTag bool
 				// if contains all resources, return entire cidr
 				if _, ok := srcTags["*"]; ok {
-					return map[string][]models.AclRule{
+					return map[string]models.AclRule{
 						acl.ID: {
-							{
-								SrcIP:     node.NetworkRange,
-								SrcIP6:    node.NetworkRange6,
-								Proto:     []models.Protocol{models.ALL},
-								Port:      acl.Port,
-								Direction: acl.AllowedDirection,
-								Allowed:   true,
-							},
+							IPList:    []net.IPNet{node.NetworkRange},
+							IP6List:   []net.IPNet{node.NetworkRange6},
+							Proto:     []models.Protocol{models.ALL},
+							Port:      acl.Port,
+							Direction: acl.AllowedDirection,
+							Allowed:   true,
 						},
 					}
 				}
 				if _, ok := dstTags["*"]; ok {
-					return map[string][]models.AclRule{
+					return map[string]models.AclRule{
 						acl.ID: {
-							{
-								SrcIP:     node.NetworkRange,
-								SrcIP6:    node.NetworkRange6,
-								Proto:     []models.Protocol{models.ALL},
-								Port:      acl.Port,
-								Direction: acl.AllowedDirection,
-								Allowed:   true,
-							},
+							IPList:    []net.IPNet{node.NetworkRange},
+							IP6List:   []net.IPNet{node.NetworkRange6},
+							Proto:     []models.Protocol{models.ALL},
+							Port:      acl.Port,
+							Direction: acl.AllowedDirection,
+							Allowed:   true,
 						},
 					}
 				}
@@ -719,6 +719,7 @@ func GetAclRulesForNode(node *models.Node) (rules map[string][]models.AclRule) {
 				if _, ok := dstTags[nodeTag.String()]; ok {
 					existsInDstTag = true
 				}
+
 				if existsInSrcTag {
 					// get all dst tags
 					for dst := range dstTags {
@@ -728,17 +729,14 @@ func GetAclRulesForNode(node *models.Node) (rules map[string][]models.AclRule) {
 						// Get peers in the tags and add allowed rules
 						nodes := taggedNodes[models.TagID(dst)]
 						for _, node := range nodes {
-							aclRules = append(aclRules, models.AclRule{
-								SrcIP:     node.Address,
-								SrcIP6:    node.Address6,
-								Proto:     acl.Proto,
-								Port:      acl.Port,
-								Direction: acl.AllowedDirection,
-								Allowed:   true,
-							})
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
 						}
 					}
-
 				}
 				if existsInDstTag {
 					// get all src tags
@@ -749,28 +747,24 @@ func GetAclRulesForNode(node *models.Node) (rules map[string][]models.AclRule) {
 						// Get peers in the tags and add allowed rules
 						nodes := taggedNodes[models.TagID(src)]
 						for _, node := range nodes {
-							aclRules = append(aclRules, models.AclRule{
-								SrcIP:     node.Address,
-								SrcIP6:    node.Address6,
-								Proto:     acl.Proto,
-								Port:      acl.Port,
-								Direction: acl.AllowedDirection,
-								Allowed:   true,
-							})
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
 						}
 					}
 				}
 				if existsInDstTag && existsInSrcTag {
 					nodes := taggedNodes[nodeTag]
 					for _, node := range nodes {
-						aclRules = append(aclRules, models.AclRule{
-							SrcIP:     node.Address,
-							SrcIP6:    node.Address6,
-							Proto:     acl.Proto,
-							Port:      acl.Port,
-							Direction: acl.AllowedDirection,
-							Allowed:   true,
-						})
+						if node.Address.IP != nil {
+							aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+						}
+						if node.Address6.IP != nil {
+							aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+						}
 					}
 				}
 			} else {
@@ -783,20 +777,18 @@ func GetAclRulesForNode(node *models.Node) (rules map[string][]models.AclRule) {
 						// Get peers in the tags and add allowed rules
 						nodes := taggedNodes[models.TagID(src)]
 						for _, node := range nodes {
-							aclRules = append(aclRules, models.AclRule{
-								SrcIP:     node.Address,
-								SrcIP6:    node.Address6,
-								Proto:     acl.Proto,
-								Port:      acl.Port,
-								Direction: acl.AllowedDirection,
-								Allowed:   true,
-							})
+							if node.Address.IP != nil {
+								aclRule.IPList = append(aclRule.IPList, node.AddressIPNet4())
+							}
+							if node.Address6.IP != nil {
+								aclRule.IP6List = append(aclRule.IP6List, node.AddressIPNet6())
+							}
 						}
 					}
 				}
 			}
-			if len(aclRules) > 0 {
-				rules[acl.ID] = aclRules
+			if len(aclRule.IPList) > 0 || len(aclRule.IP6List) > 0 {
+				rules[acl.ID] = aclRule
 			}
 		}
 	}

+ 2 - 2
logic/peers.go

@@ -76,7 +76,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 		FwUpdate: models.FwUpdate{
 			EgressInfo:  make(map[string]models.EgressInfo),
 			IngressInfo: make(map[string]models.IngressInfo),
-			AclRules:    make(map[string]map[string][]models.AclRule),
+			AclRules:    make(map[string]models.AclRule),
 		},
 		PeerIDs:           make(models.PeerMap, 0),
 		Peers:             []wgtypes.PeerConfig{},
@@ -155,7 +155,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
 		if !hostPeerUpdate.IsInternetGw {
 			hostPeerUpdate.IsInternetGw = IsInternetGw(node)
 		}
-		hostPeerUpdate.FwUpdate.AclRules[node.Network] = GetAclRulesForNode(&node)
+		hostPeerUpdate.FwUpdate.AclRules = GetAclRulesForNode(&node)
 		currentPeers := GetNetworkNodesMemory(allNodes, node.Network)
 		for _, peer := range currentPeers {
 			peer := peer

+ 7 - 7
models/acl.go

@@ -16,13 +16,13 @@ const (
 )
 
 // Protocol - allowed protocol
-type Protocol int
+type Protocol string
 
 const (
-	ALL Protocol = iota
-	UDP
-	TCP
-	ICMP
+	ALL  Protocol = "all"
+	UDP  Protocol = "udp"
+	TCP  Protocol = "tcp"
+	ICMP Protocol = "icmp"
 )
 
 type AclPolicyType string
@@ -93,8 +93,8 @@ type ProtocolType struct {
 }
 
 type AclRule struct {
-	SrcIP     net.IPNet
-	SrcIP6    net.IPNet
+	IPList    []net.IPNet
+	IP6List   []net.IPNet
 	Proto     []Protocol // tcp, udp, etc.
 	Port      []int
 	Direction AllowedTrafficDirection // inbound or outbound

+ 5 - 5
models/mqtt.go

@@ -90,11 +90,11 @@ type KeyUpdate struct {
 
 // FwUpdate - struct for firewall updates
 type FwUpdate struct {
-	IsEgressGw  bool                            `json:"is_egress_gw"`
-	IsIngressGw bool                            `json:"is_ingress_gw"`
-	EgressInfo  map[string]EgressInfo           `json:"egress_info"`
-	IngressInfo map[string]IngressInfo          `json:"ingress_info"`
-	AclRules    map[string]map[string][]AclRule `json:"acl_rules"`
+	IsEgressGw  bool                   `json:"is_egress_gw"`
+	IsIngressGw bool                   `json:"is_ingress_gw"`
+	EgressInfo  map[string]EgressInfo  `json:"egress_info"`
+	IngressInfo map[string]IngressInfo `json:"ingress_info"`
+	AclRules    map[string]AclRule     `json:"acl_rules"`
 }
 
 // FailOverMeReq - struct for failover req

+ 13 - 0
models/node.go

@@ -201,6 +201,19 @@ func (node *Node) PrimaryAddress() string {
 	return node.Address6.IP.String()
 }
 
+func (node *Node) AddressIPNet4() net.IPNet {
+	return net.IPNet{
+		IP:   node.Address.IP,
+		Mask: net.CIDRMask(32, 32),
+	}
+}
+func (node *Node) AddressIPNet6() net.IPNet {
+	return net.IPNet{
+		IP:   node.Address6.IP,
+		Mask: net.CIDRMask(128, 128),
+	}
+}
+
 // ExtClient.PrimaryAddress - returns ipv4 IPNet format
 func (extPeer *ExtClient) AddressIPNet4() net.IPNet {
 	return net.IPNet{