Browse Source

Change allow list to evaluate all vpnaddr tables when available (#1330)

Nate Brown 5 months ago
parent
commit
6a96df18cc
4 changed files with 46 additions and 32 deletions
  1. 25 11
      allow_list.go
  2. 11 11
      handshake_ix.go
  3. 1 1
      handshake_manager.go
  4. 9 9
      outside.go

+ 25 - 11
allow_list.go

@@ -250,20 +250,20 @@ func getRemoteAllowRanges(c *config.C, k string) (*bart.Table[*AllowList], error
 	return remoteAllowRanges, nil
 }
 
-func (al *AllowList) Allow(ip netip.Addr) bool {
+func (al *AllowList) Allow(addr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
 
-	result, _ := al.cidrTree.Lookup(ip)
+	result, _ := al.cidrTree.Lookup(addr)
 	return result
 }
 
-func (al *LocalAllowList) Allow(ip netip.Addr) bool {
+func (al *LocalAllowList) Allow(udpAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
 func (al *LocalAllowList) AllowName(name string) bool {
@@ -281,23 +281,37 @@ func (al *LocalAllowList) AllowName(name string) bool {
 	return !al.nameRules[0].Allow
 }
 
-func (al *RemoteAllowList) AllowUnknownVpnIp(ip netip.Addr) bool {
+func (al *RemoteAllowList) AllowUnknownVpnAddr(vpnAddr netip.Addr) bool {
 	if al == nil {
 		return true
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(vpnAddr)
 }
 
-func (al *RemoteAllowList) Allow(vpnIp netip.Addr, ip netip.Addr) bool {
-	if !al.getInsideAllowList(vpnIp).Allow(ip) {
+func (al *RemoteAllowList) Allow(vpnAddr netip.Addr, udpAddr netip.Addr) bool {
+	if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
 		return false
 	}
-	return al.AllowList.Allow(ip)
+	return al.AllowList.Allow(udpAddr)
 }
 
-func (al *RemoteAllowList) getInsideAllowList(vpnIp netip.Addr) *AllowList {
+func (al *RemoteAllowList) AllowAll(vpnAddrs []netip.Addr, udpAddr netip.Addr) bool {
+	if !al.AllowList.Allow(udpAddr) {
+		return false
+	}
+
+	for _, vpnAddr := range vpnAddrs {
+		if !al.getInsideAllowList(vpnAddr).Allow(udpAddr) {
+			return false
+		}
+	}
+
+	return true
+}
+
+func (al *RemoteAllowList) getInsideAllowList(vpnAddr netip.Addr) *AllowList {
 	if al.insideAllowLists != nil {
-		inside, ok := al.insideAllowLists.Lookup(vpnIp)
+		inside, ok := al.insideAllowLists.Lookup(vpnAddr)
 		if ok {
 			return inside
 		}

+ 11 - 11
handshake_ix.go

@@ -189,15 +189,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			return
 		}
 
-		if addr.IsValid() {
-			// addr can be invalid when the tunnel is being relayed.
-			// We only want to apply the remote allow list for direct tunnels here
-			if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) {
-				f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
-				return
-			}
-		}
-
 		// vpnAddrs outside our vpn networks are of no use to us, filter them out
 		if _, ok := f.myVpnNetworksTable.Lookup(vpnAddr); !ok {
 			continue
@@ -216,6 +207,15 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		return
 	}
 
+	if addr.IsValid() {
+		// addr can be invalid when the tunnel is being relayed.
+		// We only want to apply the remote allow list for direct tunnels here
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+			return
+		}
+	}
+
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
 		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
@@ -450,8 +450,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	hostinfo := hh.hostinfo
 	if addr.IsValid() {
 		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) {
-			f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 		}
 	}

+ 1 - 1
handshake_manager.go

@@ -138,7 +138,7 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
 func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
 	if addr.IsValid() {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.Addr()) {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
 			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}

+ 9 - 9
outside.go

@@ -231,26 +231,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) {
-	if vpnAddr.IsValid() && hostinfo.remote != vpnAddr {
-		//TODO: CERT-V2 this is weird now that we can have multiple vpn addrs
-		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) {
-			hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming")
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
+	if udpAddr.IsValid() && hostinfo.remote != udpAddr {
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
-		if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+
+		if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(vpnAddr)
+		hostinfo.SetRemote(udpAddr)
 	}
 
 }