Browse Source

Remove unusable networks from remote tunnels at handshake time (#1318)

Nate Brown 6 months ago
parent
commit
1ad0f57c1e
5 changed files with 66 additions and 27 deletions
  1. 2 4
      firewall.go
  2. 8 7
      firewall_test.go
  3. 45 9
      handshake_ix.go
  4. 10 6
      hostmap.go
  5. 1 1
      interface.go

+ 2 - 4
firewall.go

@@ -8,7 +8,6 @@ import (
 	"hash/fnv"
 	"net/netip"
 	"reflect"
-	"slices"
 	"strconv"
 	"strings"
 	"sync"
@@ -438,9 +437,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool *
 			return ErrInvalidRemoteIP
 		}
 	} else {
-		// Simple case: Certificate has one IP and no subnets
-		//TODO: we can make this more performant
-		if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) {
+		// Simple case: Certificate has one address and no unsafe networks
+		if h.vpnAddrs[0] != fp.RemoteAddr {
 			f.metrics(incoming).droppedRemoteAddr.Inc(1)
 			return ErrInvalidRemoteIP
 		}

+ 8 - 7
firewall_test.go

@@ -152,7 +152,7 @@ func TestFirewall_Drop(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
 	}
-	h.buildNetworks(&c)
+	h.buildNetworks(c.networks, c.unsafeNetworks)
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -332,7 +332,7 @@ func TestFirewall_Drop2(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	c1 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -342,11 +342,12 @@ func TestFirewall_Drop2(t *testing.T) {
 		InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
 	}
 	h1 := HostInfo{
+		vpnAddrs: []netip.Addr{network.Addr()},
 		ConnectionState: &ConnectionState{
 			peerCert: &c1,
 		},
 	}
-	h1.buildNetworks(c1.Certificate)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -394,7 +395,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h1.buildNetworks(c1.Certificate)
+	h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
 
 	c2 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -409,7 +410,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h2.buildNetworks(c2.Certificate)
+	h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
 
 	c3 := cert.CachedCertificate{
 		Certificate: &dummyCert{
@@ -424,7 +425,7 @@ func TestFirewall_Drop3(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h3.buildNetworks(c3.Certificate)
+	h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
@@ -471,7 +472,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
 		},
 		vpnAddrs: []netip.Addr{network.Addr()},
 	}
-	h.buildNetworks(c.Certificate)
+	h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
 
 	fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
 	assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))

+ 45 - 9
handshake_ix.go

@@ -172,6 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	}
 
 	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
 	certName := remoteCert.Certificate.Name()
 	fingerprint := remoteCert.Fingerprint
 	issuer := remoteCert.Certificate.Issuer()
@@ -189,15 +190,32 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		}
 
 		if addr.IsValid() {
+			// addr can be invalid when the tunnel is being relayed.
+			// We only want to apply the remote allow list for direct tunnels here
 			if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
 				f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 				return
 			}
 		}
 
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
 		vpnAddrs = append(vpnAddrs, vpnAddr)
 	}
 
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
+		return
+	}
+
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
 		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
@@ -294,7 +312,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
 	hostinfo.SetRemote(addr)
-	hostinfo.buildNetworks(remoteCert.Certificate)
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
@@ -431,7 +449,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	hostinfo := hh.hostinfo
 	if addr.IsValid() {
-		//TODO: this is kind of nonsense now
+		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
 		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
 			f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
@@ -492,7 +510,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			e = e.WithField("cert", remoteCert)
 		}
 
-		e.Info("Invalid vpn ip from host")
+		e.Info("Empty networks from host")
 		return true
 	}
 
@@ -516,9 +534,26 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 	}
 
-	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
-	for i, n := range vpnNetworks {
-		vpnAddrs[i] = n.Addr()
+	var vpnAddrs []netip.Addr
+	var filteredNetworks []netip.Prefix
+	for _, network := range vpnNetworks {
+		// vpnAddrs outside our vpn networks are of no use to us, filter them out
+		vpnAddr := network.Addr()
+		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
+			continue
+		}
+
+		filteredNetworks = append(filteredNetworks, network)
+		vpnAddrs = append(vpnAddrs, vpnAddr)
+	}
+
+	if len(vpnAddrs) == 0 {
+		f.l.WithError(err).WithField("udpAddr", addr).
+			WithField("certName", certName).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("No usable vpn addresses from host, refusing handshake")
+		return true
 	}
 
 	// Ensure the right host responded
@@ -558,7 +593,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hh.startTime).Nanoseconds()
-	f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
+	f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
 		WithField("issuer", issuer).
@@ -569,9 +604,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		Info("Handshake message received")
 
 	// Build up the radix for the firewall if we have subnets in the cert
-	hostinfo.buildNetworks(remoteCert.Certificate)
+	hostinfo.vpnAddrs = vpnAddrs
+	hostinfo.buildNetworks(filteredNetworks, remoteCert.Certificate.UnsafeNetworks())
 
-	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
+	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 

+ 10 - 6
hostmap.go

@@ -215,8 +215,12 @@ type HostInfo struct {
 	ConnectionState *ConnectionState
 	remoteIndexId   uint32
 	localIndexId    uint32
-	vpnAddrs        []netip.Addr
-	recvError       atomic.Uint32
+
+	// vpnAddrs is a list of vpn addresses assigned to this host that are within our own vpn networks
+	// The host may have other vpn addresses that are outside our
+	// vpn networks but were removed because they are not usable
+	vpnAddrs  []netip.Addr
+	recvError atomic.Uint32
 
 	// networks are both all vpn and unsafe networks assigned to this host
 	networks   *bart.Table[struct{}]
@@ -712,18 +716,18 @@ func (i *HostInfo) RecvErrorExceeded() bool {
 	return true
 }
 
-func (i *HostInfo) buildNetworks(c cert.Certificate) {
-	if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 {
+func (i *HostInfo) buildNetworks(networks, unsafeNetworks []netip.Prefix) {
+	if len(networks) == 1 && len(unsafeNetworks) == 0 {
 		// Simple case, no CIDRTree needed
 		return
 	}
 
 	i.networks = new(bart.Table[struct{}])
-	for _, network := range c.Networks() {
+	for _, network := range networks {
 		i.networks.Insert(network, struct{}{})
 	}
 
-	for _, network := range c.UnsafeNetworks() {
+	for _, network := range unsafeNetworks {
 		i.networks.Insert(network, struct{}{})
 	}
 }

+ 1 - 1
interface.go

@@ -64,7 +64,7 @@ type Interface struct {
 	myBroadcastAddrsTable *bart.Table[struct{}]
 	myVpnAddrs            []netip.Addr          // A list of addresses assigned to us via our certificate
 	myVpnAddrsTable       *bart.Table[struct{}] // A table of addresses assigned to us via our certificate
-	myVpnNetworks         []netip.Prefix        // A table of networks assigned to us via our certificate
+	myVpnNetworks         []netip.Prefix        // A list of networks assigned to us via our certificate
 	myVpnNetworksTable    *bart.Table[struct{}] // A table of networks assigned to us via our certificate
 	dropLocalBroadcast    bool
 	dropMulticast         bool