Selaa lähdekoodia

make sure hosts use the correct IP addresses when relaying

JackDoan 2 kuukautta sitten
vanhempi
sitoutus
d2cb854bff
4 muutettua tiedostoa jossa 98 lisäystä ja 81 poistoa
  1. 46 58
      handshake_manager.go
  2. 5 6
      hostmap.go
  3. 11 1
      lighthouse.go
  4. 36 16
      relay_manager.go

+ 46 - 58
handshake_manager.go

@@ -300,47 +300,27 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 						InitiatorRelayIndex: idx,
 						InitiatorRelayIndex: idx,
 					}
 					}
 
 
-					relayFrom := hm.f.myVpnAddrs[0]
-
 					switch relayHostInfo.GetCert().Certificate.Version() {
 					switch relayHostInfo.GetCert().Certificate.Version() {
 					case cert.Version1:
 					case cert.Version1:
-						if !hm.f.myVpnAddrs[0].Is4() {
-							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
-							continue
-						}
-
-						if !vpnIp.Is4() {
-							hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
-							continue
-						}
-
-						b := hm.f.myVpnAddrs[0].As4()
-						m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
-						b = vpnIp.As4()
-						m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+						err = buildRelayInfoCertV1(&m, hm.f.myVpnNetworks, vpnIp)
 					case cert.Version2:
 					case cert.Version2:
-						if vpnIp.Is4() {
-							relayFrom = hm.f.myVpnAddrs[0]
-						} else {
-							//todo do this smarter
-							relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
-						}
-						m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
-						m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+						err = buildRelayInfoCertV2(&m, hm.f.myVpnNetworks, vpnIp)
 					default:
 					default:
-						hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+						err = errors.New("unknown certificate version found while creating relay")
+					}
+					if err != nil {
+						hostinfo.logger(hm.l).WithError(err).Error("Refusing to relay")
 						continue
 						continue
 					}
 					}
 
 
 					msg, err := m.Marshal()
 					msg, err := m.Marshal()
 					if err != nil {
 					if err != nil {
-						hostinfo.logger(hm.l).
-							WithError(err).
+						hostinfo.logger(hm.l).WithError(err).
 							Error("Failed to marshal Control message to create relay")
 							Error("Failed to marshal Control message to create relay")
 					} else {
 					} else {
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						hm.l.WithFields(logrus.Fields{
 						hm.l.WithFields(logrus.Fields{
-							"relayFrom":           relayFrom,
+							"relayFrom":           m.GetRelayFrom(),
 							"relayTo":             vpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"initiatorRelayIndex": idx,
 							"relay":               relay}).
 							"relay":               relay}).
@@ -366,48 +346,27 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered
 					InitiatorRelayIndex: existingRelay.LocalIndex,
 					InitiatorRelayIndex: existingRelay.LocalIndex,
 				}
 				}
 
 
-				relayFrom := hm.f.myVpnAddrs[0]
-
+				var err error
 				switch relayHostInfo.GetCert().Certificate.Version() {
 				switch relayHostInfo.GetCert().Certificate.Version() {
 				case cert.Version1:
 				case cert.Version1:
-					if !hm.f.myVpnAddrs[0].Is4() {
-						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
-						continue
-					}
-
-					if !vpnIp.Is4() {
-						hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version")
-						continue
-					}
-
-					b := hm.f.myVpnAddrs[0].As4()
-					m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
-					b = vpnIp.As4()
-					m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+					err = buildRelayInfoCertV1(&m, hm.f.myVpnNetworks, vpnIp)
 				case cert.Version2:
 				case cert.Version2:
-					if vpnIp.Is4() {
-						relayFrom = hm.f.myVpnAddrs[0]
-					} else {
-						//todo do this smarter
-						relayFrom = hm.f.myVpnAddrs[len(hm.f.myVpnAddrs)-1]
-					}
-
-					m.RelayFromAddr = netAddrToProtoAddr(relayFrom)
-					m.RelayToAddr = netAddrToProtoAddr(vpnIp)
+					err = buildRelayInfoCertV2(&m, hm.f.myVpnNetworks, vpnIp)
 				default:
 				default:
-					hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay")
+					err = errors.New("unknown certificate version found while creating relay")
+				}
+				if err != nil {
+					hostinfo.logger(hm.l).WithError(err).Error("Refusing to relay")
 					continue
 					continue
 				}
 				}
 				msg, err := m.Marshal()
 				msg, err := m.Marshal()
 				if err != nil {
 				if err != nil {
-					hostinfo.logger(hm.l).
-						WithError(err).
-						Error("Failed to marshal Control message to create relay")
+					hostinfo.logger(hm.l).WithError(err).Error("Failed to marshal Control message to create relay")
 				} else {
 				} else {
 					// This must send over the hostinfo, not over hm.Hosts[ip]
 					// This must send over the hostinfo, not over hm.Hosts[ip]
 					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 					hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 					hm.l.WithFields(logrus.Fields{
 					hm.l.WithFields(logrus.Fields{
-						"relayFrom":           relayFrom,
+						"relayFrom":           m.GetRelayFrom(),
 						"relayTo":             vpnIp,
 						"relayTo":             vpnIp,
 						"initiatorRelayIndex": existingRelay.LocalIndex,
 						"initiatorRelayIndex": existingRelay.LocalIndex,
 						"relay":               relay}).
 						"relay":               relay}).
@@ -742,3 +701,32 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
 func hsTimeout(tries int64, interval time.Duration) time.Duration {
 func hsTimeout(tries int64, interval time.Duration) time.Duration {
 	return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
 	return time.Duration(tries / 2 * ((2 * int64(interval)) + (tries-1)*int64(interval)))
 }
 }
+
+var errNoRelayTooOld = errors.New("can not establish v1 relay with a v6 network because the relay is not running a current nebula version")
+
+func buildRelayInfoCertV1(m *NebulaControl, myVpnNetworks []netip.Prefix, peerVpnIp netip.Addr) error {
+	relayFrom := myVpnNetworks[0].Addr()
+	if !relayFrom.Is4() {
+		return errNoRelayTooOld
+	}
+	if !peerVpnIp.Is4() {
+		return errNoRelayTooOld
+	}
+
+	b := relayFrom.As4()
+	m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
+	b = peerVpnIp.As4()
+	m.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
+	return nil
+}
+
+func buildRelayInfoCertV2(m *NebulaControl, myVpnNetworks []netip.Prefix, peerVpnIp netip.Addr) error {
+	for i := range myVpnNetworks {
+		if myVpnNetworks[i].Contains(peerVpnIp) {
+			m.RelayFromAddr = netAddrToProtoAddr(myVpnNetworks[i].Addr())
+			m.RelayToAddr = netAddrToProtoAddr(peerVpnIp)
+			return nil
+		}
+	}
+	return errors.New("cannot establish relay, no networks in common")
+}

+ 5 - 6
hostmap.go

@@ -2,7 +2,6 @@ package nebula
 
 
 import (
 import (
 	"errors"
 	"errors"
-	"fmt"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
 	"slices"
 	"slices"
@@ -513,16 +512,18 @@ func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo {
 	return hm.queryVpnAddr(vpnIp, nil)
 	return hm.queryVpnAddr(vpnIp, nil)
 }
 }
 
 
+var errUnableToFindHost = errors.New("unable to find host")
+var errUnableToFindHostWithRelay = errors.New("unable to find host with relay")
+
 func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) {
 	hm.RLock()
 	hm.RLock()
 	defer hm.RUnlock()
 	defer hm.RUnlock()
 
 
 	h, ok := hm.Hosts[relayHostIp]
 	h, ok := hm.Hosts[relayHostIp]
 	if !ok {
 	if !ok {
-		return nil, nil, errors.New("unable to find host")
+		return nil, nil, errUnableToFindHost
 	}
 	}
 
 
-	lastH := h
 	for h != nil {
 	for h != nil {
 		for _, targetIp := range targetIps {
 		for _, targetIp := range targetIps {
 			r, ok := h.relayState.QueryRelayForByIp(targetIp)
 			r, ok := h.relayState.QueryRelayForByIp(targetIp)
@@ -530,12 +531,10 @@ func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp net
 				return h, r, nil
 				return h, r, nil
 			}
 			}
 		}
 		}
-		lastH = h
 		h = h.next
 		h = h.next
 	}
 	}
 
 
-	//todo no merge
-	return nil, nil, fmt.Errorf("unable to find host with relay: %v", lastH)
+	return nil, nil, errUnableToFindHostWithRelay
 }
 }
 
 
 func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {
 func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) {

+ 11 - 1
lighthouse.go

@@ -1425,7 +1425,7 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
 	return relays
 	return relays
 }
 }
 
 
-// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
+// findNetworkUnion returns the first netip.Addr of addrs contained in the list of provided netip.Prefix, if able
 func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
 func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
 	for i := range prefixes {
 	for i := range prefixes {
 		for j := range addrs {
 		for j := range addrs {
@@ -1450,3 +1450,13 @@ func (d *NebulaMetaDetails) GetVpnAddrAndVersion() (netip.Addr, cert.Version, er
 		return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
 		return netip.Addr{}, cert.Version1, ErrBadDetailsVpnAddr
 	}
 	}
 }
 }
+
+func (d *NebulaControl) GetRelayFrom() netip.Addr {
+	if d.OldRelayFromAddr != 0 {
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], d.OldRelayFromAddr)
+		return netip.AddrFrom4(b)
+	} else {
+		return protoAddrToNetAddr(d.RelayFromAddr)
+	}
+}

+ 36 - 16
relay_manager.go

@@ -155,6 +155,8 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 		"vpnAddrs":            h.vpnAddrs}).
 		"vpnAddrs":            h.vpnAddrs}).
 		Info("handleCreateRelayResponse")
 		Info("handleCreateRelayResponse")
 
 
+	//peer == relayFrom
+	//target == relayTo
 	target := m.RelayToAddr
 	target := m.RelayToAddr
 	targetAddr := protoAddrToNetAddr(target)
 	targetAddr := protoAddrToNetAddr(target)
 
 
@@ -195,7 +197,7 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 			peer := peerHostInfo.vpnAddrs[0]
 			peer := peerHostInfo.vpnAddrs[0]
 			if !peer.Is4() {
 			if !peer.Is4() {
 				rm.l.WithField("relayFrom", peer).
 				rm.l.WithField("relayFrom", peer).
-					WithField("relayTo", target).
+					WithField("relayTo", targetAddr).
 					WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
 					WithField("initiatorRelayIndex", resp.InitiatorRelayIndex).
 					WithField("responderRelayIndex", resp.ResponderRelayIndex).
 					WithField("responderRelayIndex", resp.ResponderRelayIndex).
 					WithField("vpnAddrs", peerHostInfo.vpnAddrs).
 					WithField("vpnAddrs", peerHostInfo.vpnAddrs).
@@ -208,12 +210,21 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 			b = targetAddr.As4()
 			b = targetAddr.As4()
 			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 			resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 		} else {
 		} else {
-			if targetAddr.Is4() {
-				relayFrom = h.vpnAddrs[0]
-			} else {
-				//todo do this smarter
-				relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
+			ok = false
+			peerNetworks := h.GetCert().Certificate.Networks()
+			for i := range peerNetworks {
+				if peerNetworks[i].Contains(targetAddr) {
+					relayFrom = peerNetworks[i].Addr()
+					ok = true
+					break
+				}
 			}
 			}
+			if !ok {
+				rm.l.WithFields(logrus.Fields{"from": f.myVpnNetworks, "to": targetAddr}).
+					Error("cannot establish relay, no networks in common")
+				return
+			}
+
 			resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			resp.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			resp.RelayToAddr = target
 			resp.RelayToAddr = target
 		}
 		}
@@ -225,8 +236,8 @@ func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f
 		} else {
 		} else {
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
 			rm.l.WithFields(logrus.Fields{
-				"relayFrom":           resp.RelayFromAddr,
-				"relayTo":             resp.RelayToAddr,
+				"relayFrom":           relayFrom,
+				"relayTo":             targetAddr,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"initiatorRelayIndex": resp.InitiatorRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"responderRelayIndex": resp.ResponderRelayIndex,
 				"vpnAddrs":            peerHostInfo.vpnAddrs}).
 				"vpnAddrs":            peerHostInfo.vpnAddrs}).
@@ -369,8 +380,8 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 		}
 		}
 		relayFrom := h.vpnAddrs[0]
 		relayFrom := h.vpnAddrs[0]
 		if v == cert.Version1 {
 		if v == cert.Version1 {
-			if !h.vpnAddrs[0].Is4() {
-				rm.l.WithField("relayFrom", h.vpnAddrs[0]).
+			if !relayFrom.Is4() {
+				rm.l.WithField("relayFrom", relayFrom).
 					WithField("relayTo", target).
 					WithField("relayTo", target).
 					WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
 					WithField("initiatorRelayIndex", req.InitiatorRelayIndex).
 					WithField("responderRelayIndex", req.ResponderRelayIndex).
 					WithField("responderRelayIndex", req.ResponderRelayIndex).
@@ -379,17 +390,26 @@ func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f
 				return
 				return
 			}
 			}
 
 
-			b := h.vpnAddrs[0].As4()
+			b := relayFrom.As4()
 			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
 			req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:])
 			b = target.As4()
 			b = target.As4()
 			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 			req.OldRelayToAddr = binary.BigEndian.Uint32(b[:])
 		} else {
 		} else {
-			if target.Is4() {
-				relayFrom = h.vpnAddrs[0]
-			} else {
-				//todo do this smarter
-				relayFrom = h.vpnAddrs[len(h.vpnAddrs)-1]
+			ok = false
+			peerNetworks := h.GetCert().Certificate.Networks()
+			for i := range peerNetworks {
+				if peerNetworks[i].Contains(target) {
+					relayFrom = peerNetworks[i].Addr()
+					ok = true
+					break
+				}
 			}
 			}
+			if !ok {
+				rm.l.WithFields(logrus.Fields{"from": f.myVpnNetworks, "to": target}).
+					Error("cannot establish relay, no networks in common")
+				return
+			}
+
 			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			req.RelayToAddr = netAddrToProtoAddr(target)
 			req.RelayToAddr = netAddrToProtoAddr(target)
 		}
 		}