|
@@ -58,7 +58,9 @@ type Firewall struct {
|
|
|
DefaultTimeout time.Duration //linux: 600s
|
|
|
|
|
|
// Used to ensure we don't emit local packets for ips we don't own
|
|
|
- localIps *cidr.Tree4[struct{}]
|
|
|
+ localIps *cidr.Tree4[struct{}]
|
|
|
+ assignedCIDR *net.IPNet
|
|
|
+ hasSubnets bool
|
|
|
|
|
|
rules string
|
|
|
rulesVersion uint16
|
|
@@ -103,17 +105,22 @@ func newFirewallTable() *FirewallTable {
|
|
|
}
|
|
|
|
|
|
type FirewallCA struct {
|
|
|
- Any *firewallLocalCIDR
|
|
|
- CANames map[string]*firewallLocalCIDR
|
|
|
- CAShas map[string]*firewallLocalCIDR
|
|
|
+ Any *FirewallRule
|
|
|
+ CANames map[string]*FirewallRule
|
|
|
+ CAShas map[string]*FirewallRule
|
|
|
}
|
|
|
|
|
|
type FirewallRule struct {
|
|
|
// Any makes Hosts, Groups, and CIDR irrelevant
|
|
|
- Any bool
|
|
|
- Hosts map[string]struct{}
|
|
|
- Groups [][]string
|
|
|
- CIDR *cidr.Tree4[struct{}]
|
|
|
+ Any *firewallLocalCIDR
|
|
|
+ Hosts map[string]*firewallLocalCIDR
|
|
|
+ Groups []*firewallGroups
|
|
|
+ CIDR *cidr.Tree4[*firewallLocalCIDR]
|
|
|
+}
|
|
|
+
|
|
|
+type firewallGroups struct {
|
|
|
+ Groups []string
|
|
|
+ LocalCIDR *firewallLocalCIDR
|
|
|
}
|
|
|
|
|
|
// Even though ports are uint16, int32 maps are faster for lookup
|
|
@@ -121,8 +128,8 @@ type FirewallRule struct {
|
|
|
type firewallPort map[int32]*FirewallCA
|
|
|
|
|
|
type firewallLocalCIDR struct {
|
|
|
- Any *FirewallRule
|
|
|
- LocalCIDR *cidr.Tree4[*FirewallRule]
|
|
|
+ Any bool
|
|
|
+ LocalCIDR *cidr.Tree4[struct{}]
|
|
|
}
|
|
|
|
|
|
// NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
|
|
@@ -145,8 +152,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|
|
}
|
|
|
|
|
|
localIps := cidr.NewTree4[struct{}]()
|
|
|
+ var assignedCIDR *net.IPNet
|
|
|
for _, ip := range c.Details.Ips {
|
|
|
- localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
|
|
|
+ ipNet := &net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}
|
|
|
+ localIps.AddCIDR(ipNet, struct{}{})
|
|
|
+
|
|
|
+ if assignedCIDR == nil {
|
|
|
+ // Only grabbing the first one in the cert since any more than that currently has undefined behavior
|
|
|
+ assignedCIDR = ipNet
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
for _, n := range c.Details.Subnets {
|
|
@@ -164,6 +178,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|
|
UDPTimeout: UDPTimeout,
|
|
|
DefaultTimeout: defaultTimeout,
|
|
|
localIps: localIps,
|
|
|
+ assignedCIDR: assignedCIDR,
|
|
|
+ hasSubnets: len(c.Details.Subnets) > 0,
|
|
|
l: l,
|
|
|
|
|
|
metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
|
|
@@ -276,7 +292,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort
|
|
|
return fmt.Errorf("unknown protocol %v", proto)
|
|
|
}
|
|
|
|
|
|
- return fp.addRule(startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
|
|
+ return fp.addRule(f, startPort, endPort, groups, host, ip, localIp, caName, caSha)
|
|
|
}
|
|
|
|
|
|
// GetRuleHash returns a hash representation of all inbound and outbound rules
|
|
@@ -630,7 +646,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaC
|
|
|
return false
|
|
|
}
|
|
|
|
|
|
-func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
|
|
+func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
|
|
|
if startPort > endPort {
|
|
|
return fmt.Errorf("start port was lower than end port")
|
|
|
}
|
|
@@ -638,12 +654,12 @@ func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string,
|
|
|
for i := startPort; i <= endPort; i++ {
|
|
|
if _, ok := fp[i]; !ok {
|
|
|
fp[i] = &FirewallCA{
|
|
|
- CANames: make(map[string]*firewallLocalCIDR),
|
|
|
- CAShas: make(map[string]*firewallLocalCIDR),
|
|
|
+ CANames: make(map[string]*FirewallRule),
|
|
|
+ CAShas: make(map[string]*FirewallRule),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if err := fp[i].addRule(groups, host, ip, localIp, caName, caSha); err != nil {
|
|
|
+ if err := fp[i].addRule(f, groups, host, ip, localIp, caName, caSha); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
@@ -674,26 +690,28 @@ func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCer
|
|
|
return fp[firewall.PortAny].match(p, c, caPool)
|
|
|
}
|
|
|
|
|
|
-func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
|
|
- fl := func() *firewallLocalCIDR {
|
|
|
- return &firewallLocalCIDR{
|
|
|
- LocalCIDR: cidr.NewTree4[*FirewallRule](),
|
|
|
+func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, localIp *net.IPNet, caName, caSha string) error {
|
|
|
+ fr := func() *FirewallRule {
|
|
|
+ return &FirewallRule{
|
|
|
+ Hosts: make(map[string]*firewallLocalCIDR),
|
|
|
+ Groups: make([]*firewallGroups, 0),
|
|
|
+ CIDR: cidr.NewTree4[*firewallLocalCIDR](),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if caSha == "" && caName == "" {
|
|
|
if fc.Any == nil {
|
|
|
- fc.Any = fl()
|
|
|
+ fc.Any = fr()
|
|
|
}
|
|
|
|
|
|
- return fc.Any.addRule(groups, host, ip, localIp)
|
|
|
+ return fc.Any.addRule(f, groups, host, ip, localIp)
|
|
|
}
|
|
|
|
|
|
if caSha != "" {
|
|
|
if _, ok := fc.CAShas[caSha]; !ok {
|
|
|
- fc.CAShas[caSha] = fl()
|
|
|
+ fc.CAShas[caSha] = fr()
|
|
|
}
|
|
|
- err := fc.CAShas[caSha].addRule(groups, host, ip, localIp)
|
|
|
+ err := fc.CAShas[caSha].addRule(f, groups, host, ip, localIp)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -701,9 +719,9 @@ func (fc *FirewallCA) addRule(groups []string, host string, ip, localIp *net.IPN
|
|
|
|
|
|
if caName != "" {
|
|
|
if _, ok := fc.CANames[caName]; !ok {
|
|
|
- fc.CANames[caName] = fl()
|
|
|
+ fc.CANames[caName] = fr()
|
|
|
}
|
|
|
- err := fc.CANames[caName].addRule(groups, host, ip, localIp)
|
|
|
+ err := fc.CANames[caName].addRule(f, groups, host, ip, localIp)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
@@ -735,75 +753,56 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool
|
|
|
return fc.CANames[s.Details.Name].match(p, c)
|
|
|
}
|
|
|
|
|
|
-func (fc *firewallLocalCIDR) addRule(groups []string, host string, ip, localIp *net.IPNet) error {
|
|
|
- fr := func() *FirewallRule {
|
|
|
- return &FirewallRule{
|
|
|
- Hosts: make(map[string]struct{}),
|
|
|
- Groups: make([][]string, 0),
|
|
|
- CIDR: cidr.NewTree4[struct{}](),
|
|
|
+func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip *net.IPNet, localCIDR *net.IPNet) error {
|
|
|
+ flc := func() *firewallLocalCIDR {
|
|
|
+ return &firewallLocalCIDR{
|
|
|
+ LocalCIDR: cidr.NewTree4[struct{}](),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
|
|
|
- if fc.Any == nil {
|
|
|
- fc.Any = fr()
|
|
|
+ if fr.isAny(groups, host, ip) {
|
|
|
+ if fr.Any == nil {
|
|
|
+ fr.Any = flc()
|
|
|
}
|
|
|
|
|
|
- return fc.Any.addRule(groups, host, ip)
|
|
|
+ return fr.Any.addRule(f, localCIDR)
|
|
|
}
|
|
|
|
|
|
- _, efr := fc.LocalCIDR.GetCIDR(localIp)
|
|
|
- if efr != nil {
|
|
|
- return efr.addRule(groups, host, ip)
|
|
|
- }
|
|
|
-
|
|
|
- nfr := fr()
|
|
|
- err := nfr.addRule(groups, host, ip)
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
-
|
|
|
- fc.LocalCIDR.AddCIDR(localIp, nfr)
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
-func (fc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
|
|
- if fc == nil {
|
|
|
- return false
|
|
|
- }
|
|
|
-
|
|
|
- if fc.Any.match(p, c) {
|
|
|
- return true
|
|
|
- }
|
|
|
-
|
|
|
- return fc.LocalCIDR.EachContains(p.LocalIP, func(fr *FirewallRule) bool {
|
|
|
- return fr.match(p, c)
|
|
|
- })
|
|
|
-}
|
|
|
+ if len(groups) > 0 {
|
|
|
+ nlc := flc()
|
|
|
+ err := nlc.addRule(f, localCIDR)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
|
|
|
-func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
|
|
|
- if fr.Any {
|
|
|
- return nil
|
|
|
+ fr.Groups = append(fr.Groups, &firewallGroups{
|
|
|
+ Groups: groups,
|
|
|
+ LocalCIDR: nlc,
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
- if fr.isAny(groups, host, ip) {
|
|
|
- fr.Any = true
|
|
|
- // If it's any we need to wipe out any pre-existing rules to save on memory
|
|
|
- fr.Groups = make([][]string, 0)
|
|
|
- fr.Hosts = make(map[string]struct{})
|
|
|
- fr.CIDR = cidr.NewTree4[struct{}]()
|
|
|
- } else {
|
|
|
- if len(groups) > 0 {
|
|
|
- fr.Groups = append(fr.Groups, groups)
|
|
|
+ if host != "" {
|
|
|
+ nlc := fr.Hosts[host]
|
|
|
+ if nlc == nil {
|
|
|
+ nlc = flc()
|
|
|
}
|
|
|
-
|
|
|
- if host != "" {
|
|
|
- fr.Hosts[host] = struct{}{}
|
|
|
+ err := nlc.addRule(f, localCIDR)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
}
|
|
|
+ fr.Hosts[host] = nlc
|
|
|
+ }
|
|
|
|
|
|
- if ip != nil {
|
|
|
- fr.CIDR.AddCIDR(ip, struct{}{})
|
|
|
+ if ip != nil {
|
|
|
+ _, nlc := fr.CIDR.GetCIDR(ip)
|
|
|
+ if nlc == nil {
|
|
|
+ nlc = flc()
|
|
|
}
|
|
|
+ err := nlc.addRule(f, localCIDR)
|
|
|
+ if err != nil {
|
|
|
+ return err
|
|
|
+ }
|
|
|
+ fr.CIDR.AddCIDR(ip, nlc)
|
|
|
}
|
|
|
|
|
|
return nil
|
|
@@ -837,7 +836,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
|
|
}
|
|
|
|
|
|
// Shortcut path for if groups, hosts, or cidr contained an `any`
|
|
|
- if fr.Any {
|
|
|
+ if fr.Any.match(p, c) {
|
|
|
return true
|
|
|
}
|
|
|
|
|
@@ -845,7 +844,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
|
|
for _, sg := range fr.Groups {
|
|
|
found := false
|
|
|
|
|
|
- for _, g := range sg {
|
|
|
+ for _, g := range sg.Groups {
|
|
|
if _, ok := c.Details.InvertedGroups[g]; !ok {
|
|
|
found = false
|
|
|
break
|
|
@@ -854,26 +853,48 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool
|
|
|
found = true
|
|
|
}
|
|
|
|
|
|
- if found {
|
|
|
+ if found && sg.LocalCIDR.match(p, c) {
|
|
|
return true
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if fr.Hosts != nil {
|
|
|
- if _, ok := fr.Hosts[c.Details.Name]; ok {
|
|
|
- return true
|
|
|
+ if flc, ok := fr.Hosts[c.Details.Name]; ok {
|
|
|
+ if flc.match(p, c) {
|
|
|
+ return true
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if fr.CIDR != nil {
|
|
|
- ok, _ := fr.CIDR.Contains(p.RemoteIP)
|
|
|
- if ok {
|
|
|
- return true
|
|
|
+ return fr.CIDR.EachContains(p.RemoteIP, func(flc *firewallLocalCIDR) bool {
|
|
|
+ return flc.match(p, c)
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp *net.IPNet) error {
|
|
|
+ if localIp == nil || (localIp != nil && localIp.Contains(net.IPv4(0, 0, 0, 0))) {
|
|
|
+ if !f.hasSubnets {
|
|
|
+ flc.Any = true
|
|
|
+ return nil
|
|
|
}
|
|
|
+ localIp = f.assignedCIDR
|
|
|
}
|
|
|
|
|
|
- // No host, group, or cidr matched, bye bye
|
|
|
- return false
|
|
|
+ flc.LocalCIDR.AddCIDR(localIp, struct{}{})
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool {
|
|
|
+ if flc == nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ if flc.Any {
|
|
|
+ return true
|
|
|
+ }
|
|
|
+
|
|
|
+ ok, _ := flc.LocalCIDR.Contains(p.LocalIP)
|
|
|
+ return ok
|
|
|
}
|
|
|
|
|
|
type rule struct {
|