|
@@ -1108,32 +1108,44 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
|
lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
|
|
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
|
w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
|
|
|
|
|
- // This signals the other side to punch some zero byte udp packets
|
|
|
|
- found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) {
|
|
|
|
|
|
+ lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// sendHostPunchNotification signals the other side to punch some zero byte udp packets
|
|
|
|
+func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) {
|
|
|
|
+ whereToPunch := fromVpnAddrs[0]
|
|
|
|
+ found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
|
|
n = lhh.resetMeta()
|
|
n = lhh.resetMeta()
|
|
n.Type = NebulaMeta_HostPunchNotification
|
|
n.Type = NebulaMeta_HostPunchNotification
|
|
- targetHI := lhh.lh.ifce.GetHostInfo(queryVpnAddr)
|
|
|
|
|
|
+ targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
|
|
|
|
+ var useVersion cert.Version
|
|
if targetHI == nil {
|
|
if targetHI == nil {
|
|
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
|
|
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
|
|
} else {
|
|
} else {
|
|
- useVersion = targetHI.GetCert().Certificate.Version()
|
|
|
|
|
|
+ crt := targetHI.GetCert().Certificate
|
|
|
|
+ useVersion = crt.Version()
|
|
|
|
+ // we can only retarget if we have a hostinfo
|
|
|
|
+ newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
|
|
|
|
+ if ok {
|
|
|
|
+ whereToPunch = newDest
|
|
|
|
+ } else {
|
|
|
|
+ //TODO this means the destination will have no addresses in common with the punch-ee
|
|
|
|
+ //choosing to do nothing for now, but maybe we return an error?
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
if useVersion == cert.Version1 {
|
|
if useVersion == cert.Version1 {
|
|
- if !fromVpnAddrs[0].Is4() {
|
|
|
|
|
|
+ if !whereToPunch.Is4() {
|
|
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
|
|
return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
|
|
}
|
|
}
|
|
- b := fromVpnAddrs[0].As4()
|
|
|
|
|
|
+ b := whereToPunch.As4()
|
|
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
|
n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
|
|
- lhh.coalesceAnswers(useVersion, c, n)
|
|
|
|
-
|
|
|
|
} else if useVersion == cert.Version2 {
|
|
} else if useVersion == cert.Version2 {
|
|
- n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
|
|
|
|
- lhh.coalesceAnswers(useVersion, c, n)
|
|
|
|
-
|
|
|
|
|
|
+ n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch)
|
|
} else {
|
|
} else {
|
|
- panic("unsupported version")
|
|
|
|
|
|
+ return 0, errors.New("unsupported version")
|
|
}
|
|
}
|
|
|
|
+ lhh.coalesceAnswers(useVersion, c, n)
|
|
|
|
|
|
return n.MarshalTo(lhh.pb)
|
|
return n.MarshalTo(lhh.pb)
|
|
})
|
|
})
|
|
@@ -1148,7 +1160,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
|
|
}
|
|
}
|
|
|
|
|
|
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
|
|
lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
|
|
- w.SendMessageToVpnAddr(header.LightHouse, 0, queryVpnAddr, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
|
|
|
|
|
+ w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0])
|
|
}
|
|
}
|
|
|
|
|
|
func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) {
|
|
func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) {
|
|
@@ -1429,3 +1441,15 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
|
|
}
|
|
}
|
|
return relays
|
|
return relays
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
|
|
|
|
+func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
|
|
|
|
+ for i := range prefixes {
|
|
|
|
+ for j := range addrs {
|
|
|
|
+ if prefixes[i].Contains(addrs[j]) {
|
|
|
|
+ return addrs[j], true
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return netip.Addr{}, false
|
|
|
|
+}
|