Browse Source

[cert-v2] punchy-respond on an address in common with the querying host (#1261)

Jack Doan 10 tháng trước cách đây
mục cha
commit
5380fef7b0
3 tập tin đã thay đổi với 99 bổ sung15 xóa
  1. 37 13
      lighthouse.go
  2. 60 0
      lighthouse_test.go
  3. 2 2
      relay_manager.go

+ 37 - 13
lighthouse.go

@@ -1108,32 +1108,44 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 	lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1)
 	w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0])
 
-	// This signals the other side to punch some zero byte udp packets
-	found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) {
+	lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w)
+}
+
+// sendHostPunchNotification signals the other side to punch some zero byte udp packets
+func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) {
+	whereToPunch := fromVpnAddrs[0]
+	found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostPunchNotification
-		targetHI := lhh.lh.ifce.GetHostInfo(queryVpnAddr)
+		targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest)
+		var useVersion cert.Version
 		if targetHI == nil {
 			useVersion = lhh.lh.ifce.GetCertState().defaultVersion
 		} else {
-			useVersion = targetHI.GetCert().Certificate.Version()
+			crt := targetHI.GetCert().Certificate
+			useVersion = crt.Version()
+			// we can only retarget if we have a hostinfo
+			newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs)
+			if ok {
+				whereToPunch = newDest
+			} else {
+				//TODO this means the destination will have no addresses in common with the punch-ee
+				//choosing to do nothing for now, but maybe we return an error?
+			}
 		}
 
 		if useVersion == cert.Version1 {
-			if !fromVpnAddrs[0].Is4() {
+			if !whereToPunch.Is4() {
 				return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery")
 			}
-			b := fromVpnAddrs[0].As4()
+			b := whereToPunch.As4()
 			n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])
-			lhh.coalesceAnswers(useVersion, c, n)
-
 		} else if useVersion == cert.Version2 {
-			n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0])
-			lhh.coalesceAnswers(useVersion, c, n)
-
+			n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch)
 		} else {
-			panic("unsupported version")
+			return 0, errors.New("unsupported version")
 		}
+		lhh.coalesceAnswers(useVersion, c, n)
 
 		return n.MarshalTo(lhh.pb)
 	})
@@ -1148,7 +1160,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
 	}
 
 	lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1)
-	w.SendMessageToVpnAddr(header.LightHouse, 0, queryVpnAddr, lhh.pb[:ln], lhh.nb, lhh.out[:0])
+	w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 
 func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) {
@@ -1429,3 +1441,15 @@ func (d *NebulaMetaDetails) GetRelays() []netip.Addr {
 	}
 	return relays
 }
+
+// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able
+func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) {
+	for i := range prefixes {
+		for j := range addrs {
+			if prefixes[i].Contains(addrs[j]) {
+				return addrs[j], true
+			}
+		}
+	}
+	return netip.Addr{}, false
+}

+ 60 - 0
lighthouse_test.go

@@ -494,3 +494,63 @@ func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort)
 		}
 	}
 }
+
+func Test_findNetworkUnion(t *testing.T) {
+	var out netip.Addr
+	var ok bool
+
+	tenDot := netip.MustParsePrefix("10.0.0.0/8")
+	oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16")
+	fe80 := netip.MustParsePrefix("fe80::/8")
+	fc00 := netip.MustParsePrefix("fc00::/7")
+
+	a1 := netip.MustParseAddr("10.0.0.1")
+	afe81 := netip.MustParseAddr("fe80::1")
+
+	//simple
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed lengths
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//mixed family
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+
+	//ordering
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, a1)
+	out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//some mismatches
+	out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81})
+	assert.True(t, ok)
+	assert.Equal(t, out, afe81)
+
+	//falsey cases
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+	out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81})
+	assert.False(t, ok)
+}

+ 2 - 2
relay_manager.go

@@ -137,8 +137,8 @@ func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) {
 
 func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) {
 	rm.l.WithFields(logrus.Fields{
-		"relayFrom":           m.RelayFromAddr,
-		"relayTo":             m.RelayToAddr,
+		"relayFrom":           protoAddrToNetAddr(m.RelayFromAddr),
+		"relayTo":             protoAddrToNetAddr(m.RelayToAddr),
 		"initiatorRelayIndex": m.InitiatorRelayIndex,
 		"responderRelayIndex": m.ResponderRelayIndex,
 		"vpnAddrs":            h.vpnAddrs}).