|
@@ -17,37 +17,39 @@ import (
|
|
|
func TestNewFirewall(t *testing.T) {
|
|
|
c := &cert.NebulaCertificate{}
|
|
|
fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
|
|
|
- assert.NotNil(t, fw.Conns)
|
|
|
+ conntrack := fw.Conntrack
|
|
|
+ assert.NotNil(t, conntrack)
|
|
|
+ assert.NotNil(t, conntrack.Conns)
|
|
|
+ assert.NotNil(t, conntrack.TimerWheel)
|
|
|
assert.NotNil(t, fw.InRules)
|
|
|
assert.NotNil(t, fw.OutRules)
|
|
|
- assert.NotNil(t, fw.TimerWheel)
|
|
|
assert.Equal(t, time.Second, fw.TCPTimeout)
|
|
|
assert.Equal(t, time.Minute, fw.UDPTimeout)
|
|
|
assert.Equal(t, time.Hour, fw.DefaultTimeout)
|
|
|
|
|
|
- assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
|
|
- assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
|
|
- assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
|
|
+ assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
|
|
+ assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
|
|
+ assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
|
|
|
|
|
fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
|
|
|
- assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
|
|
- assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
|
|
+ assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
|
|
+ assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
|
|
|
|
|
fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
|
|
|
- assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
|
|
- assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
|
|
+ assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
|
|
+ assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
|
|
|
|
|
fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
|
|
|
- assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
|
|
- assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
|
|
+ assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
|
|
+ assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
|
|
|
|
|
fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
|
|
|
- assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
|
|
- assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
|
|
+ assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
|
|
+ assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
|
|
|
|
|
fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
|
|
|
- assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
|
|
|
- assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
|
|
|
+ assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
|
|
|
+ assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
|
|
|
}
|
|
|
|
|
|
func TestFirewall_AddRule(t *testing.T) {
|
|
@@ -461,6 +463,74 @@ func TestFirewall_Drop3(t *testing.T) {
|
|
|
assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp), ErrNoMatchingRule)
|
|
|
}
|
|
|
|
|
|
+func TestFirewall_DropConntrackReload(t *testing.T) {
|
|
|
+ ob := &bytes.Buffer{}
|
|
|
+ out := l.Out
|
|
|
+ l.SetOutput(ob)
|
|
|
+ defer l.SetOutput(out)
|
|
|
+
|
|
|
+ p := FirewallPacket{
|
|
|
+ ip2int(net.IPv4(1, 2, 3, 4)),
|
|
|
+ ip2int(net.IPv4(1, 2, 3, 4)),
|
|
|
+ 10,
|
|
|
+ 90,
|
|
|
+ fwProtoUDP,
|
|
|
+ false,
|
|
|
+ }
|
|
|
+
|
|
|
+ ipNet := net.IPNet{
|
|
|
+ IP: net.IPv4(1, 2, 3, 4),
|
|
|
+ Mask: net.IPMask{255, 255, 255, 0},
|
|
|
+ }
|
|
|
+
|
|
|
+ c := cert.NebulaCertificate{
|
|
|
+ Details: cert.NebulaCertificateDetails{
|
|
|
+ Name: "host1",
|
|
|
+ Ips: []*net.IPNet{&ipNet},
|
|
|
+ Groups: []string{"default-group"},
|
|
|
+ InvertedGroups: map[string]struct{}{"default-group": {}},
|
|
|
+ Issuer: "signer-shasum",
|
|
|
+ },
|
|
|
+ }
|
|
|
+ h := HostInfo{
|
|
|
+ ConnectionState: &ConnectionState{
|
|
|
+ peerCert: &c,
|
|
|
+ },
|
|
|
+ hostId: ip2int(ipNet.IP),
|
|
|
+ }
|
|
|
+ h.CreateRemoteCIDR(&c)
|
|
|
+
|
|
|
+ fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
|
|
+ assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
|
|
|
+ cp := cert.NewCAPool()
|
|
|
+
|
|
|
+ // Drop outbound
|
|
|
+ assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
|
|
+ // Allow inbound
|
|
|
+ resetConntrack(fw)
|
|
|
+ assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp))
|
|
|
+ // Allow outbound because conntrack
|
|
|
+ assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
|
|
+
|
|
|
+ oldFw := fw
|
|
|
+ fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
|
|
+ assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
|
|
|
+ fw.Conntrack = oldFw.Conntrack
|
|
|
+ fw.rulesVersion = oldFw.rulesVersion + 1
|
|
|
+
|
|
|
+ // Allow outbound because conntrack and new rules allow port 10
|
|
|
+ assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp))
|
|
|
+
|
|
|
+ oldFw = fw
|
|
|
+ fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
|
|
|
+ assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
|
|
|
+ fw.Conntrack = oldFw.Conntrack
|
|
|
+ fw.rulesVersion = oldFw.rulesVersion + 1
|
|
|
+
|
|
|
+ // Drop outbound because conntrack doesn't match new ruleset
|
|
|
+ assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp), ErrNoMatchingRule)
|
|
|
+}
|
|
|
+
|
|
|
func BenchmarkLookup(b *testing.B) {
|
|
|
ml := func(m map[string]struct{}, a [][]string) {
|
|
|
for n := 0; n < b.N; n++ {
|
|
@@ -861,7 +931,7 @@ func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, end
|
|
|
}
|
|
|
|
|
|
func resetConntrack(fw *Firewall) {
|
|
|
- fw.connMutex.Lock()
|
|
|
- fw.Conns = map[FirewallPacket]*conn{}
|
|
|
- fw.connMutex.Unlock()
|
|
|
+ fw.Conntrack.Lock()
|
|
|
+ fw.Conntrack.Conns = map[FirewallPacket]*conn{}
|
|
|
+ fw.Conntrack.Unlock()
|
|
|
}
|