Jelajahi Sumber

add firewall tests for ipv6 (#1451)

Test things like cidr and local_cidr with ipv6 addresses, to ensure
everything is working correctly.
Wade Simmons 2 hari lalu
induk
melakukan
73cfa7b5b1
1 mengubah file dengan 197 tambahan dan 0 penghapusan
  1. 197 0
      firewall_test.go

+ 197 - 0
firewall_test.go

@@ -68,6 +68,9 @@ func TestFirewall_AddRule(t *testing.T) {
 	ti, err := netip.ParsePrefix("1.2.3.4/32")
 	require.NoError(t, err)
 
+	ti6, err := netip.ParsePrefix("fd12::34/128")
+	require.NoError(t, err)
+
 	require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
 	// An empty rule is any
 	assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
@@ -92,12 +95,24 @@ func TestFirewall_AddRule(t *testing.T) {
 	_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
 	assert.True(t, ok)
 
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
+	assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
+	_, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
+	assert.True(t, ok)
+
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
 	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
 	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
 	assert.True(t, ok)
 
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
+	assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
+	_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
+	assert.True(t, ok)
+
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
 	assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
@@ -117,6 +132,13 @@ func TestFirewall_AddRule(t *testing.T) {
 	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
 	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
 
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
+	anyIp6, err := netip.ParsePrefix("::/0")
+	require.NoError(t, err)
+
+	require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
+	assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
+
 	// Test error conditions
 	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
 	require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -199,6 +221,82 @@ func TestFirewall_Drop(t *testing.T) {
 	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
 }
 
+func TestFirewall_DropV6(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("fd12::34"),
+		RemoteAddr: netip.MustParseAddr("fd12::34"),
+		LocalPort:  10,
+		RemotePort: 90,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+
+	c := dummyCert{
+		name:     "host1",
+		networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
+		groups:   []string{"default-group"},
+		issuer:   "signer-shasum",
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &cert.CachedCertificate{
+				Certificate:    &c,
+				InvertedGroups: map[string]struct{}{"default-group": {}},
+			},
+		},
+		vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
+	}
+	h.buildNetworks(c.networks, c.unsafeNetworks)
+
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
+	cp := cert.NewCAPool()
+
+	// Drop outbound
+	assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
+	// Allow inbound
+	resetConntrack(fw)
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+	// Allow outbound because conntrack
+	require.NoError(t, fw.Drop(p, false, &h, cp, nil))
+
+	// test remote mismatch
+	oldRemote := p.RemoteAddr
+	p.RemoteAddr = netip.MustParseAddr("fd12::56")
+	assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
+	p.RemoteAddr = oldRemote
+
+	// ensure signer doesn't get in the way of group checks
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
+
+	// test caSha doesn't drop on match
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+
+	// ensure ca name doesn't get in the way of group checks
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
+
+	// test caName doesn't drop on match
+	cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
+	fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+}
+
 func BenchmarkFirewallTable_match(b *testing.B) {
 	f := &Firewall{}
 	ft := FirewallTable{
@@ -208,6 +306,10 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 	pfix := netip.MustParsePrefix("172.1.1.1/32")
 	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
 	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
+
+	pfix6 := netip.MustParsePrefix("fd11::11/128")
+	_ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
+	_ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
 	cp := cert.NewCAPool()
 
 	b.Run("fail on proto", func(b *testing.B) {
@@ -239,6 +341,15 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{},
+		}
+		ip := netip.MustParsePrefix("fd99::99/128")
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -252,6 +363,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
+			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
+		}
+	})
 
 	b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -265,6 +388,18 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name:     "nope",
+				networks: []netip.Prefix{netip.MustParsePrefix("fd99:99/128")},
+			},
+			InvertedGroups: map[string]struct{}{"nope": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass on group on any local cidr", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -289,6 +424,17 @@ func BenchmarkFirewallTable_match(b *testing.B) {
 			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
 		}
 	})
+	b.Run("pass on group on specific local cidr6", func(b *testing.B) {
+		c := &cert.CachedCertificate{
+			Certificate: &dummyCert{
+				name: "nope",
+			},
+			InvertedGroups: map[string]struct{}{"good-group": {}},
+		}
+		for n := 0; n < b.N; n++ {
+			assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
+		}
+	})
 
 	b.Run("pass on name", func(b *testing.B) {
 		c := &cert.CachedCertificate{
@@ -447,6 +593,42 @@ func TestFirewall_Drop3(t *testing.T) {
 	require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
 }
 
+func TestFirewall_Drop3V6(t *testing.T) {
+	l := test.NewLogger()
+	ob := &bytes.Buffer{}
+	l.SetOutput(ob)
+
+	p := firewall.Packet{
+		LocalAddr:  netip.MustParseAddr("fd12::34"),
+		RemoteAddr: netip.MustParseAddr("fd12::34"),
+		LocalPort:  1,
+		RemotePort: 1,
+		Protocol:   firewall.ProtoUDP,
+		Fragment:   false,
+	}
+
+	network := netip.MustParsePrefix("fd12::34/120")
+	c := cert.CachedCertificate{
+		Certificate: &dummyCert{
+			name:     "host-owner",
+			networks: []netip.Prefix{network},
+		},
+	}
+	h := HostInfo{
+		ConnectionState: &ConnectionState{
+			peerCert: &c,
+		},
+		vpnAddrs: []netip.Addr{network.Addr()},
+	}
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
+
+	// Test a remote address match
+	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
+	cp := cert.NewCAPool()
+	require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
+	require.NoError(t, fw.Drop(p, true, &h, cp, nil))
+}
+
 func TestFirewall_DropConntrackReload(t *testing.T) {
 	l := test.NewLogger()
 	ob := &bytes.Buffer{}
@@ -727,6 +909,21 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
 	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
 	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
 
+	// Test adding rule with cidr ipv6
+	cidr6 := netip.MustParsePrefix("fd00::/8")
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
+
+	// Test adding rule with local_cidr ipv6
+	conf = config.NewC(l)
+	mf = &mockFirewall{}
+	conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
+	require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
+	assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
+
 	// Test adding rule with ca_sha
 	conf = config.NewC(l)
 	mf = &mockFirewall{}