소스 검색

Fix re-entrant `GetOrHandshake` issues (#1044)

Nate Brown 1 년 전
부모
커밋
072edd56b3
8개의 변경된 파일74개의 추가작업 그리고 36개의 파일을 삭제
  1. 9 4
      connection_manager.go
  2. 3 2
      connection_manager_test.go
  3. 4 0
      examples/config.yml
  4. 3 4
      handshake_manager.go
  5. 1 1
      hostmap.go
  6. 1 1
      inside.go
  7. 52 23
      lighthouse.go
  8. 1 1
      ssh.go

+ 9 - 4
connection_manager.go

@@ -23,6 +23,7 @@ const (
 	swapPrimary    trafficDecision = 3
 	migrateRelays  trafficDecision = 4
 	tryRehandshake trafficDecision = 5
+	sendTestPacket trafficDecision = 6
 )
 
 type connectionManager struct {
@@ -176,7 +177,7 @@ func (n *connectionManager) Run(ctx context.Context) {
 }
 
 func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)
+	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
 
 	switch decision {
 	case deleteTunnel:
@@ -197,6 +198,9 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 
 	case tryRehandshake:
 		n.tryRehandshake(hostinfo)
+
+	case sendTestPacket:
+		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
 	}
 
 	n.resetRelayTrafficCheck(hostinfo)
@@ -289,7 +293,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	}
 }
 
-func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
 	n.hostMap.RLock()
 	defer n.hostMap.RUnlock()
 
@@ -356,6 +360,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
 		return deleteTunnel, hostinfo, nil
 	}
 
+	decision := doNothing
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
 			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
@@ -380,7 +385,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
 		}
 
 		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		decision = sendTestPacket
 
 	} else {
 		if n.l.Level >= logrus.DebugLevel {
@@ -390,7 +395,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []
 
 	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
 	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
-	return doNothing, nil, nil
+	return decision, hostinfo, nil
 }
 
 func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {

+ 3 - 2
connection_manager_test.go

@@ -21,8 +21,9 @@ var vpnIp iputil.VpnIp
 
 func newTestLighthouse() *LightHouse {
 	lh := &LightHouse{
-		l:       test.NewLogger(),
-		addrMap: map[iputil.VpnIp]*RemoteList{},
+		l:         test.NewLogger(),
+		addrMap:   map[iputil.VpnIp]*RemoteList{},
+		queryChan: make(chan iputil.VpnIp, 10),
 	}
 	lighthouses := map[iputil.VpnIp]struct{}{}
 	staticList := map[iputil.VpnIp]struct{}{}

+ 4 - 0
examples/config.yml

@@ -289,6 +289,10 @@ logging:
   # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
   #try_interval: 100ms
   #retries: 20
+
+  # query_buffer is the size of the buffer channel for querying lighthouses
+  #query_buffer: 64
+
   # trigger_buffer is the size of the buffer channel for quickly sending handshakes
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64

+ 3 - 4
handshake_manager.go

@@ -230,7 +230,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
 		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
 		// the learned public ip for them. Query again to short circuit the promotion counter
-		hm.lightHouse.QueryServer(vpnIp, hm.f)
+		hm.lightHouse.QueryServer(vpnIp)
 	}
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
@@ -374,13 +374,13 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
 func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
+	defer hm.Unlock()
 
 	if hh, ok := hm.vpnIps[vpnIp]; ok {
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
 			cacheCb(hh)
 		}
-		hm.Unlock()
 		return hh.hostinfo
 	}
 
@@ -421,8 +421,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
 		}
 	}
 
-	hm.Unlock()
-	hm.lightHouse.QueryServer(vpnIp, hm.f)
+	hm.lightHouse.QueryServer(vpnIp)
 	return hostinfo
 }
 

+ 1 - 1
hostmap.go

@@ -561,7 +561,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 		}
 
 		i.nextLHQuery.Store(now + ifce.reQueryWait.Load())
-		ifce.lightHouse.QueryServer(i.vpnIp, ifce)
+		ifce.lightHouse.QueryServer(i.vpnIp)
 	}
 }
 

+ 1 - 1
inside.go

@@ -288,7 +288,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount {
 		//NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is
 		// finally used again. This tunnel would eventually be torn down and recreated if this action didn't help.
-		f.lightHouse.QueryServer(hostinfo.vpnIp, f)
+		f.lightHouse.QueryServer(hostinfo.vpnIp)
 		hostinfo.lastRebindCount = f.rebindCount
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter")

+ 52 - 23
lighthouse.go

@@ -74,6 +74,8 @@ type LightHouse struct {
 	// IP's of relays that can be used by peers to access me
 	relaysForMe atomic.Pointer[[]iputil.VpnIp]
 
+	queryChan chan iputil.VpnIp
+
 	calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote
 
 	metrics           *MessageMetrics
@@ -110,6 +112,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		nebulaPort:   nebulaPort,
 		punchConn:    pc,
 		punchy:       p,
+		queryChan:    make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)),
 		l:            l,
 	}
 	lighthouses := make(map[iputil.VpnIp]struct{})
@@ -139,6 +142,8 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 		}
 	})
 
+	h.startQueryWorker()
+
 	return &h, nil
 }
 
@@ -443,9 +448,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList
 	return nil
 }
 
-func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
+func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
-		lh.QueryServer(ip, f)
+		lh.QueryServer(ip)
 	}
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
@@ -456,30 +461,14 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList {
 	return nil
 }
 
-// This is asynchronous so no reply should be expected
-func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) {
-	if lh.amLighthouse {
-		return
-	}
-
-	if lh.IsLighthouseIP(ip) {
-		return
-	}
-
-	// Send a query to the lighthouses and hope for the best next time
-	query, err := NewLhQueryByInt(ip).Marshal()
-	if err != nil {
-		lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
+// QueryServer is asynchronous so no reply should be expected
+func (lh *LightHouse) QueryServer(ip iputil.VpnIp) {
+	// Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses
+	if lh.amLighthouse || lh.IsLighthouseIP(ip) {
 		return
 	}
 
-	lighthouses := lh.GetLighthouses()
-	lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
-	nb := make([]byte, 12, 12)
-	out := make([]byte, mtu)
-	for n := range lighthouses {
-		f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
-	}
+	lh.queryChan <- ip
 }
 
 func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList {
@@ -752,6 +741,46 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr {
 	return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port))
 }
 
+func (lh *LightHouse) startQueryWorker() {
+	if lh.amLighthouse {
+		return
+	}
+
+	go func() {
+		nb := make([]byte, 12, 12)
+		out := make([]byte, mtu)
+
+		for {
+			select {
+			case <-lh.ctx.Done():
+				return
+			case ip := <-lh.queryChan:
+				lh.innerQueryServer(ip, nb, out)
+			}
+		}
+	}()
+}
+
+func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) {
+	if lh.IsLighthouseIP(ip) {
+		return
+	}
+
+	// Send a query to the lighthouses and hope for the best next time
+	query, err := NewLhQueryByInt(ip).Marshal()
+	if err != nil {
+		lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload")
+		return
+	}
+
+	lighthouses := lh.GetLighthouses()
+	lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
+
+	for n := range lighthouses {
+		lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
+	}
+}
+
 func (lh *LightHouse) StartUpdateWorker() {
 	interval := lh.GetUpdateInterval()
 	if lh.amLighthouse || interval == 0 {

+ 1 - 1
ssh.go

@@ -518,7 +518,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 	}
 
 	var cm *CacheMap
-	rl := ifce.lightHouse.Query(vpnIp, ifce)
+	rl := ifce.lightHouse.Query(vpnIp)
 	if rl != nil {
 		cm = rl.CopyCache()
 	}