
Refactor remotes and handshaking to give every address a fair shot (#437)

Nathan Brown committed 4 years ago
commit 710df6a876
25 changed files with 1546 additions and 1370 deletions
  1. control.go (+27 -23)
  2. control_test.go (+6 -4)
  3. control_tester.go (+27 -3)
  4. e2e/handshakes_test.go (+97 -132)
  5. e2e/helpers_test.go (+3 -0)
  6. e2e/router/doc.go (+3 -0)
  7. e2e/router/router.go (+119 -20)
  8. examples/config.yml (+4 -4)
  9. handshake_ix.go (+36 -31)
  10. handshake_manager.go (+122 -106)
  11. handshake_manager_test.go (+28 -185)
  12. hostmap.go (+52 -249)
  13. hostmap_test.go (+0 -168)
  14. inside.go (+7 -50)
  15. lighthouse.go (+117 -226)
  16. lighthouse_test.go (+100 -79)
  17. main.go (+3 -4)
  18. outside.go (+6 -3)
  19. remote_list.go (+500 -0)
  20. remote_list_test.go (+228 -0)
  21. ssh.go (+39 -48)
  22. tun_tester.go (+1 -3)
  23. udp_all.go (+3 -3)
  24. udp_linux.go (+9 -28)
  25. udp_tester.go (+9 -1)

+ 27 - 23
control.go

@@ -67,23 +67,11 @@ func (c *Control) RebindUDPServer() {
 
 // ListHostmap returns details about the actual or pending (handshaking) hostmap
 func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
-	var hm *HostMap
 	if pendingMap {
-		hm = c.f.handshakeManager.pendingHostMap
+		return listHostMap(c.f.handshakeManager.pendingHostMap)
 	} else {
-		hm = c.f.hostMap
+		return listHostMap(c.f.hostMap)
 	}
-
-	hm.RLock()
-	hosts := make([]ControlHostInfo, len(hm.Hosts))
-	i := 0
-	for _, v := range hm.Hosts {
-		hosts[i] = copyHostInfo(v)
-		i++
-	}
-	hm.RUnlock()
-
-	return hosts
 }
 
 // GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
@@ -100,7 +88,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
 		return nil
 	}
 
-	ch := copyHostInfo(h)
+	ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
 	return &ch
 }
 
@@ -112,7 +100,7 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
 	}
 
 	hostInfo.SetRemote(addr.Copy())
-	ch := copyHostInfo(hostInfo)
+	ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
 	return &ch
 }
 
@@ -163,14 +151,17 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	return
 }
 
-func copyHostInfo(h *HostInfo) ControlHostInfo {
+func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	chi := ControlHostInfo{
-		VpnIP:          int2ip(h.hostId),
-		LocalIndex:     h.localIndexId,
-		RemoteIndex:    h.remoteIndexId,
-		RemoteAddrs:    h.CopyRemotes(),
-		CachedPackets:  len(h.packetStore),
-		MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter),
+		VpnIP:         int2ip(h.hostId),
+		LocalIndex:    h.localIndexId,
+		RemoteIndex:   h.remoteIndexId,
+		RemoteAddrs:   h.remotes.CopyAddrs(preferredRanges),
+		CachedPackets: len(h.packetStore),
+	}
+
+	if h.ConnectionState != nil {
+		chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter)
 	}
 
 	if c := h.GetCert(); c != nil {
@@ -183,3 +174,16 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {
 
 	return chi
 }
+
+func listHostMap(hm *HostMap) []ControlHostInfo {
+	hm.RLock()
+	hosts := make([]ControlHostInfo, len(hm.Hosts))
+	i := 0
+	for _, v := range hm.Hosts {
+		hosts[i] = copyHostInfo(v, hm.preferredRanges)
+		i++
+	}
+	hm.RUnlock()
+
+	return hosts
+}

+ 6 - 4
control_test.go

@@ -45,10 +45,12 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		Signature: []byte{1, 2, 1, 2, 1, 3},
 	}
 
-	remotes := []*udpAddr{remote1, remote2}
+	remotes := NewRemoteList()
+	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
+	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
 	hm.Add(ip2int(ipNet.IP), &HostInfo{
 		remote:  remote1,
-		Remotes: remotes,
+		remotes: remotes,
 		ConnectionState: &ConnectionState{
 			peerCert: crt,
 		},
@@ -59,7 +61,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 
 	hm.Add(ip2int(ipNet2.IP), &HostInfo{
 		remote:  remote1,
-		Remotes: remotes,
+		remotes: remotes,
 		ConnectionState: &ConnectionState{
 			peerCert: nil,
 		},
@@ -81,7 +83,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		VpnIP:          net.IPv4(1, 2, 3, 4).To4(),
 		LocalIndex:     201,
 		RemoteIndex:    200,
-		RemoteAddrs:    []*udpAddr{remote1, remote2},
+		RemoteAddrs:    []*udpAddr{remote2, remote1},
 		CachedPackets:  0,
 		Cert:           crt.Copy(),
 		MessageCounter: 0,

+ 27 - 3
control_tester.go

@@ -44,7 +44,18 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
 // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
-	c.f.lightHouse.AddRemote(ip2int(vpnIp), &udpAddr{IP: toAddr.IP, Port: uint16(toAddr.Port)}, false)
+	c.f.lightHouse.Lock()
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
+	remoteList.Lock()
+	defer remoteList.Unlock()
+	c.f.lightHouse.Unlock()
+
+	iVpnIp := ip2int(vpnIp)
+	if v4 := toAddr.IP.To4(); v4 != nil {
+		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
+	} else {
+		remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
+	}
 }
 
 // GetFromTun will pull a packet off the tun side of nebula
@@ -84,14 +95,17 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 		SrcPort: layers.UDPPort(fromPort),
 		DstPort: layers.UDPPort(toPort),
 	}
-	udp.SetNetworkLayerForChecksum(&ip)
+	err := udp.SetNetworkLayerForChecksum(&ip)
+	if err != nil {
+		panic(err)
+	}
 
 	buffer := gopacket.NewSerializeBuffer()
 	opt := gopacket.SerializeOptions{
 		ComputeChecksums: true,
 		FixLengths:       true,
 	}
-	err := gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
 	if err != nil {
 		panic(err)
 	}
@@ -102,3 +116,13 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 func (c *Control) GetUDPAddr() string {
 	return c.f.outside.addr.String()
 }
+
+func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
+	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
+	if !ok {
+		return false
+	}
+
+	c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
+	return true
+}

+ 97 - 132
e2e/handshakes_test.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/stretchr/testify/assert"
 )
 
 func TestGoodHandshake(t *testing.T) {
@@ -23,35 +24,35 @@ func TestGoodHandshake(t *testing.T) {
 	myControl.Start()
 	theirControl.Start()
 
-	// Send a udp packet through to begin standing up the tunnel, this should come out the other side
+	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
 	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
 
-	// Have them consume my stage 0 packet. They have a tunnel now
+	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
 
-	// Get their stage 1 packet so that we can play with it
+	t.Log("Get their stage 1 packet so that we can play with it")
 	stage1Packet := theirControl.GetFromUDP(true)
 
-	// I consume a garbage packet with a proper nebula header for our tunnel
+	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
 	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
 	badPacket := stage1Packet.Copy()
 	badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
 	myControl.InjectUDPPacket(badPacket)
 
-	// Have me consume their real stage 1 packet. I have a tunnel now
+	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
 	myControl.InjectUDPPacket(stage1Packet)
 
-	// Wait until we see my cached packet come through
+	t.Log("Wait until we see my cached packet come through")
 	myControl.WaitForType(1, 0, theirControl)
 
-	// Make sure our host infos are correct
+	t.Log("Make sure our host infos are correct")
 	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
 
-	// Get that cached packet and make sure it looks right
+	t.Log("Get that cached packet and make sure it looks right")
 	myCachedPacket := theirControl.GetFromTun(true)
 	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
 
-	// Do a bidirectional tunnel test
+	t.Log("Do a bidirectional tunnel test")
 	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl))
 
 	myControl.Stop()
@@ -62,14 +63,17 @@ func TestGoodHandshake(t *testing.T) {
 func TestWrongResponderHandshake(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1})
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
-	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 99})
+	// The IPs here are chosen on purpose:
+	// The current remote handling will sort by preference, public, and then lexically.
+	// So we need them to have a higher address than evil (we could apply a preference though)
+	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100})
+	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99})
+	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2})
 
-	// Add their real udp addr, which should be tried after evil. Doing this first because learned addresses are prepended
+	// Add their real udp addr, which should be tried after evil.
 	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
 
-	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. This will now be the first attempted ip
+	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
 	myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
@@ -80,137 +84,98 @@ func TestWrongResponderHandshake(t *testing.T) {
 	theirControl.Start()
 	evilControl.Start()
 
-	t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)")
+	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
 	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-	r.OnceFrom(myControl)
-	r.OnceFrom(evilControl)
+	r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
+		h := &nebula.Header{}
+		err := h.Parse(p.Data)
+		if err != nil {
+			panic(err)
+		}
 
-	t.Log("I should have a tunnel with evil now and there should not be a cached packet waiting for us")
-	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
-	assertHostInfoPair(t, myUdpAddr, evilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
+		if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+			return router.RouteAndExit
+		}
 
-	//TODO: Assert pending hostmap - I should have a correct hostinfo for them now
+		return router.KeepRouting
+	})
 
-	t.Log("Lets let the messages fly, this time we should have a tunnel with them")
-	r.OnceFrom(myControl)
-	r.OnceFrom(theirControl)
+	//TODO: Assert pending hostmap - I should have a correct hostinfo for them now
 
-	t.Log("I should now have a tunnel with them now and my original packet should get there")
-	r.RouteUntilAfterMsgType(myControl, 1, 0)
+	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
 	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
 
-	t.Log("I should now have a proper tunnel with them")
+	t.Log("Test the tunnel with them")
 	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
 	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
 
-	t.Log("Lets make sure evil is still good")
-	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
+	t.Log("Flush all packets from all controllers")
+	r.FlushAll()
+
+	t.Log("Ensure I don't have any hostinfo artifacts from evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
+	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 	//TODO: assert hostmaps for everyone
 	t.Log("Success!")
-	//TODO: myControl is attempting to shut down 2 tunnels but is blocked on the udp txChan after the first close message
-	// what we really need here is a way to exit all the go routines loops (there are many)
-	//myControl.Stop()
-	//theirControl.Stop()
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func Test_Case1_Stage1Race(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1})
+	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(myControl, theirControl)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake to start on both me and them")
+	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them"))
+
+	t.Log("Get both stage 1 handshake packets")
+	myHsForThem := myControl.GetFromUDP(true)
+	theirHsForMe := theirControl.GetFromUDP(true)
+
+	t.Log("Now inject both stage 1 handshake packets")
+	myControl.InjectUDPPacket(theirHsForMe)
+	theirControl.InjectUDPPacket(myHsForThem)
+	//TODO: they should win, grab their index for me and make sure I use it in the end.
+
+	t.Log("They should not have a stage 2 (won the race) but I should send one")
+	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
+
+	t.Log("Route for me until I send a message packet to them")
+	myControl.WaitForType(1, 0, theirControl)
+
+	t.Log("My cached packet should be received by them")
+	myCachedPacket := theirControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+
+	t.Log("Route for them until I send a message packet to me")
+	theirControl.WaitForType(1, 0, myControl)
+
+	t.Log("Their cached packet should be received by me")
+	theirCachedPacket := myControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80)
+
+	t.Log("Do a bidirectional tunnel test")
+	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+
+	myControl.Stop()
+	theirControl.Stop()
+	//TODO: assert hostmaps
 }
 }
 
-//func TestManyWrongResponderHandshake(t *testing.T) {
-//	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-//
-//	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 99})
-//	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
-//	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 1})
-//
-//	t.Log("Build a router so we don't have to reason who gets which packet")
-//	r := newRouter(myControl, theirControl, evilControl)
-//
-//	t.Log("Lets add more than 10 evil addresses, this exceeds the hostinfo remotes limit")
-//	for i := 0; i < 10; i++ {
-//		addr := net.UDPAddr{IP: evilUdpAddr.IP, Port: evilUdpAddr.Port + i}
-//		myControl.InjectLightHouseAddr(theirVpnIp, &addr)
-//		// We also need to tell our router about it
-//		r.AddRoute(addr.IP, uint16(addr.Port), evilControl)
-//	}
-//
-//	// Start the servers
-//	myControl.Start()
-//	theirControl.Start()
-//	evilControl.Start()
-//
-//	t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)")
-//	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-//
-//	t.Log("We need to spin until we get to the right remote for them")
-//	getOut := false
-//	injected := false
-//	for {
-//		t.Log("Routing for me and evil while we work through the bad ips")
-//		r.RouteExitFunc(myControl, func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType {
-//			// We should stop routing right after we see a packet coming from us to them
-//			if *receiver == *theirControl {
-//				getOut = true
-//				return drainAndExit
-//			}
-//
-//			// We need to poke our real ip in at some point, this is a well protected check looking for that moment
-//			if *receiver == *evilControl {
-//				hi := myControl.GetHostInfoByVpnIP(ip2int(theirVpnIp), true)
-//				if !injected && len(hi.RemoteAddrs) == 1 {
-//					t.Log("I am on my last ip for them, time to inject the real one into my lighthouse")
-//					myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
-//					injected = true
-//				}
-//				return drainAndExit
-//			}
-//
-//			return keepRouting
-//		})
-//
-//		if getOut {
-//			break
-//		}
-//
-//		r.RouteForUntilAfterToAddr(evilControl, myUdpAddr, drainAndExit)
-//	}
-//
-//	t.Log("I should have a tunnel with evil and them, evil should not have a cached packet")
-//	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
-//	evilHostInfo := myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false)
-//	realEvilUdpAddr := &net.UDPAddr{IP: evilHostInfo.CurrentRemote.IP, Port: int(evilHostInfo.CurrentRemote.Port)}
-//
-//	t.Log("Assert mine and evil's host pairs", evilUdpAddr, realEvilUdpAddr)
-//	assertHostInfoPair(t, myUdpAddr, realEvilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
-//
-//	//t.Log("Draining everyones packets")
-//	//r.Drain(theirControl)
-//	//r.DrainAll(myControl, theirControl, evilControl)
-//	//
-//	//go func() {
-//	//	for {
-//	//		time.Sleep(10 * time.Millisecond)
-//	//		t.Log(len(theirControl.GetUDPTxChan()))
-//	//		t.Log(len(theirControl.GetTunTxChan()))
-//	//		t.Log(len(myControl.GetUDPTxChan()))
-//	//		t.Log(len(evilControl.GetUDPTxChan()))
-//	//		t.Log("=====")
-//	//	}
-//	//}()
-//
-//	t.Log("I should have a tunnel with them now and my original packet should get there")
-//	r.RouteUntilAfterMsgType(myControl, 1, 0)
-//	myCachedPacket := theirControl.GetFromTun(true)
-//
-//	t.Log("Got the cached packet, lets test the tunnel")
-//	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
-//
-//	t.Log("Testing tunnels with them")
-//	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
-//	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
-//
-//	t.Log("Testing tunnels with evil")
-//	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
-//
-//	//TODO: assert hostmaps for everyone
-//}
+//TODO: add a test with many lies
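
The IP choices in TestWrongResponderHandshake above lean on the ordering the new RemoteList hands out: preferred-range addresses first, then public, then lexical. remote_list.go itself is not shown in this view, so here is a minimal, hypothetical sketch of a comparator with that shape — the addr type and helpers are stand-ins, not the commit's API:

package main

import (
	"bytes"
	"fmt"
	"net"
	"sort"
)

// addr is a stand-in for nebula's udpAddr; the real ordering lives in
// remote_list.go, which this view does not show in full.
type addr struct {
	ip   net.IP
	port uint16
}

// sortAddrs orders addresses the way the test comment describes:
// preferred ranges first, then public over private, then lexically.
func sortAddrs(addrs []addr, preferred []*net.IPNet) {
	inPreferred := func(ip net.IP) bool {
		for _, n := range preferred {
			if n.Contains(ip) {
				return true
			}
		}
		return false
	}
	sort.Slice(addrs, func(i, j int) bool {
		a, b := addrs[i].ip, addrs[j].ip
		if p1, p2 := inPreferred(a), inPreferred(b); p1 != p2 {
			return p1 // preferred ranges win
		}
		if v1, v2 := a.IsPrivate(), b.IsPrivate(); v1 != v2 {
			return !v1 // public beats private
		}
		return bytes.Compare(a, b) < 0 // lexical tiebreak
	})
}

func main() {
	addrs := []addr{
		{net.IP{10, 0, 0, 100}, 4242}, // me
		{net.IP{10, 0, 0, 99}, 4242},  // them
		{net.IP{10, 0, 0, 2}, 4242},   // evil
	}
	sortAddrs(addrs, nil)
	for _, a := range addrs {
		// All private, no preferred ranges: lexical order puts evil
		// (10.0.0.2) ahead of them (10.0.0.99), as the test expects.
		fmt.Printf("%v:%d\n", a.ip, a.port)
	}
}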

+ 3 - 0
e2e/helpers_test.go

@@ -64,6 +64,9 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 				"host":  "any",
 			}},
 		},
+		//"handshakes": m{
+		//	"try_interval": "1s",
+		//},
 		"listen": m{
 			"host": udpAddr.IP.String(),
 			"port": udpAddr.Port,

+ 3 - 0
e2e/router/doc.go

@@ -0,0 +1,3 @@
+package router
+
+// This file exists to allow `go fmt` to traverse here on its own. The build tags were keeping it out before

+ 119 - 20
e2e/router/router.go

@@ -5,6 +5,7 @@ package router
 import (
 	"fmt"
 	"net"
+	"reflect"
 	"strconv"
 	"sync"
 
@@ -28,18 +29,18 @@ type R struct {
 	sync.Mutex
 }
 
-type exitType int
+type ExitType int
 
 const (
 	// Keeps routing, the function will get called again on the next packet
-	keepRouting exitType = 0
+	KeepRouting ExitType = 0
 	// Does not route this packet and exits immediately
-	exitNow exitType = 1
+	ExitNow ExitType = 1
 	// Routes this packet and exits immediately afterwards
-	routeAndExit exitType = 2
+	RouteAndExit ExitType = 2
 )
 
-type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType
+type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
 
 func NewR(controls ...*nebula.Control) *R {
 	r := &R{
@@ -77,8 +78,8 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
 // OnceFrom will route a single packet from sender then return
 // If the router doesn't have the nebula controller for that address, we panic
 func (r *R) OnceFrom(sender *nebula.Control) {
-	r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) exitType {
-		return routeAndExit
+	r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
+		return RouteAndExit
 	})
 }
 
@@ -116,7 +117,6 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 //   - exitNow: the packet will not be routed and this call will return immediately
 //   - routeAndExit: this call will return immediately after routing the last packet from sender
 //   - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
-//TODO: is this RouteWhile?
 func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 	h := &nebula.Header{}
 	for {
@@ -136,16 +136,16 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 
 		e := whatDo(p, receiver)
 		switch e {
-		case exitNow:
+		case ExitNow:
 			r.Unlock()
 			return
 
-		case routeAndExit:
+		case RouteAndExit:
 			receiver.InjectUDPPacket(p)
 			r.Unlock()
 			return
 
-		case keepRouting:
+		case KeepRouting:
 			receiver.InjectUDPPacket(p)
 
 		default:
@@ -160,35 +160,135 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 // If the router doesn't have the nebula controller for that address, we panic
 func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
 	h := &nebula.Header{}
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType {
+	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
 		if err := h.Parse(p.Data); err != nil {
 			panic(err)
 		}
 		if h.Type == msgType && h.Subtype == subType {
-			return routeAndExit
+			return RouteAndExit
 		}
 
-		return keepRouting
+		return KeepRouting
 	})
 }
 
 // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
 // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
 // If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish exitType) {
-	if finish == keepRouting {
-		finish = routeAndExit
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
+	if finish == KeepRouting {
+		finish = RouteAndExit
 	}
 
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType {
+	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
 		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
 			return finish
 		}
 
-		return keepRouting
+		return KeepRouting
 	})
 }
 
+// RouteForAllExitFunc will route for every registered controller and call the whatDo func with each udp packet seen
+// whatDo can return:
+//   - exitNow: the packet will not be routed and this call will return immediately
++//   - routeAndExit: this call will return immediately after routing the last packet
+//   - keepRouting: the packet will be routed and whatDo will be called again on the next packet
+func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
+	sc := make([]reflect.SelectCase, len(r.controls))
+	cm := make([]*nebula.Control, len(r.controls))
+
+	i := 0
+	for _, c := range r.controls {
+		sc[i] = reflect.SelectCase{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(c.GetUDPTxChan()),
+			Send: reflect.Value{},
+		}
+
+		cm[i] = c
+		i++
+	}
+
+	for {
+		x, rx, _ := reflect.Select(sc)
+		r.Lock()
+
+		p := rx.Interface().(*nebula.UdpPacket)
+
+		outAddr := cm[x].GetUDPAddr()
+		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
+		receiver := r.getControl(outAddr, inAddr, p)
+		if receiver == nil {
+			r.Unlock()
+			panic("Can't route for host: " + inAddr)
+		}
+
+		e := whatDo(p, receiver)
+		switch e {
+		case ExitNow:
+			r.Unlock()
+			return
+
+		case RouteAndExit:
+			receiver.InjectUDPPacket(p)
+			r.Unlock()
+			return
+
+		case KeepRouting:
+			receiver.InjectUDPPacket(p)
+
+		default:
+			panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
+		}
+		r.Unlock()
+	}
+}
+
+// FlushAll will route for every registered controller, exiting once there are no packets left to route
+func (r *R) FlushAll() {
+	sc := make([]reflect.SelectCase, len(r.controls))
+	cm := make([]*nebula.Control, len(r.controls))
+
+	i := 0
+	for _, c := range r.controls {
+		sc[i] = reflect.SelectCase{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(c.GetUDPTxChan()),
+			Send: reflect.Value{},
+		}
+
+		cm[i] = c
+		i++
+	}
+
+	// Add a default case to exit when nothing is left to send
+	sc = append(sc, reflect.SelectCase{
+		Dir:  reflect.SelectDefault,
+		Chan: reflect.Value{},
+		Send: reflect.Value{},
+	})
+
+	for {
+		x, rx, ok := reflect.Select(sc)
+		if !ok {
+			return
+		}
+		r.Lock()
+
+		p := rx.Interface().(*nebula.UdpPacket)
+
+		outAddr := cm[x].GetUDPAddr()
+		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
+		receiver := r.getControl(outAddr, inAddr, p)
+		if receiver == nil {
+			r.Unlock()
+			panic("Can't route for host: " + inAddr)
+		}
+		r.Unlock()
+	}
+}
+
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // This is an internal router function, the caller must hold the lock
 func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
@@ -216,6 +316,5 @@ func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Con
 		return c
 	}
 
-	//TODO: call receive hooks!
 	return r.controls[toAddr]
 }
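
RouteForAllExitFunc and FlushAll both lean on reflect.Select to receive from a set of channels whose size is only known at runtime. A self-contained demo of that fan-in pattern (a sketch, not commit code), including the non-blocking default case FlushAll uses to detect drained queues:

package main

import (
	"fmt"
	"reflect"
)

func main() {
	chans := []chan string{make(chan string, 1), make(chan string, 1)}
	chans[1] <- "hello from channel 1"

	// Build one SelectCase per channel, just like the router does for
	// each controller's UDP tx channel.
	cases := make([]reflect.SelectCase, len(chans))
	for i, ch := range chans {
		cases[i] = reflect.SelectCase{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(ch)}
	}
	// A default case makes Select non-blocking; it fires only when no
	// channel has anything to deliver.
	cases = append(cases, reflect.SelectCase{Dir: reflect.SelectDefault})

	for {
		chosen, v, recvOK := reflect.Select(cases)
		if !recvOK { // the default case fired: nothing left to receive
			fmt.Println("drained")
			return
		}
		fmt.Printf("chan %d: %s\n", chosen, v.Interface())
	}
}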

+ 4 - 4
examples/config.yml

@@ -202,16 +202,16 @@ logging:
 
 # Handshake Manager Settings
 #handshakes:
-  # Total time to try a handshake = sequence of `try_interval * retries`
-  # With 100ms interval and 20 retries it is 23.5 seconds
+  # Handshakes are sent to all known addresses at each interval with a linear backoff.
+  # Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
+  # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
   #try_interval: 100ms
   #retries: 20
-  # wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
-  #wait_rotation: 5
   # trigger_buffer is the size of the buffer channel for quickly sending handshakes
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64
 
+
 # Nebula security group configuration
 firewall:
   conntrack:
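
A quick sanity check on the numbers in the new comment (a standalone sketch, not code from this commit): with linear backoff the total time a handshake gets is try_interval * (1 + 2 + ... + retries).

package main

import (
	"fmt"
	"time"
)

func main() {
	tryInterval := 100 * time.Millisecond
	retries := 10

	// Linear backoff: wait 1*interval after the 1st attempt, 2*interval
	// after the 2nd, and so on, so the total is the arithmetic series
	// interval * retries * (retries + 1) / 2.
	total := time.Duration(retries*(retries+1)/2) * tryInterval
	fmt.Println(total) // 5.5s, matching the comment above
}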

+ 36 - 31
handshake_ix.go

@@ -14,14 +14,10 @@ import (
 // Sending is done by the handshake manager
 func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	// This queries the lighthouse if we don't know a remote for the host
+	// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
+	// more quickly, effect is a quicker handshake.
 	if hostinfo.remote == nil {
-		ips, err := f.lightHouse.Query(vpnIp, f)
-		if err != nil {
-			//l.Debugln(err)
-		}
-		for _, ip := range ips {
-			hostinfo.AddRemote(ip)
-		}
+		f.lightHouse.QueryServer(vpnIp, f)
 	}
 
 	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
@@ -69,7 +65,6 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	hostinfo.HandshakePacket[0] = msg
 	hostinfo.HandshakeReady = true
 	hostinfo.handshakeStart = time.Now()
-
 }
 
 func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
@@ -125,13 +120,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 
 	hostinfo := &HostInfo{
 		ConnectionState: ci,
-		Remotes:         []*udpAddr{},
 		localIndexId:    myIndex,
 		remoteIndexId:   hs.Details.InitiatorIndex,
 		hostId:          vpnIP,
 		HandshakePacket: make(map[uint8][]byte, 0),
 	}
 
+	hostinfo.Lock()
+	defer hostinfo.Unlock()
+
 	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -182,16 +179,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	ci.peerCert = remoteCert
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
-	//l.Debugln("got symmetric pairs")
 
-	//hostinfo.ClearRemotes()
-	hostinfo.AddRemote(addr)
-	hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
+	hostinfo.SetRemote(addr)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	// Only overwrite existing record if we should win the handshake race
 	overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
@@ -214,6 +206,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and we didn't win
 			// handshake avoidance
+
+			//TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing
+			//TODO: if not new send a test packet like old
+
 			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
@@ -234,6 +230,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 				WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
+		case ErrExistingHandshake:
+			// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
+			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+				WithField("certName", certName).
+				WithField("fingerprint", fingerprint).
+				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+				Error("Prevented a pending handshake race")
+			return
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
@@ -286,6 +291,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Info("Handshake is already complete")
 
+		//TODO: evaluate addr for preference, if we handshook with a less preferred addr we can correct quickly here
+
 		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
 		return false
 	}
@@ -334,17 +341,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 
+	// Ensure the right host responded
 	if vpnIP != hostinfo.hostId {
 		f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
 
-		if ho, _ := f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIP); ho != nil {
-			// We might have a pending tunnel to this host already, clear out that attempt since we have a tunnel now
-			f.handshakeManager.pendingHostMap.DeleteHostInfo(ho)
-		}
-
 		// Release our old handshake from pending, it should not continue
 		f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
 
@@ -354,26 +357,28 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		newHostInfo.Lock()
 
 		// Block the current used address
-		newHostInfo.unlockedBlockRemote(addr)
+		newHostInfo.remotes = hostinfo.remotes
+		newHostInfo.remotes.BlockRemote(addr)
 
-		// If this is an ongoing issue our previous hostmap will have some bad ips too
-		for _, v := range hostinfo.badRemotes {
-			newHostInfo.unlockedBlockRemote(v)
-		}
-		//TODO: this is me enabling tests
-		newHostInfo.ForcePromoteBest(f.hostMap.preferredRanges)
+		// Get the correct remote list for the host we did handshake with
+		hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
 
-		f.l.WithField("blockedUdpAddrs", newHostInfo.badRemotes).WithField("vpnIp", IntIp(vpnIP)).
-			WithField("remotes", newHostInfo.Remotes).
+		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
+			WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
 			Info("Blocked addresses for handshakes")
 
 		// Swap the packet store to benefit the original intended recipient
+		hostinfo.ConnectionState.queueLock.Lock()
 		newHostInfo.packetStore = hostinfo.packetStore
 		hostinfo.packetStore = []*cachedPacket{}
+		hostinfo.ConnectionState.queueLock.Unlock()
 
-		// Set the current hostId to the new vpnIp
+		// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
 		hostinfo.hostId = vpnIP
+		f.sendCloseTunnel(hostinfo)
 		newHostInfo.Unlock()
+
+		return true
 	}
 
 	// Mark packet 2 as seen so it doesn't show up as missed
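
The `overwrite := vpnIP > ip2int(...)` line above is the deterministic tie-break for simultaneous handshakes: both sides compare the same two vpn IPs with the operands swapped, so exactly one side overwrites and the other returns ErrExistingHandshake. A small standalone illustration (ip2int here is a big-endian reconstruction of nebula's helper, not copied from the commit):

package main

import (
	"encoding/binary"
	"fmt"
	"net"
)

// ip2int reconstructs nebula's helper: an IPv4 vpn ip as a big-endian uint32.
func ip2int(ip net.IP) uint32 {
	return binary.BigEndian.Uint32(ip.To4())
}

func main() {
	me := ip2int(net.IP{10, 0, 0, 1})
	them := ip2int(net.IP{10, 0, 0, 2})

	// Each side evaluates `overwrite := peerVpnIP > myVpnIP`, so the two
	// results always disagree and exactly one handshake survives the race,
	// which is what Test_Case1_Stage1Race above expects ("they should win").
	fmt.Println(them > me) // true:  on my side, their handshake overwrites
	fmt.Println(me > them) // false: on their side, mine loses (ErrExistingHandshake)
}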

+ 122 - 106
handshake_manager.go

@@ -12,12 +12,8 @@ import (
 )
 
 const (
-	// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
-	// With 100ms interval and 20 retries is 23.5 seconds
-	DefaultHandshakeTryInterval = time.Millisecond * 100
-	DefaultHandshakeRetries     = 20
-	// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
-	DefaultHandshakeWaitRotation  = 5
+	DefaultHandshakeTryInterval   = time.Millisecond * 100
+	DefaultHandshakeRetries       = 10
 	DefaultHandshakeTriggerBuffer = 64
 )
 
@@ -25,7 +21,6 @@ var (
 	defaultHandshakeConfig = HandshakeConfig{
 		tryInterval:   DefaultHandshakeTryInterval,
 		retries:       DefaultHandshakeRetries,
-		waitRotation:  DefaultHandshakeWaitRotation,
 		triggerBuffer: DefaultHandshakeTriggerBuffer,
 	}
 )
@@ -33,45 +28,36 @@ var (
 type HandshakeConfig struct {
 	tryInterval   time.Duration
 	retries       int
-	waitRotation  int
 	triggerBuffer int
 
 	messageMetrics *MessageMetrics
 }
 
 type HandshakeManager struct {
-	pendingHostMap *HostMap
-	mainHostMap    *HostMap
-	lightHouse     *LightHouse
-	outside        *udpConn
-	config         HandshakeConfig
+	pendingHostMap         *HostMap
+	mainHostMap            *HostMap
+	lightHouse             *LightHouse
+	outside                *udpConn
+	config                 HandshakeConfig
+	OutboundHandshakeTimer *SystemTimerWheel
+	messageMetrics         *MessageMetrics
+	l                      *logrus.Logger
 
 	// can be used to trigger outbound handshake for the given vpnIP
 	trigger chan uint32
-
-	OutboundHandshakeTimer *SystemTimerWheel
-	InboundHandshakeTimer  *SystemTimerWheel
-
-	messageMetrics *MessageMetrics
-	l              *logrus.Logger
 }
 
 func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
-		mainHostMap:    mainHostMap,
-		lightHouse:     lightHouse,
-		outside:        outside,
-
-		config: config,
-
-		trigger: make(chan uint32, config.triggerBuffer),
-
-		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
-		InboundHandshakeTimer:  NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
-
-		messageMetrics: config.messageMetrics,
-		l:              l,
+		pendingHostMap:         NewHostMap(l, "pending", tunCidr, preferredRanges),
+		mainHostMap:            mainHostMap,
+		lightHouse:             lightHouse,
+		outside:                outside,
+		config:                 config,
+		trigger:                make(chan uint32, config.triggerBuffer),
+		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		messageMetrics:         config.messageMetrics,
+		l:                      l,
 	}
 }
 
@@ -84,7 +70,6 @@ func (c *HandshakeManager) Run(f EncWriter) {
 			c.handleOutbound(vpnIP, f, true)
 		case now := <-clockSource:
 			c.NextOutboundHandshakeTimerTick(now, f)
-			c.NextInboundHandshakeTimerTick(now)
 		}
 	}
 }
@@ -109,84 +94,84 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 
-	// If we haven't finished the handshake and we haven't hit max retries, query
-	// lighthouse and then send the handshake packet again.
-	if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
-		if hostinfo.remote == nil {
-			// We continue to query the lighthouse because hosts may
-			// come online during handshake retries. If the query
-			// succeeds (no error), add the lighthouse info to hostinfo
-			ips := c.lightHouse.QueryCache(vpnIP)
-			// If we have no responses yet, or only one IP (the host hadn't
-			// finished reporting its own IPs yet), then send another query to
-			// the LH.
-			if len(ips) <= 1 {
-				ips, err = c.lightHouse.Query(vpnIP, f)
-			}
-			if err == nil {
-				for _, ip := range ips {
-					hostinfo.AddRemote(ip)
-				}
-				hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
-			}
-		} else if lighthouseTriggered {
-			// We were triggered by a lighthouse HostQueryReply packet, but
-			// we have already picked a remote for this host (this can happen
-			// if we are configured with multiple lighthouses). So we can skip
-			// this trigger and let the timerwheel handle the rest of the
-			// process
-			return
-		}
+	// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
+	if hostinfo.HandshakeComplete {
+		// Ensure we don't exist in the pending hostmap anymore since we have completed
+		c.pendingHostMap.DeleteHostInfo(hostinfo)
+		return
+	}
 
-		hostinfo.HandshakeCounter++
+	// Check if we have a handshake packet to transmit yet
+	if !hostinfo.HandshakeReady {
+		// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
+		// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
+		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		return
+	}
 
-		// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
-		// all the others until we can stand up a connection.
-		if hostinfo.HandshakeCounter > c.config.waitRotation {
-			hostinfo.rotateRemote()
-		}
+	// If we are out of time, clean up
+	if hostinfo.HandshakeCounter >= c.config.retries {
+		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
+			WithField("initiatorIndex", hostinfo.localIndexId).
+			WithField("remoteIndex", hostinfo.remoteIndexId).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
+			Info("Handshake timed out")
+		//TODO: emit metrics
+		c.pendingHostMap.DeleteHostInfo(hostinfo)
+		return
+	}
 
-		// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
-		if hostinfo.HandshakeReady && hostinfo.remote != nil {
-			c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-			err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
-			if err != nil {
-				hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
-					WithField("initiatorIndex", hostinfo.localIndexId).
-					WithField("remoteIndex", hostinfo.remoteIndexId).
-					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-					WithError(err).Error("Failed to send handshake message")
-			} else {
-				//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
-				// keep the real packet struct around for logging purposes
-				hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
-					WithField("initiatorIndex", hostinfo.localIndexId).
-					WithField("remoteIndex", hostinfo.remoteIndexId).
-					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-					Info("Handshake message sent")
-			}
-		}
+	// We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific
+	// optimization for a fast lighthouse reply
+	//TODO: it would feel better to do this once, anytime, as our delay increases over time
+	if lighthouseTriggered && hostinfo.HandshakeCounter > 0 {
+		// If we didn't return here a lighthouse could cause us to aggressively send handshakes
+		return
+	}
 
-		// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
-		if !lighthouseTriggered {
-			//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
-			c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
-		}
-	} else {
-		c.pendingHostMap.DeleteHostInfo(hostinfo)
+	// Get a remotes object if we don't already have one.
+	// This is mainly to protect us as this should never be the case
+	if hostinfo.remotes == nil {
+		hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
 	}
-}
 
-func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
-	c.InboundHandshakeTimer.advance(now)
-	for {
-		ep := c.InboundHandshakeTimer.Purge()
-		if ep == nil {
-			break
+	//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
+	if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
+		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
+		// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
+		// the learned public ip for them. Query again to short circuit the promotion counter
+		c.lightHouse.QueryServer(vpnIP, f)
+	}
+
+	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
+	var sentTo []*udpAddr
+	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
+		c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+		err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
+		if err != nil {
+			hostinfo.logger(c.l).WithField("udpAddr", addr).
+				WithField("initiatorIndex", hostinfo.localIndexId).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+				WithError(err).Error("Failed to send handshake message")
+
+		} else {
+			sentTo = append(sentTo, addr)
 		}
-		index := ep.(uint32)
+	})
+
+	hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+		WithField("initiatorIndex", hostinfo.localIndexId).
+		WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+		Info("Handshake message sent")
 
-		c.pendingHostMap.DeleteIndex(index)
+	// Increment the counter to increase our delay, linear backoff
+	hostinfo.HandshakeCounter++
+
+	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
+	if !lighthouseTriggered {
+		//TODO: feel like we dupe handshake real fast in a tight loop, why?
+		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 	}
 }
 
@@ -194,6 +179,7 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
 	hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
 	// We lock here and use an array to insert items to prevent locking the
 	// main receive thread for very long by waiting to add items to the pending map
+	//TODO: what lock?
 	c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
 
 	return hostinfo
@@ -203,6 +189,7 @@ var (
 	ErrExistingHostInfo    = errors.New("existing hostinfo")
 	ErrAlreadySeen         = errors.New("already seen")
 	ErrLocalIndexCollision = errors.New("local index collision")
+	ErrExistingHandshake   = errors.New("existing handshake")
 )
 
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
@@ -217,17 +204,21 @@ var (
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
-	c.pendingHostMap.RLock()
-	defer c.pendingHostMap.RUnlock()
+	c.pendingHostMap.Lock()
+	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 
+	// Check if we already have a tunnel with this vpn ip
 	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
 	if found && existingHostInfo != nil {
+		// Is it just a delayed handshake packet?
 		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
			return existingHostInfo, ErrAlreadySeen
 		}
+
 		if !overwrite {
+			// It's a new handshake and we lost the race
 			return existingHostInfo, ErrExistingHostInfo
 		}
 	}
@@ -237,6 +228,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 	}
+
 	existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
 	if found && existingIndex != hostinfo {
 		// We have a collision, but for a different hostinfo
@@ -252,7 +244,24 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			Info("New host shadows existing host remoteIndex")
 	}
 
+	// Check if we are also handshaking with this vpn ip
+	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
+	if found && pendingHostInfo != nil {
+		if !overwrite {
+			// We won, let our pending handshake win
+			return pendingHostInfo, ErrExistingHandshake
+		}
+
+		// We lost, take this handshake and move any cached packets over so they get sent
+		pendingHostInfo.ConnectionState.queueLock.Lock()
+		hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
+		c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
+		pendingHostInfo.ConnectionState.queueLock.Unlock()
+		pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
+	}
+
 	if existingHostInfo != nil {
+		hostinfo.logger(c.l).Info("Race lost, taking new handshake")
 		// We are going to overwrite this entry, so remove the old references
 		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
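
Taken together, the checks above give CheckAndComplete a small, deterministic decision tree; a summary sketch, since the branches span both hostmaps (the error names are the vars declared above):

// Outcome summary for CheckAndComplete, per the logic above:
//   duplicate of a handshake packet already processed  -> ErrAlreadySeen
//   main hostmap entry exists and overwrite is false   -> ErrExistingHostInfo
//   localIndexId owned by a different hostinfo         -> ErrLocalIndexCollision
//   pending handshake exists and overwrite is false    -> ErrExistingHandshake
//   otherwise: adopt the losing handshake's cached packets, remove stale
//   references, and install the new hostinfo in the main hostmap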
@@ -267,6 +276,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 // won't have a localIndexId collision because we already have an entry in the
 // won't have a localIndexId collision because we already have an entry in the
 // pendingHostMap
 // pendingHostMap
 func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
+	c.pendingHostMap.Lock()
+	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 	defer c.mainHostMap.Unlock()
 
 
@@ -288,6 +299,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	}
 	}
 
 
 	c.mainHostMap.addHostInfo(hostinfo, f)
 	c.mainHostMap.addHostInfo(hostinfo, f)
+	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
 }
 }
 
 
 // AddIndexHostInfo generates a unique localIndexId for this HostInfo
 // AddIndexHostInfo generates a unique localIndexId for this HostInfo
@@ -359,3 +371,7 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
 	}
 	}
 	return index, nil
 	return index, nil
 }
 }
+
+func hsTimeout(tries int, interval time.Duration) time.Duration {
+	return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval)))
+}
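
hsTimeout totals the linear backoff used above: attempt i waits tryInterval*i, so the sum over n tries is interval*n*(n+1)/2, expressed here in the arithmetic-series form n/2*(2a + (n-1)d) with a = d = interval. The leading integer division truncates slightly low for odd n. A quick check, assuming an interval of 100ms and with fmt/time imported:

// Sketch: hsTimeout agrees with the plain sum of the per-attempt delays.
interval := 100 * time.Millisecond
fmt.Println(hsTimeout(10, interval)) // 10/2 * (2+9) intervals = 55 intervals = 5.5s

var total time.Duration
for i := 1; i <= 10; i++ {
	total += time.Duration(i) * interval // 1+2+...+10 = 55 intervals
}
fmt.Println(total) // 5.5s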

+ 28 - 185
handshake_manager_test.go

@@ -8,66 +8,12 @@ import (
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/assert"
 )
 )
 
 
-//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
-var ips []uint32
-
-func Test_NewHandshakeManagerIndex(t *testing.T) {
-	l := NewTestLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
-	preferredRanges := []*net.IPNet{localrange}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
-
-	now := time.Now()
-	blah.NextInboundHandshakeTimerTick(now)
-
-	var indexes = make([]uint32, 4)
-	var hostinfo = make([]*HostInfo, len(indexes))
-	for i := range indexes {
-		hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
-	}
-
-	// Add four indexes
-	for i := range indexes {
-		err := blah.AddIndexHostInfo(hostinfo[i])
-		assert.NoError(t, err)
-		indexes[i] = hostinfo[i].localIndexId
-		blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
-	}
-	// Confirm they are in the pending index list
-	for _, v := range indexes {
-		assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
-	}
-	// Adding something to pending should not affect the main hostmap
-	assert.Len(t, mainHM.Indexes, 0)
-	// Jump ahead 8 seconds
-	for i := 1; i <= DefaultHandshakeRetries; i++ {
-		next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
-		blah.NextInboundHandshakeTimerTick(next_tick)
-	}
-	// Confirm they are still in the pending index list
-	for _, v := range indexes {
-		assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
-	}
-	// Jump ahead 4 more seconds
-	next_tick := now.Add(12 * time.Second)
-	blah.NextInboundHandshakeTimerTick(next_tick)
-	// Confirm they have been removed
-	for _, v := range indexes {
-		assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v))
-	}
-}
-
 func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 	l := NewTestLogger()
 	l := NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
+	ip := ip2int(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
@@ -77,39 +23,30 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 	now := time.Now()
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
 
-	// Add four "IPs" - which are just uint32s
-	for _, v := range ips {
-		blah.AddVpnIP(v)
-	}
+	i := blah.AddVpnIP(ip)
+	i.remotes = NewRemoteList()
+	i.HandshakeReady = true
+
 	// Adding something to pending should not affect the main hostmap
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)
 	assert.Len(t, mainHM.Hosts, 0)
+
 	// Confirm they are in the pending index list
 	// Confirm they are in the pending index list
-	for _, v := range ips {
-		assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
-	}
+	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
 
 
-	// Jump ahead `HandshakeRetries` ticks
-	cumulative := time.Duration(0)
-	for i := 0; i <= DefaultHandshakeRetries+1; i++ {
-		cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
-		next_tick := now.Add(cumulative)
-		//l.Infoln(next_tick)
-		blah.NextOutboundHandshakeTimerTick(next_tick, mw)
+	// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
+	for i := 1; i <= DefaultHandshakeRetries+1; i++ {
+		now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
+		blah.NextOutboundHandshakeTimerTick(now, mw)
 	}
 	}
 
 
 	// Confirm they are still in the pending index list
 	// Confirm they are still in the pending index list
-	for _, v := range ips {
-		assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
-	}
-	// Jump ahead 1 more second
-	cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
-	next_tick := now.Add(cumulative)
-	//l.Infoln(next_tick)
-	blah.NextOutboundHandshakeTimerTick(next_tick, mw)
+	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
+
+	// Tick 1 more time, a minute will certainly flush it out
+	blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
+
 	// Confirm they have been removed
 	// Confirm they have been removed
-	for _, v := range ips {
-		assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v))
-	}
+	assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
 }
 }
 
 
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
@@ -121,7 +58,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-	lh := &LightHouse{}
+	lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
 
 
 	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
 	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
 
 
@@ -130,28 +67,25 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 
 
 	assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 
 
-	blah.AddVpnIP(ip)
-
+	hi := blah.AddVpnIP(ip)
+	hi.HandshakeReady = true
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
+	assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
 
 
-	// Trigger the same method the channel will
+	// Trigger the same method the channel would; this should also set our remotes pointer
 	blah.handleOutbound(ip, mw, true)
 	blah.handleOutbound(ip, mw, true)
+	assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt")
+	assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer")
 
 
-	// Make sure the trigger doesn't schedule another timer entry
+	// Make sure the trigger doesn't double schedule the timer entry
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
-	hi := blah.pendingHostMap.Hosts[ip]
-	assert.Nil(t, hi.remote)
 
 
 	uaddr := NewUDPAddrFromString("10.1.1.1:4242")
 	uaddr := NewUDPAddrFromString("10.1.1.1:4242")
-	lh.addrMap = map[uint32]*ip4And6{}
-	lh.addrMap[ip] = &ip4And6{
-		v4: []*Ip4AndPort{NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))},
-		v6: []*Ip6AndPort{},
-	}
+	hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
 
 
-	// This should trigger the hostmap to populate the hostinfo
+	// We now have remotes but only the first trigger should have pushed things forward
 	blah.handleOutbound(ip, mw, true)
 	blah.handleOutbound(ip, mw, true)
-	assert.NotNil(t, hi.remote)
+	assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should not have done a handshake attempt")
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 }
 }
 
 
@@ -166,100 +100,9 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
 	return c
 	return c
 }
 }
 
 
-func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
-	l := NewTestLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	vpnIP = ip2int(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
-	mw := &mockEncWriter{}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
-
-	now := time.Now()
-	blah.NextOutboundHandshakeTimerTick(now, mw)
-
-	hostinfo := blah.AddVpnIP(vpnIP)
-	// Pretned we have an index too
-	err := blah.AddIndexHostInfo(hostinfo)
-	assert.NoError(t, err)
-	blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
-	assert.NotZero(t, hostinfo.localIndexId)
-	assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
-
-	// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
-	// but not main hostmap
-	cumulative := time.Duration(0)
-	for i := 1; i <= DefaultHandshakeRetries+2; i++ {
-		cumulative += DefaultHandshakeTryInterval * time.Duration(i)
-		next_tick := now.Add(cumulative)
-		blah.NextOutboundHandshakeTimerTick(next_tick, mw)
-	}
-	/*
-		for i := 0; i <= HandshakeRetries+1; i++ {
-			next_tick := now.Add(cumulative)
-			//l.Infoln(next_tick)
-			blah.NextOutboundHandshakeTimerTick(next_tick)
-		}
-	*/
-	/*
-		for i := 0; i <= HandshakeRetries+1; i++ {
-			next_tick := now.Add(time.Duration(i) * time.Second)
-			blah.NextOutboundHandshakeTimerTick(next_tick)
-		}
-	*/
-
-	/*
-		cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3
-		next_tick := now.Add(cumulative)
-		l.Infoln(cumulative, next_tick)
-		blah.NextOutboundHandshakeTimerTick(next_tick)
-	*/
-	assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
-	assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
-}
-
-func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
-	l := NewTestLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
-
-	now := time.Now()
-	blah.NextInboundHandshakeTimerTick(now)
-
-	hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
-	err := blah.AddIndexHostInfo(hostinfo)
-	assert.NoError(t, err)
-	blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
-	// Pretned we have an index too
-	blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
-	assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
-
-	for i := 1; i <= DefaultHandshakeRetries+2; i++ {
-		next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
-		blah.NextInboundHandshakeTimerTick(next_tick)
-	}
-
-	next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
-	blah.NextInboundHandshakeTimerTick(next_tick)
-	assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
-	assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
-}
-
 type mockEncWriter struct {
 type mockEncWriter struct {
 }
 }
 
 
 func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
 func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
 	return
 	return
 }
 }
-
-func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
-	return
-}
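
Each pass through NextOutboundHandshakeTimerTick performs at most one attempt per entry, so the test above has to advance the clock cumulatively for the re-added timer entry to come due each round; the final one-minute jump then guarantees the give-up path runs regardless of the exact defaults. Roughly, assuming a 100ms try interval (a hypothetical value, not asserted by the test):

// Sketch of the wall-clock span the tick loop covers:
now := time.Now()
for i := 1; i <= 11; i++ {
	now = now.Add(time.Duration(i) * 100 * time.Millisecond)
}
// 1+2+...+11 = 66 intervals, i.e. 6.6s past the starting tick.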

+ 52 - 249
hostmap.go

@@ -1,7 +1,6 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
@@ -16,6 +15,7 @@ import (
 
 
 //const ProbeLen = 100
 //const ProbeLen = 100
 const PromoteEvery = 1000
 const PromoteEvery = 1000
+const ReQueryEvery = 5000
 const MaxRemotes = 10
 const MaxRemotes = 10
 
 
 // How long we should prevent roaming back to the previous IP.
 // How long we should prevent roaming back to the previous IP.
@@ -30,7 +30,6 @@ type HostMap struct {
 	Hosts           map[uint32]*HostInfo
 	Hosts           map[uint32]*HostInfo
 	preferredRanges []*net.IPNet
 	preferredRanges []*net.IPNet
 	vpnCIDR         *net.IPNet
 	vpnCIDR         *net.IPNet
-	defaultRoute    uint32
 	unsafeRoutes    *CIDRTree
 	unsafeRoutes    *CIDRTree
 	metricsEnabled  bool
 	metricsEnabled  bool
 	l               *logrus.Logger
 	l               *logrus.Logger
@@ -40,25 +39,21 @@ type HostInfo struct {
 	sync.RWMutex
 	sync.RWMutex
 
 
 	remote            *udpAddr
 	remote            *udpAddr
-	Remotes           []*udpAddr
+	remotes           *RemoteList
 	promoteCounter    uint32
 	promoteCounter    uint32
 	ConnectionState   *ConnectionState
 	ConnectionState   *ConnectionState
-	handshakeStart    time.Time
-	HandshakeReady    bool
-	HandshakeCounter  int
-	HandshakeComplete bool
-	HandshakePacket   map[uint8][]byte
-	packetStore       []*cachedPacket
+	handshakeStart    time.Time        //todo: this should be an entry in the handshake manager
+	HandshakeReady    bool             //todo: being in the manager means you are ready
+	HandshakeCounter  int              //todo: another handshake manager entry
+	HandshakeComplete bool             //todo: this should go away in favor of ConnectionState.ready
+	HandshakePacket   map[uint8][]byte //todo: this is another handshake manager entry
+	packetStore       []*cachedPacket  //todo: this is another handshake manager entry
 	remoteIndexId     uint32
 	remoteIndexId     uint32
 	localIndexId      uint32
 	localIndexId      uint32
 	hostId            uint32
 	hostId            uint32
 	recvError         int
 	recvError         int
 	remoteCidr        *CIDRTree
 	remoteCidr        *CIDRTree
 
 
-	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
-	// They should not be tried again during a handshake
-	badRemotes []*udpAddr
-
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
 	// with a handshake
 	// with a handshake
@@ -88,7 +83,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
 		Hosts:           h,
 		Hosts:           h,
 		preferredRanges: preferredRanges,
 		preferredRanges: preferredRanges,
 		vpnCIDR:         vpnCIDR,
 		vpnCIDR:         vpnCIDR,
-		defaultRoute:    0,
 		unsafeRoutes:    NewCIDRTree(),
 		unsafeRoutes:    NewCIDRTree(),
 		l:               l,
 		l:               l,
 	}
 	}
@@ -131,7 +125,6 @@ func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
 	if _, ok := hm.Hosts[vpnIP]; !ok {
 	if _, ok := hm.Hosts[vpnIP]; !ok {
 		hm.RUnlock()
 		hm.RUnlock()
 		h = &HostInfo{
 		h = &HostInfo{
-			Remotes:         []*udpAddr{},
 			promoteCounter:  0,
 			promoteCounter:  0,
 			hostId:          vpnIP,
 			hostId:          vpnIP,
 			HandshakePacket: make(map[uint8][]byte, 0),
 			HandshakePacket: make(map[uint8][]byte, 0),
@@ -239,7 +232,11 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
 
 
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 	hm.Lock()
 	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedDeleteHostInfo(hostinfo)
+}
 
 
+func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	// Check if this same hostId is in the hostmap with a different instance.
 	// Check if this same hostId is in the hostmap with a different instance.
 	// This could happen if we have an entry in the pending hostmap with different
 	// This could happen if we have an entry in the pending hostmap with different
 	// index values than the one in the main hostmap.
 	// index values than the one in the main hostmap.
@@ -262,7 +259,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 	if len(hm.RemoteIndexes) == 0 {
 	if len(hm.RemoteIndexes) == 0 {
 		hm.RemoteIndexes = map[uint32]*HostInfo{}
 		hm.RemoteIndexes = map[uint32]*HostInfo{}
 	}
 	}
-	hm.Unlock()
 
 
 	if hm.l.Level >= logrus.DebugLevel {
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
@@ -294,30 +290,6 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
-	hm.Lock()
-	i, v := hm.Hosts[vpnIp]
-	if v {
-		i.AddRemote(remote)
-	} else {
-		i = &HostInfo{
-			Remotes:         []*udpAddr{remote.Copy()},
-			promoteCounter:  0,
-			hostId:          vpnIp,
-			HandshakePacket: make(map[uint8][]byte, 0),
-		}
-		i.remote = i.Remotes[0]
-		hm.Hosts[vpnIp] = i
-		if hm.l.Level >= logrus.DebugLevel {
-			hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
-				Debug("Hostmap remote ip added")
-		}
-	}
-	i.ForcePromoteBest(hm.preferredRanges)
-	hm.Unlock()
-	return i
-}
-
 func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
 func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
 	return hm.queryVpnIP(vpnIp, nil)
 	return hm.queryVpnIP(vpnIp, nil)
 }
 }
@@ -331,12 +303,13 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
 func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
 func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
 	hm.RLock()
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 	if h, ok := hm.Hosts[vpnIp]; ok {
-		if promoteIfce != nil {
+		// Do not attempt promotion if you are a lighthouse
+		if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
 			h.TryPromoteBest(hm.preferredRanges, promoteIfce)
 			h.TryPromoteBest(hm.preferredRanges, promoteIfce)
 		}
 		}
-		//fmt.Println(h.remote)
 		hm.RUnlock()
 		hm.RUnlock()
 		return h, nil
 		return h, nil
+
 	} else {
 	} else {
 		//return &net.UDPAddr{}, nil, errors.New("Unable to find host")
 		//return &net.UDPAddr{}, nil, errors.New("Unable to find host")
 		hm.RUnlock()
 		hm.RUnlock()
@@ -362,11 +335,8 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
 // We already have the hm Lock when this is called, so make sure to not call
 // We already have the hm Lock when this is called, so make sure to not call
 // any other methods that might try to grab it again
 // any other methods that might try to grab it again
 func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
-	remoteCert := hostinfo.ConnectionState.peerCert
-	ip := ip2int(remoteCert.Details.Ips[0].IP)
-
-	f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote)
 	if f.serveDns {
 	if f.serveDns {
+		remoteCert := hostinfo.ConnectionState.peerCert
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 	}
 	}
 
 
@@ -381,38 +351,21 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 	}
 }
 }
 
 
-func (hm *HostMap) ClearRemotes(vpnIP uint32) {
-	hm.Lock()
-	i := hm.Hosts[vpnIP]
-	if i == nil {
-		hm.Unlock()
-		return
-	}
-	i.remote = nil
-	i.Remotes = nil
-	hm.Unlock()
-}
-
-func (hm *HostMap) SetDefaultRoute(ip uint32) {
-	hm.defaultRoute = ip
-}
-
-func (hm *HostMap) PunchList() []*udpAddr {
-	var list []*udpAddr
+// punchList assembles a list of all non-nil RemoteList pointer entries in this hostmap
+// The caller can then do its work outside of the read lock
+func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
 	hm.RLock()
 	hm.RLock()
+	defer hm.RUnlock()
+
 	for _, v := range hm.Hosts {
 	for _, v := range hm.Hosts {
-		for _, r := range v.Remotes {
-			list = append(list, r)
+		if v.remotes != nil {
+			rl = append(rl, v.remotes)
 		}
 		}
-		//	if h, ok := hm.Hosts[vpnIp]; ok {
-		//		hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false)
-		//fmt.Println(h.remote)
-		//	}
 	}
 	}
-	hm.RUnlock()
-	return list
+	return rl
 }
 }
 
 
+// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
 func (hm *HostMap) Punchy(conn *udpConn) {
 func (hm *HostMap) Punchy(conn *udpConn) {
 	var metricsTxPunchy metrics.Counter
 	var metricsTxPunchy metrics.Counter
 	if hm.metricsEnabled {
 	if hm.metricsEnabled {
@@ -421,13 +374,18 @@ func (hm *HostMap) Punchy(conn *udpConn) {
 		metricsTxPunchy = metrics.NilCounter{}
 		metricsTxPunchy = metrics.NilCounter{}
 	}
 	}
 
 
+	var remotes []*RemoteList
 	b := []byte{1}
 	b := []byte{1}
 	for {
 	for {
-		for _, addr := range hm.PunchList() {
-			metricsTxPunchy.Inc(1)
-			conn.WriteTo(b, addr)
+		remotes = hm.punchList(remotes[:0])
+		for _, rl := range remotes {
+			//TODO: CopyAddrs generates garbage but ForEach holds the lock during the work; figure out which way is better
+			for _, addr := range rl.CopyAddrs(hm.preferredRanges) {
+				metricsTxPunchy.Inc(1)
+				conn.WriteTo(b, addr)
+			}
 		}
 		}
-		time.Sleep(time.Second * 30)
+		time.Sleep(time.Second * 10)
 	}
 	}
 }
 }
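
Punchy's remotes = hm.punchList(remotes[:0]) keeps one backing array alive across passes of the forever loop, so each 10-second sweep stops allocating once the slice reaches steady-state capacity (the remaining garbage is the CopyAddrs TODO above). The idiom in isolation:

// Sketch: re-slicing to length zero preserves capacity, so append reuses
// the same backing array on every round instead of reallocating.
buf := make([]int, 0, 8)
for round := 0; round < 3; round++ {
	buf = buf[:0]
	for i := 0; i < 5; i++ {
		buf = append(buf, i)
	}
	// len(buf) == 5 and cap(buf) == 8 every round; no new allocations.
}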
 
 
@@ -438,38 +396,15 @@ func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
 	}
 	}
 }
 }
 
 
-func (i *HostInfo) MarshalJSON() ([]byte, error) {
-	return json.Marshal(m{
-		"remote":             i.remote,
-		"remotes":            i.Remotes,
-		"promote_counter":    i.promoteCounter,
-		"connection_state":   i.ConnectionState,
-		"handshake_start":    i.handshakeStart,
-		"handshake_ready":    i.HandshakeReady,
-		"handshake_counter":  i.HandshakeCounter,
-		"handshake_complete": i.HandshakeComplete,
-		"handshake_packet":   i.HandshakePacket,
-		"packet_store":       i.packetStore,
-		"remote_index":       i.remoteIndexId,
-		"local_index":        i.localIndexId,
-		"host_id":            int2ip(i.hostId),
-		"receive_errors":     i.recvError,
-		"last_roam":          i.lastRoam,
-		"last_roam_remote":   i.lastRoamRemote,
-	})
-}
-
 func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
 func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
 	i.ConnectionState = cs
 	i.ConnectionState = cs
 }
 }
 
 
+// TryPromoteBest handles re-querying lighthouses and probing for better paths
+// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
-	if i.remote == nil {
-		i.ForcePromoteBest(preferredRanges)
-		return
-	}
-
-	if atomic.AddUint32(&i.promoteCounter, 1)%PromoteEvery == 0 {
+	c := atomic.AddUint32(&i.promoteCounter, 1)
+	if c%PromoteEvery == 0 {
 		// return early if we are already on a preferred remote
 		// return early if we are already on a preferred remote
 		rIP := i.remote.IP
 		rIP := i.remote.IP
 		for _, l := range preferredRanges {
 		for _, l := range preferredRanges {
@@ -478,87 +413,21 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 			}
 			}
 		}
 		}
 
 
-		// We re-query the lighthouse periodically while sending packets, so
-		// check for new remotes in our local lighthouse cache
-		ips := ifce.lightHouse.QueryCache(i.hostId)
-		for _, ip := range ips {
-			i.AddRemote(ip)
-		}
+		i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
+			if addr == nil || !preferred {
+				return
+			}
 
 
-		best, preferred := i.getBestRemote(preferredRanges)
-		if preferred && !best.Equals(i.remote) {
 			// Try to send a test packet to that host, this should
 			// Try to send a test packet to that host, this should
 			// cause it to detect a roaming event and switch remotes
 			// cause it to detect a roaming event and switch remotes
-			ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
-		}
-	}
-}
-
-func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) {
-	best, _ := i.getBestRemote(preferredRanges)
-	if best != nil {
-		i.remote = best
-	}
-}
-
-func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) {
-	if len(i.Remotes) > 0 {
-		for _, r := range i.Remotes {
-			for _, l := range preferredRanges {
-				if l.Contains(r.IP) {
-					return r, true
-				}
-			}
-
-			if best == nil || !PrivateIP(r.IP) {
-				best = r
-			}
-			/*
-				for _, r := range i.Remotes {
-					// Must have > 80% probe success to be considered.
-					//fmt.Println("GRADE:", r.addr.IP, r.Grade())
-					if r.Grade() > float64(.8) {
-						if localToMe.Contains(r.addr.IP) == true {
-							best = r.addr
-							break
-							//i.remote = i.Remotes[c].addr
-						} else {
-								//}
-					}
-			*/
-		}
-		return best, false
+			ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+		})
 	}
 	}
 
 
-	return nil, false
-}
-
-// rotateRemote will move remote to the next ip in the list of remote ips for this host
-// This is different than PromoteBest in that what is algorithmically best may not actually work.
-// Only known use case is when sending a stage 0 handshake.
-// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver.
-func (i *HostInfo) rotateRemote() {
-	// We have 0, can't rotate
-	if len(i.Remotes) < 1 {
-		return
+	// Re-query our lighthouses for new remotes occasionally
+	if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
+		ifce.lightHouse.QueryServer(i.hostId, ifce)
 	}
 	}
-
-	if i.remote == nil {
-		i.remote = i.Remotes[0]
-		return
-	}
-
-	// We want to look at all but the very last entry since that is handled at the end
-	for x := 0; x < len(i.Remotes)-1; x++ {
-		// Find our current position and move to the next one in the list
-		if i.Remotes[x].Equals(i.remote) {
-			i.remote = i.Remotes[x+1]
-			return
-		}
-	}
-
-	// Our current position was likely the last in the list, start over at 0
-	i.remote = i.Remotes[0]
 }
 }
 
 
 func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
 func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
@@ -607,23 +476,13 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
 		}
 		}
 	}
 	}
 
 
-	i.badRemotes = make([]*udpAddr, 0)
+	i.remotes.ResetBlockedRemotes()
 	i.packetStore = make([]*cachedPacket, 0)
 	i.packetStore = make([]*cachedPacket, 0)
 	i.ConnectionState.ready = true
 	i.ConnectionState.ready = true
 	i.ConnectionState.queueLock.Unlock()
 	i.ConnectionState.queueLock.Unlock()
 	i.ConnectionState.certState = nil
 	i.ConnectionState.certState = nil
 }
 }
 
 
-func (i *HostInfo) CopyRemotes() []*udpAddr {
-	i.RLock()
-	rc := make([]*udpAddr, len(i.Remotes), len(i.Remotes))
-	for x, addr := range i.Remotes {
-		rc[x] = addr.Copy()
-	}
-	i.RUnlock()
-	return rc
-}
-
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	if i.ConnectionState != nil {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
 		return i.ConnectionState.peerCert
@@ -631,58 +490,12 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	return nil
 	return nil
 }
 }
 
 
-func (i *HostInfo) AddRemote(remote *udpAddr) *udpAddr {
-	if i.unlockedIsBadRemote(remote) {
-		return i.remote
-	}
-
-	for _, r := range i.Remotes {
-		if r.Equals(remote) {
-			return r
-		}
-	}
-
-	// Trim this down if necessary
-	if len(i.Remotes) > MaxRemotes {
-		i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:]
-	}
-
-	rc := remote.Copy()
-	i.Remotes = append(i.Remotes, rc)
-	return rc
-}
-
 func (i *HostInfo) SetRemote(remote *udpAddr) {
 func (i *HostInfo) SetRemote(remote *udpAddr) {
-	i.remote = i.AddRemote(remote)
-}
-
-func (i *HostInfo) unlockedBlockRemote(remote *udpAddr) {
-	if !i.unlockedIsBadRemote(remote) {
-		// We copy here because we are taking something else's memory and we can't trust everything
-		i.badRemotes = append(i.badRemotes, remote.Copy())
+	// We copy here because we likely got this remote from a source that reuses the object
+	if !i.remote.Equals(remote) {
+		i.remote = remote.Copy()
+		i.remotes.LearnRemote(i.hostId, remote.Copy())
 	}
 	}
-
-	for k, v := range i.Remotes {
-		if v.Equals(remote) {
-			i.Remotes[k] = i.Remotes[len(i.Remotes)-1]
-			i.Remotes = i.Remotes[:len(i.Remotes)-1]
-			return
-		}
-	}
-}
-
-func (i *HostInfo) unlockedIsBadRemote(remote *udpAddr) bool {
-	for _, v := range i.badRemotes {
-		if v.Equals(remote) {
-			return true
-		}
-	}
-	return false
-}
-
-func (i *HostInfo) ClearRemotes() {
-	i.remote = nil
-	i.Remotes = []*udpAddr{}
 }
 }
 
 
 func (i *HostInfo) ClearConnectionState() {
 func (i *HostInfo) ClearConnectionState() {
@@ -805,13 +618,3 @@ func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
 	}
 	}
 	return &ips
 	return &ips
 }
 }
-
-func PrivateIP(ip net.IP) bool {
-	//TODO: Private for ipv6 or just let it ride?
-	private := false
-	_, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8")
-	_, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12")
-	_, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16")
-	private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
-	return private
-}
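
TryPromoteBest now drives two periodic jobs off a single atomic counter: every PromoteEvery-th (1000th) packet probes for a preferred-range path, and every ReQueryEvery-th (5000th) packet re-queries the lighthouses, replacing the mc%5000 check that previously lived in inside.go's send path. The throttle idiom distilled, with probe/requery as hypothetical stand-ins and sync/atomic imported:

// Sketch: one atomically incremented counter, shared by all senders, gates
// several maintenance jobs at different rates with no extra locking.
func throttled(counter *uint32, probe, requery func()) {
	c := atomic.AddUint32(counter, 1)
	if c%PromoteEvery == 0 { // every 1000th call
		probe()
	}
	if c%ReQueryEvery == 0 { // every 5000th call
		requery()
	}
}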

+ 0 - 168
hostmap_test.go

@@ -1,169 +1 @@
 package nebula
 package nebula
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-/*
-func TestHostInfoDestProbe(t *testing.T) {
-	a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222")
-	d := NewHostInfoDest(a)
-
-	// 999 probes that all return should give a 100% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		d.ProbeReceived(meh)
-	}
-	assert.Equal(t, d.Grade(), float64(1))
-
-	// 999 probes of which only half return should give a 50% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		if i%2 == 0 {
-			d.ProbeReceived(meh)
-		}
-	}
-	assert.Equal(t, d.Grade(), float64(.5))
-
-	// 999 probes of which none return should give a 0% success rate
-	for i := 0; i < 999; i++ {
-		d.Probe()
-	}
-	assert.Equal(t, d.Grade(), float64(0))
-
-	// 999 probes of which only 1/4 return should give a 25% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		if i%4 == 0 {
-			d.ProbeReceived(meh)
-		}
-	}
-	assert.Equal(t, d.Grade(), float64(.25))
-
-	// 999 probes of which only half return and are duplicates should give a 50% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		if i%2 == 0 {
-			d.ProbeReceived(meh)
-			d.ProbeReceived(meh)
-		}
-	}
-	assert.Equal(t, d.Grade(), float64(.5))
-
-	// 999 probes of which only way old replies return should give a 0% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		d.ProbeReceived(meh - 101)
-	}
-	assert.Equal(t, d.Grade(), float64(0))
-
-}
-*/
-
-func TestHostmap(t *testing.T) {
-	l := NewTestLogger()
-	_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
-	_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
-	myNets := []*net.IPNet{myNet}
-	preferredRanges := []*net.IPNet{localToMe}
-
-	m := NewHostMap(l, "test", myNet, preferredRanges)
-
-	a := NewUDPAddrFromString("10.127.0.3:11111")
-	b := NewUDPAddrFromString("1.0.0.1:22222")
-	y := NewUDPAddrFromString("10.128.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-
-	info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
-
-	// There should be three remotes in the host map
-	assert.Equal(t, 3, len(info.Remotes))
-
-	// Adding an identical remote should not change the count
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-	assert.Equal(t, 3, len(info.Remotes))
-
-	// Adding a fresh remote should add one
-	y = NewUDPAddrFromString("10.18.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-	assert.Equal(t, 4, len(info.Remotes))
-
-	// Query and reference remote should get the first one (and not nil)
-	info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
-	assert.NotNil(t, info.remote)
-
-	// Promotion should ensure that the best remote is chosen (y)
-	info.ForcePromoteBest(myNets)
-	assert.True(t, myNet.Contains(info.remote.IP))
-
-}
-
-func TestHostmapdebug(t *testing.T) {
-	l := NewTestLogger()
-	_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
-	_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
-	preferredRanges := []*net.IPNet{localToMe}
-	m := NewHostMap(l, "test", myNet, preferredRanges)
-
-	a := NewUDPAddrFromString("10.127.0.3:11111")
-	b := NewUDPAddrFromString("1.0.0.1:22222")
-	y := NewUDPAddrFromString("10.128.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-
-	//t.Errorf("%s", m.DebugRemotes(1))
-}
-
-func TestHostMap_rotateRemote(t *testing.T) {
-	h := HostInfo{}
-	// 0 remotes, no panic
-	h.rotateRemote()
-	assert.Nil(t, h.remote)
-
-	// 1 remote, no panic
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 1}, 0))
-	h.rotateRemote()
-	assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 1})
-
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 2}, 0))
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 3}, 0))
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 4}, 0))
-
-	//TODO: ensure we are copying and not storing the slice!
-
-	// Rotate through those 3
-	h.rotateRemote()
-	assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 2})
-
-	h.rotateRemote()
-	assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 3})
-
-	h.rotateRemote()
-	assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 4}, Port: 0})
-
-	// Finally, we should start over
-	h.rotateRemote()
-	assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 1}, Port: 0})
-}
-
-func BenchmarkHostmappromote2(b *testing.B) {
-	l := NewTestLogger()
-	for n := 0; n < b.N; n++ {
-		_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
-		_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
-		preferredRanges := []*net.IPNet{localToMe}
-		m := NewHostMap(l, "test", myNet, preferredRanges)
-		y := NewUDPAddrFromString("10.128.0.3:11111")
-		a := NewUDPAddrFromString("10.127.0.3:11111")
-		g := NewUDPAddrFromString("1.0.0.1:22222")
-		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
-		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
-		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-	}
-}

+ 7 - 50
inside.go

@@ -54,10 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 
 
 	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
 	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
 	if dropReason == nil {
 	if dropReason == nil {
-		mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
-		if f.lightHouse != nil && mc%5000 == 0 {
-			f.lightHouse.Query(fwPacket.RemoteIP, f)
-		}
+		f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
 
 
 	} else if f.l.Level >= logrus.DebugLevel {
 	} else if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).
 		hostinfo.logger(f.l).
@@ -84,15 +81,13 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
 			hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
 			hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
 		}
 		}
 	}
 	}
-
 	ci := hostinfo.ConnectionState
 	ci := hostinfo.ConnectionState
 
 
 	if ci != nil && ci.eKey != nil && ci.ready {
 	if ci != nil && ci.eKey != nil && ci.ready {
 		return hostinfo
 		return hostinfo
 	}
 	}
 
 
-	// Handshake is not ready, we need to grab the lock now before we start
-	// the handshake process
+	// Handshake is not ready, we need to grab the lock now before we start the handshake process
 	hostinfo.Lock()
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 	defer hostinfo.Unlock()
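
getOrHandshake keeps its fast path lock-free: the ready/eKey check runs without holding hostinfo's lock, and only a miss pays for Lock/Unlock before any handshake work begins. A condensed, hypothetical sketch of the shape:

// Sketch: the unlocked read serves the common established-tunnel case; the
// lock is only taken when a handshake may need to start.
ci := hostinfo.ConnectionState
if ci != nil && ci.eKey != nil && ci.ready {
	return hostinfo // hot send path, no lock taken
}
hostinfo.Lock()
defer hostinfo.Unlock()
// ...re-check state under the lock and begin the handshake if needed...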
 
 
@@ -150,10 +145,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 		return
 		return
 	}
 	}
 
 
-	messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
-	if f.lightHouse != nil && messageCounter%5000 == 0 {
-		f.lightHouse.Query(fp.RemoteIP, f)
-	}
+	f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
 }
 }
 
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
@@ -187,50 +179,15 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
 	f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
 	f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
 }
 }
 
 
-// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
-func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
-	hostInfo := f.getOrHandshake(vpnIp)
-	if hostInfo == nil {
-		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", IntIp(vpnIp)).
-				Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
-		}
-		return
-	}
-
-	if hostInfo.ConnectionState.ready == false {
-		// Because we might be sending stored packets, lock here to stop new things going to
-		// the packet queue.
-		hostInfo.ConnectionState.queueLock.Lock()
-		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
-			hostInfo.ConnectionState.queueLock.Unlock()
-			return
-		}
-		hostInfo.ConnectionState.queueLock.Unlock()
-	}
-
-	f.sendMessageToAll(t, st, hostInfo, p, nb, out)
-	return
-}
-
-func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) {
-	hostInfo.RLock()
-	for _, r := range hostInfo.Remotes {
-		f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b)
-	}
-	hostInfo.RUnlock()
-}
-
 func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
 func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.messageMetrics.Tx(t, st, 1)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
 }
 
 
-func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) uint64 {
+func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
 	if ci.eKey == nil {
 		//TODO: log warning
 		//TODO: log warning
-		return 0
+		return
 	}
 	}
 
 
 	var err error
 	var err error
@@ -262,7 +219,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 			WithField("udpAddr", remote).WithField("counter", c).
 			WithField("udpAddr", remote).WithField("counter", c).
 			WithField("attemptedCounter", c).
 			WithField("attemptedCounter", c).
 			Error("Failed to encrypt outgoing packet")
 			Error("Failed to encrypt outgoing packet")
-		return c
+		return
 	}
 	}
 
 
 	err = f.writers[q].WriteTo(out, remote)
 	err = f.writers[q].WriteTo(out, remote)
@@ -270,7 +227,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 		hostinfo.logger(f.l).WithError(err).
 		hostinfo.logger(f.l).WithError(err).
 			WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 			WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 	}
 	}
-	return c
+	return
 }
 }
 
 
 func isMulticast(ip uint32) bool {
 func isMulticast(ip uint32) bool {

+ 117 - 226
lighthouse.go

@@ -13,26 +13,11 @@ import (
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 )
 )
 
 
+//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY... why? handshake manager and/or getOrHandshake?
 //TODO: nodes are roaming lighthouses, this is bad. How are they learning?
 //TODO: nodes are roaming lighthouses, this is bad. How are they learning?
 
 
 var ErrHostNotKnown = errors.New("host not known")
 var ErrHostNotKnown = errors.New("host not known")
 
 
-// The maximum number of ip addresses to store for a given vpnIp per address family
-const maxAddrs = 10
-
-type ip4And6 struct {
-	//TODO: adding a lock here could allow us to release the lock on lh.addrMap quicker
-
-	// v4 and v6 store addresses that have been self reported by the client in a server or where all addresses are stored on a client
-	v4 []*Ip4AndPort
-	v6 []*Ip6AndPort
-
-	// Learned addresses are ones that a client does not know about but a lighthouse learned from as a result of the received packet
-	// This is only used if you are a lighthouse server
-	learnedV4 []*Ip4AndPort
-	learnedV6 []*Ip6AndPort
-}
-
 type LightHouse struct {
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
 	sync.RWMutex //Because we concurrently read and write to our maps
@@ -42,7 +27,8 @@ type LightHouse struct {
 	punchConn    *udpConn
 	punchConn    *udpConn
 
 
 	// Local cache of answers from light houses
 	// Local cache of answers from light houses
-	addrMap map[uint32]*ip4And6
+	// map of vpn Ip to answers
+	addrMap map[uint32]*RemoteList
 
 
 	// filters remote addresses allowed for each host
 	// filters remote addresses allowed for each host
 	// - When we are a lighthouse, this filters what addresses we store and
 	// - When we are a lighthouse, this filters what addresses we store and
@@ -81,7 +67,7 @@ func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, i
 		amLighthouse: amLighthouse,
 		amLighthouse: amLighthouse,
 		myVpnIp:      ip2int(myVpnIpNet.IP),
 		myVpnIp:      ip2int(myVpnIpNet.IP),
 		myVpnZeros:   uint32(32 - ones),
 		myVpnZeros:   uint32(32 - ones),
-		addrMap:      make(map[uint32]*ip4And6),
+		addrMap:      make(map[uint32]*RemoteList),
 		nebulaPort:   nebulaPort,
 		nebulaPort:   nebulaPort,
 		lighthouses:  make(map[uint32]struct{}),
 		lighthouses:  make(map[uint32]struct{}),
 		staticList:   make(map[uint32]struct{}),
 		staticList:   make(map[uint32]struct{}),
@@ -130,57 +116,79 @@ func (lh *LightHouse) ValidateLHStaticEntries() error {
 	return nil
 	return nil
 }
 }
 
 
-func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]*udpAddr, error) {
-	//TODO: we need to hold the lock through the next func
+func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip, f)
 		lh.QueryServer(ip, f)
 	}
 	}
 	lh.RLock()
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
 		lh.RUnlock()
-		return TransformLHReplyToUdpAddrs(v), nil
+		return v
 	}
 	}
 	lh.RUnlock()
 	lh.RUnlock()
-	return nil, ErrHostNotKnown
+	return nil
 }
 }
 
 
 // This is asynchronous so no reply should be expected
 // This is asynchronous so no reply should be expected
 func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
 func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
-	if !lh.amLighthouse {
-		// Send a query to the lighthouses and hope for the best next time
-		query, err := proto.Marshal(NewLhQueryByInt(ip))
-		if err != nil {
-			lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
-			return
-		}
+	if lh.amLighthouse {
+		return
+	}
 
 
-		lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
-		nb := make([]byte, 12, 12)
-		out := make([]byte, mtu)
-		for n := range lh.lighthouses {
-			f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
-		}
+	if lh.IsLighthouseIP(ip) {
+		return
+	}
+
+	// Send a query to the lighthouses and hope for the best next time
+	query, err := proto.Marshal(NewLhQueryByInt(ip))
+	if err != nil {
+		lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
+		return
+	}
+
+	lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
+	nb := make([]byte, 12, 12)
+	out := make([]byte, mtu)
+	for n := range lh.lighthouses {
+		f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
 	}
 	}
 }
 }
 
 
-func (lh *LightHouse) QueryCache(ip uint32) []*udpAddr {
-	//TODO: we need to hold the lock through the next func
+func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
 	lh.RLock()
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
 		lh.RUnlock()
-		return TransformLHReplyToUdpAddrs(v)
+		return v
 	}
 	}
 	lh.RUnlock()
 	lh.RUnlock()
-	return nil
+
+	lh.Lock()
+	defer lh.Unlock()
+	// Add an entry if we don't already have one
+	return lh.unlockedGetRemoteList(ip)
 }
 }
 
 
-//
-func (lh *LightHouse) queryAndPrepMessage(ip uint32, f func(*ip4And6) (int, error)) (bool, int, error) {
+// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
+// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp.
+// If one is found then f() is called with proper locking; f() must return the result of n.MarshalTo()
+func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
 	lh.RLock()
 	lh.RLock()
-	if v, ok := lh.addrMap[ip]; ok {
-		n, err := f(v)
+	// Do we have an entry in the main cache?
+	if v, ok := lh.addrMap[vpnIp]; ok {
+		// Swap lh lock for remote list lock
+		v.RLock()
+		defer v.RUnlock()
+
 		lh.RUnlock()
 		lh.RUnlock()
-		return true, n, err
+
+		// vpnIp should also be the owner here since we are a lighthouse.
+		c := v.cache[vpnIp]
+		// Make sure we have the owner's cache entry before building the reply
+		if c != nil {
+			n, err := f(c)
+			return true, n, err
+		}
+		return false, 0, nil
 	}
 	}
 	lh.RUnlock()
 	lh.RUnlock()
 	return false, 0, nil
 	return false, 0, nil
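
The lock swap above is a payoff of moving per-host addresses into RemoteList: the coarse lh lock is held only long enough to find the entry, and the per-entry lock is acquired before the coarse one is released, so the entry cannot vanish in the gap. The handoff in isolation (withRemoteList is a hypothetical helper; RemoteList embeds a sync.RWMutex):

// Sketch: coarse-to-fine lock handoff. Taking v's lock before dropping the
// map lock leaves no window in which we hold neither.
func withRemoteList(lh *LightHouse, vpnIp uint32, f func(*RemoteList)) bool {
	lh.RLock()
	v, ok := lh.addrMap[vpnIp]
	if !ok {
		lh.RUnlock()
		return false
	}
	v.RLock()
	lh.RUnlock()
	defer v.RUnlock()
	f(v)
	return true
}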
@@ -203,70 +211,47 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
 	lh.Unlock()
 	lh.Unlock()
 }
 }
 
 
-// AddRemote is correct way for non LightHouse members to add an address. toAddr will be placed in the learned map
-// static means this is a static host entry from the config file, it should only be used on start up
-func (lh *LightHouse) AddRemote(vpnIP uint32, toAddr *udpAddr, static bool) {
-	if ipv4 := toAddr.IP.To4(); ipv4 != nil {
-		lh.addRemoteV4(vpnIP, NewIp4AndPort(ipv4, uint32(toAddr.Port)), static, true)
-	} else {
-		lh.addRemoteV6(vpnIP, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)), static, true)
-	}
-
-	//TODO: if we do not add due to a config filter we may end up not having any addresses here
-	if static {
-		lh.staticList[vpnIP] = struct{}{}
-	}
-}
-
-// unlockedGetAddrs assumes you have the lh lock
-func (lh *LightHouse) unlockedGetAddrs(vpnIP uint32) *ip4And6 {
-	am, ok := lh.addrMap[vpnIP]
-	if !ok {
-		am = &ip4And6{}
-		lh.addrMap[vpnIP] = am
-	}
-	return am
-}
-
-// addRemoteV4 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated
-func (lh *LightHouse) addRemoteV4(vpnIP uint32, to *Ip4AndPort, static bool, learned bool) {
-	// First we check if the sender thinks this is a static entry
-	// and do nothing if it is not, but should be considered static
-	if static == false {
-		if _, ok := lh.staticList[vpnIP]; ok {
-			return
-		}
-	}
-
+// AddStaticRemote adds a static host entry for vpnIp with ourselves as the owner
+// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
+// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
+func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
 	lh.Lock()
 	lh.Lock()
-	defer lh.Unlock()
-	am := lh.unlockedGetAddrs(vpnIP)
+	am := lh.unlockedGetRemoteList(vpnIp)
+	am.Lock()
+	defer am.Unlock()
+	lh.Unlock()
 
 
-	if learned {
-		if !lh.unlockedShouldAddV4(am.learnedV4, to) {
+	if ipv4 := toAddr.IP.To4(); ipv4 != nil {
+		to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
+		if !lh.unlockedShouldAddV4(to) {
 			return
 			return
 		}
 		}
-		am.learnedV4 = prependAndLimitV4(am.learnedV4, to)
+		am.unlockedPrependV4(lh.myVpnIp, to)
+
 	} else {
 	} else {
-		if !lh.unlockedShouldAddV4(am.v4, to) {
+		to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
+		if !lh.unlockedShouldAddV6(to) {
 			return
 			return
 		}
 		}
-		am.v4 = prependAndLimitV4(am.v4, to)
+		am.unlockedPrependV6(lh.myVpnIp, to)
 	}
 	}
+
+	// Mark it as static
+	lh.staticList[vpnIp] = struct{}{}
 }
 }
 
 
-func prependAndLimitV4(cache []*Ip4AndPort, to *Ip4AndPort) []*Ip4AndPort {
-	cache = append(cache, nil)
-	copy(cache[1:], cache)
-	cache[0] = to
-	if len(cache) > MaxRemotes {
-		cache = cache[:maxAddrs]
+// unlockedGetRemoteList assumes you have the lh lock
+func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
+	am, ok := lh.addrMap[vpnIP]
+	if !ok {
+		am = NewRemoteList()
+		lh.addrMap[vpnIP] = am
 	}
 	}
-	return cache
+	return am
 }
 }
 
 
-// unlockedShouldAddV4 checks if to is allowed by our allow list and is not already present in the cache
-func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool {
+// unlockedShouldAddV4 checks if to is allowed by our allow list
+func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool {
 	allow := lh.remoteAllowList.AllowIpV4(to.Ip)
 	allow := lh.remoteAllowList.AllowIpV4(to.Ip)
 	if lh.l.Level >= logrus.TraceLevel {
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 		lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@@ -276,69 +261,21 @@ func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool
 		return false
 		return false
 	}
 	}
 
 
-	for _, v := range am {
-		if v.Ip == to.Ip && v.Port == to.Port {
-			return false
-		}
-	}
-
 	return true
 	return true
 }
 }
 
 
-// addRemoteV6 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated
-func (lh *LightHouse) addRemoteV6(vpnIP uint32, to *Ip6AndPort, static bool, learned bool) {
-	// First we check if the sender thinks this is a static entry
-	// and do nothing if it is not, but should be considered static
-	if static == false {
-		if _, ok := lh.staticList[vpnIP]; ok {
-			return
-		}
-	}
-
-	lh.Lock()
-	defer lh.Unlock()
-	am := lh.unlockedGetAddrs(vpnIP)
-
-	if learned {
-		if !lh.unlockedShouldAddV6(am.learnedV6, to) {
-			return
-		}
-		am.learnedV6 = prependAndLimitV6(am.learnedV6, to)
-	} else {
-		if !lh.unlockedShouldAddV6(am.v6, to) {
-			return
-		}
-		am.v6 = prependAndLimitV6(am.v6, to)
-	}
-}
-
-func prependAndLimitV6(cache []*Ip6AndPort, to *Ip6AndPort) []*Ip6AndPort {
-	cache = append(cache, nil)
-	copy(cache[1:], cache)
-	cache[0] = to
-	if len(cache) > MaxRemotes {
-		cache = cache[:maxAddrs]
-	}
-	return cache
-}
-
-// unlockedShouldAddV6 checks if to is allowed by our allow list and is not already present in the cache
-func (lh *LightHouse) unlockedShouldAddV6(am []*Ip6AndPort, to *Ip6AndPort) bool {
+// unlockedShouldAddV6 checks if to is allowed by our allow list
+func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool {
 	allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo)
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 
+	// We don't check our vpn network here because nebula does not support ipv6 on the inside
 	if !allow {
 		return false
 	}
 
-	for _, v := range am {
-		if v.Hi == to.Hi && v.Lo == to.Lo && v.Port == to.Port {
-			return false
-		}
-	}
-
 	return true
 }
 
@@ -349,13 +286,6 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
 	return ip
 }
 
-func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
-	if lh.amLighthouse {
-		lh.DeleteVpnIP(vpnIP)
-		lh.AddRemote(vpnIP, toIp, false)
-	}
-}
-
 func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
 	if _, ok := lh.lighthouses[vpnIP]; ok {
 		return true
@@ -496,7 +426,6 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 	return lhh.meta
 }
 
-//TODO: do we need c here?
 func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
@@ -544,13 +473,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	//TODO: we can DRY this further
 	reqVpnIP := n.Details.VpnIp
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
-	//TODO: If we use a lock on cache we can avoid holding it on lh.addrMap and keep things moving better
-	found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(cache *ip4And6) (int, error) {
+	found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
 		n.Details.VpnIp = reqVpnIP
 
-		lhh.coalesceAnswers(cache, n)
+		lhh.coalesceAnswers(c, n)
 
 		return n.MarshalTo(lhh.pb)
 	})
@@ -568,12 +496,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 
 	// This signals the other side to punch some zero byte udp packets
-	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(cache *ip4And6) (int, error) {
+	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostPunchNotification
 		n.Details.VpnIp = vpnIp
 
-		lhh.coalesceAnswers(cache, n)
+		lhh.coalesceAnswers(c, n)
 
 		return n.MarshalTo(lhh.pb)
 	})
@@ -591,12 +519,24 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 
-func (lhh *LightHouseHandler) coalesceAnswers(cache *ip4And6, n *NebulaMeta) {
-	n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.v4...)
-	n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.learnedV4...)
+func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
+	if c.v4 != nil {
+		if c.v4.learned != nil {
+			n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned)
+		}
+		if c.v4.reported != nil && len(c.v4.reported) > 0 {
+			n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...)
+		}
+	}
 
-	n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.v6...)
-	n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.learnedV6...)
+	if c.v6 != nil {
+		if c.v6.learned != nil {
+			n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned)
+		}
+		if c.v6.reported != nil && len(c.v6.reported) > 0 {
+			n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...)
+		}
+	}
 }
 
 func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
@@ -604,14 +544,14 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32)
 		return
 	}
 
-	// We can't just slam the responses in as they may come from multiple lighthouses and we should coalesce the answers
-	for _, to := range n.Details.Ip4AndPorts {
-		lhh.lh.addRemoteV4(n.Details.VpnIp, to, false, false)
-	}
+	lhh.lh.Lock()
+	am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
+	am.Lock()
+	lhh.lh.Unlock()
 
-	for _, to := range n.Details.Ip6AndPorts {
-		lhh.lh.addRemoteV6(n.Details.VpnIp, to, false, false)
-	}
+	am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	am.Unlock()
 
 	// Non-blocking attempt to trigger, skip if it would block
 	select {
@@ -637,35 +577,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	}
 
 	lhh.lh.Lock()
-	defer lhh.lh.Unlock()
-	am := lhh.lh.unlockedGetAddrs(vpnIp)
-
-	//TODO: other note on a lock for am so we can release more quickly and lock our real unit of change which is far less contended
-
-	// We don't accumulate addresses being told to us
-	am.v4 = am.v4[:0]
-	am.v6 = am.v6[:0]
-
-	for _, v := range n.Details.Ip4AndPorts {
-		if lhh.lh.unlockedShouldAddV4(am.v4, v) {
-			am.v4 = append(am.v4, v)
-		}
-	}
-
-	for _, v := range n.Details.Ip6AndPorts {
-		if lhh.lh.unlockedShouldAddV6(am.v6, v) {
-			am.v6 = append(am.v6, v)
-		}
-	}
+	am := lhh.lh.unlockedGetRemoteList(vpnIp)
+	am.Lock()
+	lhh.lh.Unlock()
 
-	// We prefer the first n addresses if we got too big
-	if len(am.v4) > MaxRemotes {
-		am.v4 = am.v4[:MaxRemotes]
-	}
-
-	if len(am.v6) > MaxRemotes {
-		am.v6 = am.v6[:MaxRemotes]
-	}
+	am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	am.Unlock()
 }
 
 func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
@@ -716,33 +634,6 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
 	}
 }
 
-func TransformLHReplyToUdpAddrs(ips *ip4And6) []*udpAddr {
-	addrs := make([]*udpAddr, len(ips.v4)+len(ips.v6)+len(ips.learnedV4)+len(ips.learnedV6))
-	i := 0
-
-	for _, v := range ips.learnedV4 {
-		addrs[i] = NewUDPAddrFromLH4(v)
-		i++
-	}
-
-	for _, v := range ips.v4 {
-		addrs[i] = NewUDPAddrFromLH4(v)
-		i++
-	}
-
-	for _, v := range ips.learnedV6 {
-		addrs[i] = NewUDPAddrFromLH6(v)
-		i++
-	}
-
-	for _, v := range ips.v6 {
-		addrs[i] = NewUDPAddrFromLH6(v)
-		i++
-	}
-
-	return addrs
-}
-
 // ipMaskContains checks if testIp is contained by ip after applying a cidr
 // zeros is 32 - bits from net.IPMask.Size()
 func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {

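One pattern repeats throughout the new lighthouse code above and is easy to miss in the diff: the lighthouse lock is held only long enough to fetch or create the per-host RemoteList, and that list's own lock then covers the actual mutation. A minimal sketch of the pattern (updateRemotes is a hypothetical name, not part of this commit):

func (lh *LightHouse) updateRemotes(vpnIp uint32, v4 []*Ip4AndPort, v6 []*Ip6AndPort) {
	lh.Lock()
	am := lh.unlockedGetRemoteList(vpnIp)
	am.Lock()
	lh.Unlock() // other vpn ips can be serviced while this entry is mutated
	defer am.Unlock()

	// The check funcs let the allow list filter entries as they are copied in
	am.unlockedSetV4(vpnIp, v4, lh.unlockedShouldAddV4)
	am.unlockedSetV6(vpnIp, v6, lh.unlockedShouldAddV6)
}
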
+ 100 - 79
lighthouse_test.go

@@ -48,16 +48,16 @@ func Test_lhStaticMapping(t *testing.T) {
 
 	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
 
-	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
+	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
 	err := meh.ValidateLHStaticEntries()
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
 	lh2IP := net.ParseIP(lh2)
 
-	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
+	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
 	err = meh.ValidateLHStaticEntries()
 	assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
@@ -73,17 +73,27 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 
 	hAddr := NewUDPAddrFromString("4.5.6.7:12345")
 	hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = &ip4And6{v4: []*Ip4AndPort{
-		NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
-		NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port))},
-	}
+	lh.addrMap[3] = NewRemoteList()
+	lh.addrMap[3].unlockedSetV4(
+		3,
+		[]*Ip4AndPort{
+			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
+			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
 
 	rAddr := NewUDPAddrFromString("1.2.2.3:12345")
 	rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = &ip4And6{v4: []*Ip4AndPort{
-		NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
-		NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port))},
-	}
+	lh.addrMap[2] = NewRemoteList()
+	lh.addrMap[2].unlockedSetV4(
+		3,
+		[]*Ip4AndPort{
+			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
+			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
 
 	mw := &mockEncWriter{}
 
@@ -173,7 +183,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 
 	// Ensure proper ordering and limiting
-	// Send 12 addrs, get 10 back, one removed on a dupe check the other by limiting
+	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
 	newLHHostUpdate(
 		myUdpAddr0,
 		myVpnIp,
@@ -191,11 +201,12 @@ func TestLighthouse_Memory(t *testing.T) {
 			myUdpAddr10,
 			myUdpAddr11, // This should get cut
 		}, lhh)
+
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(
 		t,
 		r.msg.Details.Ip4AndPorts,
-		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, myUdpAddr10,
+		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 	)
 
 	// Make sure we won't add ips in our vpn network
@@ -247,71 +258,71 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
 	lhh.HandleRequest(fromAddr, vpnIp, b, w)
 }
 
-func Test_lhRemoteAllowList(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
-	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
-		"10.20.0.0/12": false,
-	}
-	allowList, err := c.GetAllowList("remoteallowlist", false)
-	assert.Nil(t, err)
-
-	lh1 := "10.128.0.2"
-	lh1IP := net.ParseIP(lh1)
-
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-	lh.SetRemoteAllowList(allowList)
-
-	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
-	remote1IP := net.ParseIP("10.20.0.3")
-	lh.AddRemote(ip2int(remote1IP), NewUDPAddr(remote1IP, uint16(4242)), true)
-	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
-	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v4)
-	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v6)
-
-	// Make sure a good ip enters the cache and addrMap
-	remote2IP := net.ParseIP("10.128.0.3")
-	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
-	lh.AddRemote(ip2int(remote2IP), remote2UDPAddr, true)
-	assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote2UDPAddr)
-
-	// Another good ip gets into the cache, ordering is inverted
-	remote3IP := net.ParseIP("10.128.0.4")
-	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
-	lh.AddRemote(ip2int(remote2IP), remote3UDPAddr, true)
-	assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote3UDPAddr, remote2UDPAddr)
-
-	// If we exceed the length limit we should only have the most recent addresses
-	addedAddrs := []*udpAddr{}
-	for i := 0; i < 11; i++ {
-		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
-		lh.AddRemote(ip2int(remote2IP), remoteUDPAddr, true)
-		// The first entry here is a duplicate, don't add it to the assert list
-		if i != 0 {
-			addedAddrs = append(addedAddrs, remoteUDPAddr)
-		}
-	}
-
-	// We should only have the last 10 of what we tried to add
-	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
-	ln := len(addedAddrs)
-	assertIp4InArray(
-		t,
-		lh.addrMap[ip2int(remote2IP)].learnedV4,
-		addedAddrs[ln-1],
-		addedAddrs[ln-2],
-		addedAddrs[ln-3],
-		addedAddrs[ln-4],
-		addedAddrs[ln-5],
-		addedAddrs[ln-6],
-		addedAddrs[ln-7],
-		addedAddrs[ln-8],
-		addedAddrs[ln-9],
-		addedAddrs[ln-10],
-	)
-}
+//TODO: this is a RemoteList test
+//func Test_lhRemoteAllowList(t *testing.T) {
+//	l := NewTestLogger()
+//	c := NewConfig(l)
+//	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
+//		"10.20.0.0/12": false,
+//	}
+//	allowList, err := c.GetAllowList("remoteallowlist", false)
+//	assert.Nil(t, err)
+//
+//	lh1 := "10.128.0.2"
+//	lh1IP := net.ParseIP(lh1)
+//
+//	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
+//
+//	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+//	lh.SetRemoteAllowList(allowList)
+//
+//	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
+//	remote1IP := net.ParseIP("10.20.0.3")
+//	remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
+//	remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
+//	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
+//	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
+//
+//	// Make sure a good ip enters the cache and addrMap
+//	remote2IP := net.ParseIP("10.128.0.3")
+//	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
+//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
+//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
+//
+//	// Another good ip gets into the cache, ordering is inverted
+//	remote3IP := net.ParseIP("10.128.0.4")
+//	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
+//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
+//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
+//
+//	// If we exceed the length limit we should only have the most recent addresses
+//	addedAddrs := []*udpAddr{}
+//	for i := 0; i < 11; i++ {
+//		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
+//		lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
+//		// The first entry here is a duplicate, don't add it to the assert list
+//		if i != 0 {
+//			addedAddrs = append(addedAddrs, remoteUDPAddr)
+//		}
+//	}
+//
+//	// We should only have the last 10 of what we tried to add
+//	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
+//	assertUdpAddrInArray(
+//		t,
+//		lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
+//		addedAddrs[0],
+//		addedAddrs[1],
+//		addedAddrs[2],
+//		addedAddrs[3],
+//		addedAddrs[4],
+//		addedAddrs[5],
+//		addedAddrs[6],
+//		addedAddrs[7],
+//		addedAddrs[8],
+//		addedAddrs[9],
+//	)
+//}
 
 func Test_ipMaskContains(t *testing.T) {
 	assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
@@ -354,6 +365,16 @@ func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
 	}
 }
 
+// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
+func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
+	assert.Len(t, have, len(want))
+	for k, w := range want {
+		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
+			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
+		}
+	}
+}
+
 func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
 	addrs := make([]*udpAddr, len(ips))
 	for k, v := range ips {

+ 3 - 4
main.go

@@ -221,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 
 	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
-	hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
+
 	hostMap.addUnsafeRoutes(&unsafeRoutes)
 	hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
 
@@ -302,14 +302,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 				if err != nil {
 					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				}
-				lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true)
+				lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
 			}
 		} else {
 			ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
 			if err != nil {
 				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 			}
-			lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true)
+			lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
 		}
 	}
 
@@ -328,7 +328,6 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	handshakeConfig := HandshakeConfig{
 		tryInterval:   config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
 		retries:       config.GetInt("handshakes.retries", DefaultHandshakeRetries),
-		waitRotation:  config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
 		triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
 
 		messageMetrics: messageMetrics,

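For reference, the two static_host_map value shapes that the loop above parses into AddStaticRemote calls look like this in config (a sketch, values illustrative and not taken from this commit):

static_host_map:
  "192.168.100.1": "100.64.22.11:4242"
  "192.168.100.2": ["100.64.22.12:4242", "1.2.3.4:4242"]
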
+ 6 - 3
outside.go

@@ -132,6 +132,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 	f.connectionManager.In(hostinfo.hostId)
 }
 
+// closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
 	f.connectionManager.ClearIP(hostInfo.hostId)
@@ -140,6 +141,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	f.hostMap.DeleteHostInfo(hostInfo)
 }
 
+// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
+func (f *Interface) sendCloseTunnel(h *HostInfo) {
+	f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
+}
+
 func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 	if hostDidRoam(hostinfo.remote, addr) {
 		if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
@@ -160,9 +166,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 		remoteCopy := *hostinfo.remote
 		hostinfo.lastRoamRemote = &remoteCopy
 		hostinfo.SetRemote(addr)
-		if f.lightHouse.amLighthouse {
-			f.lightHouse.AddRemote(hostinfo.hostId, addr, false)
-		}
 	}
 
 }

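A hypothetical composition (not in this commit) makes the intended ordering of the two helpers explicit; sendCloseTunnel encrypts with the tunnel's ConnectionState, so it must run before closeTunnel tears that state down:

func (f *Interface) sendCloseAndTeardown(h *HostInfo) {
	f.sendCloseTunnel(h) // tell the remote side first
	f.closeTunnel(h)     // then drop local hostmap and connection manager state
}
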
+ 500 - 0
remote_list.go

@@ -0,0 +1,500 @@
+package nebula
+
+import (
+	"bytes"
+	"net"
+	"sort"
+	"sync"
+)
+
+// forEachFunc is used to benefit folks that want to do work inside the lock
+type forEachFunc func(addr *udpAddr, preferred bool)
+
+// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
+type checkFuncV4 func(to *Ip4AndPort) bool
+type checkFuncV6 func(to *Ip6AndPort) bool
+
+// CacheMap is a struct that better represents the lighthouse cache for humans
+// The string key is the owner's vpnIp
+type CacheMap map[string]*Cache
+
+// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
+// We don't reason about ipv4 vs ipv6 here
+type Cache struct {
+	Learned  []*udpAddr `json:"learned,omitempty"`
+	Reported []*udpAddr `json:"reported,omitempty"`
+}
+
+//TODO: Seems like we should plop static host entries in here too since they are protected by the lighthouse from deletion
+// We will never clean learned/reported information for them as it stands today
+
+// cache is an internal struct that splits v4 and v6 addresses inside the cache map
+type cache struct {
+	v4 *cacheV4
+	v6 *cacheV6
+}
+
+// cacheV4 stores learned and reported ipv4 records under cache
+type cacheV4 struct {
+	learned  *Ip4AndPort
+	reported []*Ip4AndPort
+}
+
+// cacheV6 stores learned and reported ipv6 records under cache
+type cacheV6 struct {
+	learned  *Ip6AndPort
+	reported []*Ip6AndPort
+}
+
+// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
+// It serves as a local cache of query replies, host update notifications, and locally learned addresses
+type RemoteList struct {
+	// Every interaction with internals requires a lock!
+	sync.RWMutex
+
+	// A deduplicated set of addresses. Any accessor should lock beforehand.
+	addrs []*udpAddr
+
+	// These are maps to store v4 and v6 addresses per lighthouse
+	// Map key is the vpnIp of the person that told us about the cached entries underneath.
+	// For learned addresses, this is the vpnIp that sent the packet
+	cache map[uint32]*cache
+
+	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
+	// They should not be tried again during a handshake
+	badRemotes []*udpAddr
+
+	// A flag that the cache may have changed and addrs needs to be rebuilt
+	shouldRebuild bool
+}
+
+// NewRemoteList creates a new empty RemoteList
+func NewRemoteList() *RemoteList {
+	return &RemoteList{
+		addrs: make([]*udpAddr, 0),
+		cache: make(map[uint32]*cache),
+	}
+}
+
+// Len locks and reports the size of the deduplicated address list
+// The deduplication work may need to occur here, so you must pass preferredRanges
+func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
+	r.Rebuild(preferredRanges)
+	r.RLock()
+	defer r.RUnlock()
+	return len(r.addrs)
+}
+
+// ForEach locks and will call the forEachFunc for every deduplicated address in the list
+// The deduplication work may need to occur here, so you must pass preferredRanges
+func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
+	r.Rebuild(preferredRanges)
+	r.RLock()
+	for _, v := range r.addrs {
+		forEach(v, isPreferred(v.IP, preferredRanges))
+	}
+	r.RUnlock()
+}
+
+// CopyAddrs locks and makes a deep copy of the deduplicated address list
+// The deduplication work may need to occur here, so you must pass preferredRanges
+func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
+	r.Rebuild(preferredRanges)
+
+	r.RLock()
+	defer r.RUnlock()
+	c := make([]*udpAddr, len(r.addrs))
+	for i, v := range r.addrs {
+		c[i] = v.Copy()
+	}
+	return c
+}
+
+// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
+// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
+// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
+//TODO: this needs to support the allow list
+func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
+	r.Lock()
+	defer r.Unlock()
+	if v4 := addr.IP.To4(); v4 != nil {
+		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
+	} else {
+		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
+	}
+}
+
+// CopyCache locks and creates a more human friendly form of the internal address cache.
+// This may contain duplicates and blocked addresses
+func (r *RemoteList) CopyCache() *CacheMap {
+	r.RLock()
+	defer r.RUnlock()
+
+	cm := make(CacheMap)
+	getOrMake := func(vpnIp string) *Cache {
+		c := cm[vpnIp]
+		if c == nil {
+			c = &Cache{
+				Learned:  make([]*udpAddr, 0),
+				Reported: make([]*udpAddr, 0),
+			}
+			cm[vpnIp] = c
+		}
+		return c
+	}
+
+	for owner, mc := range r.cache {
+		c := getOrMake(IntIp(owner).String())
+
+		if mc.v4 != nil {
+			if mc.v4.learned != nil {
+				c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
+			}
+
+			for _, a := range mc.v4.reported {
+				c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
+			}
+		}
+
+		if mc.v6 != nil {
+			if mc.v6.learned != nil {
+				c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
+			}
+
+			for _, a := range mc.v6.reported {
+				c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
+			}
+		}
+	}
+
+	return &cm
+}
+
+// BlockRemote locks and records the address as bad, it will be excluded from the deduplicated address list
+func (r *RemoteList) BlockRemote(bad *udpAddr) {
+	r.Lock()
+	defer r.Unlock()
+
+	// Check if we already blocked this addr
+	if r.unlockedIsBad(bad) {
+		return
+	}
+
+	// We copy here because we are taking something else's memory and we can't trust everything
+	r.badRemotes = append(r.badRemotes, bad.Copy())
+
+	// Mark the next interaction must recollect/dedupe
+	r.shouldRebuild = true
+}
+
+// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
+func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
+	r.RLock()
+	defer r.RUnlock()
+
+	c := make([]*udpAddr, len(r.badRemotes))
+	for i, v := range r.badRemotes {
+		c[i] = v.Copy()
+	}
+	return c
+}
+
+// ResetBlockedRemotes locks and clears the blocked remotes list
+func (r *RemoteList) ResetBlockedRemotes() {
+	r.Lock()
+	r.badRemotes = nil
+	r.Unlock()
+}
+
+// Rebuild locks and generates the deduplicated address list only if there is work to be done
+// There is generally no reason to call this directly but it is safe to do so
+func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
+	r.Lock()
+	defer r.Unlock()
+
+	// Only rebuild if the cache changed
+	//TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in
+	if r.shouldRebuild {
+		r.unlockedCollect()
+		r.shouldRebuild = false
+	}
+
+	// Always re-sort, preferredRanges can change via HUP
+	r.unlockedSort(preferredRanges)
+}
+
+// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
+func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
+	for _, v := range r.badRemotes {
+		if v.Equals(remote) {
+			return true
+		}
+	}
+	return false
+}
+
+// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
+// deduplicated address list as dirty
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
+	r.shouldRebuild = true
+	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
+}
+
+// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
+// and marks the deduplicated address list as dirty
+func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV4(ownerVpnIp)
+
+	// Reset the slice
+	c.reported = c.reported[:0]
+
+	// We can't take their array but we can take their pointers
+	for _, v := range to[:minInt(len(to), MaxRemotes)] {
+		if check(v) {
+			c.reported = append(c.reported, v)
+		}
+	}
+}
+
+// unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner
+// This is only useful for establishing static hosts
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV4(ownerVpnIp)
+
+	// We are doing the easy append because this is rarely called
+	c.reported = append([]*Ip4AndPort{to}, c.reported...)
+	if len(c.reported) > MaxRemotes {
+		c.reported = c.reported[:MaxRemotes]
+	}
+}
+
+// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
+// deduplicated address list as dirty
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
+	r.shouldRebuild = true
+	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
+}
+
+// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
+// and marks the deduplicated address list as dirty
+func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV6(ownerVpnIp)
+
+	// Reset the slice
+	c.reported = c.reported[:0]
+
+	// We can't take their array but we can take their pointers
+	for _, v := range to[:minInt(len(to), MaxRemotes)] {
+		if check(v) {
+			c.reported = append(c.reported, v)
+		}
+	}
+}
+
+// unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner
+// This is only useful for establishing static hosts
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV6(ownerVpnIp)
+
+	// We are doing the easy append because this is rarely called
+	c.reported = append([]*Ip6AndPort{to}, c.reported...)
+	if len(c.reported) > MaxRemotes {
+		c.reported = c.reported[:MaxRemotes]
+	}
+}
+
+// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
+// The caller must dirty the learned address cache if required
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
+	am := r.cache[ownerVpnIp]
+	if am == nil {
+		am = &cache{}
+		r.cache[ownerVpnIp] = am
+	}
+	// Avoid occupying memory for v6 addresses if we never have any
+	if am.v4 == nil {
+		am.v4 = &cacheV4{}
+	}
+	return am.v4
+}
+
+// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
+// The caller must dirty the learned address cache if required
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
+	am := r.cache[ownerVpnIp]
+	if am == nil {
+		am = &cache{}
+		r.cache[ownerVpnIp] = am
+	}
+	// Avoid occupying memory for v4 addresses if we never have any
+	if am.v6 == nil {
+		am.v6 = &cacheV6{}
+	}
+	return am.v6
+}
+
+// unlockedCollect assumes you have the write lock and collects/transforms the cache into the deduped address list.
+// The result of this function can contain duplicates. unlockedSort handles cleaning it.
+func (r *RemoteList) unlockedCollect() {
+	addrs := r.addrs[:0]
+
+	for _, c := range r.cache {
+		if c.v4 != nil {
+			if c.v4.learned != nil {
+				u := NewUDPAddrFromLH4(c.v4.learned)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+
+			for _, v := range c.v4.reported {
+				u := NewUDPAddrFromLH4(v)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+		}
+
+		if c.v6 != nil {
+			if c.v6.learned != nil {
+				u := NewUDPAddrFromLH6(c.v6.learned)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+
+			for _, v := range c.v6.reported {
+				u := NewUDPAddrFromLH6(v)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+		}
+	}
+
+	r.addrs = addrs
+}
+
+// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
+func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
+	n := len(r.addrs)
+	if n < 2 {
+		return
+	}
+
+	lessFunc := func(i, j int) bool {
+		a := r.addrs[i]
+		b := r.addrs[j]
+		// Preferred addresses first
+
+		aPref := isPreferred(a.IP, preferredRanges)
+		bPref := isPreferred(b.IP, preferredRanges)
+		switch {
+		case aPref && !bPref:
+			// If i is preferred and j is not, i is less than j
+			return true
+
+		case !aPref && bPref:
+			// If j is preferred then i is not due to the else, i is not less than j
+			return false
+
+		default:
+			// Both i and j are either preferred or not, sort within that
+		}
+
+		// ipv6 addresses 2nd
+		a4 := a.IP.To4()
+		b4 := b.IP.To4()
+		switch {
+		case a4 == nil && b4 != nil:
+			// If i is v6 and j is v4, i is less than j
+			return true
+
+		case a4 != nil && b4 == nil:
+			// If j is v6 and i is v4, i is not less than j
+			return false
+
+		case a4 != nil && b4 != nil:
+			// Special case for ipv4, a4 and b4 are not nil
+			aPrivate := isPrivateIP(a4)
+			bPrivate := isPrivateIP(b4)
+			switch {
+			case !aPrivate && bPrivate:
+				// If i is a public ip (not private) and j is a private ip, i is less than j
+				return true
+
+			case aPrivate && !bPrivate:
+				// If j is public (not private) then i is private due to the else, i is not less than j
+				return false
+
+			default:
+				// Both i and j are either public or private, sort within that
+			}
+
+		default:
+			// Both i and j are either ipv4 or ipv6, sort within that
+		}
+
+		// lexical order of ips 3rd
+		c := bytes.Compare(a.IP, b.IP)
+		if c == 0 {
+			// Ips are the same, Lexical order of ports 4th
+			return a.Port < b.Port
+		}
+
+		// Ip wasn't the same
+		return c < 0
+	}
+
+	// Sort it
+	sort.Slice(r.addrs, lessFunc)
+
+	// Deduplicate
+	a, b := 0, 1
+	for b < n {
+		if !r.addrs[a].Equals(r.addrs[b]) {
+			a++
+			if a != b {
+				r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
+			}
+		}
+		b++
+	}
+
+	r.addrs = r.addrs[:a+1]
+	return
+}
+
+// minInt returns the smaller of a and b
+func minInt(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+// isPreferred returns true if the ip is contained in the preferredRanges list
+func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
+	//TODO: this would be better in a CIDR6Tree
+	for _, p := range preferredRanges {
+		if p.Contains(ip) {
+			return true
+		}
+	}
+	return false
+}
+
+var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
+var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
+var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
+
+// isPrivateIP returns true if the ip is contained by a rfc 1918 private range
+func isPrivateIP(ip net.IP) bool {
+	//TODO: another great cidrtree option
+	//TODO: Private for ipv6 or just let it ride?
+	return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
+}

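A minimal consumer sketch (hypothetical; dialCandidates is not part of the commit): callers only ever see the deduplicated, sorted view, and Rebuild runs lazily inside each accessor:

func dialCandidates(rl *RemoteList, preferredRanges []*net.IPNet) []*udpAddr {
	out := make([]*udpAddr, 0, rl.Len(preferredRanges))
	rl.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
		// preferred is true when addr falls within preferredRanges, so those sort first
		out = append(out, addr.Copy())
	})
	return out
}
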
+ 228 - 0
remote_list_test.go

@@ -0,0 +1,228 @@
+package nebula
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRemoteList_Rebuild(t *testing.T) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		1,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	rl.Rebuild([]*net.IPNet{})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// ipv6 first, sorted lexically within
+	assert.Equal(t, "[1::1]:1", rl.addrs[0].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[1].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[2].String())
+
+	// ipv4 last, sorted by public first, then private, lexically within them
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[3].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[4].String())
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[5].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[6].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
+
+	// Now ensure we can hoist ipv4 up
+	_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
+	assert.NoError(t, err)
+	rl.Rebuild([]*net.IPNet{ipNet})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// ipv4 first, public then private, lexically within them
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[0].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[1].String())
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[2].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[3].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[4].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[5].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[6].String())
+
+	// ipv6 last, sorted by public first, then private, lexically within them
+	assert.Equal(t, "[1::1]:1", rl.addrs[7].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
+
+	// Ensure we can hoist a specific ipv4 range over anything else
+	_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
+	assert.NoError(t, err)
+	rl.Rebuild([]*net.IPNet{ipNet})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// Preferred ipv4 first
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[0].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[1].String())
+
+	// ipv6 next
+	assert.Equal(t, "[1::1]:1", rl.addrs[2].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[3].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[4].String())
+
+	// the remaining ipv4 last
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[5].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[6].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
+}
+
+func BenchmarkFullRebuild(b *testing.B) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		0,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	b.Run("no preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{})
+		}
+	})
+
+	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
+	assert.NoError(b, err)
+	b.Run("1 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet})
+		}
+	})
+
+	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
+	assert.NoError(b, err)
+	b.Run("2 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+		}
+	})
+
+	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
+	assert.NoError(b, err)
+	b.Run("3 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+		}
+	})
+}
+
+func BenchmarkSortRebuild(b *testing.B) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		0,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	b.Run("no preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{})
+		}
+	})
+
+	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
+	rl.Rebuild([]*net.IPNet{ipNet})
+
+	assert.NoError(b, err)
+	b.Run("1 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.Rebuild([]*net.IPNet{ipNet})
+		}
+	})
+
+	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
+	rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+
+	assert.NoError(b, err)
+	b.Run("2 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+		}
+	})
+
+	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
+	rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+
+	assert.NoError(b, err)
+	b.Run("3 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+		}
+	})
+}

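The two benchmarks above separate the full rebuild (collect + sort) from the sort-only path; to compare them with allocation counts, standard go tooling is enough:

go test -run '^$' -bench 'Rebuild' -benchmem
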
+ 39 - 48
ssh.go

@@ -10,8 +10,8 @@ import (
 	"os"
 	"reflect"
 	"runtime/pprof"
+	"sort"
 	"strings"
-	"sync/atomic"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
@@ -335,8 +335,10 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 		return nil
 	}
 
-	hostMap.RLock()
-	defer hostMap.RUnlock()
+	hm := listHostMap(hostMap)
+	sort.Slice(hm, func(i, j int) bool {
+		return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
+	})
 
 	if fs.Json || fs.Pretty {
 		js := json.NewEncoder(w.GetWriter())
@@ -344,35 +346,15 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 			js.SetIndent("", "    ")
 		}
 
-		d := make([]m, len(hostMap.Hosts))
-		x := 0
-		var h m
-		for _, v := range hostMap.Hosts {
-			h = m{
-				"vpnIp":         int2ip(v.hostId),
-				"localIndex":    v.localIndexId,
-				"remoteIndex":   v.remoteIndexId,
-				"remoteAddrs":   v.CopyRemotes(),
-				"cachedPackets": len(v.packetStore),
-				"cert":          v.GetCert(),
-			}
-
-			if v.ConnectionState != nil {
-				h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter)
-			}
-
-			d[x] = h
-			x++
-		}
-
-		err := js.Encode(d)
+		err := js.Encode(hm)
 		if err != nil {
 			//TODO
 			return nil
 		}
+
 	} else {
-		for i, v := range hostMap.Hosts {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.CopyRemotes()))
+		for _, v := range hm {
+			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
 			if err != nil {
 				return err
 			}
@@ -389,8 +371,26 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 		return nil
 	}
 
+	type lighthouseInfo struct {
+		VpnIP net.IP    `json:"vpnIp"`
+		Addrs *CacheMap `json:"addrs"`
+	}
+
 	lightHouse.RLock()
-	defer lightHouse.RUnlock()
+	addrMap := make([]lighthouseInfo, len(lightHouse.addrMap))
+	x := 0
+	for k, v := range lightHouse.addrMap {
+		addrMap[x] = lighthouseInfo{
+			VpnIP: int2ip(k),
+			Addrs: v.CopyCache(),
+		}
+		x++
+	}
+	lightHouse.RUnlock()
+
+	sort.Slice(addrMap, func(i, j int) bool {
+		return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
+	})
 
 	if fs.Json || fs.Pretty {
 		js := json.NewEncoder(w.GetWriter())
@@ -398,27 +398,19 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 			js.SetIndent("", "    ")
 		}
 
-		d := make([]m, len(lightHouse.addrMap))
-		x := 0
-		var h m
-		for vpnIp, v := range lightHouse.addrMap {
-			h = m{
-				"vpnIp": int2ip(vpnIp),
-				"addrs": TransformLHReplyToUdpAddrs(v),
-			}
-
-			d[x] = h
-			x++
-		}
-
-		err := js.Encode(d)
+		err := js.Encode(addrMap)
 		if err != nil {
 			//TODO
 			return nil
 		}
+
 	} else {
-		for vpnIp, v := range lightHouse.addrMap {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), TransformLHReplyToUdpAddrs(v)))
+		for _, v := range addrMap {
+			b, err := json.Marshal(v.Addrs)
+			if err != nil {
+				return err
+			}
+			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
 			if err != nil {
 				return err
 			}
@@ -469,8 +461,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	ips, _ := ifce.lightHouse.Query(vpnIp, ifce)
-	return json.NewEncoder(w.GetWriter()).Encode(ips)
+	return json.NewEncoder(w.GetWriter()).Encode(ifce.lightHouse.Query(vpnIp, ifce).CopyCache())
 }
 
 func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
@@ -727,7 +718,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
@@ -737,7 +728,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		enc.SetIndent("", "    ")
 	}
 
-	return enc.Encode(hostInfo)
+	return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
 }
 
 func sshReload(fs interface{}, a []string, w sshd.StringWriter) error {

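listHostMap itself is not visible in this diff; judging from its callers here and in control.go, its body is assumed to look roughly like this (a sketch, not the actual implementation):

func listHostMap(hm *HostMap) []ControlHostInfo {
	hm.RLock()
	hosts := make([]ControlHostInfo, 0, len(hm.Hosts))
	for _, v := range hm.Hosts {
		hosts = append(hosts, copyHostInfo(v, hm.preferredRanges))
	}
	hm.RUnlock()
	return hosts
}
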
+ 1 - 3
tun_tester.go

@@ -41,9 +41,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []r
 // These are unencrypted ip layer frames destined for another nebula node.
 // packets should exit the udp side, capture them with udpConn.Get
 func (c *Tun) Send(packet []byte) {
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.Debug("Tun injecting packet")
-	}
+	c.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
 	c.rxPackets <- packet
 }
 

+ 3 - 3
udp_all.go

@@ -13,8 +13,8 @@ type udpAddr struct {
 }
 
 func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
-	addr := udpAddr{IP: make([]byte, len(ip)), Port: port}
-	copy(addr.IP, ip)
+	addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
+	copy(addr.IP, ip.To16())
 	return &addr
 }
 
@@ -22,7 +22,7 @@ func NewUDPAddrFromString(s string) *udpAddr {
 	ip, port, err := parseIPAndPort(s)
 	//TODO: handle err
 	_ = err
-	return &udpAddr{IP: ip, Port: port}
+	return &udpAddr{IP: ip.To16(), Port: port}
 }
 
 func (ua *udpAddr) Equals(t *udpAddr) bool {

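Both constructors now normalize to the 16-byte form, so a udpAddr compares and serializes consistently no matter how the IP arrived; a quick sketch of the invariant as a hypothetical test:

func TestNewUDPAddrIsAlwaysSixteenBytes(t *testing.T) {
	a := NewUDPAddr(net.IP{10, 1, 1, 1}, 4242) // 4-byte slice in
	assert.Equal(t, net.IPv6len, len(a.IP))    // 16-byte v4-in-v6 form out
	assert.True(t, a.IP.Equal(net.ParseIP("10.1.1.1")))
}
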
+ 9 - 28
udp_linux.go

@@ -97,40 +97,21 @@ func (u *udpConn) GetSendBuffer() (int, error) {
 }
 
 func (u *udpConn) LocalAddr() (*udpAddr, error) {
-	var rsa unix.RawSockaddrAny
-	var rLen = unix.SizeofSockaddrAny
-
-	_, _, err := unix.Syscall(
-		unix.SYS_GETSOCKNAME,
-		uintptr(u.sysFd),
-		uintptr(unsafe.Pointer(&rsa)),
-		uintptr(unsafe.Pointer(&rLen)),
-	)
-
-	if err != 0 {
+	sa, err := unix.Getsockname(u.sysFd)
+	if err != nil {
 		return nil, err
 	}
 
 	addr := &udpAddr{}
-	if rsa.Addr.Family == unix.AF_INET {
-		pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(&rsa))
-		addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
-		copy(addr.IP, pp.Addr[:])
-
-	} else if rsa.Addr.Family == unix.AF_INET6 {
-		//TODO: this cast sucks and we can do better
-		pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(&rsa))
-		addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
-		copy(addr.IP, pp.Addr[:])
-
-	} else {
-		addr.Port = 0
-		addr.IP = []byte{}
+	switch sa := sa.(type) {
+	case *unix.SockaddrInet4:
+		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
+		addr.Port = uint16(sa.Port)
+	case *unix.SockaddrInet6:
+		addr.IP = sa.Addr[0:]
+		addr.Port = uint16(sa.Port)
 	}
 
-	//TODO: Just use this instead?
-	//a, b := unix.Getsockname(u.sysFd)
-
 	return addr, nil
 }
 

+ 9 - 1
udp_tester.go

@@ -3,6 +3,7 @@
 package nebula
 
 import (
+	"fmt"
 	"net"
 
 	"github.com/sirupsen/logrus"
@@ -53,7 +54,14 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
 // this is an encrypted packet or a handshake message in most cases
 // packets were transmitted from another nebula node, you can send them with Tun.Send
 func (u *udpConn) Send(packet *UdpPacket) {
-	u.l.Infof("UDP injecting packet %+v", packet)
+	h := &Header{}
+	if err := h.Parse(packet.Data); err != nil {
+		panic(err)
+	}
+	u.l.WithField("header", h).
+		WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
+		WithField("dataLen", len(packet.Data)).
+		Info("UDP receiving injected packet")
 	u.rxPackets <- packet
 }