2
0
Эх сурвалжийг харах

Preserve conntrack table during firewall rules reload (SIGHUP) (#233)

Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe).

This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded.
Wade Simmons 5 жил өмнө
parent
commit
f3a6d8d990
3 өөрчлөгдсөн 189 нэмэгдсэн , 42 устгасан
  1. 84 24
      firewall.go
  2. 88 18
      firewall_test.go
  3. 17 0
      interface.go

+ 84 - 24
firewall.go

@@ -15,6 +15,7 @@ import (
 	"time"
 
 	"github.com/rcrowley/go-metrics"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 )
 
@@ -37,13 +38,19 @@ type FirewallInterface interface {
 
 type conn struct {
 	Expires time.Time // Time when this conntrack entry will expire
-	Seq     uint32    // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
 	Sent    time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
+	Seq     uint32    // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
+
+	// record why the original connection passed the firewall, so we can re-validate
+	// after ruleset changes. Note, rulesVersion is a uint16 so that these two
+	// fields pack for free after the uint32 above
+	incoming     bool
+	rulesVersion uint16
 }
 
 // TODO: need conntrack max tracked connections handling
 type Firewall struct {
-	Conns map[FirewallPacket]*conn
+	Conntrack *FirewallConntrack
 
 	InRules  *FirewallTable
 	OutRules *FirewallTable
@@ -54,18 +61,23 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 
-	TimerWheel *TimerWheel
-
 	// Used to ensure we don't emit local packets for ips we don't own
 	localIps *CIDRTree
 
-	connMutex sync.Mutex
-	rules     string
+	rules        string
+	rulesVersion uint16
 
 	trackTCPRTT  bool
 	metricTCPRTT metrics.Histogram
 }
 
+type FirewallConntrack struct {
+	sync.Mutex
+
+	Conns      map[FirewallPacket]*conn
+	TimerWheel *TimerWheel
+}
+
 type FirewallTable struct {
 	TCP      firewallPort
 	UDP      firewallPort
@@ -171,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
 	}
 
 	return &Firewall{
-		Conns:          make(map[FirewallPacket]*conn),
+		Conntrack: &FirewallConntrack{
+			Conns:      make(map[FirewallPacket]*conn),
+			TimerWheel: NewTimerWheel(min, max),
+		},
 		InRules:        newFirewallTable(),
 		OutRules:       newFirewallTable(),
-		TimerWheel:     NewTimerWheel(min, max),
 		TCPTimeout:     tcpTimeout,
 		UDPTimeout:     UDPTimeout,
 		DefaultTimeout: defaultTimeout,
@@ -354,7 +368,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 // returns nil if the packet should not be dropped.
 func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
-	if f.inConns(packet, fp, incoming) {
+	if f.inConns(packet, fp, incoming, h, caPool) {
 		return nil
 	}
 
@@ -398,26 +412,66 @@ func (f *Firewall) Destroy() {
 }
 
 func (f *Firewall) EmitStats() {
-	conntrackCount := len(f.Conns)
+	conntrack := f.Conntrack
+	conntrack.Lock()
+	conntrackCount := len(conntrack.Conns)
+	conntrack.Unlock()
 	metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
+	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
 }
 
-func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
-	f.connMutex.Lock()
+func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
+	conntrack := f.Conntrack
+	conntrack.Lock()
 
 	// Purge every time we test
-	ep, has := f.TimerWheel.Purge()
+	ep, has := conntrack.TimerWheel.Purge()
 	if has {
 		f.evict(ep)
 	}
 
-	c, ok := f.Conns[fp]
+	c, ok := conntrack.Conns[fp]
 
 	if !ok {
-		f.connMutex.Unlock()
+		conntrack.Unlock()
 		return false
 	}
 
+	if c.rulesVersion != f.rulesVersion {
+		// This conntrack entry was for an older rule set, validate
+		// it still passes with the current rule set
+		table := f.OutRules
+		if c.incoming {
+			table = f.InRules
+		}
+
+		// We now know which firewall table to check against
+		if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
+			if l.Level >= logrus.DebugLevel {
+				h.logger().
+					WithField("fwPacket", fp).
+					WithField("incoming", c.incoming).
+					WithField("rulesVersion", f.rulesVersion).
+					WithField("oldRulesVersion", c.rulesVersion).
+					Debugln("dropping old conntrack entry, does not match new ruleset")
+			}
+			delete(conntrack.Conns, fp)
+			conntrack.Unlock()
+			return false
+		}
+
+		if l.Level >= logrus.DebugLevel {
+			h.logger().
+				WithField("fwPacket", fp).
+				WithField("incoming", c.incoming).
+				WithField("rulesVersion", f.rulesVersion).
+				WithField("oldRulesVersion", c.rulesVersion).
+				Debugln("keeping old conntrack entry, does match new ruleset")
+		}
+
+		c.rulesVersion = f.rulesVersion
+	}
+
 	switch fp.Protocol {
 	case fwProtoTCP:
 		c.Expires = time.Now().Add(f.TCPTimeout)
@@ -432,7 +486,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
 		c.Expires = time.Now().Add(f.DefaultTimeout)
 	}
 
-	f.connMutex.Unlock()
+	conntrack.Unlock()
 
 	return true
 }
@@ -453,14 +507,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 		timeout = f.DefaultTimeout
 	}
 
-	f.connMutex.Lock()
-	if _, ok := f.Conns[fp]; !ok {
-		f.TimerWheel.Add(fp, timeout)
+	conntrack := f.Conntrack
+	conntrack.Lock()
+	if _, ok := conntrack.Conns[fp]; !ok {
+		conntrack.TimerWheel.Add(fp, timeout)
 	}
 
+	// Record which rulesVersion allowed this connection, so we can retest after
+	// firewall reload
+	c.incoming = incoming
+	c.rulesVersion = f.rulesVersion
 	c.Expires = time.Now().Add(timeout)
-	f.Conns[fp] = c
-	f.connMutex.Unlock()
+	conntrack.Conns[fp] = c
+	conntrack.Unlock()
 }
 
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -468,7 +527,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 func (f *Firewall) evict(p FirewallPacket) {
 	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
-	t, ok := f.Conns[p]
+	conntrack := f.Conntrack
+	t, ok := conntrack.Conns[p]
 	if !ok {
 		return
 	}
@@ -477,12 +537,12 @@ func (f *Firewall) evict(p FirewallPacket) {
 
 	// Timeout is in the future, re-add the timer
 	if newT > 0 {
-		f.TimerWheel.Add(p, newT)
+		conntrack.TimerWheel.Add(p, newT)
 		return
 	}
 
 	// This conn is done
-	delete(f.Conns, p)
+	delete(conntrack.Conns, p)
 }
 
 func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {

+ 88 - 18
firewall_test.go

@@ -17,37 +17,39 @@ import (
 func TestNewFirewall(t *testing.T) {
 	c := &cert.NebulaCertificate{}
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
-	assert.NotNil(t, fw.Conns)
+	conntrack := fw.Conntrack
+	assert.NotNil(t, conntrack)
+	assert.NotNil(t, conntrack.Conns)
+	assert.NotNil(t, conntrack.TimerWheel)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
-	assert.NotNil(t, fw.TimerWheel)
 	assert.Equal(t, time.Second, fw.TCPTimeout)
 	assert.Equal(t, time.Minute, fw.UDPTimeout)
 	assert.Equal(t, time.Hour, fw.DefaultTimeout)
 
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 	fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 }
 
 func TestFirewall_AddRule(t *testing.T) {
@@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) {
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
 }
 
+func TestFirewall_DropConntrackReload(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
+	p := FirewallPacket{
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		10,
+		90,
+		fwProtoUDP,
+		false,
+	}
+
+	ipNet := net.IPNet{
+		IP:   net.IPv4(1, 2, 3, 4),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+
+	c := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           "host1",
+			Ips:            []*net.IPNet{&ipNet},
+			Groups:         []string{"default-group"},
+			InvertedGroups: map[string]struct{}{"default-group": {}},
+			Issuer:         "signer-shasum",
+		},
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+		hostId: ip2int(ipNet.IP),
+	}
+	h.CreateRemoteCIDR(&c)
+
+	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	cp := cert.NewCAPool()
+
+	// Drop outbound
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+	// Allow inbound
+	resetConntrack(fw)
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	// Allow outbound because conntrack
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+
+	oldFw := fw
+	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
+	fw.Conntrack = oldFw.Conntrack
+	fw.rulesVersion = oldFw.rulesVersion + 1
+
+	// Allow outbound because conntrack and new rules allow port 10
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+
+	oldFw = fw
+	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
+	fw.Conntrack = oldFw.Conntrack
+	fw.rulesVersion = oldFw.rulesVersion + 1
+
+	// Drop outbound because conntrack doesn't match new ruleset
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+}
+
 func BenchmarkLookup(b *testing.B) {
 	ml := func(m map[string]struct{}, a [][]string) {
 		for n := 0; n < b.N; n++ {
@@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 }
 
 func resetConntrack(fw *Firewall) {
-	fw.connMutex.Lock()
-	fw.Conns = map[FirewallPacket]*conn{}
-	fw.connMutex.Unlock()
+	fw.Conntrack.Lock()
+	fw.Conntrack.Conns = map[FirewallPacket]*conn{}
+	fw.Conntrack.Unlock()
 }

+ 17 - 0
interface.go

@@ -219,11 +219,28 @@ func (f *Interface) reloadFirewall(c *Config) {
 	}
 
 	oldFw := f.firewall
+	conntrack := oldFw.Conntrack
+	conntrack.Lock()
+	defer conntrack.Unlock()
+
+	fw.rulesVersion = oldFw.rulesVersion + 1
+	// If rulesVersion is back to zero, we have wrapped all the way around. Be
+	// safe and just reset conntrack in this case.
+	if fw.rulesVersion == 0 {
+		l.WithField("firewallHash", fw.GetRuleHash()).
+			WithField("oldFirewallHash", oldFw.GetRuleHash()).
+			WithField("rulesVersion", fw.rulesVersion).
+			Warn("firewall rulesVersion has overflowed, resetting conntrack")
+	} else {
+		fw.Conntrack = conntrack
+	}
+
 	f.firewall = fw
 
 	oldFw.Destroy()
 	l.WithField("firewallHash", fw.GetRuleHash()).
 		WithField("oldFirewallHash", oldFw.GetRuleHash()).
+		WithField("rulesVersion", fw.rulesVersion).
 		Info("New firewall has been installed")
 }