Use connection manager to drive NAT maintenance (#835)

Co-authored-by: brad-defined <[email protected]>
Nate Brown, 2 years ago
commit ee8e1348e9
9 changed files with 213 additions and 313 deletions
  1. connection_manager.go       + 130  - 190
  2. connection_manager_test.go  + 48   - 41
  3. handshake_ix.go             + 3    - 12
  4. handshake_manager.go        + 1    - 3
  5. hostmap.go                  + 0    - 49
  6. interface.go                + 4    - 3
  7. main.go                     + 3    - 7
  8. outside.go                  + 0    - 3
  9. punchy.go                   + 24   - 5

+ 130 - 190
connection_manager.go

@@ -5,49 +5,55 @@ import (
 	"sync"
 	"time"
 
+	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/udp"
 )
 
-// TODO: incount and outcount are intended as a shortcut to locking the mutexes for every single packet
-// and something like every 10 packets we could lock, send 10, then unlock for a moment
-
 type connectionManager struct {
-	hostMap      *HostMap
-	in           map[uint32]struct{}
-	inLock       *sync.RWMutex
-	out          map[uint32]struct{}
-	outLock      *sync.RWMutex
-	TrafficTimer *LockingTimerWheel[uint32]
-	intf         *Interface
+	in     map[uint32]struct{}
+	inLock *sync.RWMutex
 
-	pendingDeletion      map[uint32]int
-	pendingDeletionLock  *sync.RWMutex
-	pendingDeletionTimer *LockingTimerWheel[uint32]
+	out     map[uint32]struct{}
+	outLock *sync.RWMutex
 
-	checkInterval           int
-	pendingDeletionInterval int
+	hostMap                 *HostMap
+	trafficTimer            *LockingTimerWheel[uint32]
+	intf                    *Interface
+	pendingDeletion         map[uint32]struct{}
+	punchy                  *Punchy
+	checkInterval           time.Duration
+	pendingDeletionInterval time.Duration
+	metricsTxPunchy         metrics.Counter
 
 	l *logrus.Logger
-	// I wanted to call one matLock
 }
 
-func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval int) *connectionManager {
+func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {
+	var max time.Duration
+	if checkInterval < pendingDeletionInterval {
+		max = pendingDeletionInterval
+	} else {
+		max = checkInterval
+	}
+
 	nc := &connectionManager{
 		hostMap:                 intf.hostMap,
 		in:                      make(map[uint32]struct{}),
 		inLock:                  &sync.RWMutex{},
 		out:                     make(map[uint32]struct{}),
 		outLock:                 &sync.RWMutex{},
-		TrafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
+		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
 		intf:                    intf,
-		pendingDeletion:         make(map[uint32]int),
-		pendingDeletionLock:     &sync.RWMutex{},
-		pendingDeletionTimer:    NewLockingTimerWheel[uint32](time.Millisecond*500, time.Second*60),
+		pendingDeletion:         make(map[uint32]struct{}),
 		checkInterval:           checkInterval,
 		pendingDeletionInterval: pendingDeletionInterval,
+		punchy:                  punchy,
+		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),
 		l:                       l,
 	}
+
 	nc.Start(ctx)
 	return nc
 }
@@ -74,65 +80,27 @@ func (n *connectionManager) Out(localIndex uint32) {
 	}
 	n.outLock.RUnlock()
 	n.outLock.Lock()
-	// double check since we dropped the lock temporarily
-	if _, ok := n.out[localIndex]; ok {
-		n.outLock.Unlock()
-		return
-	}
 	n.out[localIndex] = struct{}{}
-	n.AddTrafficWatch(localIndex, n.checkInterval)
 	n.outLock.Unlock()
 }
 
-func (n *connectionManager) CheckIn(localIndex uint32) bool {
-	n.inLock.RLock()
-	if _, ok := n.in[localIndex]; ok {
-		n.inLock.RUnlock()
-		return true
-	}
-	n.inLock.RUnlock()
-	return false
-}
-
-func (n *connectionManager) ClearLocalIndex(localIndex uint32) {
+// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
+// resets the state for this local index
+func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
 	n.inLock.Lock()
 	n.outLock.Lock()
+	_, in := n.in[localIndex]
+	_, out := n.out[localIndex]
 	delete(n.in, localIndex)
 	delete(n.out, localIndex)
 	n.inLock.Unlock()
 	n.outLock.Unlock()
+	return in, out
 }
 
-func (n *connectionManager) ClearPendingDeletion(localIndex uint32) {
-	n.pendingDeletionLock.Lock()
-	delete(n.pendingDeletion, localIndex)
-	n.pendingDeletionLock.Unlock()
-}
-
-func (n *connectionManager) AddPendingDeletion(localIndex uint32) {
-	n.pendingDeletionLock.Lock()
-	if _, ok := n.pendingDeletion[localIndex]; ok {
-		n.pendingDeletion[localIndex] += 1
-	} else {
-		n.pendingDeletion[localIndex] = 0
-	}
-	n.pendingDeletionTimer.Add(localIndex, time.Second*time.Duration(n.pendingDeletionInterval))
-	n.pendingDeletionLock.Unlock()
-}
-
-func (n *connectionManager) checkPendingDeletion(localIndex uint32) bool {
-	n.pendingDeletionLock.RLock()
-	if _, ok := n.pendingDeletion[localIndex]; ok {
-
-		n.pendingDeletionLock.RUnlock()
-		return true
-	}
-	n.pendingDeletionLock.RUnlock()
-	return false
-}
-
-func (n *connectionManager) AddTrafficWatch(localIndex uint32, seconds int) {
-	n.TrafficTimer.Add(localIndex, time.Second*time.Duration(seconds))
+func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
+	n.Out(localIndex)
+	n.trafficTimer.Add(localIndex, n.checkInterval)
 }
 
 func (n *connectionManager) Start(ctx context.Context) {
@@ -140,6 +108,7 @@ func (n *connectionManager) Start(ctx context.Context) {
 }
 
 func (n *connectionManager) Run(ctx context.Context) {
+	//TODO: this tick should be based on the min wheel tick? Check firewall
 	clockSource := time.NewTicker(500 * time.Millisecond)
 	defer clockSource.Stop()
 
@@ -151,151 +120,106 @@ func (n *connectionManager) Run(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			return
+
 		case now := <-clockSource.C:
-			n.HandleMonitorTick(now, p, nb, out)
-			n.HandleDeletionTick(now)
+			n.trafficTimer.Advance(now)
+			for {
+				localIndex, has := n.trafficTimer.Purge()
+				if !has {
+					break
+				}
+
+				n.doTrafficCheck(localIndex, p, nb, out, now)
+			}
 		}
 	}
 }
 
-func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
-	n.TrafficTimer.Advance(now)
-	for {
-		localIndex, has := n.TrafficTimer.Purge()
-		if !has {
-			break
-		}
+func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
+	hostinfo, err := n.hostMap.QueryIndex(localIndex)
+	if err != nil {
+		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
+		delete(n.pendingDeletion, localIndex)
+		return
+	}
 
-		// Check for traffic coming back in from this host.
-		traf := n.CheckIn(localIndex)
+	if n.handleInvalidCertificate(now, hostinfo) {
+		return
+	}
 
-		hostinfo, err := n.hostMap.QueryIndex(localIndex)
-		if err != nil {
-			n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-			continue
-		}
+	primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
+	mainHostInfo := true
+	if primary != nil && primary != hostinfo {
+		mainHostInfo = false
+	}
 
-		if n.handleInvalidCertificate(now, hostinfo) {
-			continue
-		}
+	// Check for traffic on this hostinfo
+	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)
 
-		// Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to
-		// decide if this should be the main hostinfo if we are seeing traffic on it
-		primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
-		mainHostInfo := true
-		if primary != nil && primary != hostinfo {
-			mainHostInfo = false
+	// A hostinfo is determined alive if there is incoming traffic
+	if inTraffic {
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).
+				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
+				Debug("Tunnel status")
 		}
+		delete(n.pendingDeletion, hostinfo.localIndexId)
 
-		// If we saw an incoming packets from this ip and peer's certificate is not
-		// expired, just ignore.
-		if traf {
-			if n.l.Level >= logrus.DebugLevel {
-				hostinfo.logger(n.l).
-					WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
-					Debug("Tunnel status")
+		if !mainHostInfo {
+			if hostinfo.vpnIp > n.intf.myVpnIp {
+				// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
+				// This the primary and prime the old primary hostinfo for testing
+				n.hostMap.MakePrimary(hostinfo)
 			}
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-
-			if !mainHostInfo {
-				if hostinfo.vpnIp > n.intf.myVpnIp {
-					// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
-					// This the primary and prime the old primary hostinfo for testing
-					n.hostMap.MakePrimary(hostinfo)
-					n.Out(primary.localIndexId)
-				} else {
-					// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
-					// Keep tracking so that we can tear it down when it goes away
-					n.Out(hostinfo.localIndexId)
-				}
-			}
-
-			continue
-		}
-
-		if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
-			// Don't probe lighthouses since recv_error should naturally catch this.
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-			continue
 		}
 
-		hostinfo.logger(n.l).
-			WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
-			Debug("Tunnel status")
-
-		if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
-			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-			n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
 
-		} else {
-			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		if !outTraffic {
+			// Send a punch packet to keep the NAT state alive
+			n.sendPunch(hostinfo)
 		}
-		n.AddPendingDeletion(localIndex)
-	}
-
-}
 
-func (n *connectionManager) HandleDeletionTick(now time.Time) {
-	n.pendingDeletionTimer.Advance(now)
-	for {
-		localIndex, has := n.pendingDeletionTimer.Purge()
-		if !has {
-			break
-		}
+		return
+	}
 
-		hostinfo, mainHostInfo, err := n.hostMap.QueryIndexIsPrimary(localIndex)
-		if err != nil {
-			n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
-			continue
-		}
+	if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
+		// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
+		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
+		return
+	}
 
-		if n.handleInvalidCertificate(now, hostinfo) {
-			continue
-		}
+	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
+		// We have already sent a test packet and nothing was returned, this hostinfo is dead
+		hostinfo.logger(n.l).
+			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
+			Info("Tunnel status")
 
-		// If we saw an incoming packets from this ip and peer's certificate is not
-		// expired, just ignore.
-		traf := n.CheckIn(localIndex)
-		if traf {
-			hostinfo.logger(n.l).
-				WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
-				Debug("Tunnel status")
+		n.hostMap.DeleteHostInfo(hostinfo)
+		delete(n.pendingDeletion, hostinfo.localIndexId)
+		return
+	}
 
-			n.ClearLocalIndex(localIndex)
-			n.ClearPendingDeletion(localIndex)
+	hostinfo.logger(n.l).
+		WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
+		Debug("Tunnel status")
 
-			if !mainHostInfo {
-				// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
-				// Keep tracking so that we can tear it down when it goes away
-				n.Out(localIndex)
-			}
-			continue
+	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
+		if n.punchy.GetTargetEverything() {
+			// Maybe the remote is sending us packets but our NAT is blocking it and since we are configured to punch to all
+			// known remotes, go ahead and do that AND send a test packet
+			n.sendPunch(hostinfo)
 		}
 
-		// If it comes around on deletion wheel and hasn't resolved itself, delete
-		if n.checkPendingDeletion(localIndex) {
-			cn := ""
-			if hostinfo.ConnectionState != nil && hostinfo.ConnectionState.peerCert != nil {
-				cn = hostinfo.ConnectionState.peerCert.Details.Name
-			}
-
-			hostinfo.logger(n.l).
-				WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
-				WithField("certName", cn).
-				Info("Tunnel status")
+		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
+		n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
 
-			n.hostMap.DeleteHostInfo(hostinfo)
-		}
-
-		n.ClearLocalIndex(localIndex)
-		n.ClearPendingDeletion(localIndex)
+	} else {
+		hostinfo.logger(n.l).Debugf("Hostinfo sadness")
 	}
+
+	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
+	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
 }
 
 // handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
@@ -322,8 +246,24 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho
 	// Inform the remote and close the tunnel locally
 	n.intf.sendCloseTunnel(hostinfo)
 	n.intf.closeTunnel(hostinfo)
-
-	n.ClearLocalIndex(hostinfo.localIndexId)
-	n.ClearPendingDeletion(hostinfo.localIndexId)
+	delete(n.pendingDeletion, hostinfo.localIndexId)
 	return true
 }
+
+func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
+	if !n.punchy.GetPunch() {
+		// Punching is disabled
+		return
+	}
+
+	if n.punchy.GetTargetEverything() {
+		hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
+			n.metricsTxPunchy.Inc(1)
+			n.intf.outside.WriteTo([]byte{1}, addr)
+		})
+
+	} else if hostinfo.remote != nil {
+		n.metricsTxPunchy.Inc(1)
+		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
+	}
+}

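With this change the connection manager drives the whole tunnel lifecycle from a single trafficTimer wheel: each expiry calls doTrafficCheck, which reads and resets the in/out flags and decides whether to re-arm, punch, send a Test packet, or delete the tunnel. Below is a compressed, self-contained sketch of that decision flow; tunnelState and onTimerExpiry are illustrative stand-ins rather than nebula code, and the lighthouse and target_all_remotes branches are omitted.

package main

import "fmt"

// tunnelState is an illustrative stand-in for the per-hostinfo bookkeeping the
// connection manager keeps in its in/out/pendingDeletion maps.
type tunnelState struct {
	sawIn, sawOut   bool
	pendingDeletion bool
}

// onTimerExpiry mirrors the branch structure of doTrafficCheck: live tunnels are
// re-armed (and punched when they were receive-only), silent tunnels get one Test
// packet and a grace period, and tunnels that stay silent are torn down.
func onTimerExpiry(t *tunnelState) string {
	in, out := t.sawIn, t.sawOut
	t.sawIn, t.sawOut = false, false // getAndResetTrafficCheck

	switch {
	case in:
		t.pendingDeletion = false
		if !out {
			return "alive: re-arm checkInterval and send a punch to keep NAT state warm"
		}
		return "alive: re-arm checkInterval"
	case t.pendingDeletion:
		return "dead: delete the hostinfo"
	default:
		t.pendingDeletion = true
		return "testing: send a Test packet and re-arm with pendingDeletionInterval"
	}
}

func main() {
	t := &tunnelState{sawIn: true}
	fmt.Println(onTimerExpiry(t)) // inbound only -> alive, punched
	fmt.Println(onTimerExpiry(t)) // silent       -> testing
	fmt.Println(onTimerExpiry(t)) // still silent -> dead
}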
+ 48 - 41
connection_manager_test.go

@@ -10,6 +10,7 @@ import (
 
 	"github.com/flynn/noise"
 	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
 	"github.com/slackhq/nebula/udp"
@@ -60,16 +61,16 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		l:                l,
 	}
 	ifce.certState.Store(cs)
-	now := time.Now()
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	nc := newConnectionManager(ctx, l, ifce, 5, 10)
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
-	nc.HandleMonitorTick(now, p, nb, out)
+
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 		vpnIp:         vpnIp,
@@ -84,26 +85,28 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
+	nc.In(hostinfo.localIndexId)
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	// Move ahead 5s. Nothing should happen
-	next_tick := now.Add(5 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// Move ahead 6s. We haven't heard back
-	next_tick = now.Add(6 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// This host should now be up for deletion
+	assert.Contains(t, nc.out, hostinfo.localIndexId)
+
+	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+
+	// Do another traffic check tick, this host should be pending deletion now
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	// Move ahead some more
-	next_tick = now.Add(45 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// The host should be evicted
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+
+	// Do a final traffic check tick, the host should now be removed
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
 	assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 	assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
@@ -136,16 +139,16 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		l:                l,
 	}
 	ifce.certState.Store(cs)
-	now := time.Now()
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	nc := newConnectionManager(ctx, l, ifce, 5, 10)
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	p := []byte("")
 	nb := make([]byte, 12, 12)
 	out := make([]byte, mtu)
-	nc.HandleMonitorTick(now, p, nb, out)
+
 	// Add an ip we have established a connection w/ to hostmap
 	hostinfo := &HostInfo{
 		vpnIp:         vpnIp,
@@ -160,30 +163,33 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
-	assert.NotContains(t, nc.pendingDeletion, vpnIp)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
-	// Move ahead 5s. Nothing should happen
-	next_tick := now.Add(5 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// Move ahead 6s. We haven't heard back
-	next_tick = now.Add(6 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// This host should now be up for deletion
+	nc.In(hostinfo.localIndexId)
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+
+	// Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
+	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
+
+	// Do another traffic check tick, this host should be pending deletion now
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.Contains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, vpnIp)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
-	// We heard back this time
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+
+	// We saw traffic, should no longer be pending deletion
 	nc.In(hostinfo.localIndexId)
-	// Move ahead some more
-	next_tick = now.Add(45 * time.Second)
-	nc.HandleMonitorTick(next_tick, p, nb, out)
-	nc.HandleDeletionTick(next_tick)
-	// The host should not be evicted
+	nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now())
 	assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId)
-	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
+	assert.NotContains(t, nc.out, hostinfo.localIndexId)
+	assert.NotContains(t, nc.in, hostinfo.localIndexId)
 	assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId)
+	assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp)
 }
 
 // Check if we can disconnect the peer.
@@ -257,7 +263,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
-	nc := newConnectionManager(ctx, l, ifce, 5, 10)
+	punchy := NewPunchyFromConfig(l, config.NewC(l))
+	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	ifce.connectionManager = nc
 	hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
 	hostinfo.ConnectionState = &ConnectionState{

+ 3 - 12
handshake_ix.go

@@ -332,12 +332,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 			Info("Handshake message sent")
 	}
 
-	if existing != nil {
-		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
-		f.connectionManager.Out(existing.localIndexId)
-	}
-
-	f.connectionManager.Out(hostinfo.localIndexId)
+	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 
 	return
@@ -495,12 +490,8 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
-	existing := f.handshakeManager.Complete(hostinfo, f)
-	if existing != nil {
-		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
-		f.connectionManager.Out(existing.localIndexId)
-	}
-
+	f.handshakeManager.Complete(hostinfo, f)
+	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 	f.metricHandshakes.Update(duration)
 

+ 1 - 3
handshake_manager.go

@@ -380,7 +380,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 // Complete is a simpler version of CheckAndComplete when we already know we
 // won't have a localIndexId collision because we already have an entry in the
 // pendingHostMap. An existing hostinfo is returned if there was one.
-func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo {
+func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	c.pendingHostMap.Lock()
 	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
@@ -395,11 +395,9 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo
 			Info("New host shadows existing host remoteIndex")
 	}
 
-	existingHostInfo := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
 	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
 	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
-	return existingHostInfo
 }
 
 // AddIndexHostInfo generates a unique localIndexId for this HostInfo

+ 0 - 49
hostmap.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"context"
 	"errors"
 	"fmt"
 	"net"
@@ -621,54 +620,6 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 }
 
-// punchList assembles a list of all non nil RemoteList pointer entries in this hostmap
-// The caller can then do the its work outside of the read lock
-func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
-	hm.RLock()
-	defer hm.RUnlock()
-
-	for _, v := range hm.Hosts {
-		if v.remotes != nil {
-			rl = append(rl, v.remotes)
-		}
-	}
-	return rl
-}
-
-// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
-func (hm *HostMap) Punchy(ctx context.Context, conn *udp.Conn) {
-	var metricsTxPunchy metrics.Counter
-	if hm.metricsEnabled {
-		metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
-	} else {
-		metricsTxPunchy = metrics.NilCounter{}
-	}
-
-	var remotes []*RemoteList
-	b := []byte{1}
-
-	clockSource := time.NewTicker(time.Second * 10)
-	defer clockSource.Stop()
-
-	for {
-		remotes = hm.punchList(remotes[:0])
-		for _, rl := range remotes {
-			//TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better
-			for _, addr := range rl.CopyAddrs(hm.preferredRanges) {
-				metricsTxPunchy.Inc(1)
-				conn.WriteTo(b, addr)
-			}
-		}
-
-		select {
-		case <-ctx.Done():
-			return
-		case <-clockSource.C:
-			continue
-		}
-	}
-}
-
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {

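With the standalone HostMap.Punchy loop removed, hole punching is now driven per tunnel by connectionManager.sendPunch whenever a traffic check sees inbound but no outbound traffic. A minimal, self-contained sketch of the keepalive datagram itself follows; the peer address and plain net.UDPConn socket are placeholder assumptions, not nebula's udp.Conn.

package main

import (
	"log"
	"net"
)

func main() {
	// Hypothetical peer address; in nebula this comes from hostinfo.remote, or from
	// hostinfo.remotes when punchy.target_all_remotes is enabled.
	raddr, err := net.ResolveUDPAddr("udp", "203.0.113.10:4242")
	if err != nil {
		log.Fatal(err)
	}

	conn, err := net.DialUDP("udp", nil, raddr)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// sendPunch writes a single byte ([]byte{1}); the payload only needs to refresh
	// the NAT mapping, the remote side ignores it.
	if _, err := conn.Write([]byte{1}); err != nil {
		log.Fatal(err)
	}
}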
+ 4 - 3
interface.go

@@ -33,8 +33,8 @@ type InterfaceConfig struct {
 	ServeDns                bool
 	HandshakeManager        *HandshakeManager
 	lightHouse              *LightHouse
-	checkInterval           int
-	pendingDeletionInterval int
+	checkInterval           time.Duration
+	pendingDeletionInterval time.Duration
 	DropLocalBroadcast      bool
 	DropMulticast           bool
 	routines                int
@@ -43,6 +43,7 @@ type InterfaceConfig struct {
 	caPool                  *cert.NebulaCAPool
 	disconnectInvalid       bool
 	relayManager            *relayManager
+	punchy                  *Punchy
 
 	ConntrackCacheTimeout time.Duration
 	l                     *logrus.Logger
@@ -172,7 +173,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	}
 
 	ifce.certState.Store(c.certState)
-	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
+	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
 
 	return ifce, nil
 }

+ 3 - 7
main.go

@@ -213,11 +213,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	*/
 
 	punchy := NewPunchyFromConfig(l, c)
-	if punchy.GetPunch() && !configTest {
-		l.Info("UDP hole punching enabled")
-		go hostMap.Punchy(ctx, udpConns[0])
-	}
-
 	lightHouse, err := NewLightHouseFromConfig(l, c, tunCidr, udpConns[0], punchy)
 	switch {
 	case errors.As(err, &util.ContextualError{}):
@@ -272,8 +267,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		ServeDns:                serveDns,
 		HandshakeManager:        handshakeManager,
 		lightHouse:              lightHouse,
-		checkInterval:           checkInterval,
-		pendingDeletionInterval: pendingDeletionInterval,
+		checkInterval:           time.Second * time.Duration(checkInterval),
+		pendingDeletionInterval: time.Second * time.Duration(pendingDeletionInterval),
 		DropLocalBroadcast:      c.GetBool("tun.drop_local_broadcast", false),
 		DropMulticast:           c.GetBool("tun.drop_multicast", false),
 		routines:                routines,
@@ -282,6 +277,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		caPool:                  caPool,
 		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 		relayManager:            NewRelayManager(ctx, l, hostMap, c),
+		punchy:                  punchy,
 
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,

+ 0 - 3
outside.go

@@ -238,9 +238,6 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via interface{}, out []by
 
 // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
-	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
-	f.connectionManager.ClearLocalIndex(hostInfo.localIndexId)
-	f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId)
 	final := f.hostMap.DeleteHostInfo(hostInfo)
 	if final {
 		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage

+ 24 - 5
punchy.go

@@ -9,11 +9,12 @@ import (
 )
 
 type Punchy struct {
-	punch        atomic.Bool
-	respond      atomic.Bool
-	delay        atomic.Int64
-	respondDelay atomic.Int64
-	l            *logrus.Logger
+	punch           atomic.Bool
+	respond         atomic.Bool
+	delay           atomic.Int64
+	respondDelay    atomic.Int64
+	punchEverything atomic.Bool
+	l               *logrus.Logger
 }
 
 func NewPunchyFromConfig(l *logrus.Logger, c *config.C) *Punchy {
@@ -38,6 +39,12 @@ func (p *Punchy) reload(c *config.C, initial bool) {
 		}
 
 		p.punch.Store(yes)
+		if yes {
+			p.l.Info("punchy enabled")
+		} else {
+			p.l.Info("punchy disabled")
+		}
+
 	} else if c.HasChanged("punchy.punch") || c.HasChanged("punchy") {
 		//TODO: it should be relatively easy to support this, just need to be able to cancel the goroutine and boot it up from here
 		p.l.Warn("Changing punchy.punch with reload is not supported, ignoring.")
@@ -66,6 +73,14 @@ func (p *Punchy) reload(c *config.C, initial bool) {
 			p.l.Infof("punchy.delay changed to %s", p.GetDelay())
 		}
 	}
+
+	if initial || c.HasChanged("punchy.target_all_remotes") {
+		p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", true))
+		if !initial {
+			p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
+		}
+	}
+
 	if initial || c.HasChanged("punchy.respond_delay") {
 		p.respondDelay.Store((int64)(c.GetDuration("punchy.respond_delay", 5*time.Second)))
 		if !initial {
@@ -89,3 +104,7 @@ func (p *Punchy) GetDelay() time.Duration {
 func (p *Punchy) GetRespondDelay() time.Duration {
 	return (time.Duration)(p.respondDelay.Load())
 }
+
+func (p *Punchy) GetTargetEverything() bool {
+	return p.punchEverything.Load()
+}
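
The new punchy.target_all_remotes option (default true) follows the same hot-reload pattern as the other Punchy fields: reload stores the parsed value into an atomic.Bool and GetTargetEverything loads it, so packet-path goroutines never take a lock. A minimal sketch of that pattern with simplified, assumed types (not nebula's config.C API):

package main

import (
	"fmt"
	"sync/atomic"
)

// punchSettings is an illustrative stand-in for the Punchy struct above.
type punchSettings struct {
	punch            atomic.Bool
	targetEverything atomic.Bool
}

// reload mirrors how Punchy.reload stores freshly parsed config values; a reload
// goroutine can call this while other goroutines keep calling the getters.
func (p *punchSettings) reload(punch, targetAll bool) {
	p.punch.Store(punch)
	p.targetEverything.Store(targetAll)
}

func (p *punchSettings) GetPunch() bool            { return p.punch.Load() }
func (p *punchSettings) GetTargetEverything() bool { return p.targetEverything.Load() }

func main() {
	p := &punchSettings{}
	p.reload(true, true) // e.g. punchy.punch: true, punchy.target_all_remotes: true
	fmt.Println(p.GetPunch(), p.GetTargetEverything())
}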