@@ -11,6 +11,7 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"

 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 )
@@ -26,6 +27,12 @@ const (
 	sendTestPacket trafficDecision = 6
 )

+// LastCommunication tracks when we last communicated with a host
+type LastCommunication struct {
+	timestamp time.Time
+	vpnIp     netip.Addr // To help with logging
+}
+
 type connectionManager struct {
 	in     map[uint32]struct{}
 	inLock *sync.RWMutex
@@ -37,6 +44,12 @@ type connectionManager struct {
 	relayUsed     map[uint32]struct{}
 	relayUsedLock *sync.RWMutex

+	// Track last communication with hosts
+	lastCommMap       map[uint32]*LastCommunication
+	lastCommLock      *sync.RWMutex
+	inactivityTimer   *LockingTimerWheel[uint32]
+	inactivityTimeout time.Duration
+
 	hostMap      *HostMap
 	trafficTimer *LockingTimerWheel[uint32]
 	intf         *Interface
@@ -65,6 +78,9 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		outLock:       &sync.RWMutex{},
 		relayUsed:     make(map[uint32]struct{}),
 		relayUsedLock: &sync.RWMutex{},
+		lastCommMap:       make(map[uint32]*LastCommunication),
+		lastCommLock:      &sync.RWMutex{},
+		inactivityTimeout: 10 * time.Minute, // Default inactivity timeout: 10 minutes (matches the ReloadConfig default)
 		trafficTimer:    NewLockingTimerWheel[uint32](time.Millisecond*500, max),
 		intf:            intf,
 		pendingDeletion: make(map[uint32]struct{}),
@@ -75,10 +91,42 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		l:               l,
 	}

+	// Initialize the inactivity timer wheel - make wheel duration slightly longer than the timeout
+	nc.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, nc.inactivityTimeout+time.Minute)
+
 	nc.Start(ctx)
 	return nc
 }

+func (n *connectionManager) updateLastCommunication(localIndex uint32) {
+	// Get host info to record the VPN IP for better logging
+	hostInfo := n.hostMap.QueryIndex(localIndex)
+	if hostInfo == nil {
+		return
+	}
+
+	now := time.Now()
+	n.lastCommLock.Lock()
+	lastComm, exists := n.lastCommMap[localIndex]
+	if !exists {
+		// First time we've seen this host
+		lastComm = &LastCommunication{
+			timestamp: now,
+			vpnIp:     hostInfo.vpnIp,
+		}
+		n.lastCommMap[localIndex] = lastComm
+	} else {
+		// Update the existing record
+		lastComm.timestamp = now
+	}
+	n.lastCommLock.Unlock()
+
+	// Reset the inactivity timer for this host
+	n.inactivityTimer.m.Lock()
+	n.inactivityTimer.t.Add(localIndex, n.inactivityTimeout)
+	n.inactivityTimer.m.Unlock()
+}
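+
+// Note on wheel semantics: this assumes the TimerWheel does not deduplicate,
+// so a busy host accumulates one entry per updateLastCommunication call.
+// That is tolerated rather than prevented: when a stale entry fires,
+// checkInactiveTunnels consults the authoritative timestamp in lastCommMap
+// and re-adds only the remaining time instead of dropping the tunnel. The two
+// locks (lastCommLock and the wheel mutex) are never held together, so there
+// is no lock-ordering hazard here.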
+
 func (n *connectionManager) In(localIndex uint32) {
 	n.inLock.RLock()
 	// If this already exists, return
@@ -90,6 +138,9 @@ func (n *connectionManager) In(localIndex uint32) {
 	n.inLock.Lock()
 	n.in[localIndex] = struct{}{}
 	n.inLock.Unlock()
+
+	// Update last communication time
+	n.updateLastCommunication(localIndex)
 }

 func (n *connectionManager) Out(localIndex uint32) {
@@ -103,6 +154,9 @@ func (n *connectionManager) Out(localIndex uint32) {
 	n.outLock.Lock()
 	n.out[localIndex] = struct{}{}
 	n.outLock.Unlock()
+
+	// Update last communication time
+	n.updateLastCommunication(localIndex)
 }
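+
+// Note: In and Out return early when the index is already marked for the
+// current traffic-check interval (the RLock fast path above), so
+// updateLastCommunication runs at most once per interval per tunnel rather
+// than once per packet, keeping lastCommMap churn and inactivity-wheel
+// growth bounded.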

 func (n *connectionManager) RelayUsed(localIndex uint32) {
@@ -144,6 +198,134 @@ func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
 	n.outLock.Unlock()
 }

+// checkInactiveTunnels checks for tunnels that have been inactive for too long and drops them
+func (n *connectionManager) checkInactiveTunnels() {
+	now := time.Now()
+
+	// First, advance the timer wheel to the current time
+	n.inactivityTimer.m.Lock()
+	n.inactivityTimer.t.Advance(now)
+	n.inactivityTimer.m.Unlock()
+
+	// Check for expired timers (inactive connections)
+	for {
+		// Get the next expired tunnel
+		n.inactivityTimer.m.Lock()
+		localIndex, ok := n.inactivityTimer.t.Purge()
+		n.inactivityTimer.m.Unlock()
+
+		if !ok {
+			// No more expired timers
+			break
+		}
+
+		n.lastCommLock.RLock()
+		lastComm, exists := n.lastCommMap[localIndex]
+		n.lastCommLock.RUnlock()
+
+		if !exists {
+			// No last communication record; odd, but skip it
+			continue
+		}
+
+		// Calculate the inactivity duration
+		inactiveDuration := now.Sub(lastComm.timestamp)
+
+		// Check if we've exceeded the inactivity timeout
+		if inactiveDuration >= n.inactivityTimeout {
+			// Get the host info (if it still exists)
+			hostInfo := n.hostMap.QueryIndex(localIndex)
+			if hostInfo == nil {
+				// Host info is gone, remove from our tracking map
+				n.lastCommLock.Lock()
+				delete(n.lastCommMap, localIndex)
+				n.lastCommLock.Unlock()
+				continue
+			}
+
+			// Log the inactivity and drop the tunnel
+			n.l.WithField("vpnIp", lastComm.vpnIp).
+				WithField("localIndex", localIndex).
+				WithField("inactiveDuration", inactiveDuration).
+				WithField("timeout", n.inactivityTimeout).
+				Info("Dropping tunnel due to inactivity")
+
+			// Close the tunnel using the existing mechanism
+			n.intf.closeTunnel(hostInfo)
+
+			// Clean up our tracking map
+			n.lastCommLock.Lock()
+			delete(n.lastCommMap, localIndex)
+			n.lastCommLock.Unlock()
+		} else {
+			// Re-add to the timer wheel with the remaining time
+			remainingTime := n.inactivityTimeout - inactiveDuration
+			n.inactivityTimer.m.Lock()
+			n.inactivityTimer.t.Add(localIndex, remainingTime)
+			n.inactivityTimer.m.Unlock()
+		}
+	}
+}
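+
+// Note: the wheel ticks once per minute and Run drives this check from a
+// one-minute ticker, so a tunnel is actually dropped somewhere between
+// inactivityTimeout and roughly inactivityTimeout plus two minutes; the
+// configured timeout is a floor, not an exact deadline.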
+
+// CleanupDeletedHostInfos removes entries from our lastCommMap for hosts that no longer exist
+func (n *connectionManager) CleanupDeletedHostInfos() {
+	n.lastCommLock.Lock()
+	defer n.lastCommLock.Unlock()
+
+	// Find the indexes to delete
+	var toDelete []uint32
+	for localIndex := range n.lastCommMap {
+		if n.hostMap.QueryIndex(localIndex) == nil {
+			toDelete = append(toDelete, localIndex)
+		}
+	}
+
+	// Delete them
+	for _, localIndex := range toDelete {
+		delete(n.lastCommMap, localIndex)
+	}
+
+	if len(toDelete) > 0 && n.l.Level >= logrus.DebugLevel {
+		n.l.WithField("count", len(toDelete)).Debug("Cleaned up deleted host entries from lastCommMap")
+	}
+}
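+
+// The config key read in ReloadConfig below implies a YAML block like this
+// sketch (the value is parsed with config.C.GetDuration, so Go duration
+// syntax applies):
+//
+//	timers:
+//	  inactivity_timeout: 10m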
+
+// ReloadConfig updates the connection manager configuration
+func (n *connectionManager) ReloadConfig(c *config.C) {
+	// Get the inactivity timeout from config
+	inactivityTimeout := c.GetDuration("timers.inactivity_timeout", 10*time.Minute)
+
+	// Only update if different
+	if inactivityTimeout != n.inactivityTimeout {
+		n.l.WithField("old", n.inactivityTimeout).
+			WithField("new", inactivityTimeout).
+			Info("Updating inactivity timeout")
+
+		n.inactivityTimeout = inactivityTimeout
+
+		// Recreate the inactivity timer wheel with the new timeout; pending
+		// entries in the old wheel are discarded, so every tracked host must
+		// be re-added below.
+		n.inactivityTimer = NewLockingTimerWheel[uint32](time.Minute, n.inactivityTimeout+time.Minute)
+
+		// Re-add all existing hosts to the new timer wheel
+		n.lastCommLock.RLock()
+		now := time.Now()
+		for localIndex, lastComm := range n.lastCommMap {
+			// Calculate the remaining time based on the last communication.
+			// A host already past the new timeout is scheduled to fire on the
+			// wheel's next tick (the wheel clamps sub-tick timeouts up to one
+			// tick); an entry never added to the new wheel would otherwise
+			// never be purged by checkInactiveTunnels.
+			remainingTime := n.inactivityTimeout - now.Sub(lastComm.timestamp)
+			if remainingTime < 0 {
+				remainingTime = 0
+			}
+			n.inactivityTimer.m.Lock()
+			n.inactivityTimer.t.Add(localIndex, remainingTime)
+			n.inactivityTimer.m.Unlock()
+		}
+		n.lastCommLock.RUnlock()
+	}
+}
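+
+// Note: inactivityTimeout and inactivityTimer are written above while the Run
+// goroutine reads them concurrently; if config reload can fire while Run is
+// active (the usual arrangement), a hardened version would guard both fields
+// with a lock or an atomic swap.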
+
 func (n *connectionManager) Start(ctx context.Context) {
 	go n.Run(ctx)
 }
@@ -153,6 +335,14 @@ func (n *connectionManager) Run(ctx context.Context) {
 	clockSource := time.NewTicker(500 * time.Millisecond)
 	defer clockSource.Stop()

+	// Create a ticker for inactivity checks (every minute)
+	inactivityTicker := time.NewTicker(time.Minute)
+	defer inactivityTicker.Stop()
+
+	// Create a ticker for cleanup of deleted hosts (every 5 minutes)
+	cleanupTicker := time.NewTicker(5 * time.Minute)
+	defer cleanupTicker.Stop()
+
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -172,6 +362,14 @@ func (n *connectionManager) Run(ctx context.Context) {

 			n.doTrafficCheck(localIndex, p, nb, out, now)
 		}
+
+		case <-inactivityTicker.C:
+			// Check for inactive tunnels
+			n.checkInactiveTunnels()
+
+		case <-cleanupTicker.C:
+			// Periodically clean up deleted hosts
+			n.CleanupDeletedHostInfos()
 		}
 	}
 }
@@ -367,7 +565,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time

 		if !outTraffic {
 			// Send a punch packet to keep the NAT state alive
-			n.sendPunch(hostinfo)
+			//n.sendPunch(hostinfo)
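+			// (sendPunch is left disabled throughout this change, presumably
+			// so an unused tunnel goes quiet and trips the inactivity timeout
+			// instead of having its NAT state kept alive)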
 		}

 		return decision, hostinfo, primary
@@ -388,7 +586,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time

 	if !outTraffic {
 		// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test the tunnel.
 		// Just maintain NAT state if configured to do so.
-		n.sendPunch(hostinfo)
+		//n.sendPunch(hostinfo)
 		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
 		return doNothing, nil, nil
@@ -398,7 +596,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		// This is similar to the old punchy behavior with a slight optimization.
 		// We aren't receiving traffic but we are sending it, punch on all known
 		// ips in case we need to re-prime NAT state
-		n.sendPunch(hostinfo)
+		//n.sendPunch(hostinfo)
 	}

 	if n.l.Level >= logrus.DebugLevel {
|