Drop inactive tunnels (#1413)

Nate Brown 1 month ago
parent
commit
9877648da9
13 changed files with 480 additions and 277 deletions
  1. connection_manager.go (+202 -175)
  2. connection_manager_test.go (+136 -39)
  3. control.go (+12 -8)
  4. e2e/handshakes_test.go (+3 -5)
  5. e2e/router/router.go (+1 -0)
  6. e2e/tunnels_test.go (+55 -0)
  7. examples/config.yml (+12 -0)
  8. handshake_ix.go (+2 -2)
  9. hostmap.go (+8 -0)
  10. inside.go (+2 -2)
  11. interface.go (+22 -19)
  12. main.go (+22 -24)
  13. outside.go (+3 -3)
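Before the per-file diffs, a quick orientation: this change replaces the connection manager's per-index in/out maps with atomic flags and a lastUsed timestamp on each HostInfo, moves construction into newConnectionManagerFromConfig, and adds two reloadable settings under tunnels: (drop_inactive and inactivity_timeout). A minimal sketch of wiring those settings up in code, mirroring the test setup further down; the helper name and the hostmap CIDR here are illustrative and not part of the commit:

package nebula

import (
	"net/netip"

	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/test"
)

// buildInactivityAwareManager is a hypothetical helper (not in this commit) that
// strings together the same calls the tests below use.
func buildInactivityAwareManager() *connectionManager {
	l := test.NewLogger()
	c := config.NewC(l)
	// The two keys added by this commit; both can be changed on config reload.
	c.Settings["tunnels"] = map[interface{}]interface{}{
		"drop_inactive":      true,  // tear down tunnels that go idle
		"inactivity_timeout": "10m", // how long a tunnel must be idle first
	}
	hm := newHostMap(l, netip.MustParsePrefix("172.16.0.0/24"))
	punchy := NewPunchyFromConfig(l, c)
	return newConnectionManagerFromConfig(l, c, hm, punchy)
}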

+ 202 - 175
connection_manager.go

@@ -7,11 +7,13 @@ import (
 	"fmt"
 	"net/netip"
 	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 )

@@ -28,130 +30,124 @@ const (
 )

 type connectionManager struct {
-	in     map[uint32]struct{}
-	inLock *sync.RWMutex
-
-	out     map[uint32]struct{}
-	outLock *sync.RWMutex
-
 	// relayUsed holds which relay localIndexs are in use
 	relayUsed     map[uint32]struct{}
 	relayUsedLock *sync.RWMutex

-	hostMap                 *HostMap
-	trafficTimer            *LockingTimerWheel[uint32]
-	intf                    *Interface
-	pendingDeletion         map[uint32]struct{}
-	punchy                  *Punchy
+	hostMap      *HostMap
+	trafficTimer *LockingTimerWheel[uint32]
+	intf         *Interface
+	punchy       *Punchy
+
+	// Configuration settings
 	checkInterval           time.Duration
 	pendingDeletionInterval time.Duration
-	metricsTxPunchy         metrics.Counter
+	inactivityTimeout       atomic.Int64
+	dropInactive            atomic.Bool
+
+	metricsTxPunchy metrics.Counter

 	l *logrus.Logger
 }

-func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
-	var max time.Duration
-	if checkInterval < pendingDeletionInterval {
-		max = pendingDeletionInterval
-	} else {
-		max = checkInterval
+func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
+	cm := &connectionManager{
+		hostMap:         hm,
+		l:               l,
+		punchy:          p,
+		relayUsed:       make(map[uint32]struct{}),
+		relayUsedLock:   &sync.RWMutex{},
+		metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
 	}

-	nc := &connectionManager{
-		hostMap:                 intf.hostMap,
-		in:                      make(map[uint32]struct{}),
-		inLock:                  &sync.RWMutex{},
-		out:                     make(map[uint32]struct{}),
-		outLock:                 &sync.RWMutex{},
-		relayUsed:               make(map[uint32]struct{}),
-		relayUsedLock:           &sync.RWMutex{},
-		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
-		intf:                    intf,
-		pendingDeletion:         make(map[uint32]struct{}),
-		checkInterval:           checkInterval,
-		pendingDeletionInterval: pendingDeletionInterval,
-		punchy:                  punchy,
-		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
-		l:                       l,
-	}
+	cm.reload(c, true)
+	c.RegisterReloadCallback(func(c *config.C) {
+		cm.reload(c, false)
+	})

-	nc.Start(ctx)
-	return nc
+	return cm
 }

-func (n *connectionManager) In(localIndex uint32) {
-	n.inLock.RLock()
-	// If this already exists, return
-	if _, ok := n.in[localIndex]; ok {
-		n.inLock.RUnlock()
-		return
+func (cm *connectionManager) reload(c *config.C, initial bool) {
+	if initial {
+		cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
+		cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
+
+		// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
+		// pretty close to their configured duration.
+		// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
+		minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
+		maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
+		cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
 	}
-	n.inLock.RUnlock()
-	n.inLock.Lock()
-	n.in[localIndex] = struct{}{}
-	n.inLock.Unlock()
-}

-func (n *connectionManager) Out(localIndex uint32) {
-	n.outLock.RLock()
-	// If this already exists, return
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.RUnlock()
-		return
+	if initial || c.HasChanged("tunnels.inactivity_timeout") {
+		old := cm.getInactivityTimeout()
+		cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
+		if !initial {
+			cm.l.WithField("oldDuration", old).
+				WithField("newDuration", cm.getInactivityTimeout()).
+				Info("Inactivity timeout has changed")
+		}
+	}
+
+	if initial || c.HasChanged("tunnels.drop_inactive") {
+		old := cm.dropInactive.Load()
+		cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
+		if !initial {
+			cm.l.WithField("oldBool", old).
+				WithField("newBool", cm.dropInactive.Load()).
+				Info("Drop inactive setting has changed")
+		}
 	}
-	n.outLock.RUnlock()
-	n.outLock.Lock()
-	n.out[localIndex] = struct{}{}
-	n.outLock.Unlock()
 }

-func (n *connectionManager) RelayUsed(localIndex uint32) {
-	n.relayUsedLock.RLock()
+func (cm *connectionManager) getInactivityTimeout() time.Duration {
+	return (time.Duration)(cm.inactivityTimeout.Load())
+}
+
+func (cm *connectionManager) In(h *HostInfo) {
+	h.in.Store(true)
+}
+
+func (cm *connectionManager) Out(h *HostInfo) {
+	h.out.Store(true)
+}
+
+func (cm *connectionManager) RelayUsed(localIndex uint32) {
+	cm.relayUsedLock.RLock()
 	// If this already exists, return
-	if _, ok := n.relayUsed[localIndex]; ok {
-		n.relayUsedLock.RUnlock()
+	if _, ok := cm.relayUsed[localIndex]; ok {
+		cm.relayUsedLock.RUnlock()
 		return
 	}
-	n.relayUsedLock.RUnlock()
-	n.relayUsedLock.Lock()
-	n.relayUsed[localIndex] = struct{}{}
-	n.relayUsedLock.Unlock()
+	cm.relayUsedLock.RUnlock()
+	cm.relayUsedLock.Lock()
+	cm.relayUsed[localIndex] = struct{}{}
+	cm.relayUsedLock.Unlock()
 }

 // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
 // resets the state for this local index
-func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
-	n.inLock.Lock()
-	n.outLock.Lock()
-	_, in := n.in[localIndex]
-	_, out := n.out[localIndex]
-	delete(n.in, localIndex)
-	delete(n.out, localIndex)
-	n.inLock.Unlock()
-	n.outLock.Unlock()
+func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
+	in := h.in.Swap(false)
+	out := h.out.Swap(false)
+	if in || out {
+		h.lastUsed = now
+	}
 	return in, out
 }

-func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
-	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
-	n.outLock.Lock()
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.Unlock()
-		return
+// AddTrafficWatch must be called for every new HostInfo.
+// We will continue to monitor the HostInfo until the tunnel is dropped.
+func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
+	if h.out.Swap(true) == false {
+		cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
 	}
-	n.out[localIndex] = struct{}{}
-	n.trafficTimer.Add(localIndex, n.checkInterval)
-	n.outLock.Unlock()
 }

-func (n *connectionManager) Start(ctx context.Context) {
-	go n.Run(ctx)
-}
-
-func (n *connectionManager) Run(ctx context.Context) {
-	//TODO: this tick should be based on the min wheel tick? Check firewall
-	clockSource := time.NewTicker(500 * time.Millisecond)
+func (cm *connectionManager) Start(ctx context.Context) {
+	clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
 	defer clockSource.Stop()

 	p := []byte("")
@@ -164,61 +160,61 @@ func (n *connectionManager) Run(ctx context.Context) {
 			return

 		case now := <-clockSource.C:
-			n.trafficTimer.Advance(now)
+			cm.trafficTimer.Advance(now)
 			for {
-				localIndex, has := n.trafficTimer.Purge()
+				localIndex, has := cm.trafficTimer.Purge()
 				if !has {
 					break
 				}

-				n.doTrafficCheck(localIndex, p, nb, out, now)
+				cm.doTrafficCheck(localIndex, p, nb, out, now)
 			}
 		}
 	}
 }

-func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
+func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
+	decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)

 	switch decision {
 	case deleteTunnel:
-		if n.hostMap.DeleteHostInfo(hostinfo) {
+		if cm.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
+			cm.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)
 		}

 	case closeTunnel:
-		n.intf.sendCloseTunnel(hostinfo)
-		n.intf.closeTunnel(hostinfo)
+		cm.intf.sendCloseTunnel(hostinfo)
+		cm.intf.closeTunnel(hostinfo)

 	case swapPrimary:
-		n.swapPrimary(hostinfo, primary)
+		cm.swapPrimary(hostinfo, primary)

 	case migrateRelays:
-		n.migrateRelayUsed(hostinfo, primary)
+		cm.migrateRelayUsed(hostinfo, primary)

 	case tryRehandshake:
-		n.tryRehandshake(hostinfo)
+		cm.tryRehandshake(hostinfo)

 	case sendTestPacket:
-		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
 	}

-	n.resetRelayTrafficCheck(hostinfo)
+	cm.resetRelayTrafficCheck(hostinfo)
 }

-func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
+func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
 	if hostinfo != nil {
-		n.relayUsedLock.Lock()
-		defer n.relayUsedLock.Unlock()
+		cm.relayUsedLock.Lock()
+		defer cm.relayUsedLock.Unlock()
 		// No need to migrate any relays, delete usage info now.
 		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
-			delete(n.relayUsed, idx)
+			delete(cm.relayUsed, idx)
 		}
 	}
 }

-func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
+func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()

 	for _, r := range relayFor {
@@ -238,7 +234,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 				index = existing.LocalIndex
 				switch r.Type {
 				case TerminalType:
-					relayFrom = n.intf.myVpnNet.Addr()
+					relayFrom = cm.intf.myVpnNet.Addr()
 					relayTo = existing.PeerIp
 				case ForwardingType:
 					relayFrom = existing.PeerIp
@@ -249,23 +245,23 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 				}
 			}
 		case !ok:
-			n.relayUsedLock.RLock()
-			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
+			cm.relayUsedLock.RLock()
+			if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
 				// The relay hasn't been used; don't migrate it.
-				n.relayUsedLock.RUnlock()
+				cm.relayUsedLock.RUnlock()
 				continue
 			}
-			n.relayUsedLock.RUnlock()
+			cm.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerIp, nil, r.Type, Requested)
 			if err != nil {
-				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
+				cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnNet.Addr()
+				relayFrom = cm.intf.myVpnNet.Addr()
 				relayTo = r.PeerIp
 			case ForwardingType:
 				relayFrom = r.PeerIp
@@ -289,10 +285,10 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		}
 		msg, err := req.Marshal()
 		if err != nil {
-			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
+			cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
-			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
-			n.l.WithFields(logrus.Fields{
+			cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
+			cm.l.WithFields(logrus.Fields{
 				"relayFrom":           req.RelayFromIp,
 				"relayTo":             req.RelayToIp,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -303,46 +299,45 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	}
 }

-func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
-	n.hostMap.RLock()
-	defer n.hostMap.RUnlock()
+func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+	// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
+	cm.hostMap.RLock()
+	defer cm.hostMap.RUnlock()

-	hostinfo := n.hostMap.Indexes[localIndex]
+	hostinfo := cm.hostMap.Indexes[localIndex]
 	if hostinfo == nil {
-		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
-		delete(n.pendingDeletion, localIndex)
+		cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
 		return doNothing, nil, nil
 	}

-	if n.isInvalidCertificate(now, hostinfo) {
-		delete(n.pendingDeletion, hostinfo.localIndexId)
+	if cm.isInvalidCertificate(now, hostinfo) {
 		return closeTunnel, hostinfo, nil
 	}

-	primary := n.hostMap.Hosts[hostinfo.vpnIp]
+	primary := cm.hostMap.Hosts[hostinfo.vpnIp]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
 	}

 	// Check for traffic on this hostinfo
-	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
+	inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)

 	// A hostinfo is determined alive if there is incoming traffic
 	if inTraffic {
 		decision := doNothing
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).
 				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 				Debug("Tunnel status")
 		}
-		delete(n.pendingDeletion, hostinfo.localIndexId)
+		hostinfo.pendingDeletion.Store(false)

 		if mainHostInfo {
 			decision = tryRehandshake

 		} else {
-			if n.shouldSwapPrimary(hostinfo, primary) {
+			if cm.shouldSwapPrimary(hostinfo, primary) {
 				decision = swapPrimary
 			} else {
 				// migrate the relays to the primary, if in use.
@@ -350,46 +345,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 			}
 		}

-		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+		cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)

 		if !outTraffic {
 			// Send a punch packet to keep the NAT state alive
-			n.sendPunch(hostinfo)
+			cm.sendPunch(hostinfo)
 		}

 		return decision, hostinfo, primary
 	}

-	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
+	if hostinfo.pendingDeletion.Load() {
 		// We have already sent a test packet and nothing was returned, this hostinfo is dead
-		hostinfo.logger(n.l).
+		hostinfo.logger(cm.l).
 			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
 			Info("Tunnel status")

-		delete(n.pendingDeletion, hostinfo.localIndexId)
 		return deleteTunnel, hostinfo, nil
 	}

 	decision := doNothing
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
+			inactiveFor, isInactive := cm.isInactive(hostinfo, now)
+			if isInactive {
+				// Tunnel is inactive, tear it down
+				hostinfo.logger(cm.l).
+					WithField("inactiveDuration", inactiveFor).
+					WithField("primary", mainHostInfo).
+					Info("Dropping tunnel due to inactivity")
+
+				return closeTunnel, hostinfo, primary
+			}
+
 			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.
 			// Just maintain NAT state if configured to do so.
-			n.sendPunch(hostinfo)
-			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+			cm.sendPunch(hostinfo)
+			cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
 			return doNothing, nil, nil
-
 		}

-		if n.punchy.GetTargetEverything() {
+		if cm.punchy.GetTargetEverything() {
 			// This is similar to the old punchy behavior with a slight optimization.
 			// We aren't receiving traffic but we are sending it, punch on all known
 			// ips in case we need to re-prime NAT state
-			n.sendPunch(hostinfo)
+			cm.sendPunch(hostinfo)
 		}

-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).
 				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 				Debug("Tunnel status")
 		}
@@ -398,95 +402,118 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		decision = sendTestPacket

 	} else {
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
 		}
 	}

-	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
-	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
+	hostinfo.pendingDeletion.Store(true)
+	cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
 	return decision, hostinfo, nil
 }

-func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
+func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
+	if cm.dropInactive.Load() == false {
+		// We aren't configured to drop inactive tunnels
+		return 0, false
+	}
+
+	inactiveDuration := now.Sub(hostinfo.lastUsed)
+	if inactiveDuration < cm.getInactivityTimeout() {
+		// It's not considered inactive
+		return inactiveDuration, false
+	}
+
+	// The tunnel is inactive
+	return inactiveDuration, true
+}
+
+func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.

-	if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 {
+	if current.vpnIp.Compare(cm.intf.myVpnNet.Addr()) < 0 {
 		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
 		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.
 		// The remotes vpn ip is lower than mine. I will not flip.
 		return false
 	}

-	certState := n.intf.pki.GetCertState()
+	certState := cm.intf.pki.GetCertState()
 	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
 }

-func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
-	n.hostMap.Lock()
+func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
+	cm.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnIp] == primary {
-		n.hostMap.unlockedMakePrimary(current)
+	if cm.hostMap.Hosts[current.vpnIp] == primary {
+		cm.hostMap.unlockedMakePrimary(current)
 	}
-	n.hostMap.Unlock()
+	cm.hostMap.Unlock()
 }

 // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
 // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
 // check and return true.
-func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
+func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
 		return false
 	}

-	valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
+	valid, err := remoteCert.VerifyWithCache(now, cm.intf.pki.GetCAPool())
 	if valid {
 		return false
 	}

-	if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
+	if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
 		// Block listed certificates should always be disconnected
 		return false
 	}

 	fingerprint, _ := remoteCert.Sha256Sum()
-	hostinfo.logger(n.l).WithError(err).
+	hostinfo.logger(cm.l).WithError(err).
 		WithField("fingerprint", fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")

 	return true
 }

-func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
-	if !n.punchy.GetPunch() {
+func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
+	if !cm.punchy.GetPunch() {
 		// Punching is disabled
 		return
 	}

-	if n.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
-			n.metricsTxPunchy.Inc(1)
-			n.intf.outside.WriteTo([]byte{1}, addr)
+	if cm.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
+		// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
+		// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
+		// would lose the ability to notify us and punchy.respond would become unreliable.
+		return
+	}
+
+	if cm.punchy.GetTargetEverything() {
+		hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
+			cm.metricsTxPunchy.Inc(1)
+			cm.intf.outside.WriteTo([]byte{1}, addr)
 		})

 	} else if hostinfo.remote.IsValid() {
-		n.metricsTxPunchy.Inc(1)
-		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
+		cm.metricsTxPunchy.Inc(1)
+		cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 }

-func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.pki.GetCertState()
+func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
+	certState := cm.intf.pki.GetCertState()
 	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
 		return
 	}

-	n.l.WithField("vpnIp", hostinfo.vpnIp).
+	cm.l.WithField("vpnIp", hostinfo.vpnIp).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")

-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
+	cm.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil)
 }
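
The heart of the new behavior above: getAndResetTrafficCheck refreshes hostinfo.lastUsed whenever either atomic traffic flag was set during the last tick, and isInactive compares that timestamp against the configured timeout, but only when drop_inactive is enabled. A standalone sketch of that logic; the type and function names here are illustrative, not the commit's API:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// tunnel stands in for HostInfo: two atomic traffic flags plus a lastUsed
// timestamp that only the ticker routine reads and writes.
type tunnel struct {
	in, out  atomic.Bool
	lastUsed time.Time
}

// getAndResetTraffic mirrors getAndResetTrafficCheck: swap both flags to false
// and refresh lastUsed if either was set since the last check.
func getAndResetTraffic(t *tunnel, now time.Time) (in, out bool) {
	in = t.in.Swap(false)
	out = t.out.Swap(false)
	if in || out {
		t.lastUsed = now
	}
	return in, out
}

// isInactive mirrors connectionManager.isInactive: report inactivity only when
// dropping is enabled and the idle duration has reached the timeout.
func isInactive(t *tunnel, now time.Time, dropInactive bool, timeout time.Duration) (time.Duration, bool) {
	if !dropInactive {
		return 0, false
	}
	idle := now.Sub(t.lastUsed)
	if idle < timeout {
		return idle, false
	}
	return idle, true
}

func main() {
	now := time.Now()
	tun := &tunnel{}
	tun.out.Store(true)          // outbound-only traffic...
	getAndResetTraffic(tun, now) // ...still refreshes lastUsed on the next tick

	later := now.Add(10 * time.Minute)
	if idle, dead := isInactive(tun, later, true, 10*time.Minute); dead {
		fmt.Printf("tunnel idle for %v, would be torn down\n", idle)
	}
}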

+ 136 - 39
connection_manager_test.go

@@ -1,7 +1,6 @@
 package nebula

 import (
-	"context"
 	"crypto/ed25519"
 	"crypto/rand"
 	"net"
@@ -65,10 +64,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	ifce.pki.cs.Store(cs)

 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -86,31 +85,32 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo.localIndexId)
-	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.False(t, hostinfo.pendingDeletion.Load())
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.out, hostinfo.localIndexId)
+	assert.True(t, hostinfo.out.Load())
+	assert.True(t, hostinfo.in.Load())

 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())

 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo.localIndexId)
+	nc.Out(hostinfo)
+	assert.True(t, hostinfo.out.Load())
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.True(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)

 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
@@ -148,10 +148,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	ifce.pki.cs.Store(cs)

 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -169,33 +169,130 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo.localIndexId)
-	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.True(t, hostinfo.in.Load())
+	assert.True(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.pendingDeletion.Load())
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)

 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())

 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo.localIndexId)
+	nc.Out(hostinfo)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.True(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)

 	// We saw traffic, should no longer be pending deletion
-	nc.In(hostinfo.localIndexId)
+	nc.In(hostinfo)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+}
+
+func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
+	l := test.NewLogger()
+	vpncidr := netip.MustParsePrefix("172.1.1.1/24")
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnIp := netip.MustParseAddr("172.1.1.2")
+	preferredRanges := []netip.Prefix{localrange}
+
+	// Very incomplete mock objects
+	hostMap := newHostMap(l, vpncidr)
+	hostMap.preferredRanges.Store(&preferredRanges)
+
+	cs := &CertState{
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
+	}
+
+	lh := newTestLighthouse()
+	ifce := &Interface{
+		hostMap:          hostMap,
+		inside:           &test.NoopTun{},
+		outside:          &udp.NoopConn{},
+		firewall:         &Firewall{},
+		lightHouse:       lh,
+		pki:              &PKI{},
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
+		l:                l,
+	}
+	ifce.pki.cs.Store(cs)
+
+	// Create manager
+	conf := config.NewC(l)
+	conf.Settings["tunnels"] = map[interface{}]interface{}{
+		"drop_inactive": true,
+	}
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	assert.True(t, nc.dropInactive.Load())
+	nc.intf = ifce
+
+	// Add an ip we have established a connection w/ to hostmap
+	hostinfo := &HostInfo{
+		vpnIp:         vpnIp,
+		localIndexId:  1099,
+		remoteIndexId: 9901,
+	}
+	hostinfo.ConnectionState = &ConnectionState{
+		myCert: &cert.NebulaCertificate{},
+		H:      &noise.HandshakeState{},
+	}
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
+
+	// Do a traffic check tick, in and out should be cleared but should not be pending deletion
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.True(t, hostinfo.out.Load())
+	assert.True(t, hostinfo.in.Load())
+
+	now := time.Now()
+	decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
+	assert.Equal(t, tryRehandshake, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
+	assert.Equal(t, doNothing, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+
+	// Do another traffic check tick, should still not be pending deletion
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
+	assert.Equal(t, doNothing, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+
+	// Finally advance beyond the inactivity timeout
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
+	assert.Equal(t, closeTunnel, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 }
@@ -273,10 +370,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.disconnectInvalid.Store(true)

 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	ifce.connectionManager = nc

 	hostinfo := &HostInfo{

+ 12 - 8
control.go

@@ -26,14 +26,15 @@ type controlHostLister interface {
 }

 type Control struct {
-	f               *Interface
-	l               *logrus.Logger
-	ctx             context.Context
-	cancel          context.CancelFunc
-	sshStart        func()
-	statsStart      func()
-	dnsStart        func()
-	lighthouseStart func()
+	f                      *Interface
+	l                      *logrus.Logger
+	ctx                    context.Context
+	cancel                 context.CancelFunc
+	sshStart               func()
+	statsStart             func()
+	dnsStart               func()
+	lighthouseStart        func()
+	connectionManagerStart func(context.Context)
 }

 type ControlHostInfo struct {
@@ -63,6 +64,9 @@ func (c *Control) Start() {
 	if c.dnsStart != nil {
 		go c.dnsStart()
 	}
+	if c.connectionManagerStart != nil {
+		go c.connectionManagerStart(c.ctx)
+	}
 	if c.lighthouseStart != nil {
 		c.lighthouseStart()
 	}

+ 3 - 5
e2e/handshakes_test.go

@@ -4,7 +4,6 @@
 package e2e

 import (
-	"fmt"
 	"net/netip"
 	"slices"
 	"testing"
@@ -414,7 +413,7 @@ func TestReestablishRelays(t *testing.T) {
 	curIndexes := len(myControl.GetHostmap().Indexes)
 	for curIndexes >= start {
 		curIndexes = len(myControl.GetHostmap().Indexes)
-		r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
+		r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
 		myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me should fail"))

 		r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -964,9 +963,8 @@ func TestRehandshakingLoser(t *testing.T) {
 	t.Log("Stand up a tunnel between me and them")
 	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)

-	tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
-	tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)
-	fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
+	myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false)
+	theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false)

 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)


+ 1 - 0
e2e/router/router.go

@@ -690,6 +690,7 @@ func (r *R) FlushAll() {
 			r.Unlock()
 			panic("Can't FlushAll for host: " + p.To.String())
 		}
+		receiver.InjectUDPPacket(p)
 		r.Unlock()
 	}
 }

+ 55 - 0
e2e/tunnels_test.go

@@ -0,0 +1,55 @@
+//go:build e2e_testing
+// +build e2e_testing
+
+package e2e
+
+import (
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/e2e/router"
+)
+
+func TestDropInactiveTunnels(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+
+	r.Log("Assert the tunnel between me and them works")
+	assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r)
+
+	r.Log("Go inactive and wait for the tunnels to get dropped")
+	waitStart := time.Now()
+	for {
+		myIndexes := len(myControl.GetHostmap().Indexes)
+		theirIndexes := len(theirControl.GetHostmap().Indexes)
+		if myIndexes == 0 && theirIndexes == 0 {
+			break
+		}
+
+		since := time.Since(waitStart)
+		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
+		if since > time.Second*30 {
+			t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
+		}
+
+		time.Sleep(1 * time.Second)
+		r.FlushAll()
+	}
+
+	r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
+	myControl.Stop()
+	theirControl.Stop()
+}

+ 12 - 0
examples/config.yml

@@ -303,6 +303,18 @@ logging:
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64

+# Tunnel manager settings
+#tunnels:
+  # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactive_timeout period has
+  # elapsed.
+  # In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
+  # This setting is reloadable
+  #drop_inactive: false
+
+  # inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
+  # inactive and eligible to be dropped.
+  # This setting is reloadable
+  #inactivity_timeout: 10m


 # Nebula security group configuration
 firewall:
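
Both keys documented above flow through Nebula's reloadable config path; connection_manager.go registers a callback with c.RegisterReloadCallback and re-reads a key only when it changed. A small, hedged sketch of consuming tunnels.inactivity_timeout the same way, with the surrounding wiring assumed:

package nebula

import (
	"time"

	"github.com/slackhq/nebula/config"
)

// watchInactivityTimeout is illustrative only: it re-reads tunnels.inactivity_timeout
// on config reloads the same way connectionManager.reload does.
func watchInactivityTimeout(c *config.C, apply func(time.Duration)) {
	apply(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))
	c.RegisterReloadCallback(func(c *config.C) {
		if c.HasChanged("tunnels.inactivity_timeout") {
			apply(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute))
		}
	})
}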

+ 2 - 2
handshake_ix.go

@@ -335,7 +335,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			Info("Handshake message sent")
 	}

-	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo)

 	hostinfo.remotes.ResetBlockedRemotes()

@@ -493,7 +493,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha

 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
 	f.handshakeManager.Complete(hostinfo, f)
-	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo)

 	if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))

+ 8 - 0
hostmap.go

@@ -242,6 +242,14 @@ type HostInfo struct {
 	// Used to track other hostinfos for this vpn ip since only 1 can be primary
 	// Synchronised via hostmap lock and not the hostinfo lock.
 	next, prev *HostInfo
+
+	//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
+	in, out, pendingDeletion atomic.Bool
+
+	// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
+	// This value will be behind against actual tunnel utilization in the hot path.
+	// This should only be used by the ConnectionManagers ticker routine.
+	lastUsed time.Time
 }

 type ViaSender struct {

+ 2 - 2
inside.go

@@ -213,7 +213,7 @@ func (f *Interface) SendVia(via *HostInfo,
 	c := via.ConnectionState.messageCounter.Add(1)

 	out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
-	f.connectionManager.Out(via.localIndexId)
+	f.connectionManager.Out(via)

 	// Authenticate the header and payload, but do not encrypt for this message type.
 	// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -282,7 +282,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType

 	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
 	out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
-	f.connectionManager.Out(hostinfo.localIndexId)
+	f.connectionManager.Out(hostinfo)

 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// all our IPs and enable a faster roaming.

+ 22 - 19
interface.go

@@ -24,24 +24,23 @@ import (
 const mtu = 9001

 type InterfaceConfig struct {
-	HostMap                 *HostMap
-	Outside                 udp.Conn
-	Inside                  overlay.Device
-	pki                     *PKI
-	Cipher                  string
-	Firewall                *Firewall
-	ServeDns                bool
-	HandshakeManager        *HandshakeManager
-	lightHouse              *LightHouse
-	checkInterval           time.Duration
-	pendingDeletionInterval time.Duration
-	DropLocalBroadcast      bool
-	DropMulticast           bool
-	routines                int
-	MessageMetrics          *MessageMetrics
-	version                 string
-	relayManager            *relayManager
-	punchy                  *Punchy
+	HostMap            *HostMap
+	Outside            udp.Conn
+	Inside             overlay.Device
+	pki                *PKI
+	Cipher             string
+	Firewall           *Firewall
+	ServeDns           bool
+	HandshakeManager   *HandshakeManager
+	lightHouse         *LightHouse
+	connectionManager  *connectionManager
+	DropLocalBroadcast bool
+	DropMulticast      bool
+	routines           int
+	MessageMetrics     *MessageMetrics
+	version            string
+	relayManager       *relayManager
+	punchy             *Punchy

 	tryPromoteEvery uint32
 	reQueryEvery    uint32
@@ -154,6 +153,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	if c.Firewall == nil {
 		return nil, errors.New("no firewall rules")
 	}
+	if c.connectionManager == nil {
+		return nil, errors.New("no connection manager")
+	}

 	certificate := c.pki.GetCertState().Certificate

@@ -196,6 +198,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		readers:            make([]io.ReadWriteCloser, c.routines),
 		myVpnNet:           myVpnNet,
 		relayManager:       c.relayManager,
+		connectionManager:  c.connectionManager,

 		conntrackCacheTimeout: c.ConntrackCacheTimeout,

@@ -219,7 +222,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))

-	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
+	ifce.connectionManager.intf = ifce

 	return ifce, nil
 }

+ 22 - 24
main.go

@@ -199,6 +199,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg

 	hostMap := NewHostMapFromConfig(l, tunCidr, c)
 	punchy := NewPunchyFromConfig(l, c)
+	connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
 	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -234,31 +235,27 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}

-	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
-	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
-
 	ifConfig := &InterfaceConfig{
-		HostMap:                 hostMap,
-		Inside:                  tun,
-		Outside:                 udpConns[0],
-		pki:                     pki,
-		Cipher:                  c.GetString("cipher", "aes"),
-		Firewall:                fw,
-		ServeDns:                serveDns,
-		HandshakeManager:        handshakeManager,
-		lightHouse:              lightHouse,
-		checkInterval:           time.Second * time.Duration(checkInterval),
-		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
-		tryPromoteEvery:         c.GetUint32("counters.try_promote", defaultPromoteEvery),
-		reQueryEvery:            c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
-		reQueryWait:             c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
-		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
-		DropMulticast:           c.GetBool("tun.drop_multicast", false),
-		routines:                routines,
-		MessageMetrics:          messageMetrics,
-		version:                 buildVersion,
-		relayManager:            NewRelayManager(ctx, l, hostMap, c),
-		punchy:                  punchy,
+		HostMap:            hostMap,
+		Inside:             tun,
+		Outside:            udpConns[0],
+		pki:                pki,
+		Cipher:             c.GetString("cipher", "aes"),
+		Firewall:           fw,
+		ServeDns:           serveDns,
+		HandshakeManager:   handshakeManager,
+		connectionManager:  connManager,
+		lightHouse:         lightHouse,
+		tryPromoteEvery:    c.GetUint32("counters.try_promote", defaultPromoteEvery),
+		reQueryEvery:       c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
+		reQueryWait:        c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
+		DropLocalBroadcast: c.GetBool("tun.drop_local_broadcast", false),
+		DropMulticast:      c.GetBool("tun.drop_multicast", false),
+		routines:           routines,
+		MessageMetrics:     messageMetrics,
+		version:            buildVersion,
+		relayManager:       NewRelayManager(ctx, l, hostMap, c),
+		punchy:             punchy,

 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
@@ -325,5 +322,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		statsStart,
 		dnsStart,
 		lightHouse.StartUpdateWorker,
+		connManager.Start,
 	}, nil
 }

+ 3 - 3
outside.go

@@ -102,7 +102,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			// Pull the Roaming parts up here, and return in all call paths.
 			f.handleHostRoaming(hostinfo, ip)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-			f.connectionManager.In(hostinfo.localIndexId)
+			f.connectionManager.In(hostinfo)
 			f.connectionManager.RelayUsed(h.RemoteIndex)

 			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -246,7 +246,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []

 	f.handleHostRoaming(hostinfo, ip)

-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 }

 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -418,7 +418,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}

-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 	_, err = f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")