Drop inactive tunnels (#1427)

Nate Brown · 3 weeks ago · commit 52623820c2
14 changed files with 485 additions and 279 deletions
  1. connection_manager.go (+205 -178)
  2. connection_manager_test.go (+136 -40)
  3. control.go (+12 -8)
  4. e2e/handshakes_test.go (+4 -1)
  5. e2e/router/router.go (+1 -0)
  6. e2e/tunnels_test.go (+57 -0)
  7. examples/config.yml (+12 -0)
  8. handshake_ix.go (+2 -2)
  9. hostmap.go (+8 -0)
  10. inside.go (+2 -2)
  11. interface.go (+22 -19)
  12. main.go (+21 -24)
  13. outside.go (+3 -3)
  14. udp/udp_darwin.go (+0 -2)

+ 205 - 178
connection_manager.go

@@ -7,11 +7,13 @@ import (
 	"fmt"
 	"net/netip"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/header"
 )
 
@@ -28,130 +30,124 @@ const (
 )
 
 type connectionManager struct {
-	in     map[uint32]struct{}
-	inLock *sync.RWMutex
-
-	out     map[uint32]struct{}
-	outLock *sync.RWMutex
-
 	// relayUsed holds which relay localIndexes are in use
 	relayUsed     map[uint32]struct{}
 	relayUsedLock *sync.RWMutex
 
-	hostMap                 *HostMap
-	trafficTimer            *LockingTimerWheel[uint32]
-	intf                    *Interface
-	pendingDeletion         map[uint32]struct{}
-	punchy                  *Punchy
+	hostMap      *HostMap
+	trafficTimer *LockingTimerWheel[uint32]
+	intf         *Interface
+	punchy       *Punchy
+
+	// Configuration settings
 	checkInterval           time.Duration
 	pendingDeletionInterval time.Duration
-	metricsTxPunchy         metrics.Counter
+	inactivityTimeout       atomic.Int64
+	dropInactive            atomic.Bool
+
+	metricsTxPunchy metrics.Counter
 
 	l *logrus.Logger
 }
 
-func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
-	var max time.Duration
-	if checkInterval < pendingDeletionInterval {
-		max = pendingDeletionInterval
-	} else {
-		max = checkInterval
+func newConnectionManagerFromConfig(l *logrus.Logger, c *config.C, hm *HostMap, p *Punchy) *connectionManager {
+	cm := &connectionManager{
+		hostMap:         hm,
+		l:               l,
+		punchy:          p,
+		relayUsed:       make(map[uint32]struct{}),
+		relayUsedLock:   &sync.RWMutex{},
+		metricsTxPunchy: metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
 	}
 
-	nc := &connectionManager{
-		hostMap:                 intf.hostMap,
-		in:                      make(map[uint32]struct{}),
-		inLock:                  &sync.RWMutex{},
-		out:                     make(map[uint32]struct{}),
-		outLock:                 &sync.RWMutex{},
-		relayUsed:               make(map[uint32]struct{}),
-		relayUsedLock:           &sync.RWMutex{},
-		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
-		intf:                    intf,
-		pendingDeletion:         make(map[uint32]struct{}),
-		checkInterval:           checkInterval,
-		pendingDeletionInterval: pendingDeletionInterval,
-		punchy:                  punchy,
-		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
-		l:                       l,
-	}
+	cm.reload(c, true)
+	c.RegisterReloadCallback(func(c *config.C) {
+		cm.reload(c, false)
+	})
 
-	nc.Start(ctx)
-	return nc
+	return cm
 }
 
-func (n *connectionManager) In(localIndex uint32) {
-	n.inLock.RLock()
-	// If this already exists, return
-	if _, ok := n.in[localIndex]; ok {
-		n.inLock.RUnlock()
-		return
+func (cm *connectionManager) reload(c *config.C, initial bool) {
+	if initial {
+		cm.checkInterval = time.Duration(c.GetInt("timers.connection_alive_interval", 5)) * time.Second
+		cm.pendingDeletionInterval = time.Duration(c.GetInt("timers.pending_deletion_interval", 10)) * time.Second
+
+		// We want at least a minimum resolution of 500ms per tick so that we can hit these intervals
+		// pretty close to their configured duration.
+		// The inactivity duration is checked each time a hostinfo ticks through so we don't need the wheel to contain it.
+		minDuration := min(time.Millisecond*500, cm.checkInterval, cm.pendingDeletionInterval)
+		maxDuration := max(cm.checkInterval, cm.pendingDeletionInterval)
+		cm.trafficTimer = NewLockingTimerWheel[uint32](minDuration, maxDuration)
 	}
-	n.inLock.RUnlock()
-	n.inLock.Lock()
-	n.in[localIndex] = struct{}{}
-	n.inLock.Unlock()
-}
 
-func (n *connectionManager) Out(localIndex uint32) {
-	n.outLock.RLock()
-	// If this already exists, return
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.RUnlock()
-		return
+	if initial || c.HasChanged("tunnels.inactivity_timeout") {
+		old := cm.getInactivityTimeout()
+		cm.inactivityTimeout.Store((int64)(c.GetDuration("tunnels.inactivity_timeout", 10*time.Minute)))
+		if !initial {
+			cm.l.WithField("oldDuration", old).
+				WithField("newDuration", cm.getInactivityTimeout()).
+				Info("Inactivity timeout has changed")
+		}
+	}
+
+	if initial || c.HasChanged("tunnels.drop_inactive") {
+		old := cm.dropInactive.Load()
+		cm.dropInactive.Store(c.GetBool("tunnels.drop_inactive", false))
+		if !initial {
+			cm.l.WithField("oldBool", old).
+				WithField("newBool", cm.dropInactive.Load()).
+				Info("Drop inactive setting has changed")
+		}
 	}
-	n.outLock.RUnlock()
-	n.outLock.Lock()
-	n.out[localIndex] = struct{}{}
-	n.outLock.Unlock()
 }
 
-func (n *connectionManager) RelayUsed(localIndex uint32) {
-	n.relayUsedLock.RLock()
+func (cm *connectionManager) getInactivityTimeout() time.Duration {
+	return (time.Duration)(cm.inactivityTimeout.Load())
+}
+
+func (cm *connectionManager) In(h *HostInfo) {
+	h.in.Store(true)
+}
+
+func (cm *connectionManager) Out(h *HostInfo) {
+	h.out.Store(true)
+}
+
+func (cm *connectionManager) RelayUsed(localIndex uint32) {
+	cm.relayUsedLock.RLock()
 	// If this already exists, return
-	if _, ok := n.relayUsed[localIndex]; ok {
-		n.relayUsedLock.RUnlock()
+	if _, ok := cm.relayUsed[localIndex]; ok {
+		cm.relayUsedLock.RUnlock()
 		return
 	}
-	n.relayUsedLock.RUnlock()
-	n.relayUsedLock.Lock()
-	n.relayUsed[localIndex] = struct{}{}
-	n.relayUsedLock.Unlock()
+	cm.relayUsedLock.RUnlock()
+	cm.relayUsedLock.Lock()
+	cm.relayUsed[localIndex] = struct{}{}
+	cm.relayUsedLock.Unlock()
 }
 
 // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
 // resets the state for this local index
-func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
-	n.inLock.Lock()
-	n.outLock.Lock()
-	_, in := n.in[localIndex]
-	_, out := n.out[localIndex]
-	delete(n.in, localIndex)
-	delete(n.out, localIndex)
-	n.inLock.Unlock()
-	n.outLock.Unlock()
+func (cm *connectionManager) getAndResetTrafficCheck(h *HostInfo, now time.Time) (bool, bool) {
+	in := h.in.Swap(false)
+	out := h.out.Swap(false)
+	if in || out {
+		h.lastUsed = now
+	}
 	return in, out
 }
 
-func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
-	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
-	n.outLock.Lock()
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.Unlock()
-		return
+// AddTrafficWatch must be called for every new HostInfo.
+// We will continue to monitor the HostInfo until the tunnel is dropped.
+func (cm *connectionManager) AddTrafficWatch(h *HostInfo) {
+	if h.out.Swap(true) == false {
+		cm.trafficTimer.Add(h.localIndexId, cm.checkInterval)
 	}
-	n.out[localIndex] = struct{}{}
-	n.trafficTimer.Add(localIndex, n.checkInterval)
-	n.outLock.Unlock()
 }
 
-func (n *connectionManager) Start(ctx context.Context) {
-	go n.Run(ctx)
-}
-
-func (n *connectionManager) Run(ctx context.Context) {
-	//TODO: this tick should be based on the min wheel tick? Check firewall
-	clockSource := time.NewTicker(500 * time.Millisecond)
+func (cm *connectionManager) Start(ctx context.Context) {
+	clockSource := time.NewTicker(cm.trafficTimer.t.tickDuration)
 	defer clockSource.Stop()
 
 	p := []byte("")
@@ -164,61 +160,61 @@ func (n *connectionManager) Run(ctx context.Context) {
 			return
 
 		case now := <-clockSource.C:
-			n.trafficTimer.Advance(now)
+			cm.trafficTimer.Advance(now)
 			for {
-				localIndex, has := n.trafficTimer.Purge()
+				localIndex, has := cm.trafficTimer.Purge()
 				if !has {
 					break
 				}
 
-				n.doTrafficCheck(localIndex, p, nb, out, now)
+				cm.doTrafficCheck(localIndex, p, nb, out, now)
 			}
 		}
 	}
 }
 
-func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now)
+func (cm *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
+	decision, hostinfo, primary := cm.makeTrafficDecision(localIndex, now)
 
 	switch decision {
 	case deleteTunnel:
-		if n.hostMap.DeleteHostInfo(hostinfo) {
+		if cm.hostMap.DeleteHostInfo(hostinfo) {
 			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap
-			n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
+			cm.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs)
 		}
 
 	case closeTunnel:
-		n.intf.sendCloseTunnel(hostinfo)
-		n.intf.closeTunnel(hostinfo)
+		cm.intf.sendCloseTunnel(hostinfo)
+		cm.intf.closeTunnel(hostinfo)
 
 	case swapPrimary:
-		n.swapPrimary(hostinfo, primary)
+		cm.swapPrimary(hostinfo, primary)
 
 	case migrateRelays:
-		n.migrateRelayUsed(hostinfo, primary)
+		cm.migrateRelayUsed(hostinfo, primary)
 
 	case tryRehandshake:
-		n.tryRehandshake(hostinfo)
+		cm.tryRehandshake(hostinfo)
 
 	case sendTestPacket:
-		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		cm.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
 	}
 
-	n.resetRelayTrafficCheck(hostinfo)
+	cm.resetRelayTrafficCheck(hostinfo)
 }
 
-func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
+func (cm *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
 	if hostinfo != nil {
-		n.relayUsedLock.Lock()
-		defer n.relayUsedLock.Unlock()
+		cm.relayUsedLock.Lock()
+		defer cm.relayUsedLock.Unlock()
 		// No need to migrate any relays, delete usage info now.
 		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
-			delete(n.relayUsed, idx)
+			delete(cm.relayUsed, idx)
 		}
 	}
 }
 
-func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
+func (cm *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
 	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
 
 	for _, r := range relayFor {
@@ -238,7 +234,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 				index = existing.LocalIndex
 				switch r.Type {
 				case TerminalType:
-					relayFrom = n.intf.myVpnAddrs[0]
+					relayFrom = cm.intf.myVpnAddrs[0]
 					relayTo = existing.PeerAddr
 				case ForwardingType:
 					relayFrom = existing.PeerAddr
@@ -249,23 +245,23 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 				}
 			}
 		case !ok:
-			n.relayUsedLock.RLock()
-			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
+			cm.relayUsedLock.RLock()
+			if _, relayUsed := cm.relayUsed[r.LocalIndex]; !relayUsed {
 				// The relay hasn't been used; don't migrate it.
-				n.relayUsedLock.RUnlock()
+				cm.relayUsedLock.RUnlock()
 				continue
 			}
-			n.relayUsedLock.RUnlock()
+			cm.relayUsedLock.RUnlock()
 			// The relay doesn't exist at all; create some relay state and send the request.
 			var err error
-			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested)
+			index, err = AddRelay(cm.l, newhostinfo, cm.hostMap, r.PeerAddr, nil, r.Type, Requested)
 			if err != nil {
-				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
+				cm.l.WithError(err).Error("failed to migrate relay to new hostinfo")
 				continue
 			}
 			switch r.Type {
 			case TerminalType:
-				relayFrom = n.intf.myVpnAddrs[0]
+				relayFrom = cm.intf.myVpnAddrs[0]
 				relayTo = r.PeerAddr
 			case ForwardingType:
 				relayFrom = r.PeerAddr
@@ -285,12 +281,12 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 		switch newhostinfo.GetCert().Certificate.Version() {
 		case cert.Version1:
 			if !relayFrom.Is4() {
-				n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
+				cm.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version")
 				continue
 			}
 
 			if !relayTo.Is4() {
-				n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
+				cm.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version")
 				continue
 			}
 
@@ -302,16 +298,16 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 			req.RelayFromAddr = netAddrToProtoAddr(relayFrom)
 			req.RelayToAddr = netAddrToProtoAddr(relayTo)
 		default:
-			newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay")
+			newhostinfo.logger(cm.l).Error("Unknown certificate version found while attempting to migrate relay")
 			continue
 		}
 
 		msg, err := req.Marshal()
 		if err != nil {
-			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
+			cm.l.WithError(err).Error("failed to marshal Control message to migrate relay")
 		} else {
-			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
-			n.l.WithFields(logrus.Fields{
+			cm.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
+			cm.l.WithFields(logrus.Fields{
 				"relayFrom":           req.RelayFromAddr,
 				"relayTo":             req.RelayToAddr,
 				"initiatorRelayIndex": req.InitiatorRelayIndex,
@@ -322,46 +318,45 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo)
 	}
 }
 
-func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
-	n.hostMap.RLock()
-	defer n.hostMap.RUnlock()
+func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+	// Read lock the main hostmap to order decisions based on tunnels being the primary tunnel
+	cm.hostMap.RLock()
+	defer cm.hostMap.RUnlock()
 
-	hostinfo := n.hostMap.Indexes[localIndex]
+	hostinfo := cm.hostMap.Indexes[localIndex]
 	if hostinfo == nil {
-		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
-		delete(n.pendingDeletion, localIndex)
+		cm.l.WithField("localIndex", localIndex).Debugln("Not found in hostmap")
 		return doNothing, nil, nil
 	}
 
-	if n.isInvalidCertificate(now, hostinfo) {
-		delete(n.pendingDeletion, hostinfo.localIndexId)
+	if cm.isInvalidCertificate(now, hostinfo) {
 		return closeTunnel, hostinfo, nil
 	}
 
-	primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]]
+	primary := cm.hostMap.Hosts[hostinfo.vpnAddrs[0]]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
 	}
 
 	// Check for traffic on this hostinfo
-	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
+	inTraffic, outTraffic := cm.getAndResetTrafficCheck(hostinfo, now)
 
 	// A hostinfo is determined alive if there is incoming traffic
 	if inTraffic {
 		decision := doNothing
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).
 				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
 				Debug("Tunnel status")
 		}
-		delete(n.pendingDeletion, hostinfo.localIndexId)
+		hostinfo.pendingDeletion.Store(false)
 
 		if mainHostInfo {
 			decision = tryRehandshake
 
 		} else {
-			if n.shouldSwapPrimary(hostinfo, primary) {
+			if cm.shouldSwapPrimary(hostinfo, primary) {
 				decision = swapPrimary
 			} else {
 				// migrate the relays to the primary, if in use.
@@ -369,46 +364,55 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 			}
 		}
 
-		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+		cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
 
 		if !outTraffic {
 			// Send a punch packet to keep the NAT state alive
-			n.sendPunch(hostinfo)
+			cm.sendPunch(hostinfo)
 		}
 
 		return decision, hostinfo, primary
 	}
 
-	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
+	if hostinfo.pendingDeletion.Load() {
 		// We have already sent a test packet and nothing was returned, this hostinfo is dead
-		hostinfo.logger(n.l).
+		hostinfo.logger(cm.l).
 			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
 			Info("Tunnel status")
 
-		delete(n.pendingDeletion, hostinfo.localIndexId)
 		return deleteTunnel, hostinfo, nil
 	}
 
 	decision := doNothing
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
+			inactiveFor, isInactive := cm.isInactive(hostinfo, now)
+			if isInactive {
+				// Tunnel is inactive, tear it down
+				hostinfo.logger(cm.l).
+					WithField("inactiveDuration", inactiveFor).
+					WithField("primary", mainHostInfo).
+					Info("Dropping tunnel due to inactivity")
+
+				return closeTunnel, hostinfo, primary
+			}
+
 			// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test the tunnel.
 			// Just maintain NAT state if configured to do so.
-			n.sendPunch(hostinfo)
-			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+			cm.sendPunch(hostinfo)
+			cm.trafficTimer.Add(hostinfo.localIndexId, cm.checkInterval)
 			return doNothing, nil, nil
-
 		}
 
-		if n.punchy.GetTargetEverything() {
+		if cm.punchy.GetTargetEverything() {
 			// This is similar to the old punchy behavior with a slight optimization.
 			// We aren't receiving traffic but we are sending it, punch on all known
 			// ips in case we need to re-prime NAT state
-			n.sendPunch(hostinfo)
+			cm.sendPunch(hostinfo)
 		}
 
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).
 				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 				Debug("Tunnel status")
 		}
@@ -417,17 +421,33 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time
 		decision = sendTestPacket
 
 	} else {
-		if n.l.Level >= logrus.DebugLevel {
-			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		if cm.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(cm.l).Debugf("Hostinfo sadness")
 		}
 	}
 
-	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
-	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
+	hostinfo.pendingDeletion.Store(true)
+	cm.trafficTimer.Add(hostinfo.localIndexId, cm.pendingDeletionInterval)
 	return decision, hostinfo, nil
 }
 
-func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
+func (cm *connectionManager) isInactive(hostinfo *HostInfo, now time.Time) (time.Duration, bool) {
+	if cm.dropInactive.Load() == false {
+		// We aren't configured to drop inactive tunnels
+		return 0, false
+	}
+
+	inactiveDuration := now.Sub(hostinfo.lastUsed)
+	if inactiveDuration < cm.getInactivityTimeout() {
+		// It's not considered inactive
+		return inactiveDuration, false
+	}
+
+	// The tunnel is inactive
+	return inactiveDuration, true
+}
+
+func (cm *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
 	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.
 	// Let's sort this out.
@@ -435,73 +455,80 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	// Only one side should swap because if both swap then we may never resolve to a single tunnel.
 	// vpn addr is static across all tunnels for this host pair so let's
 	// use that to determine if we should consider swapping.
-	if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 {
+	if current.vpnAddrs[0].Compare(cm.intf.myVpnAddrs[0]) < 0 {
 		// Their primary vpn addr is less than mine. Do not swap.
 		return false
 	}
 
-	crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
 	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
 	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
 }
 
-func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
-	n.hostMap.Lock()
+func (cm *connectionManager) swapPrimary(current, primary *HostInfo) {
+	cm.hostMap.Lock()
 	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
-	if n.hostMap.Hosts[current.vpnAddrs[0]] == primary {
-		n.hostMap.unlockedMakePrimary(current)
+	if cm.hostMap.Hosts[current.vpnAddrs[0]] == primary {
+		cm.hostMap.unlockedMakePrimary(current)
 	}
-	n.hostMap.Unlock()
+	cm.hostMap.Unlock()
 }
 
 // isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
 // the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid
 // check and return true.
-func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
+func (cm *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	remoteCert := hostinfo.GetCert()
 	if remoteCert == nil {
 		return false
 	}
 
-	caPool := n.intf.pki.GetCAPool()
+	caPool := cm.intf.pki.GetCAPool()
 	err := caPool.VerifyCachedCertificate(now, remoteCert)
 	if err == nil {
 		return false
 	}
 
-	if !n.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
+	if !cm.intf.disconnectInvalid.Load() && err != cert.ErrBlockListed {
 		// Block listed certificates should always be disconnected
 		return false
 	}
 
-	hostinfo.logger(n.l).WithError(err).
+	hostinfo.logger(cm.l).WithError(err).
 		WithField("fingerprint", remoteCert.Fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 
 	return true
 }
 
-func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
-	if !n.punchy.GetPunch() {
+func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
+	if !cm.punchy.GetPunch() {
 		// Punching is disabled
 		return
 	}
 
-	if n.punchy.GetTargetEverything() {
-		hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
-			n.metricsTxPunchy.Inc(1)
-			n.intf.outside.WriteTo([]byte{1}, addr)
+	if cm.intf.lightHouse.IsAnyLighthouseAddr(hostinfo.vpnAddrs) {
+		// Do not punch to lighthouses, we assume our lighthouse update interval is good enough.
+		// In the event the update interval is not sufficient to maintain NAT state then a publicly available lighthouse
+		// would lose the ability to notify us and punchy.respond would become unreliable.
+		return
+	}
+
+	if cm.punchy.GetTargetEverything() {
+		hostinfo.remotes.ForEach(cm.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) {
+			cm.metricsTxPunchy.Inc(1)
+			cm.intf.outside.WriteTo([]byte{1}, addr)
 		})
 
 	} else if hostinfo.remote.IsValid() {
-		n.metricsTxPunchy.Inc(1)
-		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
+		cm.metricsTxPunchy.Inc(1)
+		cm.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 }
 
-func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	cs := n.intf.pki.getCertState()
+func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
+	cs := cm.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
 	myCrt := cs.getCertificate(curCrt.Version())
 	if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
@@ -509,9 +536,9 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 		return
 	}
 
-	n.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+	cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
 		WithField("reason", "local certificate is not current").
 		Info("Re-handshaking with remote")
 
-	n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+	cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
 }

+ 136 - 40
connection_manager_test.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"context"
 	"crypto/ed25519"
 	"crypto/rand"
 	"net/netip"
@@ -64,10 +63,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	ifce.pki.cs.Store(cs)
 
 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -85,32 +84,33 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo.localIndexId)
-	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.False(t, hostinfo.pendingDeletion.Load())
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	assert.Contains(t, nc.out, hostinfo.localIndexId)
+	assert.True(t, hostinfo.out.Load())
+	assert.True(t, hostinfo.in.Load())
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 
 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo.localIndexId)
+	nc.Out(hostinfo)
+	assert.True(t, hostinfo.out.Load())
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.True(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// Do a final traffic check tick, the host should now be removed
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
+	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs)
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 }
 
@@ -146,10 +146,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	ifce.pki.cs.Store(cs)
 
 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
@@ -167,33 +167,129 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// We saw traffic out to vpnIp
-	nc.Out(hostinfo.localIndexId)
-	nc.In(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0])
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.True(t, hostinfo.in.Load())
+	assert.True(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.pendingDeletion.Load())
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 
 	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 
 	// Do another traffic check tick, this host should be pending deletion now
-	nc.Out(hostinfo.localIndexId)
+	nc.Out(hostinfo)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.True(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 
 	// We saw traffic, should no longer be pending deletion
-	nc.In(hostinfo.localIndexId)
+	nc.In(hostinfo)
 	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
-	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.NotContains(t, nc.out, hostinfo.localIndexId)
-	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
+}
+
+func Test_NewConnectionManager_DisconnectInactive(t *testing.T) {
+	l := test.NewLogger()
+	localrange := netip.MustParsePrefix("10.1.1.1/24")
+	vpnAddrs := []netip.Addr{netip.MustParseAddr("172.1.1.2")}
+	preferredRanges := []netip.Prefix{localrange}
+
+	// Very incomplete mock objects
+	hostMap := newHostMap(l)
+	hostMap.preferredRanges.Store(&preferredRanges)
+
+	cs := &CertState{
+		initiatingVersion: cert.Version1,
+		privateKey:        []byte{},
+		v1Cert:            &dummyCert{version: cert.Version1},
+		v1HandshakeBytes:  []byte{},
+	}
+
+	lh := newTestLighthouse()
+	ifce := &Interface{
+		hostMap:          hostMap,
+		inside:           &test.NoopTun{},
+		outside:          &udp.NoopConn{},
+		firewall:         &Firewall{},
+		lightHouse:       lh,
+		pki:              &PKI{},
+		handshakeManager: NewHandshakeManager(l, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
+		l:                l,
+	}
+	ifce.pki.cs.Store(cs)
+
+	// Create manager
+	conf := config.NewC(l)
+	conf.Settings["tunnels"] = map[string]any{
+		"drop_inactive": true,
+	}
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	assert.True(t, nc.dropInactive.Load())
+	nc.intf = ifce
+
+	// Add an ip we have established a connection w/ to hostmap
+	hostinfo := &HostInfo{
+		vpnAddrs:      vpnAddrs,
+		localIndexId:  1099,
+		remoteIndexId: 9901,
+	}
+	hostinfo.ConnectionState = &ConnectionState{
+		myCert: &dummyCert{version: cert.Version1},
+		H:      &noise.HandshakeState{},
+	}
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
+
+	// Do a traffic check tick, in and out should be cleared but should not be pending deletion
+	nc.Out(hostinfo)
+	nc.In(hostinfo)
+	assert.True(t, hostinfo.out.Load())
+	assert.True(t, hostinfo.in.Load())
+
+	now := time.Now()
+	decision, _, _ := nc.makeTrafficDecision(hostinfo.localIndexId, now)
+	assert.Equal(t, tryRehandshake, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*5))
+	assert.Equal(t, doNothing, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+
+	// Do another traffic check tick, should still not be pending deletion
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Second*10))
+	assert.Equal(t, doNothing, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
+
+	// Finally advance beyond the inactivity timeout
+	decision, _, _ = nc.makeTrafficDecision(hostinfo.localIndexId, now.Add(time.Minute*10))
+	assert.Equal(t, closeTunnel, decision)
+	assert.Equal(t, now, hostinfo.lastUsed)
+	assert.False(t, hostinfo.pendingDeletion.Load())
+	assert.False(t, hostinfo.out.Load())
+	assert.False(t, hostinfo.in.Load())
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0])
 }
@@ -264,10 +360,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	ifce.disconnectInvalid.Store(true)
 
 	// Create manager
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-	punchy := NewPunchyFromConfig(l, config.NewC(l))
-	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
+	conf := config.NewC(l)
+	punchy := NewPunchyFromConfig(l, conf)
+	nc := newConnectionManagerFromConfig(l, conf, hostMap, punchy)
+	nc.intf = ifce
 	ifce.connectionManager = nc
 
 	hostinfo := &HostInfo{

+ 12 - 8
control.go

@@ -26,14 +26,15 @@ type controlHostLister interface {
 }
 
 type Control struct {
-	f               *Interface
-	l               *logrus.Logger
-	ctx             context.Context
-	cancel          context.CancelFunc
-	sshStart        func()
-	statsStart      func()
-	dnsStart        func()
-	lighthouseStart func()
+	f                      *Interface
+	l                      *logrus.Logger
+	ctx                    context.Context
+	cancel                 context.CancelFunc
+	sshStart               func()
+	statsStart             func()
+	dnsStart               func()
+	lighthouseStart        func()
+	connectionManagerStart func(context.Context)
 }
 
 type ControlHostInfo struct {
@@ -63,6 +64,9 @@ func (c *Control) Start() {
 	if c.dnsStart != nil {
 		go c.dnsStart()
 	}
+	if c.connectionManagerStart != nil {
+		go c.connectionManagerStart(c.ctx)
+	}
 	if c.lighthouseStart != nil {
 		c.lighthouseStart()
 	}

+ 4 - 1
e2e/handshakes_test.go

@@ -506,7 +506,7 @@ func TestReestablishRelays(t *testing.T) {
 	curIndexes := len(myControl.GetHostmap().Indexes)
 	for curIndexes >= start {
 		curIndexes = len(myControl.GetHostmap().Indexes)
-		r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes)
+		r.Logf("Wait for the dead index to go away:start=%v indexes, current=%v indexes", start, curIndexes)
 		myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail"))
 
 		r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
@@ -1052,6 +1052,9 @@ func TestRehandshakingLoser(t *testing.T) {
 	t.Log("Stand up a tunnel between me and them")
 	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 
+	myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+	theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+
 	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
 
 	r.Log("Renew their certificate and spin until mine sees it")

+ 1 - 0
e2e/router/router.go

@@ -700,6 +700,7 @@ func (r *R) FlushAll() {
 			r.Unlock()
 			panic("Can't FlushAll for host: " + p.To.String())
 		}
+		receiver.InjectUDPPacket(p)
 		r.Unlock()
 	}
 }

+ 57 - 0
e2e/tunnels_test.go

@@ -0,0 +1,57 @@
+//go:build e2e_testing
+// +build e2e_testing
+
+package e2e
+
+import (
+	"testing"
+	"time"
+
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/cert_test"
+	"github.com/slackhq/nebula/e2e/router"
+)
+
+func TestDropInactiveTunnels(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "5s"}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", m{"tunnels": m{"drop_inactive": true, "inactivity_timeout": "10m"}})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+
+	r.Log("Assert the tunnel between me and them works")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+
+	r.Log("Go inactive and wait for the tunnels to get dropped")
+	waitStart := time.Now()
+	for {
+		myIndexes := len(myControl.GetHostmap().Indexes)
+		theirIndexes := len(theirControl.GetHostmap().Indexes)
+		if myIndexes == 0 && theirIndexes == 0 {
+			break
+		}
+
+		since := time.Since(waitStart)
+		r.Logf("my tunnels: %v; their tunnels: %v; duration: %v", myIndexes, theirIndexes, since)
+		if since > time.Second*30 {
+			t.Fatal("Tunnel should have been declared inactive after 5 seconds and before 30 seconds")
+		}
+
+		time.Sleep(1 * time.Second)
+		r.FlushAll()
+	}
+
+	r.Logf("Inactive tunnels were dropped within %v", time.Since(waitStart))
+	myControl.Stop()
+	theirControl.Stop()
+}

+ 12 - 0
examples/config.yml

@@ -338,6 +338,18 @@ logging:
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64
 
+# Tunnel manager settings
+#tunnels:
+  # drop_inactive controls whether inactive tunnels are maintained or dropped after the inactivity_timeout period has
+  # elapsed.
+  # In general, it is a good idea to enable this setting. It will be enabled by default in a future release.
+  # This setting is reloadable
+  #drop_inactive: false
+
+  # inactivity_timeout controls how long a tunnel MUST NOT see any inbound or outbound traffic before being considered
+  # inactive and eligible to be dropped.
+  # This setting is reloadable
+  #inactivity_timeout: 10m
 
 # Nebula security group configuration
 firewall:
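
For reference, an illustrative snippet that enables the new behavior (a sketch based on the example above; drop_inactive defaults to false, inactivity_timeout defaults to 10m, and both settings are reloadable):

tunnels:
  drop_inactive: true
  inactivity_timeout: 10m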

+ 2 - 2
handshake_ix.go

@@ -457,7 +457,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
 			Info("Handshake message sent")
 	}
 
-	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo)
 
 	hostinfo.remotes.ResetBlockedRemotes()
 
@@ -652,7 +652,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for the vpnAddrs here
 	f.handshakeManager.Complete(hostinfo, f)
-	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo)
 
 	if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))

+ 8 - 0
hostmap.go

@@ -256,6 +256,14 @@ type HostInfo struct {
 	// Used to track other hostinfos for this vpn ip since only 1 can be primary
 	// Synchronised via hostmap lock and not the hostinfo lock.
 	next, prev *HostInfo
+
+	//TODO: in, out, and others might benefit from being an atomic.Int32. We could collapse connectionManager pendingDeletion, relayUsed, and in/out into this 1 thing
+	in, out, pendingDeletion atomic.Bool
+
+	// lastUsed tracks the last time ConnectionManager checked the tunnel and it was in use.
+	// This value will lag behind actual tunnel utilization in the hot path.
+	// This should only be used by the ConnectionManager's ticker routine.
+	lastUsed time.Time
 }
 
 type ViaSender struct {

+ 2 - 2
inside.go

@@ -288,7 +288,7 @@ func (f *Interface) SendVia(via *HostInfo,
 	c := via.ConnectionState.messageCounter.Add(1)
 
 	out = header.Encode(out, header.Version, header.Message, header.MessageRelay, relay.RemoteIndex, c)
-	f.connectionManager.Out(via.localIndexId)
+	f.connectionManager.Out(via)
 
 	// Authenticate the header and payload, but do not encrypt for this message type.
 	// The payload consists of the inner, unencrypted Nebula header, as well as the end-to-end encrypted payload.
@@ -356,7 +356,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 
 	//l.WithField("trace", string(debug.Stack())).Error("out Header ", &Header{Version, t, st, 0, hostinfo.remoteIndexId, c}, p)
 	out = header.Encode(out, header.Version, t, st, hostinfo.remoteIndexId, c)
-	f.connectionManager.Out(hostinfo.localIndexId)
+	f.connectionManager.Out(hostinfo)
 
 	// Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against
 	// all our addrs and enable a faster roaming.

+ 22 - 19
interface.go

@@ -24,23 +24,23 @@ import (
 const mtu = 9001
 
 type InterfaceConfig struct {
-	HostMap                 *HostMap
-	Outside                 udp.Conn
-	Inside                  overlay.Device
-	pki                     *PKI
-	Firewall                *Firewall
-	ServeDns                bool
-	HandshakeManager        *HandshakeManager
-	lightHouse              *LightHouse
-	checkInterval           time.Duration
-	pendingDeletionInterval time.Duration
-	DropLocalBroadcast      bool
-	DropMulticast           bool
-	routines                int
-	MessageMetrics          *MessageMetrics
-	version                 string
-	relayManager            *relayManager
-	punchy                  *Punchy
+	HostMap            *HostMap
+	Outside            udp.Conn
+	Inside             overlay.Device
+	pki                *PKI
+	Cipher             string
+	Firewall           *Firewall
+	ServeDns           bool
+	HandshakeManager   *HandshakeManager
+	lightHouse         *LightHouse
+	connectionManager  *connectionManager
+	DropLocalBroadcast bool
+	DropMulticast      bool
+	routines           int
+	MessageMetrics     *MessageMetrics
+	version            string
+	relayManager       *relayManager
+	punchy             *Punchy
 
 	tryPromoteEvery uint32
 	reQueryEvery    uint32
@@ -157,6 +157,9 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	if c.Firewall == nil {
 		return nil, errors.New("no firewall rules")
 	}
+	if c.connectionManager == nil {
+		return nil, errors.New("no connection manager")
+	}
 
 	cs := c.pki.getCertState()
 	ifce := &Interface{
@@ -181,7 +184,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		myVpnAddrsTable:       cs.myVpnAddrsTable,
 		myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable,
 		relayManager:          c.relayManager,
-
+		connectionManager:     c.connectionManager,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
 		metricHandshakes: metrics.GetOrRegisterHistogram("handshakes", nil, metrics.NewExpDecaySample(1028, 0.015)),
@@ -198,7 +201,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 
-	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
+	ifce.connectionManager.intf = ifce
 
 	return ifce, nil
 }

+ 21 - 24
main.go

@@ -185,6 +185,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	hostMap := NewHostMapFromConfig(l, c)
 	punchy := NewPunchyFromConfig(l, c)
+	connManager := newConnectionManagerFromConfig(l, c, hostMap, punchy)
 	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy)
 	if err != nil {
 		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
@@ -220,31 +221,26 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
 
-	checkInterval := c.GetInt("timers.connection_alive_interval", 5)
-	pendingDeletionInterval := c.GetInt("timers.pending_deletion_interval", 10)
-
 	ifConfig := &InterfaceConfig{
-		HostMap:                 hostMap,
-		Inside:                  tun,
-		Outside:                 udpConns[0],
-		pki:                     pki,
-		Firewall:                fw,
-		ServeDns:                serveDns,
-		HandshakeManager:        handshakeManager,
-		lightHouse:              lightHouse,
-		checkInterval:           time.Second * time.Duration(checkInterval),
-		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
-		tryPromoteEvery:         c.GetUint32("counters.try_promote", defaultPromoteEvery),
-		reQueryEvery:            c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
-		reQueryWait:             c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
-		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
-		DropMulticast:           c.GetBool("tun.drop_multicast", false),
-		routines:                routines,
-		MessageMetrics:          messageMetrics,
-		version:                 buildVersion,
-		relayManager:            NewRelayManager(ctx, l, hostMap, c),
-		punchy:                  punchy,
-
+		HostMap:               hostMap,
+		Inside:                tun,
+		Outside:               udpConns[0],
+		pki:                   pki,
+		Firewall:              fw,
+		ServeDns:              serveDns,
+		HandshakeManager:      handshakeManager,
+		connectionManager:     connManager,
+		lightHouse:            lightHouse,
+		tryPromoteEvery:       c.GetUint32("counters.try_promote", defaultPromoteEvery),
+		reQueryEvery:          c.GetUint32("counters.requery_every_packets", defaultReQueryEvery),
+		reQueryWait:           c.GetDuration("timers.requery_wait_duration", defaultReQueryWait),
+		DropLocalBroadcast:    c.GetBool("tun.drop_local_broadcast", false),
+		DropMulticast:         c.GetBool("tun.drop_multicast", false),
+		routines:              routines,
+		MessageMetrics:        messageMetrics,
+		version:               buildVersion,
+		relayManager:          NewRelayManager(ctx, l, hostMap, c),
+		punchy:                punchy,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
 	}
@@ -296,5 +292,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		statsStart,
 		dnsStart,
 		lightHouse.StartUpdateWorker,
+		connManager.Start,
 	}, nil
 }

+ 3 - 3
outside.go

@@ -81,7 +81,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 			// Pull the Roaming parts up here, and return in all call paths.
 			f.handleHostRoaming(hostinfo, ip)
 			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
-			f.connectionManager.In(hostinfo.localIndexId)
+			f.connectionManager.In(hostinfo)
 			f.connectionManager.RelayUsed(h.RemoteIndex)
 
 			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
@@ -213,7 +213,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []
 
 	f.handleHostRoaming(hostinfo, ip)
 
-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 }
 
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
@@ -498,7 +498,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}
 
-	f.connectionManager.In(hostinfo.localIndexId)
+	f.connectionManager.In(hostinfo)
 	_, err = f.readers[q].Write(out)
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")

+ 0 - 2
udp/udp_darwin.go

@@ -3,8 +3,6 @@
 
 package udp
 
-// Darwin support is primarily implemented in udp_generic, besides NewListenConfig
-
 import (
 	"context"
 	"encoding/binary"