浏览代码

Remove handshake race avoidance (#820)

Co-authored-by: Wade Simmons <[email protected]>
Nate Brown 2 年之前
父节点
当前提交
92cc32f844
共有 18 个文件被更改,包括 741 次插入157 次删除
  1. 23 1
      connection_manager.go
  2. 2 2
      connection_manager_test.go
  3. 27 5
      control.go
  4. 247 51
      e2e/handshakes_test.go
  5. 4 4
      e2e/helpers_test.go
  6. 32 2
      e2e/router/hostmap.go
  7. 33 1
      e2e/router/router.go
  8. 1 0
      go.mod
  9. 2 0
      go.sum
  10. 13 15
      handshake_ix.go
  11. 26 49
      handshake_manager.go
  12. 101 19
      hostmap.go
  13. 206 0
      hostmap_test.go
  14. 5 3
      outside.go
  15. 1 1
      overlay/tun_tester.go
  16. 5 0
      relay_manager.go
  17. 12 3
      ssh.go
  18. 1 1
      udp/udp_tester.go

+ 23 - 1
connection_manager.go

@@ -181,6 +181,14 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			continue
 			continue
 		}
 		}
 
 
+		// Does the vpnIp point to this hostinfo or is it ancillary? If we have ancillary hostinfos then we need to
+		// decide if this should be the main hostinfo if we are seeing traffic on it
+		primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
+		mainHostInfo := true
+		if primary != nil && primary != hostinfo {
+			mainHostInfo = false
+		}
+
 		// If we saw an incoming packets from this ip and peer's certificate is not
 		// If we saw an incoming packets from this ip and peer's certificate is not
 		// expired, just ignore.
 		// expired, just ignore.
 		if traf {
 		if traf {
@@ -191,6 +199,20 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			}
 			}
 			n.ClearLocalIndex(localIndex)
 			n.ClearLocalIndex(localIndex)
 			n.ClearPendingDeletion(localIndex)
 			n.ClearPendingDeletion(localIndex)
+
+			if !mainHostInfo {
+				if hostinfo.vpnIp > n.intf.myVpnIp {
+					// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
+					// this the primary and prime the old primary hostinfo for testing
+					n.hostMap.MakePrimary(hostinfo)
+					n.Out(primary.localIndexId)
+				} else {
+					// This hostinfo is still being used despite not being the primary hostinfo for this vpn ip
+					// Keep tracking so that we can tear it down when it goes away
+					n.Out(hostinfo.localIndexId)
+				}
+			}
+
 			continue
 			continue
 		}
 		}
 
 
@@ -198,7 +220,7 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 			WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 			Debug("Tunnel status")
 			Debug("Tunnel status")
 
 
-		if hostinfo != nil && hostinfo.ConnectionState != nil {
+		if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
 			n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
 
 

+ 2 - 2
connection_manager_test.go

@@ -80,7 +80,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		certState: cs,
 		certState: cs,
 		H:         &noise.HandshakeState{},
 		H:         &noise.HandshakeState{},
 	}
 	}
-	nc.hostMap.addHostInfo(hostinfo, ifce)
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 
 	// We saw traffic out to vpnIp
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
 	nc.Out(hostinfo.localIndexId)
@@ -156,7 +156,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		certState: cs,
 		certState: cs,
 		H:         &noise.HandshakeState{},
 		H:         &noise.HandshakeState{},
 	}
 	}
-	nc.hostMap.addHostInfo(hostinfo, ifce)
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 
 	// We saw traffic out to vpnIp
 	// We saw traffic out to vpnIp
 	nc.Out(hostinfo.localIndexId)
 	nc.Out(hostinfo.localIndexId)

+ 27 - 5
control.go

@@ -95,12 +95,21 @@ func (c *Control) RebindUDPServer() {
 	c.f.rebindCount++
 	c.f.rebindCount++
 }
 }
 
 
-// ListHostmap returns details about the actual or pending (handshaking) hostmap
-func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
+// ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
+func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
 	if pendingMap {
 	if pendingMap {
-		return listHostMap(c.f.handshakeManager.pendingHostMap)
+		return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
 	} else {
 	} else {
-		return listHostMap(c.f.hostMap)
+		return listHostMapHosts(c.f.hostMap)
+	}
+}
+
+// ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
+func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
+	if pendingMap {
+		return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
+	} else {
+		return listHostMapIndexes(c.f.hostMap)
 	}
 	}
 }
 }
 
 
@@ -232,7 +241,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	return chi
 	return chi
 }
 }
 
 
-func listHostMap(hm *HostMap) []ControlHostInfo {
+func listHostMapHosts(hm *HostMap) []ControlHostInfo {
 	hm.RLock()
 	hm.RLock()
 	hosts := make([]ControlHostInfo, len(hm.Hosts))
 	hosts := make([]ControlHostInfo, len(hm.Hosts))
 	i := 0
 	i := 0
@@ -244,3 +253,16 @@ func listHostMap(hm *HostMap) []ControlHostInfo {
 
 
 	return hosts
 	return hosts
 }
 }
+
+func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
+	hm.RLock()
+	hosts := make([]ControlHostInfo, len(hm.Indexes))
+	i := 0
+	for _, v := range hm.Indexes {
+		hosts[i] = copyHostInfo(v, hm.preferredRanges)
+		i++
+	}
+	hm.RUnlock()
+
+	return hosts
+}

+ 247 - 51
e2e/handshakes_test.go

@@ -19,10 +19,10 @@ import (
 func BenchmarkHotPath(b *testing.B) {
 func BenchmarkHotPath(b *testing.B) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
 	myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
@@ -32,7 +32,7 @@ func BenchmarkHotPath(b *testing.B) {
 	r.CancelFlowLogs()
 	r.CancelFlowLogs()
 
 
 	for n := 0; n < b.N; n++ {
 	for n := 0; n < b.N; n++ {
-		myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+		myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 		_ = r.RouteForAllUntilTxTun(theirControl)
 		_ = r.RouteForAllUntilTxTun(theirControl)
 	}
 	}
 
 
@@ -42,18 +42,18 @@ func BenchmarkHotPath(b *testing.B) {
 
 
 func TestGoodHandshake(t *testing.T) {
 func TestGoodHandshake(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 
 	// Put their info in our lighthouse
 	// Put their info in our lighthouse
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 
 	// Start the servers
 	// Start the servers
 	myControl.Start()
 	myControl.Start()
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
 	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 
 
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
@@ -74,16 +74,16 @@ func TestGoodHandshake(t *testing.T) {
 	myControl.WaitForType(1, 0, theirControl)
 	myControl.WaitForType(1, 0, theirControl)
 
 
 	t.Log("Make sure our host infos are correct")
 	t.Log("Make sure our host infos are correct")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
 
 
 	t.Log("Get that cached packet and make sure it looks right")
 	t.Log("Get that cached packet and make sure it looks right")
 	myCachedPacket := theirControl.GetFromTun(true)
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 
 
 	t.Log("Do a bidirectional tunnel test")
 	t.Log("Do a bidirectional tunnel test")
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
 	defer r.RenderFlow()
 	defer r.RenderFlow()
-	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
 
 
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	myControl.Stop()
 	myControl.Stop()
@@ -97,15 +97,15 @@ func TestWrongResponderHandshake(t *testing.T) {
 	// The IPs here are chosen on purpose:
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
 	// The current remote handling will sort by preference, public, and then lexically.
 	// So we need them to have a higher address than evil (we could apply a preference though)
 	// So we need them to have a higher address than evil (we could apply a preference though)
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
 	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
 	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
 
 
 	// Add their real udp addr, which should be tried after evil.
 	// Add their real udp addr, which should be tried after evil.
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
 	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
-	myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, evilUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl, evilControl)
 	r := router.NewR(t, myControl, theirControl, evilControl)
@@ -117,7 +117,7 @@ func TestWrongResponderHandshake(t *testing.T) {
 	evilControl.Start()
 	evilControl.Start()
 
 
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
 	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 	r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
 		h := &header.H{}
 		h := &header.H{}
 		err := h.Parse(p.Data)
 		err := h.Parse(p.Data)
@@ -136,18 +136,18 @@ func TestWrongResponderHandshake(t *testing.T) {
 
 
 	t.Log("My cached packet should be received by them")
 	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
 	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 
 
 	t.Log("Test the tunnel with them")
 	t.Log("Test the tunnel with them")
-	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
-	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl)
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
 
 
 	t.Log("Flush all packets from all controllers")
 	t.Log("Flush all packets from all controllers")
 	r.FlushAll()
 	r.FlushAll()
 
 
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
 	t.Log("Ensure ensure I don't have any hostinfo artifacts from evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), true), "My pending hostmap should not contain evil")
-	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp), false), "My main hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(evilVpnIp.IP), false), "My main hostmap should not contain evil")
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 
 	//TODO: assert hostmaps for everyone
 	//TODO: assert hostmaps for everyone
@@ -157,14 +157,17 @@ func TestWrongResponderHandshake(t *testing.T) {
 	theirControl.Stop()
 	theirControl.Stop()
 }
 }
 
 
-func Test_Case1_Stage1Race(t *testing.T) {
+func TestStage1Race(t *testing.T) {
+	// This test ensures that two hosts handshaking with each other at the same time will allow traffic to flow
+	// but will eventually collapse down to a single tunnel
+
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 
 	// Put their info in our lighthouse and vice versa
 	// Put their info in our lighthouse and vice versa
-	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
-	theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr)
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, theirControl)
 	r := router.NewR(t, myControl, theirControl)
@@ -175,8 +178,8 @@ func Test_Case1_Stage1Race(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake to start on both me and them")
 	t.Log("Trigger a handshake to start on both me and them")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-	theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
 
 
 	t.Log("Get both stage 1 handshake packets")
 	t.Log("Get both stage 1 handshake packets")
 	myHsForThem := myControl.GetFromUDP(true)
 	myHsForThem := myControl.GetFromUDP(true)
@@ -185,44 +188,165 @@ func Test_Case1_Stage1Race(t *testing.T) {
 	r.Log("Now inject both stage 1 handshake packets")
 	r.Log("Now inject both stage 1 handshake packets")
 	r.InjectUDPPacket(theirControl, myControl, theirHsForMe)
 	r.InjectUDPPacket(theirControl, myControl, theirHsForMe)
 	r.InjectUDPPacket(myControl, theirControl, myHsForThem)
 	r.InjectUDPPacket(myControl, theirControl, myHsForThem)
-	//TODO: they should win, grab their index for me and make sure I use it in the end.
 
 
-	r.Log("They should not have a stage 2 (won the race) but I should send one")
-	r.InjectUDPPacket(myControl, theirControl, myControl.GetFromUDP(true))
+	r.Log("Route until they receive a message packet")
+	myCachedPacket := r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 
 
-	r.Log("Route for me until I send a message packet to them")
-	r.RouteForAllUntilAfterMsgTypeTo(theirControl, header.Message, header.MessageNone)
+	r.Log("Their cached packet should be received by me")
+	theirCachedPacket := r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
 
 
-	t.Log("My cached packet should be received by them")
-	myCachedPacket := theirControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+	r.Log("Do a bidirectional tunnel test")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
 
 
-	t.Log("Route for them until I send a message packet to me")
-	theirControl.WaitForType(1, 0, myControl)
+	myHostmapHosts := myControl.ListHostmapHosts(false)
+	myHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirHostmapIndexes := theirControl.ListHostmapIndexes(false)
 
 
-	t.Log("Their cached packet should be received by me")
-	theirCachedPacket := myControl.GetFromTun(true)
-	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80)
+	// We should have two tunnels on both sides
+	assert.Len(t, myHostmapHosts, 1)
+	assert.Len(t, theirHostmapHosts, 1)
+	assert.Len(t, myHostmapIndexes, 2)
+	assert.Len(t, theirHostmapIndexes, 2)
 
 
-	t.Log("Do a bidirectional tunnel test")
-	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	r.Log("Spin until connection manager tears down a tunnel")
+
+	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		t.Log("Connection manager hasn't ticked yet")
+		time.Sleep(time.Second)
+	}
+
+	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
+	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
+
+	// We should only have a single tunnel now on both sides
+	assert.Len(t, myFinalHostmapHosts, 1)
+	assert.Len(t, theirFinalHostmapHosts, 1)
+	assert.Len(t, myFinalHostmapIndexes, 1)
+	assert.Len(t, theirFinalHostmapIndexes, 1)
 
 
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 	myControl.Stop()
 	myControl.Stop()
 	theirControl.Stop()
 	theirControl.Stop()
-	//TODO: assert hostmaps
+}
+
+func TestUncleanShutdownRaceLoser(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r.Log("Trigger a handshake from me to them")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+
+	r.Log("Nuke my hostmap")
+	myHostmap := myControl.GetHostmap()
+	myHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	myHostmap.Indexes = map[uint32]*nebula.HostInfo{}
+	myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
+
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me again"))
+	p = r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+
+	r.Log("Assert the tunnel works")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	r.Log("Wait for the dead index to go away")
+	start := len(theirControl.GetHostmap().Indexes)
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		if len(theirControl.GetHostmap().Indexes) < start {
+			break
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+}
+
+func TestUncleanShutdownRaceWinner(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r.Log("Trigger a handshake from me to them")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	r.Log("Nuke my hostmap")
+	theirHostmap := theirControl.GetHostmap()
+	theirHostmap.Hosts = map[iputil.VpnIp]*nebula.HostInfo{}
+	theirHostmap.Indexes = map[uint32]*nebula.HostInfo{}
+	theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{}
+
+	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them again"))
+	p = r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("Derp hostmaps", myControl, theirControl)
+
+	r.Log("Assert the tunnel works")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	r.Log("Wait for the dead index to go away")
+	start := len(myControl.GetHostmap().Indexes)
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		if len(myControl.GetHostmap().Indexes) < start {
+			break
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
 }
 }
 
 
 func TestRelays(t *testing.T) {
 func TestRelays(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIp, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIp, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 
 
 	// Teach me how to get to the relay and that they can be reached via the relay
 	// Teach me how to get to the relay and that they can be reached via the relay
-	myControl.InjectLightHouseAddr(relayVpnIp, relayUdpAddr)
-	myControl.InjectRelays(theirVpnIp, []net.IP{relayVpnIp})
-	relayControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
 
 
 	// Build a router so we don't have to reason who gets which packet
 	// Build a router so we don't have to reason who gets which packet
 	r := router.NewR(t, myControl, relayControl, theirControl)
 	r := router.NewR(t, myControl, relayControl, theirControl)
@@ -234,12 +358,84 @@ func TestRelays(t *testing.T) {
 	theirControl.Start()
 	theirControl.Start()
 
 
 	t.Log("Trigger a handshake from me to them via the relay")
 	t.Log("Trigger a handshake from me to them via the relay")
-	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
 
 
 	p := r.RouteForAllUntilTxTun(theirControl)
 	p := r.RouteForAllUntilTxTun(theirControl)
-	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIp, theirVpnIp, 80, 80)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl)
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 	//TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it
 }
 }
 
 
+func TestStage1RaceRelays(t *testing.T) {
+	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+
+	// Teach me and them how to get to the relay, and that each of us can reach the other via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	theirControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	theirControl.InjectRelays(myVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	relayControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	r.Log("Trigger a handshake to start on both me and relay")
+	myControl.InjectTunUDPPacket(relayVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	relayControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from relay"))
+
+	r.Log("Get both stage 1 handshake packets")
+	//TODO: this is where it breaks, we need to get the hs packets for the relay not for the destination
+	myHsForThem := myControl.GetFromUDP(true)
+	relayHsForMe := relayControl.GetFromUDP(true)
+
+	r.Log("Now inject both stage 1 handshake packets")
+	r.InjectUDPPacket(relayControl, myControl, relayHsForMe)
+	r.InjectUDPPacket(myControl, relayControl, myHsForThem)
+
+	r.Log("Route for me until I send a message packet to relay")
+	r.RouteForAllUntilAfterMsgTypeTo(relayControl, header.Message, header.MessageNone)
+
+	r.Log("My cached packet should be received by relay")
+	myCachedPacket := relayControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.IP, relayVpnIpNet.IP, 80, 80)
+
+	r.Log("Relays cached packet should be received by me")
+	relayCachedPacket := r.RouteForAllUntilTxTun(myControl)
+	assertUdpPacket(t, []byte("Hi from relay"), relayCachedPacket, relayVpnIpNet.IP, myVpnIpNet.IP, 80, 80)
+
+	r.Log("Do a bidirectional tunnel test; me and relay")
+	assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+
+	r.Log("Create a tunnel between relay and them")
+	assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+
+	r.RenderHostmaps("Starting hostmaps", myControl, relayControl, theirControl)
+
+	r.Log("Trigger a handshake to start from me to them via the relay")
+	//TODO: if we initiate a handshake from me and then assert the tunnel it will cause a relay control race that can blow up
+	//	this is a problem that exists on master today
+	//myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	myControl.Stop()
+	theirControl.Stop()
+	relayControl.Stop()
+	//
+	////TODO: assert hostmaps
+}
+
 //TODO: add a test with many lies
 //TODO: add a test with many lies

+ 4 - 4
e2e/helpers_test.go

@@ -30,7 +30,7 @@ import (
 type m map[string]interface{}
 type m map[string]interface{}
 
 
 // newSimpleServer creates a nebula instance with many assumptions
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, net.IP, *net.UDPAddr) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr) {
 	l := NewTestLogger()
 	l := NewTestLogger()
 
 
 	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
 	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
@@ -101,7 +101,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 		panic(err)
 	}
 	}
 
 
-	return control, vpnIpNet.IP, &udpAddr
+	return control, vpnIpNet, &udpAddr
 }
 }
 
 
 // newTestCaCert will generate a CA cert
 // newTestCaCert will generate a CA cert
@@ -231,12 +231,12 @@ func deadline(t *testing.T, seconds time.Duration) doneCb {
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
 func assertTunnel(t *testing.T, vpnIpA, vpnIpB net.IP, controlA, controlB *nebula.Control, r *router.R) {
 	// Send a packet from them to me
 	// Send a packet from them to me
 	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
 	controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B"))
-	bPacket := r.RouteUntilTxTun(controlB, controlA)
+	bPacket := r.RouteForAllUntilTxTun(controlA)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 	assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80)
 
 
 	// And once more from me to them
 	// And once more from me to them
 	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
 	controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A"))
-	aPacket := r.RouteUntilTxTun(controlA, controlB)
+	aPacket := r.RouteForAllUntilTxTun(controlB)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 	assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80)
 }
 }
 
 

+ 32 - 2
e2e/router/hostmap.go

@@ -5,9 +5,11 @@ package router
 
 
 import (
 import (
 	"fmt"
 	"fmt"
+	"sort"
 	"strings"
 	"strings"
 
 
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula"
+	"github.com/slackhq/nebula/iputil"
 )
 )
 
 
 type edge struct {
 type edge struct {
@@ -64,7 +66,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 
 
 	// Draw the vpn to index nodes
 	// Draw the vpn to index nodes
 	r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName)
 	r += fmt.Sprintf("\t\tsubgraph %s.hosts[\"Hosts (vpn ip to index)\"]\n", clusterName)
-	for vpnIp, hi := range hm.Hosts {
+	for _, vpnIp := range sortedHosts(hm.Hosts) {
+		hi := hm.Hosts[vpnIp]
 		r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp)
 		r += fmt.Sprintf("\t\t\t%v.%v[\"%v\"]\n", clusterName, vpnIp, vpnIp)
 		lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex()))
 		lines = append(lines, fmt.Sprintf("%v.%v --> %v.%v", clusterName, vpnIp, clusterName, hi.GetLocalIndex()))
 
 
@@ -91,7 +94,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 
 
 	// Draw the local index to relay or remote index nodes
 	// Draw the local index to relay or remote index nodes
 	r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName)
 	r += fmt.Sprintf("\t\tsubgraph indexes.%s[\"Indexes (index to hostinfo)\"]\n", clusterName)
-	for idx, hi := range hm.Indexes {
+	for _, idx := range sortedIndexes(hm.Indexes) {
+		hi := hm.Indexes[idx]
 		r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
 		r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp())
 		remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
 		remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ")
 		globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
 		globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())})
@@ -107,3 +111,29 @@ func renderHostmap(c *nebula.Control) (string, []*edge) {
 	r += "\tend\n"
 	r += "\tend\n"
 	return r, globalLines
 	return r, globalLines
 }
 }
+
+func sortedHosts(hosts map[iputil.VpnIp]*nebula.HostInfo) []iputil.VpnIp {
+	keys := make([]iputil.VpnIp, 0, len(hosts))
+	for key := range hosts {
+		keys = append(keys, key)
+	}
+
+	sort.SliceStable(keys, func(i, j int) bool {
+		return keys[i] > keys[j]
+	})
+
+	return keys
+}
+
+func sortedIndexes(indexes map[uint32]*nebula.HostInfo) []uint32 {
+	keys := make([]uint32, 0, len(indexes))
+	for key := range indexes {
+		keys = append(keys, key)
+	}
+
+	sort.SliceStable(keys, func(i, j int) bool {
+		return keys[i] > keys[j]
+	})
+
+	return keys
+}

+ 33 - 1
e2e/router/router.go

@@ -10,6 +10,7 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"reflect"
 	"reflect"
+	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -22,6 +23,7 @@ import (
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/slackhq/nebula/udp"
+	"golang.org/x/exp/maps"
 )
 )
 
 
 type R struct {
 type R struct {
@@ -150,6 +152,7 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R {
 			case <-ctx.Done():
 			case <-ctx.Done():
 				return
 				return
 			case <-clockSource.C:
 			case <-clockSource.C:
+				r.renderHostmaps("clock tick")
 				r.renderFlow()
 				r.renderFlow()
 			}
 			}
 		}
 		}
@@ -220,11 +223,16 @@ func (r *R) renderFlow() {
 		)
 		)
 	}
 	}
 
 
+	if len(participantsVals) > 2 {
+		// Get the first and last participantVals for notes
+		participantsVals = []string{participantsVals[0], participantsVals[len(participantsVals)-1]}
+	}
+
 	// Print packets
 	// Print packets
 	h := &header.H{}
 	h := &header.H{}
 	for _, e := range r.flow {
 	for _, e := range r.flow {
 		if e.packet == nil {
 		if e.packet == nil {
-			fmt.Fprintf(f, "    note over %s: %s\n", strings.Join(participantsVals, ", "), e.note)
+			//fmt.Fprintf(f, "    note over %s: %s\n", strings.Join(participantsVals, ", "), e.note)
 			continue
 			continue
 		}
 		}
 
 
@@ -294,6 +302,28 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) {
 	})
 	})
 }
 }
 
 
+func (r *R) renderHostmaps(title string) {
+	c := maps.Values(r.controls)
+	sort.SliceStable(c, func(i, j int) bool {
+		return c[i].GetVpnIp() > c[j].GetVpnIp()
+	})
+
+	s := renderHostmaps(c...)
+	if len(r.additionalGraphs) > 0 {
+		lastGraph := r.additionalGraphs[len(r.additionalGraphs)-1]
+		if lastGraph.content == s {
+			// Ignore this rendering if it matches the last rendering added
+			// This is useful if you want to track rendering changes
+			return
+		}
+	}
+
+	r.additionalGraphs = append(r.additionalGraphs, mermaidGraph{
+		title:   title,
+		content: s,
+	})
+}
+
 // InjectFlow can be used to record packet flow if the test is handling the routing on its own.
 // InjectFlow can be used to record packet flow if the test is handling the routing on its own.
 // The packet is assumed to have been received
 // The packet is assumed to have been received
 func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) {
 func (r *R) InjectFlow(from, to *nebula.Control, p *udp.Packet) {
@@ -332,6 +362,8 @@ func (r *R) unlockedInjectFlow(from, to *nebula.Control, p *udp.Packet, tun bool
 		return nil
 		return nil
 	}
 	}
 
 
+	r.renderHostmaps(fmt.Sprintf("Packet %v", len(r.flow)))
+
 	if len(r.ignoreFlows) > 0 {
 	if len(r.ignoreFlows) > 0 {
 		var h header.H
 		var h header.H
 		err := h.Parse(p.Data)
 		err := h.Parse(p.Data)

+ 1 - 0
go.mod

@@ -21,6 +21,7 @@ require (
 	github.com/stretchr/testify v1.8.1
 	github.com/stretchr/testify v1.8.1
 	github.com/vishvananda/netlink v1.1.0
 	github.com/vishvananda/netlink v1.1.0
 	golang.org/x/crypto v0.3.0
 	golang.org/x/crypto v0.3.0
+	golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2
 	golang.org/x/net v0.2.0
 	golang.org/x/net v0.2.0
 	golang.org/x/sys v0.2.0
 	golang.org/x/sys v0.2.0
 	golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224
 	golang.zx2c4.com/wintun v0.0.0-20211104114900-415007cec224

+ 2 - 0
go.sum

@@ -266,6 +266,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0
 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4=
 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
 golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM=
 golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
 golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU=
+golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 h1:Jvc7gsqn21cJHCmAWx0LiimpP18LZmUxkT5Mp7EZ1mI=
+golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=

+ 13 - 15
handshake_ix.go

@@ -207,9 +207,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 	hostinfo.SetRemote(addr)
 	hostinfo.SetRemote(addr)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
 
-	// Only overwrite existing record if we should win the handshake race
-	overwrite := vpnIp > f.myVpnIp
-	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
+	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f)
 	if err != nil {
 	if err != nil {
 		switch err {
 		switch err {
 		case ErrAlreadySeen:
 		case ErrAlreadySeen:
@@ -280,16 +278,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
 				WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp).
 				Error("Failed to add HostInfo due to localIndex collision")
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
 			return
-		case ErrExistingHandshake:
-			// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
-			f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
-				WithField("certName", certName).
-				WithField("fingerprint", fingerprint).
-				WithField("issuer", issuer).
-				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
-				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				Error("Prevented a pending handshake race")
-			return
 		default:
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
 			// And we forget to update it here
@@ -344,6 +332,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via interface{}, packet []b
 			Info("Handshake message sent")
 			Info("Handshake message sent")
 	}
 	}
 
 
+	if existing != nil {
+		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
+		f.connectionManager.Out(existing.localIndexId)
+	}
+
+	f.connectionManager.Out(hostinfo.localIndexId)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 
 
 	return
 	return
@@ -501,8 +495,12 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via interface{}, hostinfo *
 	hostinfo.CreateRemoteCIDR(remoteCert)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
 
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
-	//TODO: Complete here does not do a race avoidance, it will just take the new tunnel. Is this ok?
-	f.handshakeManager.Complete(hostinfo, f)
+	existing := f.handshakeManager.Complete(hostinfo, f)
+	if existing != nil {
+		// Make sure we are tracking the old primary if there was one, it needs to go away eventually
+		f.connectionManager.Out(existing.localIndexId)
+	}
+
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
 	f.metricHandshakes.Update(duration)
 	f.metricHandshakes.Update(duration)
 
 

+ 26 - 49
handshake_manager.go

@@ -53,6 +53,10 @@ type HandshakeManager struct {
 	metricTimedOut         metrics.Counter
 	metricTimedOut         metrics.Counter
 	l                      *logrus.Logger
 	l                      *logrus.Logger
 
 
+	// vpnIps is another map, similar to the pending hostmap, but it tracks entries in the timer wheel instead.
+	// This avoids situations where the same vpn ip enters the wheel again and causes rapid-fire handshaking.
+	vpnIps map[iputil.VpnIp]struct{}
+
 	// can be used to trigger outbound handshake for the given vpnIp
 	// can be used to trigger outbound handshake for the given vpnIp
 	trigger chan iputil.VpnIp
 	trigger chan iputil.VpnIp
 }
 }
@@ -66,6 +70,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 		config:                 config,
 		config:                 config,
 		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
 		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
 		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		vpnIps:                 map[iputil.VpnIp]struct{}{},
 		messageMetrics:         config.messageMetrics,
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -103,6 +108,7 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.E
 func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
 func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, lighthouseTriggered bool) {
 	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
 	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
 	if err != nil {
 	if err != nil {
+		delete(c.vpnIps, vpnIp)
 		return
 		return
 	}
 	}
 	hostinfo.Lock()
 	hostinfo.Lock()
@@ -160,7 +166,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
 		c.lightHouse.QueryServer(vpnIp, f)
 		c.lightHouse.QueryServer(vpnIp, f)
 	}
 	}
 
 
-	// Send a the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
+	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
 	var sentTo []*udp.Addr
 	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
 	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
 		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
 		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
@@ -260,7 +266,6 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
 
 
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
 	if !lighthouseTriggered {
-		//TODO: feel like we dupe handshake real fast in a tight loop, why?
 		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 	}
 	}
 }
 }
@@ -269,7 +274,10 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
 	hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
 	hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
 
 
 	if created {
 	if created {
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
+		if _, ok := c.vpnIps[vpnIp]; !ok {
+			c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
+		}
+		c.vpnIps[vpnIp] = struct{}{}
 		c.metricInitiated.Inc(1)
 		c.metricInitiated.Inc(1)
 	}
 	}
 
 
@@ -280,7 +288,6 @@ var (
 	ErrExistingHostInfo    = errors.New("existing hostinfo")
 	ErrExistingHostInfo    = errors.New("existing hostinfo")
 	ErrAlreadySeen         = errors.New("already seen")
 	ErrAlreadySeen         = errors.New("already seen")
 	ErrLocalIndexCollision = errors.New("local index collision")
 	ErrLocalIndexCollision = errors.New("local index collision")
-	ErrExistingHandshake   = errors.New("existing handshake")
 )
 )
 
 
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
@@ -294,7 +301,7 @@ var (
 //
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 // hostmap for the hostinfo.localIndexId.
-func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
+func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
 	c.pendingHostMap.Lock()
 	c.pendingHostMap.Lock()
 	defer c.pendingHostMap.Unlock()
 	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	c.mainHostMap.Lock()
@@ -303,9 +310,14 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	// Check if we already have a tunnel with this vpn ip
 	// Check if we already have a tunnel with this vpn ip
 	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
 	if found && existingHostInfo != nil {
 	if found && existingHostInfo != nil {
-		// Is it just a delayed handshake packet?
-		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
-			return existingHostInfo, ErrAlreadySeen
+		testHostInfo := existingHostInfo
+		for testHostInfo != nil {
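+			// Is it just a delayed handshake packet?
+			// NOTE(review): the comparison and return below use existingHostInfo rather than testHostInfo,
+			// so every pass of this loop re-checks the primary only - confirm whether testHostInfo was intended.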
+			if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
+				return existingHostInfo, ErrAlreadySeen
+			}
+
+			testHostInfo = testHostInfo.next
 		}
 		}
 
 
 		// Is this a newer handshake?
 		// Is this a newer handshake?
@@ -337,56 +349,19 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
 
 
-	// Check if we are also handshaking with this vpn ip
-	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.vpnIp]
-	if found && pendingHostInfo != nil {
-		if !overwrite {
-			// We won, let our pending handshake win
-			return pendingHostInfo, ErrExistingHandshake
-		}
-
-		// We lost, take this handshake and move any cached packets over so they get sent
-		pendingHostInfo.ConnectionState.queueLock.Lock()
-		hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
-		c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
-		pendingHostInfo.ConnectionState.queueLock.Unlock()
-		pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
-	}
-
-	if existingHostInfo != nil {
-		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
-		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
-		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
-		for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() {
-			delete(c.mainHostMap.Relays, relayIdx)
-		}
-	}
-
-	c.mainHostMap.addHostInfo(hostinfo, f)
+	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	return existingHostInfo, nil
 	return existingHostInfo, nil
 }
 }
 
 
 // Complete is a simpler version of CheckAndComplete when we already know we
 // Complete is a simpler version of CheckAndComplete when we already know we
 // won't have a localIndexId collision because we already have an entry in the
 // won't have a localIndexId collision because we already have an entry in the
-// pendingHostMap
-func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
+// pendingHostMap. An existing hostinfo is returned if there was one.
+func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) *HostInfo {
 	c.pendingHostMap.Lock()
 	c.pendingHostMap.Lock()
 	defer c.pendingHostMap.Unlock()
 	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 	defer c.mainHostMap.Unlock()
 
 
-	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
-	if found && existingHostInfo != nil {
-		// We are going to overwrite this entry, so remove the old references
-		delete(c.mainHostMap.Hosts, existingHostInfo.vpnIp)
-		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
-		delete(c.mainHostMap.RemoteIndexes, existingHostInfo.remoteIndexId)
-		for _, relayIdx := range existingHostInfo.relayState.CopyRelayForIdxs() {
-			delete(c.mainHostMap.Relays, relayIdx)
-		}
-	}
-
 	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
 	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
 	if found && existingRemoteIndex != nil {
 	if found && existingRemoteIndex != nil {
 		// We have a collision, but this can happen since we can't control
 		// We have a collision, but this can happen since we can't control
@@ -396,8 +371,10 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 			Info("New host shadows existing host remoteIndex")
 			Info("New host shadows existing host remoteIndex")
 	}
 	}
 
 
-	c.mainHostMap.addHostInfo(hostinfo, f)
+	existingHostInfo := c.mainHostMap.Hosts[hostinfo.vpnIp]
+	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
 	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
+	return existingHostInfo
 }
 }
 
 
 // AddIndexHostInfo generates a unique localIndexId for this HostInfo
 // AddIndexHostInfo generates a unique localIndexId for this HostInfo

+ 101 - 19
hostmap.go

@@ -23,6 +23,10 @@ const PromoteEvery = 1000
 const ReQueryEvery = 5000
 const ReQueryEvery = 5000
 const MaxRemotes = 10
 const MaxRemotes = 10
 
 
+// MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
+// 5 allows for an initial handshake and each host pair re-handshaking twice
+const MaxHostInfosPerVpnIp = 5
+
 // How long we should prevent roaming back to the previous IP.
 // How long we should prevent roaming back to the previous IP.
 // This helps prevent flapping due to packets already in flight
 // This helps prevent flapping due to packets already in flight
 const RoamingSuppressSeconds = 2
 const RoamingSuppressSeconds = 2
@@ -180,6 +184,10 @@ type HostInfo struct {
 
 
 	lastRoam       time.Time
 	lastRoam       time.Time
 	lastRoamRemote *udp.Addr
 	lastRoamRemote *udp.Addr
+
+	// Used to track other hostinfos for this vpn ip since only 1 can be primary
+	// Synchronised via hostmap lock and not the hostinfo lock.
+	next, prev *HostInfo
 }
 }
 
 
 type ViaSender struct {
 type ViaSender struct {
@@ -395,9 +403,12 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
+// DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
+func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
 	// Delete the host itself, ensuring it's not modified anymore
 	hm.Lock()
 	hm.Lock()
+	// If we have a previous or next hostinfo then we are not the last one for this vpn ip
+	final := (hostinfo.next == nil && hostinfo.prev == nil)
 	hm.unlockedDeleteHostInfo(hostinfo)
 	hm.unlockedDeleteHostInfo(hostinfo)
 	hm.Unlock()
 	hm.Unlock()
 
 
@@ -421,6 +432,8 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 	for _, localIdx := range teardownRelayIdx {
 	for _, localIdx := range teardownRelayIdx {
 		hm.RemoveRelay(localIdx)
 		hm.RemoveRelay(localIdx)
 	}
 	}
+
+	return final
 }
 }
 
 
 func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
 func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
@@ -429,29 +442,81 @@ func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
 	delete(hm.RemoteIndexes, localIdx)
 	delete(hm.RemoteIndexes, localIdx)
 }
 }
 
 
+func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
+	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedMakePrimary(hostinfo)
+}
+
+func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) {
+	oldHostinfo := hm.Hosts[hostinfo.vpnIp]
+	if oldHostinfo == hostinfo {
+		return
+	}
+
+	if hostinfo.prev != nil {
+		hostinfo.prev.next = hostinfo.next
+	}
+
+	if hostinfo.next != nil {
+		hostinfo.next.prev = hostinfo.prev
+	}
+
+	hm.Hosts[hostinfo.vpnIp] = hostinfo
+
+	if oldHostinfo == nil {
+		return
+	}
+
+	hostinfo.next = oldHostinfo
+	oldHostinfo.prev = hostinfo
+	hostinfo.prev = nil
+}
+
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
-	// Check if this same hostId is in the hostmap with a different instance.
-	// This could happen if we have an entry in the pending hostmap with different
-	// index values than the one in the main hostmap.
-	hostinfo2, ok := hm.Hosts[hostinfo.vpnIp]
-	if ok && hostinfo2 != hostinfo {
-		delete(hm.Hosts, hostinfo2.vpnIp)
-		delete(hm.Indexes, hostinfo2.localIndexId)
-		delete(hm.RemoteIndexes, hostinfo2.remoteIndexId)
+	primary, ok := hm.Hosts[hostinfo.vpnIp]
+	if ok && primary == hostinfo {
+		// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
+		delete(hm.Hosts, hostinfo.vpnIp)
+		if len(hm.Hosts) == 0 {
+			hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+		}
+
+		if hostinfo.next != nil {
+			// We had more than 1 hostinfo at this vpn ip, promote the next in the list to primary
+			hm.Hosts[hostinfo.vpnIp] = hostinfo.next
+			// It is primary, there is no previous hostinfo now
+			hostinfo.next.prev = nil
+		}
+
+	} else {
+		// Relink if we were in the middle of multiple hostinfos for this vpn ip
+		if hostinfo.prev != nil {
+			hostinfo.prev.next = hostinfo.next
+		}
+
+		if hostinfo.next != nil {
+			hostinfo.next.prev = hostinfo.prev
+		}
 	}
 	}
 
 
-	delete(hm.Hosts, hostinfo.vpnIp)
-	if len(hm.Hosts) == 0 {
-		hm.Hosts = map[iputil.VpnIp]*HostInfo{}
+	hostinfo.next = nil
+	hostinfo.prev = nil
+
+	// The remote index uses index ids outside our control, so let's make sure we only remove
+	// the remote index pointer here if it points to the hostinfo we are deleting.
+	hostinfo2, ok := hm.RemoteIndexes[hostinfo.remoteIndexId]
+	if ok && hostinfo2 == hostinfo {
+		delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
+		if len(hm.RemoteIndexes) == 0 {
+			hm.RemoteIndexes = map[uint32]*HostInfo{}
+		}
 	}
 	}
+
 	delete(hm.Indexes, hostinfo.localIndexId)
 	delete(hm.Indexes, hostinfo.localIndexId)
 	if len(hm.Indexes) == 0 {
 	if len(hm.Indexes) == 0 {
 		hm.Indexes = map[uint32]*HostInfo{}
 		hm.Indexes = map[uint32]*HostInfo{}
 	}
 	}
-	delete(hm.RemoteIndexes, hostinfo.remoteIndexId)
-	if len(hm.RemoteIndexes) == 0 {
-		hm.RemoteIndexes = map[uint32]*HostInfo{}
-	}
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
@@ -520,15 +585,22 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
 	return nil, errors.New("unable to find host")
 	return nil, errors.New("unable to find host")
 }
 }
 
 
-// We already have the hm Lock when this is called, so make sure to not call
-// any other methods that might try to grab it again
-func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
+// unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
+// If an entry exists for the Hosts table (vpnIp -> hostinfo) then the provided hostinfo will be made primary
+func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	if f.serveDns {
 	if f.serveDns {
 		remoteCert := hostinfo.ConnectionState.peerCert
 		remoteCert := hostinfo.ConnectionState.peerCert
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 	}
 	}
 
 
+	existing := hm.Hosts[hostinfo.vpnIp]
 	hm.Hosts[hostinfo.vpnIp] = hostinfo
 	hm.Hosts[hostinfo.vpnIp] = hostinfo
+
+	if existing != nil {
+		hostinfo.next = existing
+		existing.prev = hostinfo
+	}
+
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.Indexes[hostinfo.localIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 
@@ -537,6 +609,16 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
 			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 			Debug("Hostmap vpnIp added")
 	}
 	}
+
+	i := 1
+	check := hostinfo
+	for check != nil {
+		if i > MaxHostInfosPerVpnIp {
+			hm.unlockedDeleteHostInfo(check)
+		}
+		check = check.next
+		i++
+	}
 }
 }
 
 
 // punchList assembles a list of all non nil RemoteList pointer entries in this hostmap
 // punchList assembles a list of all non nil RemoteList pointer entries in this hostmap

+ 206 - 0
hostmap_test.go

@@ -1 +1,207 @@
 package nebula
 package nebula
+
+import (
+	"net"
+	"testing"
+
+	"github.com/slackhq/nebula/test"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestHostMap_MakePrimary(t *testing.T) {
+	l := test.NewLogger()
+	hm := NewHostMap(
+		l, "test",
+		&net.IPNet{
+			IP:   net.IP{10, 0, 0, 1},
+			Mask: net.IPMask{255, 255, 255, 0},
+		},
+		[]*net.IPNet{},
+	)
+
+	f := &Interface{}
+
+	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
+	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
+	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
+	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+
+	hm.unlockedAddHostInfo(h4, f)
+	hm.unlockedAddHostInfo(h3, f)
+	hm.unlockedAddHostInfo(h2, f)
+	hm.unlockedAddHostInfo(h1, f)
+
+	// Make sure we go h1 -> h2 -> h3 -> h4
+	prim, _ := hm.QueryVpnIp(1)
+	assert.Equal(t, h1.localIndexId, prim.localIndexId)
+	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
+	assert.Nil(t, h4.next)
+
+	// Promote h3 (a middle entry) to primary; the rest of the chain keeps its relative order
+	hm.MakePrimary(h3)
+
+	// Make sure we go h3 -> h1 -> h2 -> h4
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h3.localIndexId, prim.localIndexId)
+	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
+	assert.Nil(t, h4.next)
+
+	// Promote h4 (the tail) to primary; the rest of the chain keeps its relative order
+	hm.MakePrimary(h4)
+
+	// Make sure we go h4 -> h3 -> h1 -> h2
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h4.localIndexId, prim.localIndexId)
+	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Nil(t, h2.next)
+
+	// Promoting h4 again should be a no-op since it is already primary
+	hm.MakePrimary(h4)
+
+	// Make sure we go h4 -> h3 -> h1 -> h2
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h4.localIndexId, prim.localIndexId)
+	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h2.localIndexId, h1.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h1.prev.localIndexId)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Nil(t, h2.next)
+}
+
+func TestHostMap_DeleteHostInfo(t *testing.T) {
+	l := test.NewLogger()
+	hm := NewHostMap(
+		l, "test",
+		&net.IPNet{
+			IP:   net.IP{10, 0, 0, 1},
+			Mask: net.IPMask{255, 255, 255, 0},
+		},
+		[]*net.IPNet{},
+	)
+
+	f := &Interface{}
+
+	h1 := &HostInfo{vpnIp: 1, localIndexId: 1}
+	h2 := &HostInfo{vpnIp: 1, localIndexId: 2}
+	h3 := &HostInfo{vpnIp: 1, localIndexId: 3}
+	h4 := &HostInfo{vpnIp: 1, localIndexId: 4}
+	h5 := &HostInfo{vpnIp: 1, localIndexId: 5}
+	h6 := &HostInfo{vpnIp: 1, localIndexId: 6}
+
+	hm.unlockedAddHostInfo(h6, f)
+	hm.unlockedAddHostInfo(h5, f)
+	hm.unlockedAddHostInfo(h4, f)
+	hm.unlockedAddHostInfo(h3, f)
+	hm.unlockedAddHostInfo(h2, f)
+	hm.unlockedAddHostInfo(h1, f)
+
+	// h6 was added first, so it sits at the tail of the chain and should have been deleted once the chain grew too long
+	assert.Nil(t, h6.next)
+	assert.Nil(t, h6.prev)
+	_, err := hm.QueryIndex(h6.localIndexId)
+	assert.Error(t, err)
+
+	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
+	prim, _ := hm.QueryVpnIp(1)
+	assert.Equal(t, h1.localIndexId, prim.localIndexId)
+	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h1.localIndexId, h2.prev.localIndexId)
+	assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
+	assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
+	assert.Nil(t, h5.next)
+
+	// Delete primary
+	hm.DeleteHostInfo(h1)
+	assert.Nil(t, h1.prev)
+	assert.Nil(t, h1.next)
+
+	// Make sure we go h2 -> h3 -> h4 -> h5
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h2.localIndexId, prim.localIndexId)
+	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h3.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h3.prev.localIndexId)
+	assert.Equal(t, h4.localIndexId, h3.next.localIndexId)
+	assert.Equal(t, h3.localIndexId, h4.prev.localIndexId)
+	assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
+	assert.Nil(t, h5.next)
+
+	// Delete in the middle
+	hm.DeleteHostInfo(h3)
+	assert.Nil(t, h3.prev)
+	assert.Nil(t, h3.next)
+
+	// Make sure we go h2 -> h4 -> h5
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h2.localIndexId, prim.localIndexId)
+	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
+	assert.Equal(t, h5.localIndexId, h4.next.localIndexId)
+	assert.Equal(t, h4.localIndexId, h5.prev.localIndexId)
+	assert.Nil(t, h5.next)
+
+	// Delete the tail
+	hm.DeleteHostInfo(h5)
+	assert.Nil(t, h5.prev)
+	assert.Nil(t, h5.next)
+
+	// Make sure we go h2 -> h4
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h2.localIndexId, prim.localIndexId)
+	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Equal(t, h4.localIndexId, h2.next.localIndexId)
+	assert.Equal(t, h2.localIndexId, h4.prev.localIndexId)
+	assert.Nil(t, h4.next)
+
+	// Delete the head
+	hm.DeleteHostInfo(h2)
+	assert.Nil(t, h2.prev)
+	assert.Nil(t, h2.next)
+
+	// Make sure we only have h4
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Equal(t, h4.localIndexId, prim.localIndexId)
+	assert.Nil(t, prim.prev)
+	assert.Nil(t, prim.next)
+	assert.Nil(t, h4.next)
+
+	// Delete the only item
+	hm.DeleteHostInfo(h4)
+	assert.Nil(t, h4.prev)
+	assert.Nil(t, h4.next)
+
+	// Make sure we have nil
+	prim, _ = hm.QueryVpnIp(1)
+	assert.Nil(t, prim)
+}

+ 5 - 3
outside.go

@@ -245,9 +245,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
 	f.connectionManager.ClearLocalIndex(hostInfo.localIndexId)
 	f.connectionManager.ClearLocalIndex(hostInfo.localIndexId)
 	f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId)
 	f.connectionManager.ClearPendingDeletion(hostInfo.localIndexId)
-	f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
-
-	f.hostMap.DeleteHostInfo(hostInfo)
+	final := f.hostMap.DeleteHostInfo(hostInfo)
+	if final {
+		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
+		f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
+	}
 }
 }
 
 
 // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
 // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote

+ 1 - 1
overlay/tun_tester.go

@@ -51,7 +51,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []Route, _ int
 // packets should exit the udp side, capture them with udpConn.Get
 // packets should exit the udp side, capture them with udpConn.Get
 func (t *TestTun) Send(packet []byte) {
 func (t *TestTun) Send(packet []byte) {
 	if t.l.Level >= logrus.InfoLevel {
 	if t.l.Level >= logrus.InfoLevel {
-		t.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
+		t.l.WithField("dataLen", len(packet)).Debug("Tun receiving injected packet")
 	}
 	}
 	t.rxPackets <- packet
 	t.rxPackets <- packet
 }
 }

+ 5 - 0
relay_manager.go

@@ -61,6 +61,11 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp iput
 
 
 		_, inRelays := hm.Relays[index]
 		_, inRelays := hm.Relays[index]
 		if !inRelays {
 		if !inRelays {
+			// Avoid standing up a relay that can't be used since only the primary hostinfo
+			// will be pointed to by the relay logic
+			//TODO: if there was an existing primary and it had relay state, should we merge?
+			hm.unlockedMakePrimary(relayHostInfo)
+
 			hm.Relays[index] = relayHostInfo
 			hm.Relays[index] = relayHostInfo
 			newRelay := Relay{
 			newRelay := Relay{
 				Type:       relayType,
 				Type:       relayType,

+ 12 - 3
ssh.go

@@ -22,8 +22,9 @@ import (
 )
 )
 
 
 type sshListHostMapFlags struct {
 type sshListHostMapFlags struct {
-	Json   bool
-	Pretty bool
+	Json    bool
+	Pretty  bool
+	ByIndex bool
 }
 }
 
 
 type sshPrintCertFlags struct {
 type sshPrintCertFlags struct {
@@ -174,6 +175,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			s := sshListHostMapFlags{}
 			s := sshListHostMapFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
+			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
 			return fl, &s
 			return fl, &s
 		},
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
@@ -189,6 +191,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			s := sshListHostMapFlags{}
 			s := sshListHostMapFlags{}
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Json, "json", false, "outputs as json with more information")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
 			fl.BoolVar(&s.Pretty, "pretty", false, "pretty prints json, assumes -json")
+			fl.BoolVar(&s.ByIndex, "by-index", false, "gets all hosts in the hostmap from the index table")
 			return fl, &s
 			return fl, &s
 		},
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
@@ -368,7 +371,13 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 		return nil
 		return nil
 	}
 	}
 
 
-	hm := listHostMap(hostMap)
+	var hm []ControlHostInfo
+	if fs.ByIndex {
+		hm = listHostMapIndexes(hostMap)
+	} else {
+		hm = listHostMapHosts(hostMap)
+	}
+
 	sort.Slice(hm, func(i, j int) bool {
 	sort.Slice(hm, func(i, j int) bool {
 		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
 		return bytes.Compare(hm[i].VpnIp, hm[j].VpnIp) < 0
 	})
 	})

+ 1 - 1
udp/udp_tester.go

@@ -66,7 +66,7 @@ func (u *Conn) Send(packet *Packet) {
 		u.l.WithField("header", h).
 		u.l.WithField("header", h).
 			WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
 			WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
 			WithField("dataLen", len(packet.Data)).
 			WithField("dataLen", len(packet.Data)).
-			Info("UDP receiving injected packet")
+			Debug("UDP receiving injected packet")
 	}
 	}
 	u.RxPackets <- packet
 	u.RxPackets <- packet
 }
 }