Browse Source

test and stupid fix

JackDoan 2 months ago
parent
commit
5fa386bb70
5 changed files with 109 additions and 11 deletions
  1. 19 2
      e2e/helpers_test.go
  2. 47 0
      e2e/tunnels_test.go
  3. 21 4
      handshake_manager.go
  4. 5 1
      hostmap.go
  5. 17 4
      relay_manager.go

+ 19 - 2
e2e/helpers_test.go

@@ -29,8 +29,6 @@ type m = map[string]any
 
 // newSimpleServer creates a nebula instance with many assumptions
 func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
-	l := NewTestLogger()
-
 	var vpnNetworks []netip.Prefix
 	for _, sn := range strings.Split(sVpnNetworks, ",") {
 		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
+	return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
+}
+
+func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	l := NewTestLogger()
+
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
+	}
+
 	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
 
 	caB, err := caCrt.MarshalPEM()

+ 47 - 0
e2e/tunnels_test.go

@@ -318,3 +318,50 @@ func TestCertMismatchCorrection(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestCrossStackRelaysWork(t *testing.T) {
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me     ", "10.128.0.1/24,fc00::1/64", m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "relay  ", "10.128.0.128/24,fc00::128/64", m{"relay": m{"am_relay": true}})
+	theirUdp := netip.MustParseAddrPort("10.0.0.2:4242")
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServerWithUdp(cert.Version2, ca, caKey, "them   ", "fc00::2/64", theirUdp, m{"relay": m{"use_relays": true}})
+
+	//myVpnV4 := myVpnIpNet[0]
+	myVpnV6 := myVpnIpNet[1]
+	relayVpnV4 := relayVpnIpNet[0]
+	relayVpnV6 := relayVpnIpNet[1]
+	theirVpnV6 := theirVpnIpNet[0]
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnV4.Addr(), relayUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnV6.Addr(), relayUdpAddr)
+	myControl.InjectRelays(theirVpnV6.Addr(), []netip.Addr{relayVpnV6.Addr()})
+	relayControl.InjectLightHouseAddr(theirVpnV6.Addr(), theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnV6.Addr(), 80, myVpnV6.Addr(), 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnV6.Addr(), theirVpnV6.Addr(), 80, 80)
+
+	t.Log("reply?")
+	theirControl.InjectTunUDPPacket(myVpnV6.Addr(), 80, theirVpnV6.Addr(), 80, []byte("Hi from them"))
+	p = r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from them"), p, theirVpnV6.Addr(), myVpnV6.Addr(), 80, 80)
+
+	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
+	//t.Log("finish up")
+	//myControl.Stop()
+	//theirControl.Stop()
+	//relayControl.Stop()
+}

+ 21 - 4
handshake_manager.go

@@ -300,6 +300,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						InitiatorRelayIndex: idx,
 					}
 
+					relayFrom := hm.f.myVpnAddrs[0]
+
 					switch relayHostInfo.GetCert().Certificate.Version() {
 					case cert.Version1:
 						if !hm.f.myVpnAddrs[0].Is4() {
@@ -317,7 +319,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						b = vpnIp.As4()
 						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 					case cert.Version2:
-						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						if vpnIp.Is4() {
+							relayFrom = hm.f.myVpnAddrs[0]
+						} else {
+							//todo do this smarter
+							relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
+						}
+						m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
 					default:
 						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
@@ -332,7 +340,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnAddrs[0],
+							"relayFrom":           relayFrom,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"relay":               relay}).
@@ -358,6 +366,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					InitiatorRelayIndex: existingRelay.LocalIndex,
 				}
 
+				relayFrom := hm.f.myVpnAddrs[0]
+
 				switch relayHostInfo.GetCert().Certificate.Version() {
 				case cert.Version1:
 					if !hm.f.myVpnAddrs[0].Is4() {
@@ -375,7 +385,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					b = vpnIp.As4()
 					m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 				case cert.Version2:
-					m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+					if vpnIp.Is4() {
+						relayFrom = hm.f.myVpnAddrs[0]
+					} else {
+						//todo do this smarter
+						relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
+					}
+
+					m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 					m.RelayToAddr = netAddrToProtoAddr(vpnIp)
 				default:
 					hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
@@ -390,7 +407,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					// This must send over the hostinfo, not over hm.Hosts[ip]
 					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 					hm.l.WithFields(logrus.Fields{
-						"relayFrom":           hm.f.myVpnAddrs[0],
+						"relayFrom":           relayFrom,
 						"relayTo":             vpnIp,
 						"initiatorRelayIndex": existingRelay.LocalIndex,
 						"relay":               relay}).

+ 5 - 1
hostmap.go

@@ -2,6 +2,7 @@ package nebula
 
 import (
 	"errors"
+	"fmt"
 	"net"
 	"net/netip"
 	"slices"
@@ -521,6 +522,7 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
 		return nil, nil, errors.New("unable to find host")
 	}
 
+	lastH := h
 	for h != nil {
 		for _, targetIp := range targetIps {
 			r, ok := h.relayState.QueryRelayForByIp(targetIp)
@@ -528,10 +530,12 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
 				return h, r, nil
 			}
 		}
+		lastH = h
 		h = h.next
 	}
 
-	return nil, nil, errors.New("unable to find host with relay")
+	//todo no merge
+	return nil, nil, fmt.Errorf("unable to find host with relay: %v", lastH)
 }
 
 func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {

+ 17 - 4
relay_manager.go

@@ -190,6 +190,7 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
 		}
 
+		relayFrom := h.vpnAddrs[0]
 		if v == cert.Version1 {
 			peer := peerHostInfo.vpnAddrs[0]
 			if !peer.Is4() {
@@ -207,7 +208,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 			b = targetAddr.As4()
 			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 		} else {
-			resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
+			if targetAddr.Is4() {
+				relayFrom = h.vpnAddrs[0]
+			} else {
+				//todo do this smarter
+				relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
+			}
+			resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			resp.RelayToAddr = target
 		}
 
@@ -360,7 +367,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
 		}
-
+		relayFrom := h.vpnAddrs[0]
 		if v == cert.Version1 {
 			if !h.vpnAddrs[0].Is4() {
 				rm.l.WithField("relayFrom", h.vpnAddrs[0]).
@@ -377,7 +384,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 			b = target.As4()
 			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 		} else {
-			req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
+			if target.Is4() {
+				relayFrom = h.vpnAddrs[0]
+			} else {
+				//todo do this smarter
+				relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
+			}
+			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			req.RelayToAddr = netAddrToProtoAddr(target)
 		}
 
@@ -388,7 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           h.vpnAddrs[0],
+				"relayFrom":           relayFrom,
 				"relayTo":             target,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,