
remove unused and stale tunnels. punch less.

Ryan Huber 4 months ago
commit 47d4055e10

1 changed file with 201 additions and 3 deletions:
connection_manager.go (+201 -3)
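
The core of the change: every In/Out event records a last-communication timestamp and re-arms an entry in a timer wheel; when an entry fires and the host has been quiet for longer than the configured timeout, the tunnel is closed, and the NAT punch calls for unused tunnels are disabled. Below is a minimal, self-contained sketch of that inactivity-tracking pattern, using a plain map and an explicit sweep instead of Nebula's LockingTimerWheel; the names tracker, Touch, and sweep are illustrative only and do not appear in the diff.

package main

import (
	"fmt"
	"sync"
	"time"
)

// tracker records the last time each tunnel (keyed by local index) saw traffic.
type tracker struct {
	mu       sync.Mutex
	lastSeen map[uint32]time.Time
	timeout  time.Duration
}

// Touch plays the role of updateLastCommunication: called on every In/Out event.
func (t *tracker) Touch(idx uint32) {
	t.mu.Lock()
	t.lastSeen[idx] = time.Now()
	t.mu.Unlock()
}

// sweep mirrors checkInactiveTunnels: close anything idle past the timeout.
func (t *tracker) sweep(closeTunnel func(uint32)) {
	now := time.Now()
	t.mu.Lock()
	defer t.mu.Unlock()
	for idx, seen := range t.lastSeen {
		if now.Sub(seen) >= t.timeout {
			closeTunnel(idx)
			delete(t.lastSeen, idx)
		}
	}
}

func main() {
	t := &tracker{lastSeen: map[uint32]time.Time{}, timeout: 10 * time.Minute}
	t.Touch(1)
	t.sweep(func(idx uint32) { fmt.Println("closing tunnel", idx) })
}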

@@ -11,6 +11,7 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 )
 
@@ -26,6 +27,12 @@ const (
 	sendTestPacket trafficDecision = 6
 )
 
+// LastCommunication tracks when we last communicated with a host
+type LastCommunication struct {
+	timestamp time.Time
+	vpnIp     netip.Addr // To help with logging
+}
+
 type connectionManager struct {
 	in     map[uint32]struct{}
 	inLock *sync.RWMutex
@@ -37,6 +44,12 @@ type connectionManager struct {
 	relayUsed     map[uint32]struct{}
 	relayUsedLock *sync.RWMutex
 
+	// Track last communication with hosts
+	lastCommMap       map[uint32]*LastCommunication
+	lastCommLock      *sync.RWMutex
+	inactivityTimer   *LockingTimerWheel[uint32]
+	inactivityTimeout time.Duration
+
 	hostMap                 *HostMap
 	trafficTimer            *LockingTimerWheel[uint32]
 	intf                    *Interface
@@ -65,6 +78,9 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		outLock:                 &sync.RWMutex{},
 		relayUsed:               make(map[uint32]struct{}),
 		relayUsedLock:           &sync.RWMutex{},
+		lastCommMap:             make(map[uint32]*LastCommunication),
+		lastCommLock:            &sync.RWMutex{},
+		inactivityTimeout:       1 * time.Minute, // Default inactivity timeout; ReloadConfig may override via timers.inactivity_timeout
 		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
 		intf:                    intf,
 		pendingDeletion:         make(map[uint32]struct{}),
@@ -75,10 +91,42 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		l:                       l,
 	}
 
+	// Initialize the inactivity timer wheel - make wheel duration slightly longer than the timeout
+	nc.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, nc.inactivityTimeout+time.Minute)
+
 	nc.Start(ctx)
 	return nc
 }
 
+func (n *connectionManager) updateLastCommunication(localIndex uint32) {
+	// Get host info to record VPN IP for better logging
+	hostInfo := n.hostMap.QueryIndex(localIndex)
+	if hostInfo == nil {
+		return
+	}
+
+	now := time.Now()
+	n.lastCommLock.Lock()
+	lastComm, exists := n.lastCommMap[localIndex]
+	if !exists {
+		// First time we've seen this host
+		lastComm = &LastCommunication{
+			timestamp: now,
+			vpnIp:     hostInfo.vpnIp,
+		}
+		n.lastCommMap[localIndex] = lastComm
+	} else {
+		// Update existing record
+		lastComm.timestamp = now
+	}
+	n.lastCommLock.Unlock()
+
+	// Reset the inactivity timer for this host
+	n.inactivityTimer.m.Lock()
+	n.inactivityTimer.t.Add(localIndex, n.inactivityTimeout)
+	n.inactivityTimer.m.Unlock()
+}
+
 func (n *connectionManager) In(localIndex uint32) {
 	n.inLock.RLock()
 	// If this already exists, return
@@ -90,6 +138,9 @@ func (n *connectionManager) In(localIndex uint32) {
 	n.inLock.Lock()
 	n.in[localIndex] = struct{}{}
 	n.inLock.Unlock()
+
+	// Update last communication time
+	n.updateLastCommunication(localIndex)
 }
 
 func (n *connectionManager) Out(localIndex uint32) {
@@ -103,6 +154,9 @@ func (n *connectionManager) Out(localIndex uint32) {
 	n.outLock.Lock()
 	n.out[localIndex] = struct{}{}
 	n.outLock.Unlock()
+
+	// Update last communication time
+	n.updateLastCommunication(localIndex)
 }
 
 func (n *connectionManager) RelayUsed(localIndex uint32) {
@@ -144,6 +198,134 @@ func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
 	n.outLock.Unlock()
 }
 
+// checkInactiveTunnels checks for tunnels that have been inactive for too long and drops them
+func (n *connectionManager) checkInactiveTunnels() {
+	now := time.Now()
+
+	// First, advance the timer wheel to the current time
+	n.inactivityTimer.m.Lock()
+	n.inactivityTimer.t.Advance(now)
+	n.inactivityTimer.m.Unlock()
+
+	// Check for expired timers (inactive connections)
+	for {
+		// Get the next expired tunnel
+		n.inactivityTimer.m.Lock()
+		localIndex, ok := n.inactivityTimer.t.Purge()
+		n.inactivityTimer.m.Unlock()
+
+		if !ok {
+			// No more expired timers
+			break
+		}
+
+		n.lastCommLock.RLock()
+		lastComm, exists := n.lastCommMap[localIndex]
+		n.lastCommLock.RUnlock()
+
+		if !exists {
+			// No last communication record, odd but skip
+			continue
+		}
+
+		// Calculate inactivity duration
+		inactiveDuration := now.Sub(lastComm.timestamp)
+
+		// Check if we've exceeded the inactivity timeout
+		if inactiveDuration >= n.inactivityTimeout {
+			// Get the host info (if it still exists)
+			hostInfo := n.hostMap.QueryIndex(localIndex)
+			if hostInfo == nil {
+				// Host info is gone, remove from our tracking map
+				n.lastCommLock.Lock()
+				delete(n.lastCommMap, localIndex)
+				n.lastCommLock.Unlock()
+				continue
+			}
+
+			// Log the inactivity and drop the tunnel
+			n.l.WithField("vpnIp", lastComm.vpnIp).
+				WithField("localIndex", localIndex).
+				WithField("inactiveDuration", inactiveDuration).
+				WithField("timeout", n.inactivityTimeout).
+				Info("Dropping tunnel due to inactivity")
+
+			// Close the tunnel using the existing mechanism
+			n.intf.closeTunnel(hostInfo)
+
+			// Clean up our tracking map
+			n.lastCommLock.Lock()
+			delete(n.lastCommMap, localIndex)
+			n.lastCommLock.Unlock()
+		} else {
+			// Re-add to the timer wheel with the remaining time
+			remainingTime := n.inactivityTimeout - inactiveDuration
+			n.inactivityTimer.m.Lock()
+			n.inactivityTimer.t.Add(localIndex, remainingTime)
+			n.inactivityTimer.m.Unlock()
+		}
+	}
+}
+
+// CleanupDeletedHostInfos removes entries from our lastCommMap for hosts that no longer exist
+func (n *connectionManager) CleanupDeletedHostInfos() {
+	n.lastCommLock.Lock()
+	defer n.lastCommLock.Unlock()
+
+	// Find indexes to delete
+	var toDelete []uint32
+	for localIndex := range n.lastCommMap {
+		if n.hostMap.QueryIndex(localIndex) == nil {
+			toDelete = append(toDelete, localIndex)
+		}
+	}
+
+	// Delete them
+	for _, localIndex := range toDelete {
+		delete(n.lastCommMap, localIndex)
+	}
+
+	if len(toDelete) > 0 && n.l.Level >= logrus.DebugLevel {
+		n.l.WithField("count", len(toDelete)).Debug("Cleaned up deleted host entries from lastCommMap")
+	}
+}
+
+// ReloadConfig updates the connection manager configuration
+func (n *connectionManager) ReloadConfig(c *config.C) {
+	// Get the inactivity timeout from config
+	inactivityTimeout := c.GetDuration("timers.inactivity_timeout", 10*time.Minute)
+
+	// Only update if different
+	if inactivityTimeout != n.inactivityTimeout {
+		n.l.WithField("old", n.inactivityTimeout).
+			WithField("new", inactivityTimeout).
+			Info("Updating inactivity timeout")
+
+		n.inactivityTimeout = inactivityTimeout
+
+		// Recreate the inactivity timer wheel with the new timeout
+		n.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, n.inactivityTimeout+time.Minute)
+
+		// Re-add all existing hosts to the new timer wheel
+		n.lastCommLock.RLock()
+		for localIndex, lastComm := range n.lastCommMap {
+			// Calculate remaining time based on last communication
+			now := time.Now()
+			elapsed := now.Sub(lastComm.timestamp)
+
+			// If the elapsed time exceeds the new timeout, this will be caught
+			// in the next inactivity check. Otherwise, add with remaining time.
+			if elapsed < n.inactivityTimeout {
+				remainingTime := n.inactivityTimeout - elapsed
+				n.inactivityTimer.m.Lock()
+				n.inactivityTimer.t.Add(localIndex, remainingTime)
+				n.inactivityTimer.m.Unlock()
+			}
+		}
+		n.lastCommLock.RUnlock()
+	}
+}
+
 func (n *connectionManager) Start(ctx context.Context) {
 	go n.Run(ctx)
 }
@@ -153,6 +335,14 @@ func (n *connectionManager) Run(ctx context.Context) {
 	clockSource := time.NewTicker(500 * time.Millisecond)
 	defer clockSource.Stop()
 
+	// Create ticker for inactivity checks (every minute)
+	inactivityTicker := time.NewTicker(time.Minute)
+	defer inactivityTicker.Stop()
+
+	// Create ticker for cleanup (every 5 minutes)
+	cleanupTicker := time.NewTicker(5 * time.Minute)
+	defer cleanupTicker.Stop()
+
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -172,6 +362,14 @@ func (n *connectionManager) Run(ctx context.Context) {
 
 				n.doTrafficCheck(localIndex, p, nb, out, now)
 			}
+
+		case <-inactivityTicker.C:
+			// Check for inactive tunnels
+			n.checkInactiveTunnels()
+
+		case <-cleanupTicker.C:
+			// Periodically clean up deleted hosts
+			n.CleanupDeletedHostInfos()
 		}
 	}
 }
@@ -367,7 +565,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 
 		if !outTraffic {
 			// Send a punch packet to keep the NAT state alive
-			n.sendPunch(hostinfo)
+			//n.sendPunch(hostinfo)
 		}
 
 		return decision, hostinfo, primary
@@ -388,7 +586,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		if !outTraffic {
 		// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test it.
 			// Just maintain NAT state if configured to do so.
-			n.sendPunch(hostinfo)
+			//n.sendPunch(hostinfo)
 			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
 			return doNothing, nil, nil
 
@@ -398,7 +596,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 			// This is similar to the old punchy behavior with a slight optimization.
 			// We aren't receiving traffic but we are sending it, punch on all known
 			// ips in case we need to re-prime NAT state
-			n.sendPunch(hostinfo)
+			//n.sendPunch(hostinfo)
 		}
 
 		if n.l.Level >= logrus.DebugLevel {
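
One subtlety in checkInactiveTunnels above: a wheel entry can fire even though the host was touched after the entry was armed, so on expiry the elapsed time is re-checked and the entry is re-armed with the remaining time rather than dropped. A small runnable sketch of that expire-or-re-arm decision, with simplified stand-in types (decide is a hypothetical helper, not a function in the diff):

package main

import (
	"fmt"
	"time"
)

// decide reports whether a tunnel whose timer just fired should be closed,
// and if not, how long to re-arm its timer for.
func decide(lastSeen, now time.Time, timeout time.Duration) (drop bool, rearm time.Duration) {
	idle := now.Sub(lastSeen)
	if idle >= timeout {
		return true, 0 // genuinely inactive: close the tunnel
	}
	return false, timeout - idle // traffic arrived after arming: re-arm for the remainder
}

func main() {
	now := time.Now()
	drop, rearm := decide(now.Add(-3*time.Minute), now, 10*time.Minute)
	fmt.Println(drop, rearm) // false 7m0s
}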