Browse Source

Routine-local conntrack cache (#391)

Previously, every packet we saw took a lock on the conntrack table and updated it. When running with multiple routines, this can cause heavy lock contention and limit the ability of the threads to run independently. This change caches reads from the conntrack table for a very short period of time to reduce this lock contention. The cache is currently disabled by default unless you are running with multiple routines, in which case the default cache delay is 1 second. This means that entries in the conntrack table may be up to 1 second out of date, and may remain in a routine-local cache for up to 1 second longer than in the global table.

Instead of calling time.Now() for every packet, this cache system relies on a tick thread that updates the current cache "version" on each tick. For every packet, we check whether the cache version is out of date and reset the cache if so.
Wade Simmons 4 years ago
parent
commit
2a4beb41b9
8 changed files with 118 additions and 31 deletions
  1. 64 3
      firewall.go
  2. 18 18
      firewall_test.go
  3. 3 3
      inside.go
  4. 9 1
      interface.go
  5. 14 0
      main.go
  6. 4 4
      outside.go
  7. 3 1
      udp_generic.go
  8. 3 1
      udp_linux.go

+ 64 - 3
firewall.go

@@ -12,6 +12,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
@@ -372,9 +373,9 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 
 // Drop returns an error if the packet should be dropped, explaining why. It
 // returns nil if the packet should not be dropped.
-func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
+func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
-	if f.inConns(packet, fp, incoming, h, caPool) {
+	if f.inConns(packet, fp, incoming, h, caPool, localCache) {
 		return nil
 	}
 
@@ -426,7 +427,12 @@ func (f *Firewall) EmitStats() {
 	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
 }
 
-func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
+func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
+	if localCache != nil {
+		if _, ok := localCache[fp]; ok {
+			return true
+		}
+	}
 	conntrack := f.Conntrack
 	conntrack.Lock()
 
@@ -494,6 +500,10 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *H
 
 	conntrack.Unlock()
 
+	if localCache != nil {
+		localCache[fp] = struct{}{}
+	}
+
 	return true
 }
 
@@ -923,3 +933,54 @@ func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
 	c.Seq = 0
 	return true
 }
+
+// ConntrackCache is used as a local routine cache to know if a given flow
+// has been seen in the conntrack table. It is used as a set keyed by
+// FirewallPacket: only the presence of a key matters, the empty-struct
+// values carry no data.
+type ConntrackCache map[FirewallPacket]struct{}
+
+// ConntrackCacheTicker pairs a ConntrackCache with the version counters that
+// Get uses to decide when the cache must be thrown away.
+type ConntrackCacheTicker struct {
+	// cacheTick is incremented atomically by the tick goroutine; cacheV is the
+	// value of cacheTick that Get last observed. They differ when the cache is
+	// stale.
+	cacheV    uint64
+	cacheTick uint64
+
+	// cache is the map handed out by Get.
+	cache ConntrackCache
+}
+
+// NewConntrackCacheTicker creates a conntrack cache whose contents are
+// discarded roughly once every d.
+//
+// A zero or negative d disables caching and returns nil; Get may still be
+// called on the nil result and yields a nil ConntrackCache. Negative values
+// are rejected because time.Sleep returns immediately for them, which would
+// turn the tick goroutine into a busy loop.
+//
+// For a positive d, a background goroutine running tick is started before
+// the ticker is returned. Nothing here provides a way to stop it.
+func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
+	if d <= 0 {
+		return nil
+	}
+
+	c := &ConntrackCacheTicker{
+		cache: ConntrackCache{},
+	}
+
+	go c.tick(d)
+
+	return c
+}
+
+// tick loops forever, sleeping for d and then atomically incrementing
+// cacheTick so that Get can notice the cache has become stale.
+// NOTE(review): the loop has no exit condition, so the goroutine started for
+// it never terminates — confirm that is acceptable for every caller.
+func (c *ConntrackCacheTicker) tick(d time.Duration) {
+	for {
+		time.Sleep(d)
+		atomic.AddUint64(&c.cacheTick, 1)
+	}
+}
+
+// Get returns the current cache map, first replacing it with an empty one if
+// the tick counter has advanced since the previous call.
+//
+// Calling Get on a nil receiver returns a nil ConntrackCache, which is how a
+// disabled cache is represented.
+//
+// cacheTick is read atomically, but cacheV and cache are accessed without
+// synchronization, so Get is presumably meant to be called only from the one
+// routine that owns this ticker.
+func (c *ConntrackCacheTicker) Get() ConntrackCache {
+	if c == nil {
+		return nil
+	}
+	if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
+		c.cacheV = tick
+		// An already-empty map is kept as is; only a populated one is replaced.
+		if ll := len(c.cache); ll > 0 {
+			// Use >= so the message is also emitted at more verbose levels
+			// (e.g. trace), not only when the level is exactly debug.
+			if l.GetLevel() >= logrus.DebugLevel {
+				l.WithField("len", ll).Debug("resetting conntrack cache")
+			}
+			// Size the new map like the old one, expecting similar traffic.
+			c.cache = make(ConntrackCache, ll)
+		}
+	}
+
+	return c.cache
+}

+ 18 - 18
firewall_test.go

@@ -182,44 +182,44 @@ func TestFirewall_Drop(t *testing.T) {
 	cp := cert.NewCAPool()
 
 	// Drop outbound
-	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
 	// Allow inbound
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 	// Allow outbound because conntrack
-	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
 
 	// test remote mismatch
 	oldRemote := p.RemoteIP
 	p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
-	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrInvalidRemoteIP)
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
 	p.RemoteIP = oldRemote
 
 	// ensure signer doesn't get in the way of group checks
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
-	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
+	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caSha doesn't drop on match
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
-	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 
 	// ensure ca name doesn't get in the way of group checks
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
-	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp), ErrNoMatchingRule)
+	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
 
 	// test caName doesn't drop on match
 	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
 	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
-	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 }
 
 func BenchmarkFirewallTable_match(b *testing.B) {
@@ -370,10 +370,10 @@ func TestFirewall_Drop2(t *testing.T) {
 	cp := cert.NewCAPool()
 
 	// h1/c1 lacks the proper groups
-	assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp), ErrNoMatchingRule)
+	assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
 	// c has the proper groups
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 }
 
 func TestFirewall_Drop3(t *testing.T) {
@@ -454,13 +454,13 @@ func TestFirewall_Drop3(t *testing.T) {
 	cp := cert.NewCAPool()
 
 	// c1 should pass because host match
-	assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
 	// c2 should pass because ca sha match
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
 	// c3 should fail because no match
 	resetConntrack(fw)
-	assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
+	assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
 }
 
 func TestFirewall_DropConntrackReload(t *testing.T) {
@@ -505,12 +505,12 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	cp := cert.NewCAPool()
 
 	// Drop outbound
-	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
 	// Allow inbound
 	resetConntrack(fw)
-	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
 	// Allow outbound because conntrack
-	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
 
 	oldFw := fw
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
@@ -519,7 +519,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 	// Allow outbound because conntrack and new rules allow port 10
-	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
 
 	oldFw = fw
 	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
@@ -528,7 +528,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 	fw.rulesVersion = oldFw.rulesVersion + 1
 
 	// Drop outbound because conntrack doesn't match new ruleset
-	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
 }
 
 func BenchmarkLookup(b *testing.B) {

+ 3 - 3
inside.go

@@ -7,7 +7,7 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
-func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int) {
+func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket, nb, out []byte, q int, localCache ConntrackCache) {
 	err := newPacket(packet, false, fwPacket)
 	if err != nil {
 		l.WithField("packet", packet).Debugf("Error while validating outbound packet: %s", err)
@@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 		ci.queueLock.Unlock()
 	}
 
-	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs)
+	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache)
 	if dropReason == nil {
 		mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
 		if f.lightHouse != nil && mc%5000 == 0 {
@@ -129,7 +129,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 	}
 
 	// check if packet is in outbound fw rules
-	dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs)
+	dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
 	if dropReason != nil {
 		if l.Level >= logrus.DebugLevel {
 			l.WithField("fwPacket", fp).

+ 9 - 1
interface.go

@@ -40,6 +40,8 @@ type InterfaceConfig struct {
 	routines                int
 	MessageMetrics          *MessageMetrics
 	version                 string
+
+	ConntrackCacheTimeout time.Duration
 }
 
 type Interface struct {
@@ -61,6 +63,8 @@ type Interface struct {
 	routines           int
 	version            string
 
+	conntrackCacheTimeout time.Duration
+
 	writers []*udpConn
 	readers []io.ReadWriteCloser
 
@@ -102,6 +106,8 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 		writers:            make([]*udpConn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
 
+		conntrackCacheTimeout: c.ConntrackCacheTimeout,
+
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
 		messageMetrics:   c.MessageMetrics,
 	}
@@ -173,6 +179,8 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 	fwPacket := &FirewallPacket{}
 	nb := make([]byte, 12, 12)
 
+	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
+
 	for {
 		n, err := reader.Read(packet)
 		if err != nil {
@@ -181,7 +189,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 			os.Exit(2)
 		}
 
-		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i)
+		f.consumeInsidePacket(packet[:n], fwPacket, nb, out, i, conntrackCache.Get())
 	}
 }
 

+ 14 - 0
main.go

@@ -117,6 +117,18 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		}
 	}
 
+	// EXPERIMENTAL
+	// Intentionally not documented yet while we do more testing and determine
+	// a good default value.
+	conntrackCacheTimeout := config.GetDuration("firewall.conntrack.routine_cache_timeout", 0)
+	if routines > 1 && !config.IsSet("firewall.conntrack.routine_cache_timeout") {
+		// Use a different default if we are running with multiple routines
+		conntrackCacheTimeout = 1 * time.Second
+	}
+	if conntrackCacheTimeout > 0 {
+		l.WithField("duration", conntrackCacheTimeout).Info("Using routine-local conntrack cache")
+	}
+
 	var tun Inside
 	if !configTest {
 		config.CatchHUP()
@@ -359,6 +371,8 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
+
+		ConntrackCacheTimeout: conntrackCacheTimeout,
 	}
 
 	switch ifConfig.Cipher {

+ 4 - 4
outside.go

@@ -17,7 +17,7 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int) {
+func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte, header *Header, fwPacket *FirewallPacket, lhh *LightHouseHandler, nb []byte, q int, localCache ConntrackCache) {
 	err := header.Parse(packet)
 	if err != nil {
 		// TODO: best if we return this and let caller log
@@ -45,7 +45,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 			return
 		}
 
-		f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q)
+		f.decryptToTun(hostinfo, header.MessageCounter, out, packet, fwPacket, nb, q, localCache)
 
 		// Fallthrough to the bottom to record incoming traffic
 
@@ -257,7 +257,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int) {
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *FirewallPacket, nb []byte, q int, localCache ConntrackCache) {
 	var err error
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:HeaderLen], packet[HeaderLen:], messageCounter, nb)
@@ -281,7 +281,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return
 	}
 
-	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs)
+	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
 	if dropReason != nil {
 		if l.Level >= logrus.DebugLevel {
 			hostinfo.logger().WithField("fwPacket", fwPacket).

+ 3 - 1
udp_generic.go

@@ -115,6 +115,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 
 	lhh := f.lightHouse.NewRequestHandler()
 
+	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
+
 	for {
 		// Just read one packet at a time
 		n, rua, err := u.ReadFromUDP(buffer)
@@ -124,7 +126,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 		}
 
 		udpAddr.UDPAddr = *rua
-		f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q)
+		f.readOutsidePackets(udpAddr, plaintext[:0], buffer[:n], header, fwPacket, lhh, nb, q, conntrackCache.Get())
 	}
 }
 

+ 3 - 1
udp_linux.go

@@ -174,6 +174,8 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 		read = u.ReadSingle
 	}
 
+	conntrackCache := NewConntrackCacheTicker(f.conntrackCacheTimeout)
+
 	for {
 		n, err := read(msgs)
 		if err != nil {
@@ -186,7 +188,7 @@ func (u *udpConn) ListenOut(f *Interface, q int) {
 			udpAddr.IP = binary.BigEndian.Uint32(names[i][4:8])
 			udpAddr.Port = binary.BigEndian.Uint16(names[i][2:4])
 
-			f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q)
+			f.readOutsidePackets(udpAddr, plaintext[:0], buffers[i][:msgs[i].Len], header, fwPacket, lhh, nb, q, conntrackCache.Get())
 		}
 	}
 }