
Use connection manager to drive NAT maintenance (#835)

Co-authored-by: brad-defined <[email protected]>
Nate Brown, 2 years ago
Parent
Current commit
ee8e1348e9
9 changed files with 213 additions and 313 deletions
  1. connection_manager.go (+130 -190)
  2. connection_manager_test.go (+48 -41)
  3. handshake_ix.go (+3 -12)
  4. handshake_manager.go (+1 -3)
  5. hostmap.go (+0 -49)
  6. interface.go (+4 -3)
  7. main.go (+3 -7)
  8. outside.go (+0 -3)
  9. punchy.go (+24 -5)

+ 130 - 190
connection_manager.go

@@ -5,49 +5,55 @@ import (
 	"sync"
 	"time"

+	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/udp"
 )

-// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
-// and something like every 10 packets we could lock, send 10, then unlock for a moment
-
 type connectionManager struct {
-	hostMap      *HostMap
-	in           map[uint32]struct{}
-	inLock       *sync.RWMutex
-	out          map[uint32]struct{}
-	outLock      *sync.RWMutex
-	TrafficTimer *LockingTimerWheel[uint32]
-	intf         *Interface
+	in     map[uint32]struct{}
+	inLock *sync.RWMutex

-	pendingDeletion      map[uint32]int
-	pendingDeletionLock  *sync.RWMutex
-	pendingDeletionTimer *LockingTimerWheel[uint32]
+	out     map[uint32]struct{}
+	outLock *sync.RWMutex

-	checkInterval           int
-	pendingDeletionInterval int
+	hostMap                 *HostMap
+	trafficTimer            *LockingTimerWheel[uint32]
+	intf                    *Interface
+	pendingDeletion         map[uint32]struct{}
+	punchy                  *Punchy
+	checkInterval           time.Duration
+	pendingDeletionInterval time.Duration
+	metricsTxPunchy         metrics.Counter

 	l *logrus.Logger
-	// I wanted to call one matLock
 }

-func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
+func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
+	var max time.Duration
+	if checkInterval < pendingDeletionInterval {
+		max = pendingDeletionInterval
+	} else {
+		max = checkInterval
+	}
+
 	nc := &connectionManager{
 		hostMap:                 intf.hostMap,
 		in:                      make(map[uint32]struct{}),
 		inLock:                  &sync.RWMutex{},
 		out:                     make(map[uint32]struct{}),
 		outLock:                 &sync.RWMutex{},
-		TrafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
+		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
 		intf:                    intf,
-		pendingDeletion:         make(map[uint32]int),
-		pendingDeletionLock:     &sync.RWMutex{},
-		pendingDeletionTimer:    NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
+		pendingDeletion:         make(map[uint32]struct{}),
 		checkInterval:           checkInterval,
 		pendingDeletionInterval: pendingDeletionInterval,
+		punchy:                  punchy,
+		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
 		l:                       l,
 	}
+
 	nc.Start(ctx)
 	return nc
 }
@@ -74,65 +80,27 @@ func (n *connectionManager) Out(localIndex uint32) {
 	}
 	n.outLock.RUnlock()
 	n.outLock.Lock()
-	// double check since we dropped the lock temporarily
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.Unlock()
-		return
-	}
 	n.out[localIndex] = struct{}{}
-	n.AddTrafficWatch(localIndex, n.checkInterval)
 	n.outLock.Unlock()
 }

-func (n *connectionManager) CheckIn(localIndex uint32) bool {
-	n.inLock.RLock()
-	if _, ok := n.in[localIndex]; ok {
-		n.inLock.RUnlock()
-		return true
-	}
-	n.inLock.RUnlock()
-	return false
-}
-
-func (n *connectionManager) ClearLocalIndex(localIndex uint32) {
+// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
+// resets the state for this local index
+func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
 	n.inLock.Lock()
 	n.outLock.Lock()
+	_, in := n.in[localIndex]
+	_, out := n.out[localIndex]
 	delete(n.in, localIndex)
 	delete(n.out, localIndex)
 	n.inLock.Unlock()
 	n.outLock.Unlock()
+	return in, out
 }

-func (n *connectionManager) ClearPendingDeletion(localIndex uint32) {
-	n.pendingDeletionLock.Lock()
-	delete(n.pendingDeletion, localIndex)
-	n.pendingDeletionLock.Unlock()
-}
-
-func (n *connectionManager) AddPendingDeletion(localIndex uint32) {
-	n.pendingDeletionLock.Lock()
-	if _, ok := n.pendingDeletion[localIndex]; ok {
-		n.pendingDeletion[localIndex] += 1
-	} else {
-		n.pendingDeletion[localIndex] = 0
-	}
-	n.pendingDeletionTimer.Add(localIndex, time.Second*time.Duration(n.pendingDeletionInterval))
-	n.pendingDeletionLock.Unlock()
-}
-
-func (n *connectionManager) checkPendingDeletion(localIndex uint32) bool {
-	n.pendingDeletionLock.RLock()
-	if _, ok := n.pendingDeletion[localIndex]; ok {
-
-		n.pendingDeletionLock.RUnlock()
-		return true
-	}
-	n.pendingDeletionLock.RUnlock()
-	return false
-}
-
-func (n *connectionManager) AddTrafficWatch(localIndex uint32, seconds int) {
-	n.TrafficTimer.Add(localIndex, time.Second*time.Duration(seconds))
+func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
+	n.Out(localIndex)
+	n.trafficTimer.Add(localIndex, n.checkInterval)
 }

 func (n *connectionManager) Start(ctx context.Context) {
@@ -140,6 +108,7 @@ func (n *connectionManager) Start(ctx context.Context) {
 }

 func (n *connectionManager) Run(ctx context.Context) {
+	//TODO: this tick should be based on the min wheel tick? Check firewall
 	clockSource := time.NewTicker(500 * time.Millisecond)
 	defer clockSource.Stop()

@@ -151,151 +120,106 @@ func (n *connectionManager) Run(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			return
+
 		case now := <-clockSource.C:
-			n.HandleMonitorTick(now, p, nb, out)
-			n.HandleDeletionTick(now)
+			n.trafficTimer.Advance(now)
+			for {
+				localIndex, has := n.trafficTimer.Purge()
+				if !has {
+					break
+				}
+
+				n.doTrafficCheck(localIndex, p, nb, out, now)
+			}
 		}
 	}
 }

-func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
-	n.TrafficTimer.Advance(now)
-	for {
-		localIndex, has := n.TrafficTimer.Purge()
-		if !has {
-			break
-		}
+func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
+	hostinfo, err := n.hostMap.QueryIndex(localIndex)
+	if err != nil {
+		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
+		delete(n.pendingDeletion, localIndex)
+		return
+	}

-		// Check for traffic coming back in from this host.
-		traf := n.CheckIn(localIndex)
+	if n.handleInvalidCertificate(now, hostinfo) {
+		return
+	}

-		hostinfo, err := n.hostMap.QueryIndex(localIndex)
-		if err != nil {
-			n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-			continue
-		}
+	primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
+	mainHostInfo := true
+	if primary != nil && primary != hostinfo {
+		mainHostInfo = false
+	}

-		if n.handleInvalidCertificate(now, hostinfo) {
-			continue
-		}
+	// Check for traffic on this hostinfo
+	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)

-		// Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to
-		// decide if this should be the main hostinfo if we are seeing traffic on it
-		primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
-		mainHostInfo := true
-		if primary != nil && primary != hostinfo {
-			mainHostInfo = false
+	// A hostinfo is determined alive if there is incoming traffic
+	if inTraffic {
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).
+				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
+				Debug("Tunnel status")
 		}
+		delete(n.pendingDeletion, hostinfo.localIndexId)

-		// If we saw an incoming packets from this ip and peer's certificate is not
-		// expired, just ignore.
-		if traf {
-			if n.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(n.l).
-					WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
-					Debug("Tunnel status")
+		if !mainHostInfo {
+			if hostinfo.vpnIp > n.intf.myVpnIp {
+				// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
+				// This the primary and prime the old primary hostinfo for testing
+				n.hostMap.MakePrimary(hostinfo)
 			}
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-
-			if !mainHostInfo {
-				if hostinfo.vpnIp > n.intf.myVpnIp {
-					// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
-					// This the primary and prime the old primary hostinfo for testing
-					n.hostMap.MakePrimary(hostinfo)
-					n.Out(primary.localIndexId)
-				} else {
-					// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
-					// Keep tracking so that we can tear it down when it goes away
-					n.Out(hostinfo.localIndexId)
-				}
-			}
-
-			continue
-		}
-
-		if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
-			// Don't probe lighthouses since recv_error should naturally catch this.
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-			continue
 		}

-		hostinfo.logger(n.l).
-			WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
-			Debug("Tunnel status")
-
-		if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
-			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)

-		} else {
-			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		if !outTraffic {
+			// Send a punch packet to keep the NAT state alive
+			n.sendPunch(hostinfo)
 		}
-		n.AddPendingDeletion(localIndex)
-	}
-
-}

-func (n *connectionManager) HandleDeletionTick(now time.Time) {
-	n.pendingDeletionTimer.Advance(now)
-	for {
-		localIndex, has := n.pendingDeletionTimer.Purge()
-		if !has {
-			break
-		}
+		return
+	}

-		hostinfo, mainHostInfo, err := n.hostMap.QueryIndexIsPrimary(localIndex)
-		if err != nil {
-			n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-			continue
-		}
+	if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
+		// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
+		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+		return
+	}

-		if n.handleInvalidCertificate(now, hostinfo) {
-			continue
-		}
+	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
+		// We have already sent a test packet and nothing was returned, this hostinfo is dead
+		hostinfo.logger(n.l).
+			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
+			Info("Tunnel status")

-		// If we saw an incoming packets from this ip and peer's certificate is not
-		// expired, just ignore.
-		traf := n.CheckIn(localIndex)
-		if traf {
-			hostinfo.logger(n.l).
-				WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
-				Debug("Tunnel status")
+		n.hostMap.DeleteHostInfo(hostinfo)
+		delete(n.pendingDeletion, hostinfo.localIndexId)
+		return
+	}

-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
+	hostinfo.logger(n.l).
+		WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
+		Debug("Tunnel status")

-			if !mainHostInfo {
-				// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
-				// Keep tracking so that we can tear it down when it goes away
-				n.Out(localIndex)
-			}
-			continue
+	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
+		if n.punchy.GetTargetEverything() {
+			// Maybe the remote is sending us packets but our NAT is blocking it and since we are configured to punch to all
+			// known remotes, go ahead and do that AND send a test packet
+			n.sendPunch(hostinfo)
 		}

-		// If it comes around on deletion wheel and hasn't resolved itself, delete
-		if n.checkPendingDeletion(localIndex) {
-			cn := ""
-			if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
-				cn = hostinfo.ConnectionState.peerCert.Details.Name
-			}
-
-			hostinfo.logger(n.l).
-				WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
-				WithField("certName", cn).
-				Info("Tunnel status")
+		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
+		n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)

-			n.hostMap.DeleteHostInfo(hostinfo)
-		}
-
-		n.ClearLocalIndex(localIndex)
-		n.ClearPendingDeletion(localIndex)
+	} else {
+		hostinfo.logger(n.l).Debugf("Hostinfo sadness")
 	}
+
+	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
+	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
 }

 // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
@@ -322,8 +246,24 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho
 	// Inform the remote and close the tunnel locally
 	n.intf.sendCloseTunnel(hostinfo)
 	n.intf.closeTunnel(hostinfo)
-
-	n.ClearLocalIndex(hostinfo.localIndexId)
-	n.ClearPendingDeletion(hostinfo.localIndexId)
+	delete(n.pendingDeletion, hostinfo.localIndexId)
 	return true
 }
+
+func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
+	if !n.punchy.GetPunch() {
+		// Punching is disabled
+		return
+	}
+
+	if n.punchy.GetTargetEverything() {
+		hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
+			n.metricsTxPunchy.Inc(1)
+			n.intf.outside.WriteTo([]byte{1}, addr)
+		})
+
+	} else if hostinfo.remote != nil {
+		n.metricsTxPunchy.Inc(1)
+		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
+	}
+}
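
Taken together, doTrafficCheck collapses the old monitor and deletion wheels into one cycle: inbound traffic marks a tunnel alive (and triggers a punch if we were silent outbound), lighthouses are only rescheduled, a quiet tunnel gets a test packet and is marked pending deletion, and a second quiet interval deletes it. A minimal, self-contained sketch of that decision flow with plain booleans; the names and return strings are illustrative only, not the actual Nebula types:

package main

import "fmt"

// decide mirrors the per-tick branching of doTrafficCheck in simplified form.
func decide(inTraffic, outTraffic, isLighthouse, pendingDeletion bool) string {
	switch {
	case inTraffic:
		if !outTraffic {
			// We heard from the peer but sent nothing ourselves; refresh our NAT mapping.
			return "alive: send punch, reschedule check"
		}
		return "alive: reschedule check"
	case isLighthouse:
		// Lighthouses are never probed; recv_error surfaces dead tunnels there.
		return "skip: reschedule check"
	case pendingDeletion:
		// The previous interval already sent a test packet and nothing came back.
		return "dead: delete hostinfo"
	default:
		// First quiet interval: probe the tunnel and check again after pendingDeletionInterval.
		return "testing: send test packet, mark pending deletion"
	}
}

func main() {
	fmt.Println(decide(true, false, false, false))  // alive: send punch, reschedule check
	fmt.Println(decide(false, false, false, false)) // testing: send test packet, mark pending deletion
	fmt.Println(decide(false, false, false, true))  // dead: delete hostinfo
}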

+ 48 - 41
connection_manager_test.go

@@ -10,6 +10,7 @@ import (

 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
@@ -60,16 +61,16 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		l:                l,
 	}
 	ifce.certState.Store(cs)
-	now := time.Now()

 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	nc := newConnectionManager(ctx, l, ifce, 5, 10)
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
-	nc.HandleMonitorTick(now, p, nb, out)
+
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 		vpnIp:         vpnIp,
@@ -84,26 +85,28 @@ func Test_NewConnectionManagerTest(t *testing.T) {

 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
+	nc.In(hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	// Move ahead 5s. Nothing should happen
-	next_tick := now.Add(5 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// Move ahead 6s. We haven't heard back
-	next_tick = now.Add(6 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// This host should now be up for deletion
+	assert.Contains(t, nc.out, hostinfo.localIndexId)
+
+	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+
+	// Do another traffic check tick, this host should be pending deletion now
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	// Move ahead some more
-	next_tick = now.Add(45 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// The host should be evicted
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+
+	// Do a final traffic check tick, the host should now be removed
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
@@ -136,16 +139,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		l:                l,
 	}
 	ifce.certState.Store(cs)
-	now := time.Now()

 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	nc := newConnectionManager(ctx, l, ifce, 5, 10)
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
-	nc.HandleMonitorTick(now, p, nb, out)
+
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 		vpnIp:         vpnIp,
@@ -160,30 +163,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {

 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, vpnIp)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
-	// Move ahead 5s. Nothing should happen
-	next_tick := now.Add(5 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// Move ahead 6s. We haven't heard back
-	next_tick = now.Add(6 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// This host should now be up for deletion
+	nc.In(hostinfo.localIndexId)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+
+	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+
+	// Do another traffic check tick, this host should be pending deletion now
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	// We heard back this time
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+
+	// We saw traffic, should no longer be pending deletion
 	nc.In(hostinfo.localIndexId)
-	// Move ahead some more
-	next_tick = now.Add(45 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// The host should not be evicted
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 }

 // Check if we can disconnect the peer.
@@ -257,7 +263,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	nc := newConnectionManager(ctx, l, ifce, 5, 10)
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	ifce.connectionManager = nc
 	hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
 	hostinfo.ConnectionState = &ConnectionState{

+ 3 - 12
handshake_ix.go

@@ -332,12 +332,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 			Info("Handshake message sent")
 	}

-	if existing != nil {
-		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
-		f.connectionManager.Out(existing.localIndexId)
-	}
-
-	f.connectionManager.Out(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)

 	return
@@ -495,12 +490,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
 	hostinfo.CreateRemoteCIDR(remoteCert)

 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
-	existing := f.handshakeManager.Complete(hostinfo, f)
-	if existing != nil {
-		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
-		f.connectionManager.Out(existing.localIndexId)
-	}
-
+	f.handshakeManager.Complete(hostinfo, f)
+	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 	f.metricHandshakes.Update(duration)


+ 1 - 3
handshake_manager.go

@@ -380,7 +380,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 // Complete is a simpler version of CheckAndComplete when we already know we
 // won't have a localIndexId collision because we already have an entry in the
 // pendingHostMap. An existing hostinfo is returned if there was one.
-func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo {
+func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	c.pendingHostMap.Lock()
 	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
@@ -395,11 +395,9 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo
 			Info("New host shadows existing host remoteIndex")
 	}

-	existingHostInfo := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
 	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
 	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
-	return existingHostInfo
 }

 // AddIndexHostInfo generates a unique localIndexId for this HostInfo

+ 0 - 49
hostmap.go

@@ -1,7 +1,6 @@
 package nebula

 import (
-	"context"
 	"errors"
 	"fmt"
 	"net"
@@ -621,54 +620,6 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 }

-// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap
-// The caller can then do the its work outside of the read lock
-func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
-	hm.RLock()
-	defer hm.RUnlock()
-
-	for _, v := range hm.Hosts {
-		if v.remotes != nil {
-			rl = append(rl, v.remotes)
-		}
-	}
-	return rl
-}
-
-// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
-func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
-	var metricsTxPunchy metrics.Counter
-	if hm.metricsEnabled {
-		metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
-	} else {
-		metricsTxPunchy = metrics.NilCounter{}
-	}
-
-	var remotes []*RemoteList
-	b := []byte{1}
-
-	clockSource := time.NewTicker(time.Second * 10)
-	defer clockSource.Stop()
-
-	for {
-		remotes = hm.punchList(remotes[:0])
-		for _, rl := range remotes {
-			//TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better
-			for _, addr := range rl.CopyAddrs(hm.preferredRanges) {
-				metricsTxPunchy.Inc(1)
-				conn.WriteTo(b, addr)
-			}
-		}
-
-		select {
-		case <-ctx.Done():
-			return
-		case <-clockSource.C:
-			continue
-		}
-	}
-}
-
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {

+ 4 - 3
interface.go

@@ -33,8 +33,8 @@ type InterfaceConfig struct {
 	ServeDns                bool
 	HandshakeManager        *HandshakeManager
 	lightHouse              *LightHouse
-	checkInterval           int
-	pendingDeletionInterval int
+	checkInterval           time.Duration
+	pendingDeletionInterval time.Duration
 	DropLocalBroadcast      bool
 	DropMulticast           bool
 	routines                int
@@ -43,6 +43,7 @@ type InterfaceConfig struct {
 	caPool                  *cert.NebulaCAPool
 	disconnectInvalid       bool
 	relayManager            *relayManager
+	punchy                  *Punchy

 	ConntrackCacheTimeout time.Duration
 	l                     *logrus.Logger
@@ -172,7 +173,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	}

 	ifce.certState.Store(c.certState)
-	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
+	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)

 	return ifce, nil
 }

+ 3 - 7
main.go

@@ -213,11 +213,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	*/

 	punchy := NewPunchyFromConfig(l, c)
-	if punchy.GetPunch() && !configTest {
-		l.Info("UDP hole punching enabled")
-		go hostMap.Punchy(ctx, udpConns[0])
-	}
-
 	lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
 	switch {
 	case errors.As(err, &util.ContextualError{}):
@@ -272,8 +267,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
 		lightHouse:              lightHouse,
-		checkInterval:           checkInterval,
-		pendingDeletionInterval: pendingDeletionInterval,
+		checkInterval:           time.Second * time.Duration(checkInterval),
+		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
 		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:           c.GetBool("tun.drop_multicast", false),
 		routines:                routines,
@@ -282,6 +277,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		caPool:                  caPool,
 		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 		relayManager:            NewRelayManager(ctx, l, hostMap, c),
+		punchy:                  punchy,

 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
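
Note on the unit change above: checkInterval and pendingDeletionInterval are still read from config as plain integers (seconds), and only this call site now wraps them into time.Duration for InterfaceConfig. A trivial illustration of the conversion (the values below are made up, not the configured defaults):

package main

import (
	"fmt"
	"time"
)

func main() {
	// Integer seconds as they come out of config; InterfaceConfig now wants durations.
	checkInterval := 5
	pendingDeletionInterval := 10

	fmt.Println(time.Second * time.Duration(checkInterval))           // 5s
	fmt.Println(time.Second * time.Duration(pendingDeletionInterval)) // 10s
}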

+ 0 - 3
outside.go

@@ -238,9 +238,6 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by

 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
-	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
-	f.connectionManager.ClearLocalIndex(hostInfo.localIndexId)
-	f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId)
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	if final {
 		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage

+ 24 - 5
punchy.go

@@ -9,11 +9,12 @@ import (
 )

 type Punchy struct {
-	punch        atomic.Bool
-	respond      atomic.Bool
-	delay        atomic.Int64
-	respondDelay atomic.Int64
-	l            *logrus.Logger
+	punch           atomic.Bool
+	respond         atomic.Bool
+	delay           atomic.Int64
+	respondDelay    atomic.Int64
+	punchEverything atomic.Bool
+	l               *logrus.Logger
 }

 func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
@@ -38,6 +39,12 @@ func (p *Punchy) reload(c *config.C, initial bool) {
 		}

 		p.punch.Store(yes)
+		if yes {
+			p.l.Info("punchy enabled")
+		} else {
+			p.l.Info("punchy disabled")
+		}
+
 	} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
 		//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
 		p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
@@ -66,6 +73,14 @@ func (p *Punchy) reload(c *config.C, initial bool) {
 			p.l.Infof("punchy.delay changed to %s", p.GetDelay())
 		}
 	}
+
+	if initial || c.HasChanged("punchy.target_all_remotes") {
+		p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", true))
+		if !initial {
+			p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
+		}
+	}
+
 	if initial || c.HasChanged("punchy.respond_delay") {
 		p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
 		if !initial {
@@ -89,3 +104,7 @@ func (p *Punchy) GetDelay() time.Duration {
 func (p *Punchy) GetRespondDelay() time.Duration {
 	return (time.Duration)(p.respondDelay.Load())
 }
+
+func (p *Punchy) GetTargetEverything() bool {
+	return p.punchEverything.Load()
+}
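
The new punchy.target_all_remotes key defaults to true, so with punching enabled the connection manager punches every known remote for a host rather than only the current one. A small sketch of how the setting surfaces through the getter, assuming the repo's test logger helper and an empty config so every value falls back to its default (the test name is hypothetical):

package nebula

import (
	"testing"

	"github.com/slackhq/nebula/config"
	"github.com/slackhq/nebula/test"
	"github.com/stretchr/testify/assert"
)

func Test_PunchyTargetAllRemotesDefault(t *testing.T) {
	l := test.NewLogger()

	// config.NewC(l) holds no settings, so GetBool("punchy.target_all_remotes", true)
	// returns the default and GetTargetEverything reports true.
	p := NewPunchyFromConfig(l, config.NewC(l))
	assert.True(t, p.GetTargetEverything())
}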