소스 검색

don't require a detailsVpnAddr in a HostUpdateNotification

JackDoan 2 일 전
부모
커밋
d3444f4593
1개의 변경된 파일32개의 추가작업 그리고 19개의 파일을 삭제
  1. 32 19
      lighthouse.go

+ 32 - 19
lighthouse.go

@@ -1060,7 +1060,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 	queryVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
 	if err != nil {
 		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery")
+			lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).
+				Debugln("Dropping malformed HostQuery")
+		}
+		return
+	} else if useVersion == cert.Version1 && queryVpnAddr.Is6() {
+		// this case really shouldn't be possible to represent, but reject it anyway.
+		if lhh.l.Level >= logrus.DebugLevel {
+			lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("queryVpnAddr", queryVpnAddr).
+				Debugln("invalid vpn addr for v1 handleHostQuery")
 		}
 		return
 	}
@@ -1069,9 +1077,6 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
 		if useVersion == cert.Version1 {
-			if !queryVpnAddr.Is4() {
-				return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
-			}
 			b := queryVpnAddr.As4()
 			n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
 		} else {
@@ -1231,16 +1236,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 		return
 	}
 
-	detailsVpnAddr, useVersion, err := n.Details.GetVpnAddrAndVersion()
-	if err != nil {
-		if lhh.l.Level >= logrus.DebugLevel {
-			lhh.l.WithField("details", n.Details).WithError(err).Debugln("dropping invalid HostUpdateNotification")
-		}
+	// not using GetVpnAddrAndVersion because we don't want to error on a blank detailsVpnAddr
+	var detailsVpnAddr netip.Addr
+	var useVersion cert.Version
+	if n.Details.OldVpnAddr != 0 { //v1 always sets this field
+		b := [4]byte{}
+		binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr)
+		detailsVpnAddr = netip.AddrFrom4(b)
+		useVersion = cert.Version1
+	} else if n.Details.VpnAddr != nil { //this field is "optional" in v2, but if it's set, we should enforce it
+		detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr)
+		useVersion = cert.Version2
+	} else {
+		detailsVpnAddr = netip.Addr{}
+		useVersion = cert.Version2
 	}
 
-	//TODO: CERT-V2 why do we care about the vpnAddr in the packet? We know where it came from, right?
-	//Simple check that the host sent this not someone else
-	if !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
+	//Simple check that the host sent this not someone else, if detailsVpnAddr is filled
+	if detailsVpnAddr.IsValid() && !slices.Contains(fromVpnAddrs, detailsVpnAddr) {
 		if lhh.l.Level >= logrus.DebugLevel {
 			lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update")
 		}
@@ -1254,24 +1267,24 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVp
 	am.Lock()
 	lhh.lh.Unlock()
 
-	am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
-	am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
+	am.unlockedSetV4(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(fromVpnAddrs[0], fromVpnAddrs[0], n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6)
 	am.unlockedSetRelay(fromVpnAddrs[0], relays)
 	am.Unlock()
 
 	n = lhh.resetMeta()
 	n.Type = NebulaMeta_HostUpdateNotificationAck
-
-	if useVersion == cert.Version1 {
+	switch useVersion {
+	case cert.Version1:
 		if !fromVpnAddrs[0].Is4() {
 			lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message")
 			return
 		}
 		vpnAddrB := fromVpnAddrs[0].As4()
 		n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:])
-	} else if useVersion == cert.Version2 {
-		n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
-	} else {
+	case cert.Version2:
+		// do nothing, we want to send a blank message
+	default:
 		lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version")
 		return
 	}