Browse Source

test and stupid fix

JackDoan 2 months ago
parent
commit
7797927401
4 changed files with 62 additions and 11 deletions
  1. 19 2
      e2e/helpers_test.go
  2. 21 4
      handshake_manager.go
  3. 5 1
      hostmap.go
  4. 17 4
      relay_manager.go

+ 19 - 2
e2e/helpers_test.go

@@ -29,8 +29,6 @@ type m = map[string]any
 
 // newSimpleServer creates a nebula instance with many assumptions
 func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
-	l := NewTestLogger()
-
 	var vpnNetworks []netip.Prefix
 	for _, sn := range strings.Split(sVpnNetworks, ",") {
 		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
@@ -56,6 +54,25 @@ func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name
 		budpIp[3] = 239
 		udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242)
 	}
+	return newSimpleServerWithUdp(v, caCrt, caKey, name, sVpnNetworks, udpAddr, overrides)
+}
+
+func newSimpleServerWithUdp(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, udpAddr netip.AddrPort, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) {
+	l := NewTestLogger()
+
+	var vpnNetworks []netip.Prefix
+	for _, sn := range strings.Split(sVpnNetworks, ",") {
+		vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn))
+		if err != nil {
+			panic(err)
+		}
+		vpnNetworks = append(vpnNetworks, vpnIpNet)
+	}
+
+	if len(vpnNetworks) == 0 {
+		panic("no vpn networks")
+	}
+
 	_, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{})
 
 	caB, err := caCrt.MarshalPEM()

+ 21 - 4
handshake_manager.go

@@ -299,6 +299,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						InitiatorRelayIndex: idx,
 					}
 
+					relayFrom := hm.f.myVpnAddrs[0]
+
 					switch relayHostInfo.GetCert().Certificate.Version() {
 					case cert.Version1:
 						if !hm.f.myVpnAddrs[0].Is4() {
@@ -316,7 +318,13 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						b = vpnIp.As4()
 						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 					case cert.Version2:
-						m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+						if vpnIp.Is4() {
+							relayFrom = hm.f.myVpnAddrs[0]
+						} else {
+							//todo do this smarter
+							relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
+						}
+						m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
 					default:
 						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
@@ -331,7 +339,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           hm.f.myVpnAddrs[0],
+							"relayFrom":           relayFrom,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"relay":               relay}).
@@ -357,6 +365,8 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					InitiatorRelayIndex: existingRelay.LocalIndex,
 				}
 
+				relayFrom := hm.f.myVpnAddrs[0]
+
 				switch relayHostInfo.GetCert().Certificate.Version() {
 				case cert.Version1:
 					if !hm.f.myVpnAddrs[0].Is4() {
@@ -374,7 +384,14 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					b = vpnIp.As4()
 					m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 				case cert.Version2:
-					m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0])
+					if vpnIp.Is4() {
+						relayFrom = hm.f.myVpnAddrs[0]
+					} else {
+						//todo do this smarter
+						relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
+					}
+
+					m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 					m.RelayToAddr = netAddrToProtoAddr(vpnIp)
 				default:
 					hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
@@ -389,7 +406,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					// This must send over the hostinfo, not over hm.Hosts[ip]
 					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 					hm.l.WithFields(logrus.Fields{
-						"relayFrom":           hm.f.myVpnAddrs[0],
+						"relayFrom":           relayFrom,
 						"relayTo":             vpnIp,
 						"initiatorRelayIndex": existingRelay.LocalIndex,
 						"relay":               relay}).

+ 5 - 1
hostmap.go

@@ -2,6 +2,7 @@ package nebula
 
 import (
 	"errors"
+	"fmt"
 	"net"
 	"net/netip"
 	"slices"
@@ -521,6 +522,7 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
 		return nil, nil, errors.New("unable to find host")
 	}
 
+	lastH := h
 	for h != nil {
 		for _, targetIp := range targetIps {
 			r, ok := h.relayState.QueryRelayForByIp(targetIp)
@@ -528,10 +530,12 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
 				return h, r, nil
 			}
 		}
+		lastH = h
 		h = h.next
 	}
 
-	return nil, nil, errors.New("unable to find host with relay")
+	//todo no merge
+	return nil, nil, fmt.Errorf("unable to find host with relay: %v", lastH)
 }
 
 func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {

+ 17 - 4
relay_manager.go

@@ -190,6 +190,7 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 			InitiatorRelayIndex: peerRelay.RemoteIndex,
 		}
 
+		relayFrom := h.vpnAddrs[0]
 		if v == cert.Version1 {
 			peer := peerHostInfo.vpnAddrs[0]
 			if !peer.Is4() {
@@ -207,7 +208,13 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 			b = targetAddr.As4()
 			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 		} else {
-			resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0])
+			if targetAddr.Is4() {
+				relayFrom = h.vpnAddrs[0]
+			} else {
+				//todo do this smarter
+				relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
+			}
+			resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			resp.RelayToAddr = target
 		}
 
@@ -360,7 +367,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 			Type:                NebulaControl_CreateRelayRequest,
 			InitiatorRelayIndex: index,
 		}
-
+		relayFrom := h.vpnAddrs[0]
 		if v == cert.Version1 {
 			if !h.vpnAddrs[0].Is4() {
 				rm.l.WithField("relayFrom", h.vpnAddrs[0]).
@@ -377,7 +384,13 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 			b = target.As4()
 			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 		} else {
-			req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0])
+			if target.Is4() {
+				relayFrom = h.vpnAddrs[0]
+			} else {
+				//todo do this smarter
+				relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
+			}
+			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			req.RelayToAddr = netAddrToProtoAddr(target)
 		}
 
@@ -388,7 +401,7 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           h.vpnAddrs[0],
+				"relayFrom":           relayFrom,
 				"relayTo":             target,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
 				"responderRelayIndex": req.ResponderRelayIndex,