Quellcode durchsuchen

Preserve conntrack table during firewall rules reload (SIGHUP) (#233)

Currently, we drop the conntrack table when firewall rules change during a SIGHUP reload. This means responses to inflight HTTP requests can be dropped, among other issues. This change copies the conntrack table over to the new firewall (it holds the conntrack mutex lock during this process, to be safe).

This change also records which firewall rules hash each conntrack entry used, so that we can re-verify the rules after the new firewall has been loaded.
Wade Simmons vor 5 Jahren
Ursprung
Commit
f3a6d8d990
3 geänderte Dateien mit 189 neuen und 42 gelöschten Zeilen
  1. 84 24
      firewall.go
  2. 88 18
      firewall_test.go
  3. 17 0
      interface.go

+ 84 - 24
firewall.go

@@ -15,6 +15,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 )
 )
 
 
@@ -37,13 +38,19 @@ type FirewallInterface interface {
 
 
 type conn struct {
 type conn struct {
 	Expires time.Time // Time when this conntrack entry will expire
 	Expires time.Time // Time when this conntrack entry will expire
-	Seq     uint32    // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
 	Sent    time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
 	Sent    time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
+	Seq     uint32    // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
+
+	// record why the original connection passed the firewall, so we can re-validate
+	// after ruleset changes. Note, rulesVersion is a uint16 so that these two
+	// fields pack for free after the uint32 above
+	incoming     bool
+	rulesVersion uint16
 }
 }
 
 
 // TODO: need conntrack max tracked connections handling
 // TODO: need conntrack max tracked connections handling
 type Firewall struct {
 type Firewall struct {
-	Conns map[FirewallPacket]*conn
+	Conntrack *FirewallConntrack
 
 
 	InRules  *FirewallTable
 	InRules  *FirewallTable
 	OutRules *FirewallTable
 	OutRules *FirewallTable
@@ -54,18 +61,23 @@ type Firewall struct {
 	UDPTimeout     time.Duration //linux: 180s max
 	UDPTimeout     time.Duration //linux: 180s max
 	DefaultTimeout time.Duration //linux: 600s
 	DefaultTimeout time.Duration //linux: 600s
 
 
-	TimerWheel *TimerWheel
-
 	// Used to ensure we don't emit local packets for ips we don't own
 	// Used to ensure we don't emit local packets for ips we don't own
 	localIps *CIDRTree
 	localIps *CIDRTree
 
 
-	connMutex sync.Mutex
-	rules     string
+	rules        string
+	rulesVersion uint16
 
 
 	trackTCPRTT  bool
 	trackTCPRTT  bool
 	metricTCPRTT metrics.Histogram
 	metricTCPRTT metrics.Histogram
 }
 }
 
 
+type FirewallConntrack struct {
+	sync.Mutex
+
+	Conns      map[FirewallPacket]*conn
+	TimerWheel *TimerWheel
+}
+
 type FirewallTable struct {
 type FirewallTable struct {
 	TCP      firewallPort
 	TCP      firewallPort
 	UDP      firewallPort
 	UDP      firewallPort
@@ -171,10 +183,12 @@ func NewFirewall(tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.N
 	}
 	}
 
 
 	return &Firewall{
 	return &Firewall{
-		Conns:          make(map[FirewallPacket]*conn),
+		Conntrack: &FirewallConntrack{
+			Conns:      make(map[FirewallPacket]*conn),
+			TimerWheel: NewTimerWheel(min, max),
+		},
 		InRules:        newFirewallTable(),
 		InRules:        newFirewallTable(),
 		OutRules:       newFirewallTable(),
 		OutRules:       newFirewallTable(),
-		TimerWheel:     NewTimerWheel(min, max),
 		TCPTimeout:     tcpTimeout,
 		TCPTimeout:     tcpTimeout,
 		UDPTimeout:     UDPTimeout,
 		UDPTimeout:     UDPTimeout,
 		DefaultTimeout: defaultTimeout,
 		DefaultTimeout: defaultTimeout,
@@ -354,7 +368,7 @@ var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
 // returns nil if the packet should not be dropped.
 // returns nil if the packet should not be dropped.
 func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
 func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) error {
 	// Check if we spoke to this tuple, if we did then allow this packet
 	// Check if we spoke to this tuple, if we did then allow this packet
-	if f.inConns(packet, fp, incoming) {
+	if f.inConns(packet, fp, incoming, h, caPool) {
 		return nil
 		return nil
 	}
 	}
 
 
@@ -398,26 +412,66 @@ func (f *Firewall) Destroy() {
 }
 }
 
 
 func (f *Firewall) EmitStats() {
 func (f *Firewall) EmitStats() {
-	conntrackCount := len(f.Conns)
+	conntrack := f.Conntrack
+	conntrack.Lock()
+	conntrackCount := len(conntrack.Conns)
+	conntrack.Unlock()
 	metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
 	metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
+	metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
 }
 }
 
 
-func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool {
-	f.connMutex.Lock()
+func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool) bool {
+	conntrack := f.Conntrack
+	conntrack.Lock()
 
 
 	// Purge every time we test
 	// Purge every time we test
-	ep, has := f.TimerWheel.Purge()
+	ep, has := conntrack.TimerWheel.Purge()
 	if has {
 	if has {
 		f.evict(ep)
 		f.evict(ep)
 	}
 	}
 
 
-	c, ok := f.Conns[fp]
+	c, ok := conntrack.Conns[fp]
 
 
 	if !ok {
 	if !ok {
-		f.connMutex.Unlock()
+		conntrack.Unlock()
 		return false
 		return false
 	}
 	}
 
 
+	if c.rulesVersion != f.rulesVersion {
+		// This conntrack entry was for an older rule set, validate
+		// it still passes with the current rule set
+		table := f.OutRules
+		if c.incoming {
+			table = f.InRules
+		}
+
+		// We now know which firewall table to check against
+		if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
+			if l.Level >= logrus.DebugLevel {
+				h.logger().
+					WithField("fwPacket", fp).
+					WithField("incoming", c.incoming).
+					WithField("rulesVersion", f.rulesVersion).
+					WithField("oldRulesVersion", c.rulesVersion).
+					Debugln("dropping old conntrack entry, does not match new ruleset")
+			}
+			delete(conntrack.Conns, fp)
+			conntrack.Unlock()
+			return false
+		}
+
+		if l.Level >= logrus.DebugLevel {
+			h.logger().
+				WithField("fwPacket", fp).
+				WithField("incoming", c.incoming).
+				WithField("rulesVersion", f.rulesVersion).
+				WithField("oldRulesVersion", c.rulesVersion).
+				Debugln("keeping old conntrack entry, does match new ruleset")
+		}
+
+		c.rulesVersion = f.rulesVersion
+	}
+
 	switch fp.Protocol {
 	switch fp.Protocol {
 	case fwProtoTCP:
 	case fwProtoTCP:
 		c.Expires = time.Now().Add(f.TCPTimeout)
 		c.Expires = time.Now().Add(f.TCPTimeout)
@@ -432,7 +486,7 @@ func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool) bool
 		c.Expires = time.Now().Add(f.DefaultTimeout)
 		c.Expires = time.Now().Add(f.DefaultTimeout)
 	}
 	}
 
 
-	f.connMutex.Unlock()
+	conntrack.Unlock()
 
 
 	return true
 	return true
 }
 }
@@ -453,14 +507,19 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 		timeout = f.DefaultTimeout
 		timeout = f.DefaultTimeout
 	}
 	}
 
 
-	f.connMutex.Lock()
-	if _, ok := f.Conns[fp]; !ok {
-		f.TimerWheel.Add(fp, timeout)
+	conntrack := f.Conntrack
+	conntrack.Lock()
+	if _, ok := conntrack.Conns[fp]; !ok {
+		conntrack.TimerWheel.Add(fp, timeout)
 	}
 	}
 
 
+	// Record which rulesVersion allowed this connection, so we can retest after
+	// firewall reload
+	c.incoming = incoming
+	c.rulesVersion = f.rulesVersion
 	c.Expires = time.Now().Add(timeout)
 	c.Expires = time.Now().Add(timeout)
-	f.Conns[fp] = c
-	f.connMutex.Unlock()
+	conntrack.Conns[fp] = c
+	conntrack.Unlock()
 }
 }
 
 
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
 // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
@@ -468,7 +527,8 @@ func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
 func (f *Firewall) evict(p FirewallPacket) {
 func (f *Firewall) evict(p FirewallPacket) {
 	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	//TODO: report a stat if the tcp rtt tracking was never resolved?
 	// Are we still tracking this conn?
 	// Are we still tracking this conn?
-	t, ok := f.Conns[p]
+	conntrack := f.Conntrack
+	t, ok := conntrack.Conns[p]
 	if !ok {
 	if !ok {
 		return
 		return
 	}
 	}
@@ -477,12 +537,12 @@ func (f *Firewall) evict(p FirewallPacket) {
 
 
 	// Timeout is in the future, re-add the timer
 	// Timeout is in the future, re-add the timer
 	if newT > 0 {
 	if newT > 0 {
-		f.TimerWheel.Add(p, newT)
+		conntrack.TimerWheel.Add(p, newT)
 		return
 		return
 	}
 	}
 
 
 	// This conn is done
 	// This conn is done
-	delete(f.Conns, p)
+	delete(conntrack.Conns, p)
 }
 }
 
 
 func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
 func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {

+ 88 - 18
firewall_test.go

@@ -17,37 +17,39 @@ import (
 func TestNewFirewall(t *testing.T) {
 func TestNewFirewall(t *testing.T) {
 	c := &cert.NebulaCertificate{}
 	c := &cert.NebulaCertificate{}
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
 	fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
-	assert.NotNil(t, fw.Conns)
+	conntrack := fw.Conntrack
+	assert.NotNil(t, conntrack)
+	assert.NotNil(t, conntrack.Conns)
+	assert.NotNil(t, conntrack.TimerWheel)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.InRules)
 	assert.NotNil(t, fw.OutRules)
 	assert.NotNil(t, fw.OutRules)
-	assert.NotNil(t, fw.TimerWheel)
 	assert.Equal(t, time.Second, fw.TCPTimeout)
 	assert.Equal(t, time.Second, fw.TCPTimeout)
 	assert.Equal(t, time.Minute, fw.UDPTimeout)
 	assert.Equal(t, time.Minute, fw.UDPTimeout)
 	assert.Equal(t, time.Hour, fw.DefaultTimeout)
 	assert.Equal(t, time.Hour, fw.DefaultTimeout)
 
 
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
 	fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
 	fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
 	fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
 	fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
 	fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
-	assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
+	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
+	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
 }
 }
 
 
 func TestFirewall_AddRule(t *testing.T) {
 func TestFirewall_AddRule(t *testing.T) {
@@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) {
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
 	assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
 }
 }
 
 
+func TestFirewall_DropConntrackReload(t *testing.T) {
+	ob := &bytes.Buffer{}
+	out := l.Out
+	l.SetOutput(ob)
+	defer l.SetOutput(out)
+
+	p := FirewallPacket{
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		ip2int(net.IPv4(1, 2, 3, 4)),
+		10,
+		90,
+		fwProtoUDP,
+		false,
+	}
+
+	ipNet := net.IPNet{
+		IP:   net.IPv4(1, 2, 3, 4),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+
+	c := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:           "host1",
+			Ips:            []*net.IPNet{&ipNet},
+			Groups:         []string{"default-group"},
+			InvertedGroups: map[string]struct{}{"default-group": {}},
+			Issuer:         "signer-shasum",
+		},
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+		hostId: ip2int(ipNet.IP),
+	}
+	h.CreateRemoteCIDR(&c)
+
+	fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
+	cp := cert.NewCAPool()
+
+	// Drop outbound
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+	// Allow inbound
+	resetConntrack(fw)
+	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
+	// Allow outbound because conntrack
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+
+	oldFw := fw
+	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
+	fw.Conntrack = oldFw.Conntrack
+	fw.rulesVersion = oldFw.rulesVersion + 1
+
+	// Allow outbound because conntrack and new rules allow port 10
+	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
+
+	oldFw = fw
+	fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
+	assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
+	fw.Conntrack = oldFw.Conntrack
+	fw.rulesVersion = oldFw.rulesVersion + 1
+
+	// Drop outbound because conntrack doesn't match new ruleset
+	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
+}
+
 func BenchmarkLookup(b *testing.B) {
 func BenchmarkLookup(b *testing.B) {
 	ml := func(m map[string]struct{}, a [][]string) {
 	ml := func(m map[string]struct{}, a [][]string) {
 		for n := 0; n < b.N; n++ {
 		for n := 0; n < b.N; n++ {
@@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
 }
 }
 
 
 func resetConntrack(fw *Firewall) {
 func resetConntrack(fw *Firewall) {
-	fw.connMutex.Lock()
-	fw.Conns = map[FirewallPacket]*conn{}
-	fw.connMutex.Unlock()
+	fw.Conntrack.Lock()
+	fw.Conntrack.Conns = map[FirewallPacket]*conn{}
+	fw.Conntrack.Unlock()
 }
 }

+ 17 - 0
interface.go

@@ -219,11 +219,28 @@ func (f *Interface) reloadFirewall(c *Config) {
 	}
 	}
 
 
 	oldFw := f.firewall
 	oldFw := f.firewall
+	conntrack := oldFw.Conntrack
+	conntrack.Lock()
+	defer conntrack.Unlock()
+
+	fw.rulesVersion = oldFw.rulesVersion + 1
+	// If rulesVersion is back to zero, we have wrapped all the way around. Be
+	// safe and just reset conntrack in this case.
+	if fw.rulesVersion == 0 {
+		l.WithField("firewallHash", fw.GetRuleHash()).
+			WithField("oldFirewallHash", oldFw.GetRuleHash()).
+			WithField("rulesVersion", fw.rulesVersion).
+			Warn("firewall rulesVersion has overflowed, resetting conntrack")
+	} else {
+		fw.Conntrack = conntrack
+	}
+
 	f.firewall = fw
 	f.firewall = fw
 
 
 	oldFw.Destroy()
 	oldFw.Destroy()
 	l.WithField("firewallHash", fw.GetRuleHash()).
 	l.WithField("firewallHash", fw.GetRuleHash()).
 		WithField("oldFirewallHash", oldFw.GetRuleHash()).
 		WithField("oldFirewallHash", oldFw.GetRuleHash()).
+		WithField("rulesVersion", fw.rulesVersion).
 		Info("New firewall has been installed")
 		Info("New firewall has been installed")
 }
 }