| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983 | package nebulaimport (	"bytes"	"encoding/binary"	"errors"	"math"	"net"	"testing"	"time"	"github.com/rcrowley/go-metrics"	"github.com/slackhq/nebula/cert"	"github.com/slackhq/nebula/config"	"github.com/slackhq/nebula/firewall"	"github.com/slackhq/nebula/iputil"	"github.com/slackhq/nebula/test"	"github.com/stretchr/testify/assert")func TestNewFirewall(t *testing.T) {	l := test.NewLogger()	c := &cert.NebulaCertificate{}	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)	conntrack := fw.Conntrack	assert.NotNil(t, conntrack)	assert.NotNil(t, conntrack.Conns)	assert.NotNil(t, conntrack.TimerWheel)	assert.NotNil(t, fw.InRules)	assert.NotNil(t, fw.OutRules)	assert.Equal(t, time.Second, fw.TCPTimeout)	assert.Equal(t, time.Minute, fw.UDPTimeout)	assert.Equal(t, time.Hour, fw.DefaultTimeout)	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)	fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)	fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)	fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)	fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)	fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)}func TestFirewall_AddRule(t *testing.T) {	l := test.NewLogger()	ob := &bytes.Buffer{}	l.SetOutput(ob)	c := &cert.NebulaCertificate{}	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.NotNil(t, fw.InRules)	assert.NotNil(t, fw.OutRules)	_, ti, _ := net.ParseCIDR("1.2.3.4/32")	assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))	// An empty rule is any	assert.True(t, fw.InRules.TCP[1].Any.Any)	assert.Empty(t, fw.InRules.TCP[1].Any.Groups)	assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))	assert.False(t, fw.InRules.UDP[1].Any.Any)	assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")	assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))	assert.False(t, fw.InRules.ICMP[1].Any.Any)	assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)	assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))	assert.False(t, fw.OutRules.AnyProto[1].Any.Any)	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)	assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))	assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")	// Set any and clear fields	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))	assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])	assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))	assert.NotNil(t, fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP)))	// run twice just to make sure	//TODO: these ANY rules should clear the CA firewall portion	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)	assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	_, anyIp, _ := net.ParseCIDR("0.0.0.0/0")	assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))	assert.True(t, fw.OutRules.AnyProto[0].Any.Any)	// Test error conditions	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)	assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))	assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))}func TestFirewall_Drop(t *testing.T) {	l := test.NewLogger()	ob := &bytes.Buffer{}	l.SetOutput(ob)	p := firewall.Packet{		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		LocalPort:  10,		RemotePort: 90,		Protocol:   firewall.ProtoUDP,		Fragment:   false,	}	ipNet := net.IPNet{		IP:   net.IPv4(1, 2, 3, 4),		Mask: net.IPMask{255, 255, 255, 0},	}	c := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name:           "host1",			Ips:            []*net.IPNet{&ipNet},			Groups:         []string{"default-group"},			InvertedGroups: map[string]struct{}{"default-group": {}},			Issuer:         "signer-shasum",		},	}	h := HostInfo{		ConnectionState: &ConnectionState{			peerCert: &c,		},		vpnIp: iputil.Ip2VpnIp(ipNet.IP),	}	h.CreateRemoteCIDR(&c)	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))	cp := cert.NewCAPool()	// Drop outbound	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)	// Allow inbound	resetConntrack(fw)	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))	// Allow outbound because conntrack	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))	// test remote mismatch	oldRemote := p.RemoteIP	p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)	p.RemoteIP = oldRemote	// ensure signer doesn't get in the way of group checks	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)	// test caSha doesn't drop on match	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))	// ensure ca name doesn't get in the way of group checks	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))	assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)	// test caName doesn't drop on match	cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))}func BenchmarkFirewallTable_match(b *testing.B) {	ft := FirewallTable{		TCP: firewallPort{},	}	_, n, _ := net.ParseCIDR("172.1.1.1/32")	_ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")	_ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")	_ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")	_ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")	_ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")	cp := cert.NewCAPool()	b.Run("fail on proto", func(b *testing.B) {		c := &cert.NebulaCertificate{}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)		}	})	b.Run("fail on port", func(b *testing.B) {		c := &cert.NebulaCertificate{}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)		}	})	b.Run("fail all group, name, and cidr", func(b *testing.B) {		_, ip, _ := net.ParseCIDR("9.254.254.254/32")		c := &cert.NebulaCertificate{			Details: cert.NebulaCertificateDetails{				InvertedGroups: map[string]struct{}{"nope": {}},				Name:           "nope",				Ips:            []*net.IPNet{ip},			},		}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)		}	})	b.Run("pass on group", func(b *testing.B) {		c := &cert.NebulaCertificate{			Details: cert.NebulaCertificateDetails{				InvertedGroups: map[string]struct{}{"good-group": {}},				Name:           "nope",			},		}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)		}	})	b.Run("pass on name", func(b *testing.B) {		c := &cert.NebulaCertificate{			Details: cert.NebulaCertificateDetails{				InvertedGroups: map[string]struct{}{"nope": {}},				Name:           "good-host",			},		}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)		}	})	b.Run("pass on ip", func(b *testing.B) {		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))		c := &cert.NebulaCertificate{			Details: cert.NebulaCertificateDetails{				InvertedGroups: map[string]struct{}{"nope": {}},				Name:           "good-host",			},		}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)		}	})	b.Run("pass on local ip", func(b *testing.B) {		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))		c := &cert.NebulaCertificate{			Details: cert.NebulaCertificateDetails{				InvertedGroups: map[string]struct{}{"nope": {}},				Name:           "good-host",			},		}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)		}	})	_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")	b.Run("pass on ip with any port", func(b *testing.B) {		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))		c := &cert.NebulaCertificate{			Details: cert.NebulaCertificateDetails{				InvertedGroups: map[string]struct{}{"nope": {}},				Name:           "good-host",			},		}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)		}	})	b.Run("pass on local ip with any port", func(b *testing.B) {		ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))		c := &cert.NebulaCertificate{			Details: cert.NebulaCertificateDetails{				InvertedGroups: map[string]struct{}{"nope": {}},				Name:           "good-host",			},		}		for n := 0; n < b.N; n++ {			ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)		}	})}func TestFirewall_Drop2(t *testing.T) {	l := test.NewLogger()	ob := &bytes.Buffer{}	l.SetOutput(ob)	p := firewall.Packet{		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		LocalPort:  10,		RemotePort: 90,		Protocol:   firewall.ProtoUDP,		Fragment:   false,	}	ipNet := net.IPNet{		IP:   net.IPv4(1, 2, 3, 4),		Mask: net.IPMask{255, 255, 255, 0},	}	c := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name:           "host1",			Ips:            []*net.IPNet{&ipNet},			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},		},	}	h := HostInfo{		ConnectionState: &ConnectionState{			peerCert: &c,		},		vpnIp: iputil.Ip2VpnIp(ipNet.IP),	}	h.CreateRemoteCIDR(&c)	c1 := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name:           "host1",			Ips:            []*net.IPNet{&ipNet},			InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},		},	}	h1 := HostInfo{		ConnectionState: &ConnectionState{			peerCert: &c1,		},	}	h1.CreateRemoteCIDR(&c1)	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))	cp := cert.NewCAPool()	// h1/c1 lacks the proper groups	assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)	// c has the proper groups	resetConntrack(fw)	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))}func TestFirewall_Drop3(t *testing.T) {	l := test.NewLogger()	ob := &bytes.Buffer{}	l.SetOutput(ob)	p := firewall.Packet{		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		LocalPort:  1,		RemotePort: 1,		Protocol:   firewall.ProtoUDP,		Fragment:   false,	}	ipNet := net.IPNet{		IP:   net.IPv4(1, 2, 3, 4),		Mask: net.IPMask{255, 255, 255, 0},	}	c := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name: "host-owner",			Ips:  []*net.IPNet{&ipNet},		},	}	c1 := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name:   "host1",			Ips:    []*net.IPNet{&ipNet},			Issuer: "signer-sha-bad",		},	}	h1 := HostInfo{		ConnectionState: &ConnectionState{			peerCert: &c1,		},		vpnIp: iputil.Ip2VpnIp(ipNet.IP),	}	h1.CreateRemoteCIDR(&c1)	c2 := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name:   "host2",			Ips:    []*net.IPNet{&ipNet},			Issuer: "signer-sha",		},	}	h2 := HostInfo{		ConnectionState: &ConnectionState{			peerCert: &c2,		},		vpnIp: iputil.Ip2VpnIp(ipNet.IP),	}	h2.CreateRemoteCIDR(&c2)	c3 := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name:   "host3",			Ips:    []*net.IPNet{&ipNet},			Issuer: "signer-sha-bad",		},	}	h3 := HostInfo{		ConnectionState: &ConnectionState{			peerCert: &c3,		},		vpnIp: iputil.Ip2VpnIp(ipNet.IP),	}	h3.CreateRemoteCIDR(&c3)	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))	cp := cert.NewCAPool()	// c1 should pass because host match	assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))	// c2 should pass because ca sha match	resetConntrack(fw)	assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))	// c3 should fail because no match	resetConntrack(fw)	assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)}func TestFirewall_DropConntrackReload(t *testing.T) {	l := test.NewLogger()	ob := &bytes.Buffer{}	l.SetOutput(ob)	p := firewall.Packet{		LocalIP:    iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		RemoteIP:   iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),		LocalPort:  10,		RemotePort: 90,		Protocol:   firewall.ProtoUDP,		Fragment:   false,	}	ipNet := net.IPNet{		IP:   net.IPv4(1, 2, 3, 4),		Mask: net.IPMask{255, 255, 255, 0},	}	c := cert.NebulaCertificate{		Details: cert.NebulaCertificateDetails{			Name:           "host1",			Ips:            []*net.IPNet{&ipNet},			Groups:         []string{"default-group"},			InvertedGroups: map[string]struct{}{"default-group": {}},			Issuer:         "signer-shasum",		},	}	h := HostInfo{		ConnectionState: &ConnectionState{			peerCert: &c,		},		vpnIp: iputil.Ip2VpnIp(ipNet.IP),	}	h.CreateRemoteCIDR(&c)	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))	cp := cert.NewCAPool()	// Drop outbound	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)	// Allow inbound	resetConntrack(fw)	assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))	// Allow outbound because conntrack	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))	oldFw := fw	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))	fw.Conntrack = oldFw.Conntrack	fw.rulesVersion = oldFw.rulesVersion + 1	// Allow outbound because conntrack and new rules allow port 10	assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))	oldFw = fw	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))	fw.Conntrack = oldFw.Conntrack	fw.rulesVersion = oldFw.rulesVersion + 1	// Drop outbound because conntrack doesn't match new ruleset	assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)}func BenchmarkLookup(b *testing.B) {	ml := func(m map[string]struct{}, a [][]string) {		for n := 0; n < b.N; n++ {			for _, sg := range a {				found := false				for _, g := range sg {					if _, ok := m[g]; !ok {						found = false						break					}					found = true				}				if found {					return				}			}		}	}	b.Run("array to map best", func(b *testing.B) {		m := map[string]struct{}{			"1ne": {},			"2wo": {},			"3hr": {},			"4ou": {},			"5iv": {},			"6ix": {},		}		a := [][]string{			{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},			{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},			{"one", "two", "3hr", "4ou", "5iv", "6ix"},			{"one", "two", "thr", "4ou", "5iv", "6ix"},			{"one", "two", "thr", "fou", "5iv", "6ix"},			{"one", "two", "thr", "fou", "fiv", "6ix"},			{"one", "two", "thr", "fou", "fiv", "six"},		}		for n := 0; n < b.N; n++ {			ml(m, a)		}	})	b.Run("array to map worst", func(b *testing.B) {		m := map[string]struct{}{			"one": {},			"two": {},			"thr": {},			"fou": {},			"fiv": {},			"six": {},		}		a := [][]string{			{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},			{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},			{"one", "two", "3hr", "4ou", "5iv", "6ix"},			{"one", "two", "thr", "4ou", "5iv", "6ix"},			{"one", "two", "thr", "fou", "5iv", "6ix"},			{"one", "two", "thr", "fou", "fiv", "6ix"},			{"one", "two", "thr", "fou", "fiv", "six"},		}		for n := 0; n < b.N; n++ {			ml(m, a)		}	})	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster}func Test_parsePort(t *testing.T) {	_, _, err := parsePort("")	assert.EqualError(t, err, "was not a number; ``")	_, _, err = parsePort("  ")	assert.EqualError(t, err, "was not a number; `  `")	_, _, err = parsePort("-")	assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")	_, _, err = parsePort(" - ")	assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")	_, _, err = parsePort("a-b")	assert.EqualError(t, err, "beginning range was not a number; `a`")	_, _, err = parsePort("1-b")	assert.EqualError(t, err, "ending range was not a number; `b`")	s, e, err := parsePort(" 1 - 2    ")	assert.Equal(t, int32(1), s)	assert.Equal(t, int32(2), e)	assert.Nil(t, err)	s, e, err = parsePort("0-1")	assert.Equal(t, int32(0), s)	assert.Equal(t, int32(0), e)	assert.Nil(t, err)	s, e, err = parsePort("9919")	assert.Equal(t, int32(9919), s)	assert.Equal(t, int32(9919), e)	assert.Nil(t, err)	s, e, err = parsePort("any")	assert.Equal(t, int32(0), s)	assert.Equal(t, int32(0), e)	assert.Nil(t, err)}func TestNewFirewallFromConfig(t *testing.T) {	l := test.NewLogger()	// Test a bad rule definition	c := &cert.NebulaCertificate{}	conf := config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}	_, err := NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")	// Test both port and code	conf = config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")	// Test missing host, group, cidr, ca_name and ca_sha	conf = config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")	// Test code/port error	conf = config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")	// Test proto error	conf = config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")	// Test cidr parse error	conf = config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")	// Test local_cidr parse error	conf = config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")	// Test both group and groups	conf = config.NewC(l)	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}	_, err = NewFirewallFromConfig(l, c, conf)	assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")}func TestAddFirewallRulesFromConfig(t *testing.T) {	l := test.NewLogger()	// Test adding tcp rule	conf := config.NewC(l)	mf := &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)	// Test adding udp rule	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)	// Test adding icmp rule	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))	assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)	// Test adding any rule	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)	// Test adding rule with cidr	cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)	// Test adding rule with local_cidr	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)	// Test adding rule with ca_sha	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)	// Test adding rule with ca_name	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)	// Test single group	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)	// Test single groups	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)	// Test multiple AND groups	conf = config.NewC(l)	mf = &mockFirewall{}	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}	assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)	// Test Add error	conf = config.NewC(l)	mf = &mockFirewall{}	mf.nextCallReturn = errors.New("test error")	conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}	assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")}func TestTCPRTTTracking(t *testing.T) {	b := make([]byte, 200)	// Max ip IHL (60 bytes) and tcp IHL (60 bytes)	b[0] = 15	b[60+12] = 15 << 4	f := Firewall{		metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),	}	// Set SEQ to 1	binary.BigEndian.PutUint32(b[60+4:60+8], 1)	c := &conn{}	setTCPRTTTracking(c, b)	assert.Equal(t, uint32(1), c.Seq)	// Bad ack - no ack flag	binary.BigEndian.PutUint32(b[60+8:60+12], 80)	assert.False(t, f.checkTCPRTT(c, b))	// Bad ack, number is too low	binary.BigEndian.PutUint32(b[60+8:60+12], 0)	b[60+13] = uint8(0x10)	assert.False(t, f.checkTCPRTT(c, b))	// Good ack	binary.BigEndian.PutUint32(b[60+8:60+12], 80)	assert.True(t, f.checkTCPRTT(c, b))	assert.Equal(t, uint32(0), c.Seq)	// Set SEQ to 1	binary.BigEndian.PutUint32(b[60+4:60+8], 1)	c = &conn{}	setTCPRTTTracking(c, b)	assert.Equal(t, uint32(1), c.Seq)	// Good acks	binary.BigEndian.PutUint32(b[60+8:60+12], 81)	assert.True(t, f.checkTCPRTT(c, b))	assert.Equal(t, uint32(0), c.Seq)	// Set SEQ to max uint32 - 20	binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)	c = &conn{}	setTCPRTTTracking(c, b)	assert.Equal(t, ^uint32(0)-20, c.Seq)	// Good acks	binary.BigEndian.PutUint32(b[60+8:60+12], 81)	assert.True(t, f.checkTCPRTT(c, b))	assert.Equal(t, uint32(0), c.Seq)	// Set SEQ to max uint32 / 2	binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)	c = &conn{}	setTCPRTTTracking(c, b)	assert.Equal(t, ^uint32(0)/2, c.Seq)	// Below	binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)	assert.False(t, f.checkTCPRTT(c, b))	assert.Equal(t, ^uint32(0)/2, c.Seq)	// Halfway below	binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))	assert.False(t, f.checkTCPRTT(c, b))	assert.Equal(t, ^uint32(0)/2, c.Seq)	// Halfway above is ok	binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))	assert.True(t, f.checkTCPRTT(c, b))	assert.Equal(t, uint32(0), c.Seq)	// Set SEQ to max uint32	binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))	c = &conn{}	setTCPRTTTracking(c, b)	assert.Equal(t, ^uint32(0), c.Seq)	// Halfway + 1 above	binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)	assert.False(t, f.checkTCPRTT(c, b))	assert.Equal(t, ^uint32(0), c.Seq)	// Halfway above	binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)	assert.True(t, f.checkTCPRTT(c, b))	assert.Equal(t, uint32(0), c.Seq)}func TestFirewall_convertRule(t *testing.T) {	l := test.NewLogger()	ob := &bytes.Buffer{}	l.SetOutput(ob)	// Ensure group array of 1 is converted and a warning is printed	c := map[interface{}]interface{}{		"group": []interface{}{"group1"},	}	r, err := convertRule(l, c, "test", 1)	assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")	assert.Nil(t, err)	assert.Equal(t, "group1", r.Group)	// Ensure group array of > 1 is errord	ob.Reset()	c = map[interface{}]interface{}{		"group": []interface{}{"group1", "group2"},	}	r, err = convertRule(l, c, "test", 1)	assert.Equal(t, "", ob.String())	assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")	// Make sure a well formed group is alright	ob.Reset()	c = map[interface{}]interface{}{		"group": "group1",	}	r, err = convertRule(l, c, "test", 1)	assert.Nil(t, err)	assert.Equal(t, "group1", r.Group)}type addRuleCall struct {	incoming  bool	proto     uint8	startPort int32	endPort   int32	groups    []string	host      string	ip        *net.IPNet	localIp   *net.IPNet	caName    string	caSha     string}type mockFirewall struct {	lastCall       addRuleCall	nextCallReturn error}func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {	mf.lastCall = addRuleCall{		incoming:  incoming,		proto:     proto,		startPort: startPort,		endPort:   endPort,		groups:    groups,		host:      host,		ip:        ip,		localIp:   localIp,		caName:    caName,		caSha:     caSha,	}	err := mf.nextCallReturn	mf.nextCallReturn = nil	return err}func resetConntrack(fw *Firewall) {	fw.Conntrack.Lock()	fw.Conntrack.Conns = map[firewall.Packet]*conn{}	fw.Conntrack.Unlock()}
 |