瀏覽代碼

Remove handshake race avoidance (#820)

Co-authored-by: Wade Simmons <[email protected]>
Nate Brown 2 年之前
父節點
當前提交
92cc32f844
共有 18 個文件被更改,包括 741 次插入157 次删除
  1. 23 1
      connection_manager.go
  2. 2 2
      connection_manager_test.go
  3. 27 5
      control.go
  4. 247 51
      e2e/handshakes_test.go
  5. 4 4
      e2e/helpers_test.go
  6. 32 2
      e2e/router/hostmap.go
  7. 33 1
      e2e/router/router.go
  8. 1 0
      go.mod
  9. 2 0
      go.sum
  10. 13 15
      handshake_ix.go
  11. 26 49
      handshake_manager.go
  12. 101 19
      hostmap.go
  13. 206 0
      hostmap_test.go
  14. 5 3
      outside.go
  15. 1 1
      overlay/tun_tester.go
  16. 5 0
      relay_manager.go
  17. 12 3
      ssh.go
  18. 1 1
      udp/udp_tester.go

+ 23 - 1
connection_manager.go

@@ -181,6 +181,14 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			continue
 		}
 
+		// Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to
+		// decide if this should be the main hostinfo if we are seeing traffic on it
+		primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
+		mainHostInfo := true
+		if primary != nil && primary != hostinfo {
+			mainHostInfo = false
+		}
+
 		// If we saw an incoming packets from this ip and peer's certificate is not
 		// expired, just ignore.
 		if traf {
@@ -191,6 +199,20 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			}
 			n.ClearLocalIndex(localIndex)
 			n.ClearPendingDeletion(localIndex)
+
+			if !mainHostInfo {
+				if hostinfo.vpnIp > n.intf.myVpnIp {
+					// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
+					// This the primary and prime the old primary hostinfo for testing
+					n.hostMap.MakePrimary(hostinfo)
+					n.Out(primary.localIndexId)
+				} else {
+					// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
+					// Keep tracking so that we can tear it down when it goes away
+					n.Out(hostinfo.localIndexId)
+				}
+			}
+
 			continue
 		}
 
@@ -198,7 +220,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 			Debug("Tunnel status")
 
-		if hostinfo != nil && hostinfo.ConnectionState != nil {
+		if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
 

+ 2 - 2
connection_manager_test.go

@@ -80,7 +80,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		certState: cs,
 		H:         &noise.HandshakeState{},
 	}
-	nc.hostMap.addHostInfo(hostinfo, ifce)
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
@@ -156,7 +156,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		certState: cs,
 		H:         &noise.HandshakeState{},
 	}
-	nc.hostMap.addHostInfo(hostinfo, ifce)
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)

+ 27 - 5
control.go

@@ -95,12 +95,21 @@ func (c *Control) RebindUDPServer() {
 	c.f.rebindCount++
 }
 
-// ListHostmap returns details about the actual or pending (handshaking) hostmap
-func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
+// ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
+func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
 	if pendingMap {
-		return listHostMap(c.f.handshakeManager.pendingHostMap)
+		return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
 	} else {
-		return listHostMap(c.f.hostMap)
+		return listHostMapHosts(c.f.hostMap)
+	}
+}
+
+// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
+func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
+	if pendingMap {
+		return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
+	} else {
+		return listHostMapIndexes(c.f.hostMap)
 	}
 }
 
@@ -232,7 +241,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	return chi
 }
 
-func listHostMap(hm *HostMap) []ControlHostInfo {
+func listHostMapHosts(hm *HostMap) []ControlHostInfo {
 	hm.RLock()
 	hosts := make([]ControlHostInfo, len(hm.Hosts))
 	i := 0
@@ -244,3 +253,16 @@ func listHostMap(hm *HostMap) []ControlHostInfo {
 
 	return hosts
 }
+
+func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
+	hm.RLock()
+	hosts := make([]ControlHostInfo, len(hm.Indexes))
+	i := 0
+	for _, v := range hm.Indexes {
+		hosts[i] = copyHostInfo(v, hm.preferredRanges)
+		i++
+	}
+	hm.RUnlock()
+
+	return hosts
+}

+ 247 - 51
e2e/handshakes_test.go

@@ -19,10 +19,10 @@ import (
 func BenchmarkHotPath(b *testing.B) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 	// Start the servers
 	myControl.Start()
@@ -32,7 +32,7 @@ func BenchmarkHotPath(b *testing.B) {
 	r.CancelFlowLogs()
 
 	for n := 0; n < b.N; n++ {
-		myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+		myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 		_ = r.RouteForAllUntilTxTun(theirControl)
 	}
 
@@ -42,18 +42,18 @@ func BenchmarkHotPath(b *testing.B) {
 
 func TestGoodHandshake(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 	// Start the servers
 	myControl.Start()
 	theirControl.Start()
 
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -74,16 +74,16 @@ func TestGoodHandshake(t *testing.T) {
 	myControl.WaitForType(1, 0, theirControl)
 
 	t.Log("Make sure our host infos are correct")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
 
 	t.Log("Get that cached packet and make sure it looks right")
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 
 	t.Log("Do a bidirectional tunnel test")
 	r := router.NewR(t, myControl, theirControl)
 	defer r.RenderFlow()
-	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
 
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	myControl.Stop()
@@ -97,15 +97,15 @@ func TestWrongResponderHandshake(t *testing.T) {
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
 	// So we need them to have a higher address than evil (we could apply a preference though)
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
 	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
 
 	// Add their real udp addr, which should be tried after evil.
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
-	myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl, evilControl)
@@ -117,7 +117,7 @@ func TestWrongResponderHandshake(t *testing.T) {
 	evilControl.Start()
 
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 		h := &header.H{}
 		err := h.Parse(p.Data)
@@ -136,18 +136,18 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 
 	t.Log("Test the tunnel with them")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
-	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
 
 	t.Log("Flush all packets from all controllers")
 	r.FlushAll()
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 	//TODO: assert hostmaps for everyone
@@ -157,14 +157,17 @@ func TestWrongResponderHandshake(t *testing.T) {
 	theirControl.Stop()
 }
 
-func Test_Case1_Stage1Race(t *testing.T) {
+func TestStage1Race(t *testing.T) {
+	// This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow
+	// But will eventually collapse down to a single tunnel
+
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
@@ -175,8 +178,8 @@ func Test_Case1_Stage1Race(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Trigger a handshake to start on both me and them")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
 
 	t.Log("Get both stage 1 handshake packets")
 	myHsForThem := myControl.GetFromUDP(true)
@@ -185,44 +188,165 @@ func Test_Case1_Stage1Race(t *testing.T) {
 	r.Log("Now inject both stage 1 handshake packets")
 	r.InjectUDPPacket(theirControl, myControl, theirHsForMe)
 	r.InjectUDPPacket(myControl, theirControl, myHsForThem)
-	//TODO: they should win, grab their index for me and make sure I use it in the end.
 
-	r.Log("They should not have a stage 2 (won the race) but I should send one")
-	r.InjectUDPPacket(myControl, theirControl, myControl.GetFromUDP(true))
+	r.Log("Route until they receive a message packet")
+	myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 
-	r.Log("Route for me until I send a message packet to them")
-	r.RouteForAllUntilAfterMsgTypeTo(theirControl, header.Message, header.MessageNone)
+	r.Log("Their cached packet should be received by me")
+	theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
 
-	t.Log("My cached packet should be received by them")
-	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+	r.Log("Do a bidirectional tunnel test")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
 
-	t.Log("Route for them until I send a message packet to me")
-	theirControl.WaitForType(1, 0, myControl)
+	myHostmapHosts := myControl.ListHostmapHosts(false)
+	myHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
-	t.Log("Their cached packet should be received by me")
-	theirCachedPacket := myControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80)
+	// We should have two tunnels on both sides
+	assert.Len(t, myHostmapHosts, 1)
+	assert.Len(t, theirHostmapHosts, 1)
+	assert.Len(t, myHostmapIndexes, 2)
+	assert.Len(t, theirHostmapIndexes, 2)
 
-	t.Log("Do a bidirectional tunnel test")
-	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	r.Log("Spin until connection manager tears down a tunnel")
+
+	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		t.Log("Connection manager hasn't ticked yet")
+		time.Sleep(time.Second)
+	}
+
+	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
+	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
+
+	// We should only have a single tunnel now on both sides
+	assert.Len(t, myFinalHostmapHosts, 1)
+	assert.Len(t, theirFinalHostmapHosts, 1)
+	assert.Len(t, myFinalHostmapIndexes, 1)
+	assert.Len(t, theirFinalHostmapIndexes, 1)
 
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	myControl.Stop()
 	theirControl.Stop()
-	//TODO: assert hostmaps
+}
+
+func TestUncleanShutdownRaceLoser(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r.Log("Trigger a handshake from me to them")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+
+	r.Log("Nuke my hostmap")
+	myHostmap := myControl.GetHostmap()
+	myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
+	myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
+
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
+	p = r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+
+	r.Log("Assert the tunnel works")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	r.Log("Wait for the dead index to go away")
+	start := len(theirControl.GetHostmap().Indexes)
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		if len(theirControl.GetHostmap().Indexes) < start {
+			break
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+}
+
+func TestUncleanShutdownRaceWinner(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r.Log("Trigger a handshake from me to them")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	r.Log("Nuke my hostmap")
+	theirHostmap := theirControl.GetHostmap()
+	theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
+	theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
+
+	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
+	p = r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
+
+	r.Log("Assert the tunnel works")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	r.Log("Wait for the dead index to go away")
+	start := len(myControl.GetHostmap().Indexes)
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		if len(myControl.GetHostmap().Indexes) < start {
+			break
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 }
 
 func TestRelays(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIp, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIp, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIp, []net.IP{relayVpnIp})
-	relayControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -234,12 +358,84 @@ func TestRelays(t *testing.T) {
 	theirControl.Start()
 
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIp, theirVpnIp, 80, 80)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 }
 
+func TestStage1RaceRelays(t *testing.T) {
+	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	r.Log("Trigger a handshake to start on both me and relay")
+	myControl.InjectTunUDPPacket(relayVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	relayControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from relay"))
+
+	r.Log("Get both stage 1 handshake packets")
+	//TODO: this is where it breaks, we need to get the hs packets for the relay not for the destination
+	myHsForThem := myControl.GetFromUDP(true)
+	relayHsForMe := relayControl.GetFromUDP(true)
+
+	r.Log("Now inject both stage 1 handshake packets")
+	r.InjectUDPPacket(relayControl, myControl, relayHsForMe)
+	r.InjectUDPPacket(myControl, relayControl, myHsForThem)
+
+	r.Log("Route for me until I send a message packet to relay")
+	r.RouteForAllUntilAfterMsgTypeTo(relayControl, header.Message, header.MessageNone)
+
+	r.Log("My cached packet should be received by relay")
+	myCachedPacket := relayControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, relayVpnIpNet.IP, 80, 80)
+
+	r.Log("Relays cached packet should be received by me")
+	relayCachedPacket := r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from relay"), relayCachedPacket, relayVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+
+	r.Log("Do a bidirectional tunnel test; me and relay")
+	assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+
+	r.Log("Create a tunnel between relay and them")
+	assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+
+	r.RenderHostmaps("Starting hostmaps", myControl, relayControl, theirControl)
+
+	r.Log("Trigger a handshake to start from me to them via the relay")
+	//TODO: if we initiate a handshake from me and then assert the tunnel it will cause a relay control race that can blow up
+	//	this is a problem that exists on master today
+	//myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	myControl.Stop()
+	theirControl.Stop()
+	relayControl.Stop()
+	//
+	////TODO: assert hostmaps
+}
+
 //TODO: add a test with many lies

+ 4 - 4
e2e/helpers_test.go

@@ -30,7 +30,7 @@ import (
 type m map[string]interface{}
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, net.IP, *net.UDPAddr) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr) {
 	l := NewTestLogger()
 
 	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
@@ -101,7 +101,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 	}
 
-	return control, vpnIpNet.IP, &udpAddr
+	return control, vpnIpNet, &udpAddr
 }
 
 // newTestCaCert will generate a CA cert
@@ -231,12 +231,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
-	bPacket := r.RouteUntilTxTun(controlB, controlA)
+	bPacket := r.RouteForAllUntilTxTun(controlA)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 
 	// And once more from me to them
 	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
-	aPacket := r.RouteUntilTxTun(controlA, controlB)
+	aPacket := r.RouteForAllUntilTxTun(controlB)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 

+ 32 - 2
e2e/router/hostmap.go

@@ -5,9 +5,11 @@ package router
 
 import (
 	"fmt"
+	"sort"
 	"strings"
 
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/iputil"
 )
 
 type edge struct {
@@ -64,7 +66,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 
 	// Draw the vpn to index nodes
 	r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName)
-	for vpnIp, hi := range hm.Hosts {
+	for _, vpnIp := range sortedHosts(hm.Hosts) {
+		hi := hm.Hosts[vpnIp]
 		r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp)
 		lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex()))
 
@@ -91,7 +94,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 
 	// Draw the local index to relay or remote index nodes
 	r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName)
-	for idx, hi := range hm.Indexes {
+	for _, idx := range sortedIndexes(hm.Indexes) {
+		hi := hm.Indexes[idx]
 		r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
 		remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
 		globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
@@ -107,3 +111,29 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	r += "\tend\n"
 	return r, globalLines
 }
+
+func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
+	keys := make([]iputil.VpnIp, 0, len(hosts))
+	for key := range hosts {
+		keys = append(keys, key)
+	}
+
+	sort.SliceStable(keys, func(i, j int) bool {
+		return keys[i] > keys[j]
+	})
+
+	return keys
+}
+
+func sortedIndexes(indexes map[uint32]*nebula.HostInfo) []uint32 {
+	keys := make([]uint32, 0, len(indexes))
+	for key := range indexes {
+		keys = append(keys, key)
+	}
+
+	sort.SliceStable(keys, func(i, j int) bool {
+		return keys[i] > keys[j]
+	})
+
+	return keys
+}

+ 33 - 1
e2e/router/router.go

@@ -10,6 +10,7 @@ import (
 	"os"
 	"path/filepath"
 	"reflect"
+	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -22,6 +23,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
+	"golang.org/x/exp/maps"
 )
 
 type R struct {
@@ -150,6 +152,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 			case <-ctx.Done():
 				return
 			case <-clockSource.C:
+				r.renderHostmaps("clock tick")
 				r.renderFlow()
 			}
 		}
@@ -220,11 +223,16 @@ func (r *R) renderFlow() {
 		)
 	}
 
+	if len(participantsVals) > 2 {
+		// Get the first and last participantVals for notes
+		participantsVals = []string{participantsVals[0], participantsVals[len(participantsVals)-1]}
+	}
+
 	// Print packets
 	h := &header.H{}
 	for _, e := range r.flow {
 		if e.packet == nil {
-			fmt.Fprintf(f, "    note over %s: %s\n", strings.Join(participantsVals, ", "), e.note)
+			//fmt.Fprintf(f, "    note over %s: %s\n", strings.Join(participantsVals, ", "), e.note)
 			continue
 		}
 
@@ -294,6 +302,28 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 	})
 }
 
+func (r *R) renderHostmaps(title string) {
+	c := maps.Values(r.controls)
+	sort.SliceStable(c, func(i, j int) bool {
+		return c[i].GetVpnIp() > c[j].GetVpnIp()
+	})
+
+	s := renderHostmaps(c...)
+	if len(r.additionalGraphs) > 0 {
+		lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1]
+		if lastGraph.content == s {
+			// Ignore this rendering if it matches the last rendering added
+			// This is useful if you want to track rendering changes
+			return
+		}
+	}
+
+	r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{
+		title:   title,
+		content: s,
+	})
+}
+
 // InjectFlow can be used to record packet flow if the test is handling the routing on its own.
 // The packet is assumed to have been received
 func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) {
@@ -332,6 +362,8 @@ func (r *R) unlockedInjectFlow(from, to *nebula.Control, p *udp.Packet, tun bool
 		return nil
 	}
 
+	r.renderHostmaps(fmt.Sprintf("Packet %v", len(r.flow)))
+
 	if len(r.ignoreFlows) > 0 {
 		var h header.H
 		err := h.Parse(p.Data)

+ 1 - 0
go.mod

@@ -21,6 +21,7 @@ require (
 	github.com/stretchr/testify v1.8.1
 	github.com/vishvananda/netlink v1.1.0
 	golang.org/x/crypto v0.3.0
+	golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2
 	golang.org/x/net v0.2.0
 	golang.org/x/sys v0.2.0
 	golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224

+ 2 - 0
go.sum

@@ -266,6 +266,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
 golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
+golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI=
+golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=

+ 13 - 15
handshake_ix.go

@@ -207,9 +207,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 	hostinfo.SetRemote(addr)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
-	// Only overwrite existing record if we should win the handshake race
-	overwrite := vpnIp > f.myVpnIp
-	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
+	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
 		switch err {
 		case ErrAlreadySeen:
@@ -280,16 +278,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
-		case ErrExistingHandshake:
-			// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
-				WithField("certName", certName).
-				WithField("fingerprint", fingerprint).
-				WithField("issuer", issuer).
-				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				Error("Prevented a pending handshake race")
-			return
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
@@ -344,6 +332,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 			Info("Handshake message sent")
 	}
 
+	if existing != nil {
+		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
+		f.connectionManager.Out(existing.localIndexId)
+	}
+
+	f.connectionManager.Out(hostinfo.localIndexId)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 
 	return
@@ -501,8 +495,12 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
-	//TODO: Complete here does not do a race avoidance, it will just take the new tunnel. Is this ok?
-	f.handshakeManager.Complete(hostinfo, f)
+	existing := f.handshakeManager.Complete(hostinfo, f)
+	if existing != nil {
+		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
+		f.connectionManager.Out(existing.localIndexId)
+	}
+
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 	f.metricHandshakes.Update(duration)
 

+ 26 - 49
handshake_manager.go

@@ -53,6 +53,10 @@ type HandshakeManager struct {
 	metricTimedOut         metrics.Counter
 	l                      *logrus.Logger
 
+	// vpnIps is another map similar to the pending hostmap but tracks entries in the wheel instead
+	// this is to avoid situations where the same vpn ip enters the wheel and causes rapid fire handshaking
+	vpnIps map[iputil.VpnIp]struct{}
+
 	// can be used to trigger outbound handshake for the given vpnIp
 	trigger chan iputil.VpnIp
 }
@@ -66,6 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 		config:                 config,
 		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
 		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		vpnIps:                 map[iputil.VpnIp]struct{}{},
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -103,6 +108,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
 func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
 	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
 	if err != nil {
+		delete(c.vpnIps, vpnIp)
 		return
 	}
 	hostinfo.Lock()
@@ -160,7 +166,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
 		c.lightHouse.QueryServer(vpnIp, f)
 	}
 
-	// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
+	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
 	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
 		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
@@ -260,7 +266,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
 
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
-		//TODO: feel like we dupe handshake real fast in a tight loop, why?
 		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 	}
 }
@@ -269,7 +274,10 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
 	hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
 
 	if created {
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
+		if _, ok := c.vpnIps[vpnIp]; !ok {
+			c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
+		}
+		c.vpnIps[vpnIp] = struct{}{}
 		c.metricInitiated.Inc(1)
 	}
 
@@ -280,7 +288,6 @@ var (
 	ErrExistingHostInfo    = errors.New("existing hostinfo")
 	ErrAlreadySeen         = errors.New("already seen")
 	ErrLocalIndexCollision = errors.New("local index collision")
-	ErrExistingHandshake   = errors.New("existing handshake")
 )
 
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
@@ -294,7 +301,7 @@ var (
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
-func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
+func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
 	c.pendingHostMap.Lock()
 	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
@@ -303,9 +310,14 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	// Check if we already have a tunnel with this vpn ip
 	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	if found && existingHostInfo != nil {
-		// Is it just a delayed handshake packet?
-		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
-			return existingHostInfo, ErrAlreadySeen
+		testHostInfo := existingHostInfo
+		for testHostInfo != nil {
+			// Is it just a delayed handshake packet?
+			if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
+				return existingHostInfo, ErrAlreadySeen
+			}
+
+			testHostInfo = testHostInfo.next
 		}
 
 		// Is this a newer handshake?
@@ -337,56 +349,19 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			Info("New host shadows existing host remoteIndex")
 	}
 
-	// Check if we are also handshaking with this vpn ip
-	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
-	if found && pendingHostInfo != nil {
-		if !overwrite {
-			// We won, let our pending handshake win
-			return pendingHostInfo, ErrExistingHandshake
-		}
-
-		// We lost, take this handshake and move any cached packets over so they get sent
-		pendingHostInfo.ConnectionState.queueLock.Lock()
-		hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
-		c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
-		pendingHostInfo.ConnectionState.queueLock.Unlock()
-		pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
-	}
-
-	if existingHostInfo != nil {
-		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
-		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
-		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
-		for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() {
-			delete(c.mainHostMap.Relays, relayIdx)
-		}
-	}
-
-	c.mainHostMap.addHostInfo(hostinfo, f)
+	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	return existingHostInfo, nil
 }
 
 // Complete is a simpler version of CheckAndComplete when we already know we
 // won't have a localIndexId collision because we already have an entry in the
-// pendingHostMap
-func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
+// pendingHostMap. An existing hostinfo is returned if there was one.
+func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo {
 	c.pendingHostMap.Lock()
 	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
-	if found && existingHostInfo != nil {
-		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
-		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
-		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
-		for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() {
-			delete(c.mainHostMap.Relays, relayIdx)
-		}
-	}
-
 	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
 	if found && existingRemoteIndex != nil {
 		// We have a collision, but this can happen since we can't control
@@ -396,8 +371,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 			Info("New host shadows existing host remoteIndex")
 	}
 
-	c.mainHostMap.addHostInfo(hostinfo, f)
+	existingHostInfo := c.mainHostMap.Hosts[hostinfo.vpnIp]
+	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
+	return existingHostInfo
 }
 
 // AddIndexHostInfo generates a unique localIndexId for this HostInfo

+ 101 - 19
hostmap.go

@@ -23,6 +23,10 @@ const PromoteEvery = 1000
 const ReQueryEvery = 5000
 const MaxRemotes = 10
 
+// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
+// 5 allows for an initial handshake and each host pair re-handshaking twice
+const MaxHostInfosPerVpnIp = 5
+
 // How long we should prevent roaming back to the previous IP.
 // This helps prevent flapping due to packets already in flight
 const RoamingSuppressSeconds = 2
@@ -180,6 +184,10 @@ type HostInfo struct {
 
 	lastRoam       time.Time
 	lastRoamRemote *udp.Addr
+
+	// Used to track other hostinfos for this vpn ip since only 1 can be primary
+	// Synchronised via hostmap lock and not the hostinfo lock.
+	next, prev *HostInfo
 }
 
 type ViaSender struct {
@@ -395,9 +403,12 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
 	}
 }
 
-func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
+// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
+func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
 	hm.Lock()
+	// If we have a previous or next hostinfo then we are not the last one for this vpn ip
+	final := (hostinfo.next == nil && hostinfo.prev == nil)
 	hm.unlockedDeleteHostInfo(hostinfo)
 	hm.Unlock()
 
@@ -421,6 +432,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 	for _, localIdx := range teardownRelayIdx {
 		hm.RemoveRelay(localIdx)
 	}
+
+	return final
 }
 
 func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
@@ -429,29 +442,81 @@ func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
 	delete(hm.RemoteIndexes, localIdx)
 }
 
+func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
+	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedMakePrimary(hostinfo)
+}
+
+func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
+	oldHostinfo := hm.Hosts[hostinfo.vpnIp]
+	if oldHostinfo == hostinfo {
+		return
+	}
+
+	if hostinfo.prev != nil {
+		hostinfo.prev.next = hostinfo.next
+	}
+
+	if hostinfo.next != nil {
+		hostinfo.next.prev = hostinfo.prev
+	}
+
+	hm.Hosts[hostinfo.vpnIp] = hostinfo
+
+	if oldHostinfo == nil {
+		return
+	}
+
+	hostinfo.next = oldHostinfo
+	oldHostinfo.prev = hostinfo
+	hostinfo.prev = nil
+}
+
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	// Check if this same hostId is in the hostmap with a different instance.
-	// This could happen if we have an entry in the pending hostmap with different
-	// index values than the one in the main hostmap.
-	hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
-	if ok && hostinfo2 != hostinfo {
-		delete(hm.Hosts, hostinfo2.vpnIp)
-		delete(hm.Indexes, hostinfo2.localIndexId)
-		delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
+	primary, ok := hm.Hosts[hostinfo.vpnIp]
+	if ok && primary == hostinfo {
+		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
+		delete(hm.Hosts, hostinfo.vpnIp)
+		if len(hm.Hosts) == 0 {
+			hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+		}
+
+		if hostinfo.next != nil {
+			// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
+			hm.Hosts[hostinfo.vpnIp] = hostinfo.next
+			// It is primary, there is no previous hostinfo now
+			hostinfo.next.prev = nil
+		}
+
+	} else {
+		// Relink if we were in the middle of multiple hostinfos for this vpn ip
+		if hostinfo.prev != nil {
+			hostinfo.prev.next = hostinfo.next
+		}
+
+		if hostinfo.next != nil {
+			hostinfo.next.prev = hostinfo.prev
+		}
 	}
 
-	delete(hm.Hosts, hostinfo.vpnIp)
-	if len(hm.Hosts) == 0 {
-		hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+	hostinfo.next = nil
+	hostinfo.prev = nil
+
+	// The remote index uses index ids outside our control so lets make sure we are only removing
+	// the remote index pointer here if it points to the hostinfo we are deleting
+	hostinfo2, ok := hm.RemoteIndexes[hostinfo.remoteIndexId]
+	if ok && hostinfo2 == hostinfo {
+		delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
+		if len(hm.RemoteIndexes) == 0 {
+			hm.RemoteIndexes = map[uint32]*HostInfo{}
+		}
 	}
+
 	delete(hm.Indexes, hostinfo.localIndexId)
 	if len(hm.Indexes) == 0 {
 		hm.Indexes = map[uint32]*HostInfo{}
 	}
-	delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
-	if len(hm.RemoteIndexes) == 0 {
-		hm.RemoteIndexes = map[uint32]*HostInfo{}
-	}
 
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
@@ -520,15 +585,22 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
 	return nil, errors.New("unable to find host")
 }
 
-// We already have the hm Lock when this is called, so make sure to not call
-// any other methods that might try to grab it again
-func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
+// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
+// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
+func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	if f.serveDns {
 		remoteCert := hostinfo.ConnectionState.peerCert
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 	}
 
+	existing := hm.Hosts[hostinfo.vpnIp]
 	hm.Hosts[hostinfo.vpnIp] = hostinfo
+
+	if existing != nil {
+		hostinfo.next = existing
+		existing.prev = hostinfo
+	}
+
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
@@ -537,6 +609,16 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 	}
+
+	i := 1
+	check := hostinfo
+	for check != nil {
+		if i > MaxHostInfosPerVpnIp {
+			hm.unlockedDeleteHostInfo(check)
+		}
+		check = check.next
+		i++
+	}
 }
 
 // punchList assembles a list of all non nil RemoteList pointer entries in this hostmap

+ 206 - 0
hostmap_test.go

@@ -1 +1,207 @@
 package nebula
+
+import (
+	"net"
+	"testing"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestHostMap_MakePrimary(t *testing.T) {
+	l := test.NewLogger()
+	hm := NewHostMap(
+		l, "test",
+		&net.IPNet{
+			IP:   net.IP{10, 0, 0, 1},
+			Mask: net.IPMask{255, 255, 255, 0},
+		},
+		[]*net.IPNet{},
+	)
+
+	f := &Interface{}
+
+	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
+	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
+	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
+	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+
+	hm.unlockedAddHostInfo(h4, f)
+	hm.unlockedAddHostInfo(h3, f)
+	hm.unlockedAddHostInfo(h2, f)
+	hm.unlockedAddHostInfo(h1, f)
+
+	// Make sure we go h1 -> h2 -> h3 -> h4
+	prim, _ := hm.QueryVpnIp(1)
+	assert.Equal(t, h1.localIndexId, prim.localIndexId)
+	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
+	assert.Nil(t, h4.next)
+
+	// Swap h3/middle to primary
+	hm.MakePrimary(h3)
+
+	// Make sure we go h3 -> h1 -> h2 -> h4
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h3.localIndexId, prim.localIndexId)
+	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
+	assert.Nil(t, h4.next)
+
+	// Swap h4/tail to primary
+	hm.MakePrimary(h4)
+
+	// Make sure we go h4 -> h3 -> h1 -> h2
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h4.localIndexId, prim.localIndexId)
+	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Nil(t, h2.next)
+
+	// Swap h4 again should be no-op
+	hm.MakePrimary(h4)
+
+	// Make sure we go h4 -> h3 -> h1 -> h2
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h4.localIndexId, prim.localIndexId)
+	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Nil(t, h2.next)
+}
+
+func TestHostMap_DeleteHostInfo(t *testing.T) {
+	l := test.NewLogger()
+	hm := NewHostMap(
+		l, "test",
+		&net.IPNet{
+			IP:   net.IP{10, 0, 0, 1},
+			Mask: net.IPMask{255, 255, 255, 0},
+		},
+		[]*net.IPNet{},
+	)
+
+	f := &Interface{}
+
+	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
+	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
+	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
+	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+	h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
+	h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
+
+	hm.unlockedAddHostInfo(h6, f)
+	hm.unlockedAddHostInfo(h5, f)
+	hm.unlockedAddHostInfo(h4, f)
+	hm.unlockedAddHostInfo(h3, f)
+	hm.unlockedAddHostInfo(h2, f)
+	hm.unlockedAddHostInfo(h1, f)
+
+	// h6 should be deleted
+	assert.Nil(t, h6.next)
+	assert.Nil(t, h6.prev)
+	_, err := hm.QueryIndex(h6.localIndexId)
+	assert.Error(t, err)
+
+	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
+	prim, _ := hm.QueryVpnIp(1)
+	assert.Equal(t, h1.localIndexId, prim.localIndexId)
+	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
+	assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
+	assert.Nil(t, h5.next)
+
+	// Delete primary
+	hm.DeleteHostInfo(h1)
+	assert.Nil(t, h1.prev)
+	assert.Nil(t, h1.next)
+
+	// Make sure we go h2 -> h3 -> h4 -> h5
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h2.localIndexId, prim.localIndexId)
+	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
+	assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
+	assert.Nil(t, h5.next)
+
+	// Delete in the middle
+	hm.DeleteHostInfo(h3)
+	assert.Nil(t, h3.prev)
+	assert.Nil(t, h3.next)
+
+	// Make sure we go h2 -> h4 -> h5
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h2.localIndexId, prim.localIndexId)
+	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
+	assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
+	assert.Nil(t, h5.next)
+
+	// Delete the tail
+	hm.DeleteHostInfo(h5)
+	assert.Nil(t, h5.prev)
+	assert.Nil(t, h5.next)
+
+	// Make sure we go h2 -> h4
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h2.localIndexId, prim.localIndexId)
+	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
+	assert.Nil(t, h4.next)
+
+	// Delete the head
+	hm.DeleteHostInfo(h2)
+	assert.Nil(t, h2.prev)
+	assert.Nil(t, h2.next)
+
+	// Make sure we only have h4
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h4.localIndexId, prim.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Nil(t, prim.next)
+	assert.Nil(t, h4.next)
+
+	// Delete the only item
+	hm.DeleteHostInfo(h4)
+	assert.Nil(t, h4.prev)
+	assert.Nil(t, h4.next)
+
+	// Make sure we have nil
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Nil(t, prim)
+}

+ 5 - 3
outside.go

@@ -245,9 +245,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
 	f.connectionManager.ClearLocalIndex(hostInfo.localIndexId)
 	f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId)
-	f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
-
-	f.hostMap.DeleteHostInfo(hostInfo)
+	final := f.hostMap.DeleteHostInfo(hostInfo)
+	if final {
+		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
+		f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
+	}
 }
 
 // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote

+ 1 - 1
overlay/tun_tester.go

@@ -51,7 +51,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
 // packets should exit the udp side, capture them with udpConn.Get
 func (t *TestTun) Send(packet []byte) {
 	if t.l.Level >= logrus.InfoLevel {
-		t.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
+		t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
 	}
 	t.rxPackets <- packet
 }

+ 5 - 0
relay_manager.go

@@ -61,6 +61,11 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput
 
 		_, inRelays := hm.Relays[index]
 		if !inRelays {
+			// Avoid standing up a relay that can't be used since only the primary hostinfo
+			// will be pointed to by the relay logic
+			//TODO: if there was an existing primary and it had relay state, should we merge?
+			hm.unlockedMakePrimary(relayHostInfo)
+
 			hm.Relays[index] = relayHostInfo
 			newRelay := Relay{
 				Type:       relayType,

+ 12 - 3
ssh.go

@@ -22,8 +22,9 @@ import (
 )
 
 type sshListHostMapFlags struct {
-	Json   bool
-	Pretty bool
+	Json    bool
+	Pretty  bool
+	ByIndex bool
 }
 
 type sshPrintCertFlags struct {
@@ -174,6 +175,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			s := sshListHostMapFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
+			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
@@ -189,6 +191,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			s := sshListHostMapFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
+			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
@@ -368,7 +371,13 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 		return nil
 	}
 
-	hm := listHostMap(hostMap)
+	var hm []ControlHostInfo
+	if fs.ByIndex {
+		hm = listHostMapIndexes(hostMap)
+	} else {
+		hm = listHostMapHosts(hostMap)
+	}
+
 	sort.Slice(hm, func(i, j int) bool {
 		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
 	})

+ 1 - 1
udp/udp_tester.go

@@ -66,7 +66,7 @@ func (u *Conn) Send(packet *Packet) {
 		u.l.WithField("header", h).
 			WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
 			WithField("dataLen", len(packet.Data)).
-			Info("UDP receiving injected packet")
+			Debug("UDP receiving injected packet")
 	}
 	u.RxPackets <- packet
 }