Explorar o código

Stab at better logging when a relay is being used (#1533)

Nate Brown hai 1 semana
pai
achega
56067afca2
Modificáronse 8 ficheiros con 180 adicións e 122 borrados
  1. 38 1
      e2e/handshakes_test.go
  2. 4 4
      e2e/helpers_test.go
  3. 60 65
      handshake_ix.go
  4. 6 6
      handshake_manager.go
  5. 25 6
      hostmap.go
  6. 1 1
      interface.go
  7. 41 34
      outside.go
  8. 5 5
      remote_list.go

+ 38 - 1
e2e/handshakes_test.go

@@ -25,11 +25,12 @@ import (
 
 func BenchmarkHotPath(b *testing.B) {
 	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
-	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil)
 	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil)
 
 	// Put their info in our lighthouse
 	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
 
 	// Start the servers
 	myControl.Start()
@@ -38,6 +39,41 @@ func BenchmarkHotPath(b *testing.B) {
 	r := router.NewR(b, myControl, theirControl)
 	r.CancelFlowLogs()
 
+	assertTunnel(b, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	b.ResetTimer()
+
+	for n := 0; n < b.N; n++ {
+		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
+		_ = r.RouteForAllUntilTxTun(theirControl)
+	}
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func BenchmarkHotPathRelay(b *testing.B) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me     ", "10.128.0.1/24", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay  ", "10.128.0.128/24", m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them   ", "10.128.0.2/24", m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(b, myControl, relayControl, theirControl)
+	r.CancelFlowLogs()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	assertTunnel(b, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	b.ResetTimer()
+
 	for n := 0; n < b.N; n++ {
 		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me"))
 		_ = r.RouteForAllUntilTxTun(theirControl)
@@ -45,6 +81,7 @@ func BenchmarkHotPath(b *testing.B) {
 
 	myControl.Stop()
 	theirControl.Stop()
+	relayControl.Stop()
 }
 
 func TestGoodHandshake(t *testing.T) {

+ 4 - 4
e2e/helpers_test.go

@@ -292,7 +292,7 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 	}
 }
 
-func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
+func assertTunnel(t testing.TB, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B"))
 	bPacket := r.RouteForAllUntilTxTun(controlA)
@@ -325,7 +325,7 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpn
 	assert.Equal(t, hBinA.RemoteIndex, hAinB.LocalIndex, "Host B remote index does not match host A local index")
 }
 
-func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+func assertUdpPacket(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	if toIp.Is6() {
 		assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort)
 	} else {
@@ -333,7 +333,7 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	}
 }
 
-func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+func assertUdpPacket6(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy)
 	v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6)
 	assert.NotNil(t, v6, "No ipv6 data found")
@@ -352,7 +352,7 @@ func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr,
 	assert.Equal(t, expected, data.Payload(), "Data was incorrect")
 }
 
-func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
+func assertUdpPacket4(t testing.TB, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) {
 	packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy)
 	v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4)
 	assert.NotNil(t, v4, "No ipv4 data found")

+ 60 - 65
handshake_ix.go

@@ -99,11 +99,11 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 	return true
 }
 
-func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
+func ixHandshakeStage1(f *Interface, via ViaSender, packet []byte, h *header.H) {
 	cs := f.pki.getCertState()
 	crt := cs.GetDefaultCertificate()
 	if crt == nil {
-		f.l.WithField("udpAddr", addr).
+		f.l.WithField("from", via).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
 			WithField("certVersion", cs.initiatingVersion).
 			Error("Unable to handshake with host because no certificate is available")
@@ -112,7 +112,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Error("Failed to create connection state")
 		return
@@ -123,7 +123,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Error("Failed to call noise.ReadMessage")
 		return
@@ -132,7 +132,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Error("Failed unmarshal handshake message")
 		return
@@ -140,7 +140,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Info("Handshake did not contain a certificate")
 		return
@@ -153,7 +153,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			fp = "<error generating certificate fingerprint>"
 		}
 
-		e := f.l.WithError(err).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("from", via).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			WithField("certVpnNetworks", rc.Networks()).
 			WithField("certFingerprint", fp)
@@ -172,7 +172,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		if myCertOtherVersion == nil {
 			if f.l.Level >= logrus.DebugLevel {
 				f.l.WithError(err).WithFields(m{
-					"udpAddr":   addr,
+					"from":      via,
 					"handshake": m{"stage": 1, "style": "ix_psk0"},
 					"cert":      remoteCert,
 				}).Debug("Might be unable to handshake with host due to missing certificate version")
@@ -184,7 +184,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	}
 
 	if len(remoteCert.Certificate.Networks()) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("cert", remoteCert).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Info("No networks in certificate")
@@ -201,7 +201,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	vpnAddrs := make([]netip.Addr, len(vpnNetworks))
 	for i, network := range vpnNetworks {
 		if f.myVpnAddrsTable.Contains(network.Addr()) {
-			f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr).
+			f.l.WithField("vpnNetworks", vpnNetworks).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
@@ -215,18 +215,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		}
 	}
 
-	if addr.IsValid() {
-		// addr can be invalid when the tunnel is being relayed.
+	if !via.IsRelayed {
 		// We only want to apply the remote allow list for direct tunnels here
-		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, addr.Addr()) {
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(vpnAddrs, via.UdpAddr.Addr()) {
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
+				Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
 	}
 
 	myIndex, err := generateIndex(f.l)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -251,7 +251,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	msgRxL := f.l.WithFields(m{
 		"vpnAddrs":       vpnAddrs,
-		"udpAddr":        addr,
+		"from":           via,
 		"certName":       certName,
 		"certVersion":    certVersion,
 		"fingerprint":    fingerprint,
@@ -283,7 +283,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	hsBytes, err := hs.Marshal()
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -295,7 +295,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2)
 	msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -303,7 +303,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("fingerprint", fingerprint).
@@ -329,7 +329,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	ci.eKey = NewNebulaCipherState(eKey)
 
 	hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs)
-	hostinfo.SetRemote(addr)
+	if !via.IsRelayed {
+		hostinfo.SetRemote(via.UdpAddr)
+	}
 	hostinfo.buildNetworks(f.myVpnNetworksTable, remoteCert.Certificate)
 
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
@@ -337,7 +339,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		switch err {
 		case ErrAlreadySeen:
 			// Update remote if preferred
-			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
+			if existing.SetRemoteIfPreferred(f.hostMap, via) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@@ -345,21 +347,21 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-			if addr.IsValid() {
-				err := f.outside.WriteTo(msg, addr)
+			if !via.IsRelayed {
+				err := f.outside.WriteTo(msg, via.UdpAddr)
 				if err != nil {
-					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						WithError(err).Error("Failed to send handshake message")
 				} else {
-					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr).
+					f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("from", via).
 						WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true).
 						Info("Handshake message sent")
 				}
 				return
 			} else {
-				if via == nil {
-					f.l.Error("Handshake send failed: both addr and via are nil.")
+				if via.relay == nil {
+					f.l.Error("Handshake send failed: both addr and via.relay are nil.")
 					return
 				}
 				hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -371,7 +373,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			}
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("oldHandshakeTime", existing.lastHandshakeTime).
@@ -387,7 +389,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			return
 		case ErrLocalIndexCollision:
 			// This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
@@ -400,7 +402,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
-			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+			f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 				WithField("certName", certName).
 				WithField("certVersion", certVersion).
 				WithField("fingerprint", fingerprint).
@@ -414,30 +416,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 
 	// Do the send
 	f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
-	if addr.IsValid() {
-		err = f.outside.WriteTo(msg, addr)
+	if !via.IsRelayed {
+		err = f.outside.WriteTo(msg, via.UdpAddr)
+		log := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
+			WithField("certName", certName).
+			WithField("certVersion", certVersion).
+			WithField("fingerprint", fingerprint).
+			WithField("issuer", issuer).
+			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
 		if err != nil {
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-				WithField("certName", certName).
-				WithField("certVersion", certVersion).
-				WithField("fingerprint", fingerprint).
-				WithField("issuer", issuer).
-				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-				WithError(err).Error("Failed to send handshake")
+			log.WithError(err).Error("Failed to send handshake")
 		} else {
-			f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
-				WithField("certName", certName).
-				WithField("certVersion", certVersion).
-				WithField("fingerprint", fingerprint).
-				WithField("issuer", issuer).
-				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-				Info("Handshake message sent")
+			log.Info("Handshake message sent")
 		}
 	} else {
-		if via == nil {
-			f.l.Error("Handshake send failed: both addr and via are nil.")
+		if via.relay == nil {
+			f.l.Error("Handshake send failed: both addr and via.relay are nil.")
 			return
 		}
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
@@ -462,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+func ixHandshakeStage2(f *Interface, via ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
 	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
@@ -472,10 +467,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	defer hh.Unlock()
 
 	hostinfo := hh.hostinfo
-	if addr.IsValid() {
+	if !via.IsRelayed {
 		// The vpnAddr we know about is the one we tried to handshake with, use it to apply the remote allow list.
-		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, addr.Addr()) {
-			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
+			f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return false
 		}
 	}
@@ -483,7 +478,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci := hostinfo.ConnectionState
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Error("Failed to call noise.ReadMessage")
 
@@ -492,7 +487,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		// near future
 		return false
 	} else if dKey == nil || eKey == nil {
-		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Error("Noise did not arrive at a key")
 
@@ -504,7 +499,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	hs := &NebulaHandshake{}
 	err = hs.Unmarshal(msg)
 	if err != nil || hs.Details == nil {
-		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("from", via).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
 
 		// The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again
@@ -513,7 +508,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve())
 	if err != nil {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("vpnAddrs", hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Handshake did not contain a certificate")
@@ -527,7 +522,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 			fp = "<error generating certificate fingerprint>"
 		}
 
-		e := f.l.WithError(err).WithField("udpAddr", addr).
+		e := f.l.WithError(err).WithField("from", via).
 			WithField("vpnAddrs", hostinfo.vpnAddrs).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("certFingerprint", fp).
@@ -542,7 +537,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	}
 
 	if len(remoteCert.Certificate.Networks()) == 0 {
-		f.l.WithError(err).WithField("udpAddr", addr).
+		f.l.WithError(err).WithField("from", via).
 			WithField("vpnAddrs", hostinfo.vpnAddrs).
 			WithField("cert", remoteCert).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -565,8 +560,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.eKey = NewNebulaCipherState(eKey)
 
 	// Make sure the current udpAddr being used is set for responding
-	if addr.IsValid() {
-		hostinfo.SetRemote(addr)
+	if !via.IsRelayed {
+		hostinfo.SetRemote(via.UdpAddr)
 	} else {
 		hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0])
 	}
@@ -588,7 +583,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	// Ensure the right host responded
 	if !correctHostResponded {
 		f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks).
-			WithField("udpAddr", addr).
+			WithField("from", via).
 			WithField("certName", certName).
 			WithField("certVersion", certVersion).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
@@ -602,7 +597,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 		f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) {
 			// Block the current used address
 			newHH.hostinfo.remotes = hostinfo.remotes
-			newHH.hostinfo.remotes.BlockRemote(addr)
+			newHH.hostinfo.remotes.BlockRemote(via)
 
 			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).
 				WithField("vpnNetworks", vpnNetworks).
@@ -625,7 +620,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 	ci.window.Update(f.l, 2)
 
 	duration := time.Since(hh.startTime).Nanoseconds()
-	msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
+	msgRxL := f.l.WithField("vpnAddrs", vpnAddrs).WithField("from", via).
 		WithField("certName", certName).
 		WithField("certVersion", certVersion).
 		WithField("fingerprint", fingerprint).

+ 6 - 6
handshake_manager.go

@@ -136,11 +136,11 @@ func (hm *HandshakeManager) Run(ctx context.Context) {
 	}
 }
 
-func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
+func (hm *HandshakeManager) HandleIncoming(via ViaSender, packet []byte, h *header.H) {
 	// First remote allow list check before we know the vpnIp
-	if addr.IsValid() {
-		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(addr.Addr()) {
-			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+	if !via.IsRelayed {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnAddr(via.UdpAddr.Addr()) {
+			hm.l.WithField("from", via).Debug("lighthouse.remote_allow_list denied incoming handshake")
 			return
 		}
 	}
@@ -149,11 +149,11 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender,
 	case header.HandshakeIXPSK0:
 		switch h.MessageCounter {
 		case 1:
-			ixHandshakeStage1(hm.f, addr, via, packet, h)
+			ixHandshakeStage1(hm.f, via, packet, h)
 
 		case 2:
 			newHostinfo := hm.queryIndex(h.RemoteIndex)
-			tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
+			tearDown := ixHandshakeStage2(hm.f, via, newHostinfo, packet, h)
 			if tearDown && newHostinfo != nil {
 				hm.DeleteHostInfo(newHostinfo.hostinfo)
 			}

+ 25 - 6
hostmap.go

@@ -1,7 +1,9 @@
 package nebula
 
 import (
+	"encoding/json"
 	"errors"
+	"fmt"
 	"net"
 	"net/netip"
 	"slices"
@@ -276,9 +278,25 @@ type HostInfo struct {
 }
 
 type ViaSender struct {
+	UdpAddr   netip.AddrPort
 	relayHI   *HostInfo // relayHI is the host info object of the relay
 	remoteIdx uint32    // remoteIdx is the index included in the header of the received packet
 	relay     *Relay    // relay contains the rest of the relay information, including the PeerIP of the host trying to communicate with us.
+	IsRelayed bool      // IsRelayed is true if the packet was sent through a relay
+}
+
+func (v ViaSender) String() string {
+	if v.IsRelayed {
+		return fmt.Sprintf("%s (relayed)", v.UdpAddr)
+	}
+	return v.UdpAddr.String()
+}
+
+func (v ViaSender) MarshalJSON() ([]byte, error) {
+	if v.IsRelayed {
+		return json.Marshal(m{"relay": v.UdpAddr})
+	}
+	return json.Marshal(m{"direct": v.UdpAddr})
 }
 
 type cachedPacket struct {
@@ -694,6 +712,7 @@ func (i *HostInfo) GetCert() *cert.CachedCertificate {
 	return nil
 }
 
+// TODO: Maybe use ViaSender here?
 func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 	// We copy here because we likely got this remote from a source that reuses the object
 	if i.remote != remote {
@@ -704,14 +723,14 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) {
 
 // SetRemoteIfPreferred returns true if the remote was changed. The lastRoam
 // time on the HostInfo will also be updated.
-func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) bool {
-	if !newRemote.IsValid() {
-		// relays have nil udp Addrs
+func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, via ViaSender) bool {
+	if via.IsRelayed {
 		return false
 	}
+
 	currentRemote := i.remote
 	if !currentRemote.IsValid() {
-		i.SetRemote(newRemote)
+		i.SetRemote(via.UdpAddr)
 		return true
 	}
 
@@ -724,7 +743,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 			return false
 		}
 
-		if l.Contains(newRemote.Addr()) {
+		if l.Contains(via.UdpAddr.Addr()) {
 			newIsPreferred = true
 		}
 	}
@@ -734,7 +753,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote netip.AddrPort) b
 		i.lastRoam = time.Now()
 		i.lastRoamRemote = currentRemote
 
-		i.SetRemote(newRemote)
+		i.SetRemote(via.UdpAddr)
 
 		return true
 	}

+ 1 - 1
interface.go

@@ -279,7 +279,7 @@ func (f *Interface) listenOut(i int) {
 	nb := make([]byte, 12, 12)
 
 	li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) {
-		f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
+		f.readOutsidePackets(ViaSender{UdpAddr: fromUdpAddr}, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l))
 	})
 }
 

+ 41 - 34
outside.go

@@ -19,21 +19,21 @@ const (
 	minFwPacketLen = 4
 )
 
-func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) readOutsidePackets(via ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) {
 	err := h.Parse(packet)
 	if err != nil {
 		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
 		if len(packet) > 1 {
-			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", ip, err)
+			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", via, err)
 		}
 		return
 	}
 
 	//l.Error("in packet ", header, packet[HeaderLen:])
-	if ip.IsValid() {
-		if f.myVpnNetworksTable.Contains(ip.Addr()) {
+	if !via.IsRelayed {
+		if f.myVpnNetworksTable.Contains(via.UdpAddr.Addr()) {
 			if f.l.Level >= logrus.DebugLevel {
-				f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet")
+				f.l.WithField("from", via).Debug("Refusing to process double encrypted packet")
 			}
 			return
 		}
@@ -54,8 +54,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	switch h.Type {
 	case header.Message:
-		// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
@@ -79,7 +78,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			// Successfully validated the thing. Get rid of the Relay header.
 			signedPayload = signedPayload[header.Len:]
 			// Pull the Roaming parts up here, and return in all call paths.
-			f.handleHostRoaming(hostinfo, ip)
+			f.handleHostRoaming(hostinfo, via)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
 			f.connectionManager.In(hostinfo)
 			f.connectionManager.RelayUsed(h.RemoteIndex)
@@ -96,7 +95,14 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			case TerminalType:
 				// If I am the target of this relay, process the unwrapped packet
 				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
-				f.readOutsidePackets(netip.AddrPort{}, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
+				via = ViaSender{
+					UdpAddr:   via.UdpAddr,
+					relayHI:   hostinfo,
+					remoteIdx: relay.RemoteIndex,
+					relay:     relay,
+					IsRelayed: true,
+				}
+				f.readOutsidePackets(via, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
@@ -126,31 +132,32 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	case header.LightHouse:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+			hostinfo.logger(f.l).WithError(err).WithField("from", via).
 				WithField("packet", packet).
 				Error("Failed to decrypt lighthouse packet")
 			return
 		}
 
-		lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f)
+		//TODO: assert via is not relayed
+		lhf.HandleRequest(via.UdpAddr, hostinfo.vpnAddrs, d, f)
 
 		// Fallthrough to the bottom to record incoming traffic
 
 	case header.Test:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+			hostinfo.logger(f.l).WithError(err).WithField("from", via).
 				WithField("packet", packet).
 				Error("Failed to decrypt test packet")
 			return
@@ -159,7 +166,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 		if h.Subtype == header.TestRequest {
 			// This testRequest might be from TryPromoteBest, so we should roam
 			// to the new IP address before responding
-			f.handleHostRoaming(hostinfo, ip)
+			f.handleHostRoaming(hostinfo, via)
 			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
 		}
 
@@ -170,34 +177,34 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	case header.Handshake:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handshakeManager.HandleIncoming(ip, via, packet, h)
+		f.handshakeManager.HandleIncoming(via, packet, h)
 		return
 
 	case header.RecvError:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		f.handleRecvError(ip, h)
+		f.handleRecvError(via.UdpAddr, h)
 		return
 
 	case header.CloseTunnel:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", ip).
+		hostinfo.logger(f.l).WithField("from", via).
 			Info("Close tunnel received, tearing down.")
 
 		f.closeTunnel(hostinfo)
 		return
 
 	case header.Control:
-		if !f.handleEncrypted(ci, ip, h) {
+		if !f.handleEncrypted(ci, via, h) {
 			return
 		}
 
 		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
 		if err != nil {
-			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", ip).
+			hostinfo.logger(f.l).WithError(err).WithField("from", via).
 				WithField("packet", packet).
 				Error("Failed to decrypt Control packet")
 			return
@@ -207,11 +214,11 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	default:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", ip)
+		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", via)
 		return
 	}
 
-	f.handleHostRoaming(hostinfo, ip)
+	f.handleHostRoaming(hostinfo, via)
 
 	f.connectionManager.In(hostinfo)
 }
@@ -230,36 +237,36 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) {
 	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
 }
 
-func (f *Interface) handleHostRoaming(hostinfo *HostInfo, udpAddr netip.AddrPort) {
-	if udpAddr.IsValid() && hostinfo.remote != udpAddr {
-		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, udpAddr.Addr()) {
-			hostinfo.logger(f.l).WithField("newAddr", udpAddr).Debug("lighthouse.remote_allow_list denied roaming")
+func (f *Interface) handleHostRoaming(hostinfo *HostInfo, via ViaSender) {
+	if !via.IsRelayed && hostinfo.remote != via.UdpAddr {
+		if !f.lightHouse.GetRemoteAllowList().AllowAll(hostinfo.vpnAddrs, via.UdpAddr.Addr()) {
+			hostinfo.logger(f.l).WithField("newAddr", via.UdpAddr).Debug("lighthouse.remote_allow_list denied roaming")
 			return
 		}
 
-		if !hostinfo.lastRoam.IsZero() && udpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
+		if !hostinfo.lastRoam.IsZero() && via.UdpAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
 			if f.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
+				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
 					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
 			}
 			return
 		}
 
-		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", udpAddr).
+		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", via.UdpAddr).
 			Info("Host roamed to new udp ip/port.")
 		hostinfo.lastRoam = time.Now()
 		hostinfo.lastRoamRemote = hostinfo.remote
-		hostinfo.SetRemote(udpAddr)
+		hostinfo.SetRemote(via.UdpAddr)
 	}
 
 }
 
 // handleEncrypted returns true if a packet should be processed, false otherwise
-func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h *header.H) bool {
+func (f *Interface) handleEncrypted(ci *ConnectionState, via ViaSender, h *header.H) bool {
 	// If connectionstate does not exist, send a recv error, if possible, to encourage a fast reconnect
 	if ci == nil {
-		if addr.IsValid() {
-			f.maybeSendRecvError(addr, h.RemoteIndex)
+		if !via.IsRelayed {
+			f.maybeSendRecvError(via.UdpAddr, h.RemoteIndex)
 		}
 		return false
 	}

+ 5 - 5
remote_list.go

@@ -338,21 +338,21 @@ func (r *RemoteList) CopyCache() *CacheMap {
 }
 
 // BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
-func (r *RemoteList) BlockRemote(bad netip.AddrPort) {
-	if !bad.IsValid() {
-		// relays can have nil udp Addrs
+func (r *RemoteList) BlockRemote(bad ViaSender) {
+	if bad.IsRelayed {
 		return
 	}
+
 	r.Lock()
 	defer r.Unlock()
 
 	// Check if we already blocked this addr
-	if r.unlockedIsBad(bad) {
+	if r.unlockedIsBad(bad.UdpAddr) {
 		return
 	}
 
 	// We copy here because we are taking something else's memory and we can't trust everything
-	r.badRemotes = append(r.badRemotes, bad)
+	r.badRemotes = append(r.badRemotes, bad.UdpAddr)
 
 	// Mark the next interaction must recollect/dedupe
 	r.shouldRebuild = true