Browse Source

add locking for stop crash

Ryan 1 month ago
parent
commit
2d128a3254
4 changed files with 60 additions and 18 deletions
  1. 5 7
      firewall.go
  2. 52 8
      firewall/cache.go
  3. 1 1
      inside.go
  4. 2 2
      outside.go

+ 5 - 7
firewall.go

@@ -423,7 +423,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) error {
+func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	if f.inConns(fp, h, caPool, localCache) {
 		return nil
@@ -490,11 +490,9 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.hash", nil).Update(int64(f.GetRuleHashFNV()))
 }
 
-func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache firewall.ConntrackCache) bool {
-	if localCache != nil {
-		if _, ok := localCache[fp]; ok {
-			return true
-		}
+func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool, localCache *firewall.ConntrackCache) bool {
+	if localCache != nil && localCache.Has(fp) {
+		return true
 	}
 	conntrack := f.Conntrack
 	conntrack.Lock()
@@ -559,7 +557,7 @@ func (f *Firewall) inConns(fp firewall.Packet, h *HostInfo, caPool *cert.CAPool,
 	conntrack.Unlock()
 
 	if localCache != nil {
-		localCache[fp] = struct{}{}
+		localCache.Add(fp)
 	}
 
 	return true

+ 52 - 8
firewall/cache.go

@@ -1,6 +1,7 @@
 package firewall
 
 import (
+	"sync"
 	"sync/atomic"
 	"time"
 
@@ -9,13 +10,58 @@ import (
 
 // ConntrackCache is used as a local routine cache to know if a given flow
 // has been seen in the conntrack table.
-type ConntrackCache map[Packet]struct{}
+type ConntrackCache struct {
+	mu      sync.Mutex
+	entries map[Packet]struct{}
+}
+
+func newConntrackCache() *ConntrackCache {
+	return &ConntrackCache{entries: make(map[Packet]struct{})}
+}
+
+func (c *ConntrackCache) Has(p Packet) bool {
+	if c == nil {
+		return false
+	}
+	c.mu.Lock()
+	_, ok := c.entries[p]
+	c.mu.Unlock()
+	return ok
+}
+
+func (c *ConntrackCache) Add(p Packet) {
+	if c == nil {
+		return
+	}
+	c.mu.Lock()
+	c.entries[p] = struct{}{}
+	c.mu.Unlock()
+}
+
+func (c *ConntrackCache) Len() int {
+	if c == nil {
+		return 0
+	}
+	c.mu.Lock()
+	l := len(c.entries)
+	c.mu.Unlock()
+	return l
+}
+
+func (c *ConntrackCache) Reset(capHint int) {
+	if c == nil {
+		return
+	}
+	c.mu.Lock()
+	c.entries = make(map[Packet]struct{}, capHint)
+	c.mu.Unlock()
+}
 
 type ConntrackCacheTicker struct {
 	cacheV    uint64
 	cacheTick atomic.Uint64
 
-	cache ConntrackCache
+	cache *ConntrackCache
 }
 
 func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
@@ -23,9 +69,7 @@ func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
 		return nil
 	}
 
-	c := &ConntrackCacheTicker{
-		cache: ConntrackCache{},
-	}
+	c := &ConntrackCacheTicker{cache: newConntrackCache()}
 
 	go c.tick(d)
 
@@ -41,17 +85,17 @@ func (c *ConntrackCacheTicker) tick(d time.Duration) {
 
 // Get checks if the cache ticker has moved to the next version before returning
 // the map. If it has moved, we reset the map.
-func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
+func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
 	if c == nil {
 		return nil
 	}
 	if tick := c.cacheTick.Load(); tick != c.cacheV {
 		c.cacheV = tick
-		if ll := len(c.cache); ll > 0 {
+		if ll := c.cache.Len(); ll > 0 {
 			if l.Level == logrus.DebugLevel {
 				l.WithField("len", ll).Debug("resetting conntrack cache")
 			}
-			c.cache = make(ConntrackCache, ll)
+			c.cache.Reset(ll)
 		}
 	}
 

+ 1 - 1
inside.go

@@ -13,7 +13,7 @@ import (
 	"github.com/slackhq/nebula/routing"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache *firewall.ConntrackCache) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		if f.l.Level >= logrus.DebugLevel {

+ 2 - 2
outside.go

@@ -20,7 +20,7 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache *firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
@@ -466,7 +466,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache *firewall.ConntrackCache, addr netip.AddrPort, recvIndex uint32) bool {
 	var (
 		err error
 		pkt *overlay.Packet