
Refactor remotes and handshaking to give every address a fair shot (#437)

Nathan Brown, 4 years ago (commit 710df6a876)
25 changed files with 1546 additions and 1370 deletions
  1. control.go (+27 -23)
  2. control_test.go (+6 -4)
  3. control_tester.go (+27 -3)
  4. e2e/handshakes_test.go (+97 -132)
  5. e2e/helpers_test.go (+3 -0)
  6. e2e/router/doc.go (+3 -0)
  7. e2e/router/router.go (+119 -20)
  8. examples/config.yml (+4 -4)
  9. handshake_ix.go (+36 -31)
  10. handshake_manager.go (+122 -106)
  11. handshake_manager_test.go (+28 -185)
  12. hostmap.go (+52 -249)
  13. hostmap_test.go (+0 -168)
  14. inside.go (+7 -50)
  15. lighthouse.go (+117 -226)
  16. lighthouse_test.go (+100 -79)
  17. main.go (+3 -4)
  18. outside.go (+6 -3)
  19. remote_list.go (+500 -0)
  20. remote_list_test.go (+228 -0)
  21. ssh.go (+39 -48)
  22. tun_tester.go (+1 -3)
  23. udp_all.go (+3 -3)
  24. udp_linux.go (+9 -28)
  25. udp_tester.go (+9 -1)

control.go (+27 -23)

@@ -67,23 +67,11 @@ func (c *Control) RebindUDPServer() {
 
 // ListHostmap returns details about the actual or pending (handshaking) hostmap
 func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
-	var hm *HostMap
 	if pendingMap {
-		hm = c.f.handshakeManager.pendingHostMap
+		return listHostMap(c.f.handshakeManager.pendingHostMap)
 	} else {
-		hm = c.f.hostMap
+		return listHostMap(c.f.hostMap)
 	}
-
-	hm.RLock()
-	hosts := make([]ControlHostInfo, len(hm.Hosts))
-	i := 0
-	for _, v := range hm.Hosts {
-		hosts[i] = copyHostInfo(v)
-		i++
-	}
-	hm.RUnlock()
-
-	return hosts
 }
 
 // GetHostInfoByVpnIP returns a single tunnel's hostInfo, or nil if not found
@@ -100,7 +88,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
 		return nil
 	}
 
-	ch := copyHostInfo(h)
+	ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
 	return &ch
 }
 
@@ -112,7 +100,7 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
 	}
 
 	hostInfo.SetRemote(addr.Copy())
-	ch := copyHostInfo(hostInfo)
+	ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
 	return &ch
 }
 
@@ -163,14 +151,17 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
 	return
 }
 
-func copyHostInfo(h *HostInfo) ControlHostInfo {
+func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	chi := ControlHostInfo{
-		VpnIP:          int2ip(h.hostId),
-		LocalIndex:     h.localIndexId,
-		RemoteIndex:    h.remoteIndexId,
-		RemoteAddrs:    h.CopyRemotes(),
-		CachedPackets:  len(h.packetStore),
-		MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter),
+		VpnIP:         int2ip(h.hostId),
+		LocalIndex:    h.localIndexId,
+		RemoteIndex:   h.remoteIndexId,
+		RemoteAddrs:   h.remotes.CopyAddrs(preferredRanges),
+		CachedPackets: len(h.packetStore),
+	}
+
+	if h.ConnectionState != nil {
+		chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter)
 	}
 
 	if c := h.GetCert(); c != nil {
@@ -183,3 +174,16 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {
 
 	return chi
 }
+
+func listHostMap(hm *HostMap) []ControlHostInfo {
+	hm.RLock()
+	hosts := make([]ControlHostInfo, len(hm.Hosts))
+	i := 0
+	for _, v := range hm.Hosts {
+		hosts[i] = copyHostInfo(v, hm.preferredRanges)
+		i++
+	}
+	hm.RUnlock()
+
+	return hosts
+}
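
A minimal, illustrative caller for the refactored listing helper, assuming only what this diff shows (the exported Control.ListHostmap and the ControlHostInfo fields); this sketch is not part of the commit:

package example

import (
	"fmt"

	"github.com/slackhq/nebula"
)

// dumpHostmaps walks both the main and the pending hostmap through the
// refactored ListHostmap helper and prints a one line summary per host.
func dumpHostmaps(c *nebula.Control) {
	for _, pending := range []bool{false, true} {
		for _, h := range c.ListHostmap(pending) {
			fmt.Printf("pending=%v vpnIp=%v remotes=%d cached=%d\n",
				pending, h.VpnIP, len(h.RemoteAddrs), h.CachedPackets)
		}
	}
}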

control_test.go (+6 -4)

@@ -45,10 +45,12 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		Signature: []byte{1, 2, 1, 2, 1, 3},
 	}
 
-	remotes := []*udpAddr{remote1, remote2}
+	remotes := NewRemoteList()
+	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
+	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
 	hm.Add(ip2int(ipNet.IP), &HostInfo{
 		remote:  remote1,
-		Remotes: remotes,
+		remotes: remotes,
 		ConnectionState: &ConnectionState{
 			peerCert: crt,
 		},
@@ -59,7 +61,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 
 	hm.Add(ip2int(ipNet2.IP), &HostInfo{
 		remote:  remote1,
-		Remotes: remotes,
+		remotes: remotes,
 		ConnectionState: &ConnectionState{
 			peerCert: nil,
 		},
@@ -81,7 +83,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
 		VpnIP:          net.IPv4(1, 2, 3, 4).To4(),
 		LocalIndex:     201,
 		RemoteIndex:    200,
-		RemoteAddrs:    []*udpAddr{remote1, remote2},
+		RemoteAddrs:    []*udpAddr{remote2, remote1},
 		CachedPackets:  0,
 		Cert:           crt.Copy(),
 		MessageCounter: 0,

control_tester.go (+27 -3)

@@ -44,7 +44,18 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
 // InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
 // This is necessary if you did not configure static hosts or are not running a lighthouse
 func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
-	c.f.lightHouse.AddRemote(ip2int(vpnIp), &udpAddr{IP: toAddr.IP, Port: uint16(toAddr.Port)}, false)
+	c.f.lightHouse.Lock()
+	remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
+	remoteList.Lock()
+	defer remoteList.Unlock()
+	c.f.lightHouse.Unlock()
+
+	iVpnIp := ip2int(vpnIp)
+	if v4 := toAddr.IP.To4(); v4 != nil {
+		remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
+	} else {
+		remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
+	}
 }
 
 // GetFromTun will pull a packet off the tun side of nebula
@@ -84,14 +95,17 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 		SrcPort: layers.UDPPort(fromPort),
 		DstPort: layers.UDPPort(toPort),
 	}
-	udp.SetNetworkLayerForChecksum(&ip)
+	err := udp.SetNetworkLayerForChecksum(&ip)
+	if err != nil {
+		panic(err)
+	}
 
 	buffer := gopacket.NewSerializeBuffer()
 	opt := gopacket.SerializeOptions{
 		ComputeChecksums: true,
 		FixLengths:       true,
 	}
-	err := gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
+	err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
 	if err != nil {
 		panic(err)
 	}
@@ -102,3 +116,13 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
 func (c *Control) GetUDPAddr() string {
 	return c.f.outside.addr.String()
 }
+
+func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
+	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
+	if !ok {
+		return false
+	}
+
+	c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
+	return true
+}
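
KillPendingTunnel is a new hook for the e2e tests (these tester helpers are only compiled for the e2e build). A hypothetical helper showing how it might be combined with InjectLightHouseAddr; the names and flow here are invented for illustration and are not taken from this diff:

package e2e

import (
	"net"
	"testing"

	"github.com/slackhq/nebula"
)

// seedAndAbort is a hypothetical e2e helper: seed a lighthouse address for a
// peer, let the caller drive the handshake, then drop the pending attempt.
func seedAndAbort(t *testing.T, c *nebula.Control, vpnIp net.IP, addr *net.UDPAddr) {
	c.InjectLightHouseAddr(vpnIp, addr)
	// ... drive packets with the router here ...
	if c.KillPendingTunnel(vpnIp) {
		t.Logf("aborted pending handshake for %v", vpnIp)
	}
}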

e2e/handshakes_test.go (+97 -132)

@@ -9,6 +9,7 @@ import (
 
 	"github.com/slackhq/nebula"
 	"github.com/slackhq/nebula/e2e/router"
+	"github.com/stretchr/testify/assert"
 )
 
 func TestGoodHandshake(t *testing.T) {
@@ -23,35 +24,35 @@ func TestGoodHandshake(t *testing.T) {
 	myControl.Start()
 	theirControl.Start()
 
-	// Send a udp packet through to begin standing up the tunnel, this should come out the other side
+	t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side")
 	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
 
-	// Have them consume my stage 0 packet. They have a tunnel now
+	t.Log("Have them consume my stage 0 packet. They have a tunnel now")
 	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
 
-	// Get their stage 1 packet so that we can play with it
+	t.Log("Get their stage 1 packet so that we can play with it")
 	stage1Packet := theirControl.GetFromUDP(true)
 
-	// I consume a garbage packet with a proper nebula header for our tunnel
+	t.Log("I consume a garbage packet with a proper nebula header for our tunnel")
 	// this should log a statement and get ignored, allowing the real handshake packet to complete the tunnel
 	badPacket := stage1Packet.Copy()
 	badPacket.Data = badPacket.Data[:len(badPacket.Data)-nebula.HeaderLen]
 	myControl.InjectUDPPacket(badPacket)
 
-	// Have me consume their real stage 1 packet. I have a tunnel now
+	t.Log("Have me consume their real stage 1 packet. I have a tunnel now")
 	myControl.InjectUDPPacket(stage1Packet)
 
-	// Wait until we see my cached packet come through
+	t.Log("Wait until we see my cached packet come through")
 	myControl.WaitForType(1, 0, theirControl)
 
-	// Make sure our host infos are correct
+	t.Log("Make sure our host infos are correct")
 	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
 
-	// Get that cached packet and make sure it looks right
+	t.Log("Get that cached packet and make sure it looks right")
 	myCachedPacket := theirControl.GetFromTun(true)
 	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
 
-	// Do a bidirectional tunnel test
+	t.Log("Do a bidirectional tunnel test")
 	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, router.NewR(myControl, theirControl))
 
 	myControl.Stop()
@@ -62,14 +63,17 @@ func TestGoodHandshake(t *testing.T) {
 func TestWrongResponderHandshake(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
 
-	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1})
-	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
-	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 99})
+	// The IPs here are chosen on purpose:
+	// The current remote handling will sort by preference, public, and then lexically.
+	// So we need them to have a higher address than evil (we could apply a preference though)
+	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100})
+	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99})
+	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2})
 
-	// Add their real udp addr, which should be tried after evil. Doing this first because learned addresses are prepended
+	// Add their real udp addr, which should be tried after evil.
 	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
 
-	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. This will now be the first attempted ip
+	// Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse.
 	myControl.InjectLightHouseAddr(theirVpnIp, evilUdpAddr)
 
 	// Build a router so we don't have to reason who gets which packet
@@ -80,137 +84,98 @@ func TestWrongResponderHandshake(t *testing.T) {
 	theirControl.Start()
 	evilControl.Start()
 
-	t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)")
+	t.Log("Start the handshake process, we will route until we see our cached packet get sent to them")
 	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-	r.OnceFrom(myControl)
-	r.OnceFrom(evilControl)
+	r.RouteForAllExitFunc(func(p *nebula.UdpPacket, c *nebula.Control) router.ExitType {
+		h := &nebula.Header{}
+		err := h.Parse(p.Data)
+		if err != nil {
+			panic(err)
+		}
 
-	t.Log("I should have a tunnel with evil now and there should not be a cached packet waiting for us")
-	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
-	assertHostInfoPair(t, myUdpAddr, evilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
+		if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
+			return router.RouteAndExit
+		}
 
-	//TODO: Assert pending hostmap - I should have a correct hostinfo for them now
+		return router.KeepRouting
+	})
 
-	t.Log("Lets let the messages fly, this time we should have a tunnel with them")
-	r.OnceFrom(myControl)
-	r.OnceFrom(theirControl)
+	//TODO: Assert pending hostmap - I should have a correct hostinfo for them now
 
-	t.Log("I should now have a tunnel with them now and my original packet should get there")
-	r.RouteUntilAfterMsgType(myControl, 1, 0)
+	t.Log("My cached packet should be received by them")
 	myCachedPacket := theirControl.GetFromTun(true)
 	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
 
-	t.Log("I should now have a proper tunnel with them")
+	t.Log("Test the tunnel with them")
 	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
 	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
 
-	t.Log("Lets make sure evil is still good")
-	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
+	t.Log("Flush all packets from all controllers")
+	r.FlushAll()
+
+	t.Log("Ensure I don't have any hostinfo artifacts from evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), true), "My pending hostmap should not contain evil")
+	assert.Nil(t, myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false), "My main hostmap should not contain evil")
+	//NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete
 
 	//TODO: assert hostmaps for everyone
 	t.Log("Success!")
-	//TODO: myControl is attempting to shut down 2 tunnels but is blocked on the udp txChan after the first close message
-	// what we really need here is a way to exit all the go routines loops (there are many)
-	//myControl.Stop()
-	//theirControl.Stop()
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func Test_Case1_Stage1Race(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1})
+	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIp, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(myControl, theirControl)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake to start on both me and them")
+	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIp, 80, 80, []byte("Hi from them"))
+
+	t.Log("Get both stage 1 handshake packets")
+	myHsForThem := myControl.GetFromUDP(true)
+	theirHsForMe := theirControl.GetFromUDP(true)
+
+	t.Log("Now inject both stage 1 handshake packets")
+	myControl.InjectUDPPacket(theirHsForMe)
+	theirControl.InjectUDPPacket(myHsForThem)
+	//TODO: they should win, grab their index for me and make sure I use it in the end.
+
+	t.Log("They should not have a stage 2 (won the race) but I should send one")
+	theirControl.InjectUDPPacket(myControl.GetFromUDP(true))
+
+	t.Log("Route for me until I send a message packet to them")
+	myControl.WaitForType(1, 0, theirControl)
+
+	t.Log("My cached packet should be received by them")
+	myCachedPacket := theirControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
+
+	t.Log("Route for them until they send a message packet to me")
+	theirControl.WaitForType(1, 0, myControl)
+
+	t.Log("Their cached packet should be received by me")
+	theirCachedPacket := myControl.GetFromTun(true)
+	assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIp, myVpnIp, 80, 80)
+
+	t.Log("Do a bidirectional tunnel test")
+	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
+
+	myControl.Stop()
+	theirControl.Stop()
+	//TODO: assert hostmaps
 }
 
-////TODO: We need to test lies both as the race winner and race loser
-//func TestManyWrongResponderHandshake(t *testing.T) {
-//	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-//
-//	myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 99})
-//	theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
-//	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 1})
-//
-//	t.Log("Build a router so we don't have to reason who gets which packet")
-//	r := newRouter(myControl, theirControl, evilControl)
-//
-//	t.Log("Lets add more than 10 evil addresses, this exceeds the hostinfo remotes limit")
-//	for i := 0; i < 10; i++ {
-//		addr := net.UDPAddr{IP: evilUdpAddr.IP, Port: evilUdpAddr.Port + i}
-//		myControl.InjectLightHouseAddr(theirVpnIp, &addr)
-//		// We also need to tell our router about it
-//		r.AddRoute(addr.IP, uint16(addr.Port), evilControl)
-//	}
-//
-//	// Start the servers
-//	myControl.Start()
-//	theirControl.Start()
-//	evilControl.Start()
-//
-//	t.Log("Stand up the tunnel with evil (because the lighthouse cache is lying to us about who it is)")
-//	myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
-//
-//	t.Log("We need to spin until we get to the right remote for them")
-//	getOut := false
-//	injected := false
-//	for {
-//		t.Log("Routing for me and evil while we work through the bad ips")
-//		r.RouteExitFunc(myControl, func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType {
-//			// We should stop routing right after we see a packet coming from us to them
-//			if *receiver == *theirControl {
-//				getOut = true
-//				return drainAndExit
-//			}
-//
-//			// We need to poke our real ip in at some point, this is a well protected check looking for that moment
-//			if *receiver == *evilControl {
-//				hi := myControl.GetHostInfoByVpnIP(ip2int(theirVpnIp), true)
-//				if !injected && len(hi.RemoteAddrs) == 1 {
-//					t.Log("I am on my last ip for them, time to inject the real one into my lighthouse")
-//					myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
-//					injected = true
-//				}
-//				return drainAndExit
-//			}
-//
-//			return keepRouting
-//		})
-//
-//		if getOut {
-//			break
-//		}
-//
-//		r.RouteForUntilAfterToAddr(evilControl, myUdpAddr, drainAndExit)
-//	}
-//
-//	t.Log("I should have a tunnel with evil and them, evil should not have a cached packet")
-//	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
-//	evilHostInfo := myControl.GetHostInfoByVpnIP(ip2int(evilVpnIp), false)
-//	realEvilUdpAddr := &net.UDPAddr{IP: evilHostInfo.CurrentRemote.IP, Port: int(evilHostInfo.CurrentRemote.Port)}
-//
-//	t.Log("Assert mine and evil's host pairs", evilUdpAddr, realEvilUdpAddr)
-//	assertHostInfoPair(t, myUdpAddr, realEvilUdpAddr, myVpnIp, evilVpnIp, myControl, evilControl)
-//
-//	//t.Log("Draining everyones packets")
-//	//r.Drain(theirControl)
-//	//r.DrainAll(myControl, theirControl, evilControl)
-//	//
-//	//go func() {
-//	//	for {
-//	//		time.Sleep(10 * time.Millisecond)
-//	//		t.Log(len(theirControl.GetUDPTxChan()))
-//	//		t.Log(len(theirControl.GetTunTxChan()))
-//	//		t.Log(len(myControl.GetUDPTxChan()))
-//	//		t.Log(len(evilControl.GetUDPTxChan()))
-//	//		t.Log("=====")
-//	//	}
-//	//}()
-//
-//	t.Log("I should have a tunnel with them now and my original packet should get there")
-//	r.RouteUntilAfterMsgType(myControl, 1, 0)
-//	myCachedPacket := theirControl.GetFromTun(true)
-//
-//	t.Log("Got the cached packet, lets test the tunnel")
-//	assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
-//
-//	t.Log("Testing tunnels with them")
-//	assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
-//	assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
-//
-//	t.Log("Testing tunnels with evil")
-//	assertTunnel(t, myVpnIp, evilVpnIp, myControl, evilControl, r)
-//
-//	//TODO: assert hostmaps for everyone
-//}
+//TODO: add a test with many lies
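
The comment at the top of TestWrongResponderHandshake leans on the new ordering of remotes: preferred ranges first, then public addresses, then a lexical sort. The real implementation lives in remote_list.go (not shown in this excerpt); the sketch below only illustrates that described rule, assuming ties sort ascending (consistent with the test wanting "them" to have a higher address than "evil", so 10.0.0.2 is attempted before 10.0.0.99). It uses net.IP.IsPrivate from Go 1.17+ and is not the code from this commit:

package example

import (
	"bytes"
	"net"
	"sort"
)

// sortRemotes illustrates the ordering described above; it is not the code
// from remote_list.go. Preferred ranges win, then public addresses, then
// everything else, with a lexical tie break inside each group.
func sortRemotes(addrs []*net.UDPAddr, preferredRanges []*net.IPNet) {
	rank := func(ip net.IP) int {
		for _, r := range preferredRanges {
			if r.Contains(ip) {
				return 0 // preferred (usually local) ranges first
			}
		}
		if !ip.IsPrivate() && !ip.IsLoopback() {
			return 1 // then public addresses
		}
		return 2 // then everything else
	}

	sort.SliceStable(addrs, func(i, j int) bool {
		ri, rj := rank(addrs[i].IP), rank(addrs[j].IP)
		if ri != rj {
			return ri < rj
		}
		return bytes.Compare(addrs[i].IP.To16(), addrs[j].IP.To16()) < 0
	})
}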

e2e/helpers_test.go (+3 -0)

@@ -64,6 +64,9 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 				"host":  "any",
 			}},
 		},
+		//"handshakes": m{
+		//	"try_interval": "1s",
+		//},
 		"listen": m{
 			"host": udpAddr.IP.String(),
 			"port": udpAddr.Port,

e2e/router/doc.go (+3 -0)

@@ -0,0 +1,3 @@
+package router
+
+// This file exists to allow `go fmt` to traverse here on its own. The build tags were keeping it out before

e2e/router/router.go (+119 -20)

@@ -5,6 +5,7 @@ package router
 import (
 	"fmt"
 	"net"
+	"reflect"
 	"strconv"
 	"sync"
 
@@ -28,18 +29,18 @@ type R struct {
 	sync.Mutex
 }
 
-type exitType int
+type ExitType int
 
 const (
 	// Keeps routing, the function will get called again on the next packet
-	keepRouting exitType = 0
+	KeepRouting ExitType = 0
 	// Does not route this packet and exits immediately
-	exitNow exitType = 1
+	ExitNow ExitType = 1
 	// Routes this packet and exits immediately afterwards
-	routeAndExit exitType = 2
+	RouteAndExit ExitType = 2
 )
 
-type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) exitType
+type ExitFunc func(packet *nebula.UdpPacket, receiver *nebula.Control) ExitType
 
 func NewR(controls ...*nebula.Control) *R {
 	r := &R{
@@ -77,8 +78,8 @@ func (r *R) AddRoute(ip net.IP, port uint16, c *nebula.Control) {
 // OnceFrom will route a single packet from sender then return
 // If the router doesn't have the nebula controller for that address, we panic
 func (r *R) OnceFrom(sender *nebula.Control) {
-	r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) exitType {
-		return routeAndExit
+	r.RouteExitFunc(sender, func(*nebula.UdpPacket, *nebula.Control) ExitType {
+		return RouteAndExit
 	})
 }
 
@@ -116,7 +117,6 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) []
 //   - exitNow: the packet will not be routed and this call will return immediately
 //   - routeAndExit: this call will return immediately after routing the last packet from sender
 //   - keepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
-//TODO: is this RouteWhile?
 func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 	h := &nebula.Header{}
 	for {
@@ -136,16 +136,16 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 
 		e := whatDo(p, receiver)
 		switch e {
-		case exitNow:
+		case ExitNow:
 			r.Unlock()
 			return
 
-		case routeAndExit:
+		case RouteAndExit:
 			receiver.InjectUDPPacket(p)
 			r.Unlock()
 			return
 
-		case keepRouting:
+		case KeepRouting:
 			receiver.InjectUDPPacket(p)
 
 		default:
@@ -160,35 +160,135 @@ func (r *R) RouteExitFunc(sender *nebula.Control, whatDo ExitFunc) {
 // If the router doesn't have the nebula controller for that address, we panic
 func (r *R) RouteUntilAfterMsgType(sender *nebula.Control, msgType nebula.NebulaMessageType, subType nebula.NebulaMessageSubType) {
 	h := &nebula.Header{}
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType {
+	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
 		if err := h.Parse(p.Data); err != nil {
 			panic(err)
 		}
 		if h.Type == msgType && h.Subtype == subType {
-			return routeAndExit
+			return RouteAndExit
 		}
 
-		return keepRouting
+		return KeepRouting
 	})
 }
 
 // RouteForUntilAfterToAddr will route for sender and return only after it sees and sends a packet destined for toAddr
 // finish can be any of the exitType values except `keepRouting`, the default value is `routeAndExit`
 // If the router doesn't have the nebula controller for that address, we panic
-func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish exitType) {
-	if finish == keepRouting {
-		finish = routeAndExit
+func (r *R) RouteForUntilAfterToAddr(sender *nebula.Control, toAddr *net.UDPAddr, finish ExitType) {
+	if finish == KeepRouting {
+		finish = RouteAndExit
 	}
 
-	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) exitType {
+	r.RouteExitFunc(sender, func(p *nebula.UdpPacket, r *nebula.Control) ExitType {
 		if p.ToIp.Equal(toAddr.IP) && p.ToPort == uint16(toAddr.Port) {
 			return finish
 		}
 
-		return keepRouting
+		return KeepRouting
 	})
 }
 
+// RouteForAllExitFunc will route for every registered controller and calls the whatDo func with each udp packet it routes
+// whatDo can return:
+//   - ExitNow: the packet will not be routed and this call will return immediately
+//   - RouteAndExit: this call will return immediately after routing the last packet from sender
+//   - KeepRouting: the packet will be routed and whatDo will be called again on the next packet from sender
+func (r *R) RouteForAllExitFunc(whatDo ExitFunc) {
+	sc := make([]reflect.SelectCase, len(r.controls))
+	cm := make([]*nebula.Control, len(r.controls))
+
+	i := 0
+	for _, c := range r.controls {
+		sc[i] = reflect.SelectCase{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(c.GetUDPTxChan()),
+			Send: reflect.Value{},
+		}
+
+		cm[i] = c
+		i++
+	}
+
+	for {
+		x, rx, _ := reflect.Select(sc)
+		r.Lock()
+
+		p := rx.Interface().(*nebula.UdpPacket)
+
+		outAddr := cm[x].GetUDPAddr()
+		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
+		receiver := r.getControl(outAddr, inAddr, p)
+		if receiver == nil {
+			r.Unlock()
+			panic("Can't route for host: " + inAddr)
+		}
+
+		e := whatDo(p, receiver)
+		switch e {
+		case ExitNow:
+			r.Unlock()
+			return
+
+		case RouteAndExit:
+			receiver.InjectUDPPacket(p)
+			r.Unlock()
+			return
+
+		case KeepRouting:
+			receiver.InjectUDPPacket(p)
+
+		default:
+			panic(fmt.Sprintf("Unknown exitFunc return: %v", e))
+		}
+		r.Unlock()
+	}
+}
+
+// FlushAll will route for every registered controller, exiting once there are no packets left to route
+func (r *R) FlushAll() {
+	sc := make([]reflect.SelectCase, len(r.controls))
+	cm := make([]*nebula.Control, len(r.controls))
+
+	i := 0
+	for _, c := range r.controls {
+		sc[i] = reflect.SelectCase{
+			Dir:  reflect.SelectRecv,
+			Chan: reflect.ValueOf(c.GetUDPTxChan()),
+			Send: reflect.Value{},
+		}
+
+		cm[i] = c
+		i++
+	}
+
+	// Add a default case to exit when nothing is left to send
+	sc = append(sc, reflect.SelectCase{
+		Dir:  reflect.SelectDefault,
+		Chan: reflect.Value{},
+		Send: reflect.Value{},
+	})
+
+	for {
+		x, rx, ok := reflect.Select(sc)
+		if !ok {
+			return
+		}
+		r.Lock()
+
+		p := rx.Interface().(*nebula.UdpPacket)
+
+		outAddr := cm[x].GetUDPAddr()
+		inAddr := net.JoinHostPort(p.ToIp.String(), fmt.Sprintf("%v", p.ToPort))
+		receiver := r.getControl(outAddr, inAddr, p)
+		if receiver == nil {
+			r.Unlock()
+			panic("Can't route for host: " + inAddr)
+		}
+		r.Unlock()
+	}
+}
+
 // getControl performs or seeds NAT translation and returns the control for toAddr, p from fields may change
 // This is an internal router function, the caller must hold the lock
 func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Control {
@@ -216,6 +316,5 @@ func (r *R) getControl(fromAddr, toAddr string, p *nebula.UdpPacket) *nebula.Con
 		return c
 	}
 
-	//TODO: call receive hooks!
 	return r.controls[toAddr]
 }

examples/config.yml (+4 -4)

@@ -202,16 +202,16 @@ logging:
 
 # Handshake Manager Settings
 #handshakes:
-  # Total time to try a handshake = sequence of `try_interval * retries`
-  # With 100ms interval and 20 retries it is 23.5 seconds
+  # Handshakes are sent to all known addresses at each interval with a linear backoff,
+  # Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
+  # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
   #try_interval: 100ms
   #retries: 20
-  # wait_rotation is the number of handshake attempts to do before starting to try non-local IP addresses
-  #wait_rotation: 5
   # trigger_buffer is the size of the buffer channel for quickly sending handshakes
   # after receiving the response for lighthouse queries
   #trigger_buffer: 64
 
+
 # Nebula security group configuration
 firewall:
   conntrack:

handshake_ix.go (+36 -31)

@@ -14,14 +14,10 @@ import (
 // Sending is done by the handshake manager
 func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	// This queries the lighthouse if we don't know a remote for the host
+	// We do it here to provoke the lighthouse to preempt our timer wheel and trigger the stage 1 packet to send
+	// more quickly, the effect is a quicker handshake.
 	if hostinfo.remote == nil {
-		ips, err := f.lightHouse.Query(vpnIp, f)
-		if err != nil {
-			//l.Debugln(err)
-		}
-		for _, ip := range ips {
-			hostinfo.AddRemote(ip)
-		}
+		f.lightHouse.QueryServer(vpnIp, f)
 	}
 
 	err := f.handshakeManager.AddIndexHostInfo(hostinfo)
@@ -69,7 +65,6 @@ func ixHandshakeStage0(f *Interface, vpnIp uint32, hostinfo *HostInfo) {
 	hostinfo.HandshakePacket[0] = msg
 	hostinfo.HandshakeReady = true
 	hostinfo.handshakeStart = time.Now()
-
 }
 
 func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
@@ -125,13 +120,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 
 	hostinfo := &HostInfo{
 		ConnectionState: ci,
-		Remotes:         []*udpAddr{},
 		localIndexId:    myIndex,
 		remoteIndexId:   hs.Details.InitiatorIndex,
 		hostId:          vpnIP,
 		HandshakePacket: make(map[uint8][]byte, 0),
 	}
 
+	hostinfo.Lock()
+	defer hostinfo.Unlock()
+
 	f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -182,16 +179,11 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	ci.peerCert = remoteCert
 	ci.dKey = NewNebulaCipherState(dKey)
 	ci.eKey = NewNebulaCipherState(eKey)
-	//l.Debugln("got symmetric pairs")
 
-	//hostinfo.ClearRemotes()
-	hostinfo.AddRemote(addr)
-	hostinfo.ForcePromoteBest(f.hostMap.preferredRanges)
+	hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
+	hostinfo.SetRemote(addr)
 	hostinfo.CreateRemoteCIDR(remoteCert)
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	// Only overwrite existing record if we should win the handshake race
 	overwrite := vpnIP > ip2int(f.certState.certificate.Details.Ips[0].IP)
 	existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, overwrite, f)
@@ -214,6 +206,10 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		case ErrExistingHostInfo:
 			// This means there was an existing tunnel and we didn't win
 			// handshake avoidance
+
+			//TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing
+			//TODO: if not new send a test packet like old
+
 			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 				WithField("certName", certName).
 				WithField("fingerprint", fingerprint).
@@ -234,6 +230,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 				WithField("localIndex", hostinfo.localIndexId).WithField("collision", IntIp(existing.hostId)).
 				Error("Failed to add HostInfo due to localIndex collision")
 			return
+		case ErrExistingHandshake:
+			// We have a race where both parties think they are an initiator and this tunnel lost, let the other one finish
+			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
+				WithField("certName", certName).
+				WithField("fingerprint", fingerprint).
+				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
+				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+				Error("Prevented a pending handshake race")
+			return
 		default:
 			// Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete
 			// And we forget to update it here
@@ -286,6 +291,8 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
 			Info("Handshake is already complete")
 
+		//TODO: evaluate addr for preference, if we handshook with a less preferred addr we can correct quickly here
+
 		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
 		return false
 	}
@@ -334,17 +341,13 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 	certName := remoteCert.Details.Name
 	fingerprint, _ := remoteCert.Sha256Sum()
 
+	// Ensure the right host responded
 	if vpnIP != hostinfo.hostId {
 		f.l.WithField("intendedVpnIp", IntIp(hostinfo.hostId)).WithField("haveVpnIp", IntIp(vpnIP)).
 			WithField("udpAddr", addr).WithField("certName", certName).
 			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			Info("Incorrect host responded to handshake")
 
-		if ho, _ := f.handshakeManager.pendingHostMap.QueryVpnIP(vpnIP); ho != nil {
-			// We might have a pending tunnel to this host already, clear out that attempt since we have a tunnel now
-			f.handshakeManager.pendingHostMap.DeleteHostInfo(ho)
-		}
-
 		// Release our old handshake from pending, it should not continue
 		f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
 
@@ -354,26 +357,28 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		newHostInfo.Lock()
 
 		// Block the current used address
-		newHostInfo.unlockedBlockRemote(addr)
+		newHostInfo.remotes = hostinfo.remotes
+		newHostInfo.remotes.BlockRemote(addr)
 
-		// If this is an ongoing issue our previous hostmap will have some bad ips too
-		for _, v := range hostinfo.badRemotes {
-			newHostInfo.unlockedBlockRemote(v)
-		}
-		//TODO: this is me enabling tests
-		newHostInfo.ForcePromoteBest(f.hostMap.preferredRanges)
+		// Get the correct remote list for the host we did handshake with
+		hostinfo.remotes = f.lightHouse.QueryCache(vpnIP)
 
-		f.l.WithField("blockedUdpAddrs", newHostInfo.badRemotes).WithField("vpnIp", IntIp(vpnIP)).
-			WithField("remotes", newHostInfo.Remotes).
+		f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", IntIp(vpnIP)).
+			WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
 			Info("Blocked addresses for handshakes")
 
 		// Swap the packet store to benefit the original intended recipient
+		hostinfo.ConnectionState.queueLock.Lock()
 		newHostInfo.packetStore = hostinfo.packetStore
 		hostinfo.packetStore = []*cachedPacket{}
+		hostinfo.ConnectionState.queueLock.Unlock()
 
-		// Set the current hostId to the new vpnIp
+		// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
 		hostinfo.hostId = vpnIP
+		f.sendCloseTunnel(hostinfo)
 		newHostInfo.Unlock()
+
+		return true
 	}
 
 	// Mark packet 2 as seen so it doesn't show up as missed

handshake_manager.go (+122 -106)

@@ -12,12 +12,8 @@ import (
 )
 
 const (
-	// Total time to try a handshake = sequence of HandshakeTryInterval * HandshakeRetries
-	// With 100ms interval and 20 retries is 23.5 seconds
-	DefaultHandshakeTryInterval = time.Millisecond * 100
-	DefaultHandshakeRetries     = 20
-	// DefaultHandshakeWaitRotation is the number of handshake attempts to do before starting to use other ips addresses
-	DefaultHandshakeWaitRotation  = 5
+	DefaultHandshakeTryInterval   = time.Millisecond * 100
+	DefaultHandshakeRetries       = 10
 	DefaultHandshakeTriggerBuffer = 64
 )
 
@@ -25,7 +21,6 @@ var (
 	defaultHandshakeConfig = HandshakeConfig{
 		tryInterval:   DefaultHandshakeTryInterval,
 		retries:       DefaultHandshakeRetries,
-		waitRotation:  DefaultHandshakeWaitRotation,
 		triggerBuffer: DefaultHandshakeTriggerBuffer,
 	}
 )
@@ -33,45 +28,36 @@ var (
 type HandshakeConfig struct {
 	tryInterval   time.Duration
 	retries       int
-	waitRotation  int
 	triggerBuffer int
 
 	messageMetrics *MessageMetrics
 }
 
 type HandshakeManager struct {
-	pendingHostMap *HostMap
-	mainHostMap    *HostMap
-	lightHouse     *LightHouse
-	outside        *udpConn
-	config         HandshakeConfig
+	pendingHostMap         *HostMap
+	mainHostMap            *HostMap
+	lightHouse             *LightHouse
+	outside                *udpConn
+	config                 HandshakeConfig
+	OutboundHandshakeTimer *SystemTimerWheel
+	messageMetrics         *MessageMetrics
+	l                      *logrus.Logger
 
 	// can be used to trigger outbound handshake for the given vpnIP
 	trigger chan uint32
-
-	OutboundHandshakeTimer *SystemTimerWheel
-	InboundHandshakeTimer  *SystemTimerWheel
-
-	messageMetrics *MessageMetrics
-	l              *logrus.Logger
 }
 
 func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside *udpConn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		pendingHostMap: NewHostMap(l, "pending", tunCidr, preferredRanges),
-		mainHostMap:    mainHostMap,
-		lightHouse:     lightHouse,
-		outside:        outside,
-
-		config: config,
-
-		trigger: make(chan uint32, config.triggerBuffer),
-
-		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
-		InboundHandshakeTimer:  NewSystemTimerWheel(config.tryInterval, config.tryInterval*time.Duration(config.retries)),
-
-		messageMetrics: config.messageMetrics,
-		l:              l,
+		pendingHostMap:         NewHostMap(l, "pending", tunCidr, preferredRanges),
+		mainHostMap:            mainHostMap,
+		lightHouse:             lightHouse,
+		outside:                outside,
+		config:                 config,
+		trigger:                make(chan uint32, config.triggerBuffer),
+		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		messageMetrics:         config.messageMetrics,
+		l:                      l,
 	}
 }
 
@@ -84,7 +70,6 @@ func (c *HandshakeManager) Run(f EncWriter) {
 			c.handleOutbound(vpnIP, f, true)
 		case now := <-clockSource:
 			c.NextOutboundHandshakeTimerTick(now, f)
-			c.NextInboundHandshakeTimerTick(now)
 		}
 	}
 }
@@ -109,84 +94,84 @@ func (c *HandshakeManager) handleOutbound(vpnIP uint32, f EncWriter, lighthouseT
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 
-	// If we haven't finished the handshake and we haven't hit max retries, query
-	// lighthouse and then send the handshake packet again.
-	if hostinfo.HandshakeCounter < c.config.retries && !hostinfo.HandshakeComplete {
-		if hostinfo.remote == nil {
-			// We continue to query the lighthouse because hosts may
-			// come online during handshake retries. If the query
-			// succeeds (no error), add the lighthouse info to hostinfo
-			ips := c.lightHouse.QueryCache(vpnIP)
-			// If we have no responses yet, or only one IP (the host hadn't
-			// finished reporting its own IPs yet), then send another query to
-			// the LH.
-			if len(ips) <= 1 {
-				ips, err = c.lightHouse.Query(vpnIP, f)
-			}
-			if err == nil {
-				for _, ip := range ips {
-					hostinfo.AddRemote(ip)
-				}
-				hostinfo.ForcePromoteBest(c.mainHostMap.preferredRanges)
-			}
-		} else if lighthouseTriggered {
-			// We were triggered by a lighthouse HostQueryReply packet, but
-			// we have already picked a remote for this host (this can happen
-			// if we are configured with multiple lighthouses). So we can skip
-			// this trigger and let the timerwheel handle the rest of the
-			// process
-			return
-		}
+	// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
+	if hostinfo.HandshakeComplete {
+		// Ensure we don't exist in the pending hostmap anymore since we have completed
+		c.pendingHostMap.DeleteHostInfo(hostinfo)
+		return
+	}
 
-		hostinfo.HandshakeCounter++
+	// Check if we have a handshake packet to transmit yet
+	if !hostinfo.HandshakeReady {
+		// There is currently a slight race in getOrHandshake due to ConnectionState not being part of the HostInfo directly
+		// Our hostinfo here was added to the pending map and the wheel may have ticked to us before we created ConnectionState
+		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		return
+	}
 
-		// We want to use the "best" calculated ip for the first 5 attempts, after that we just blindly rotate through
-		// all the others until we can stand up a connection.
-		if hostinfo.HandshakeCounter > c.config.waitRotation {
-			hostinfo.rotateRemote()
-		}
+	// If we are out of time, clean up
+	if hostinfo.HandshakeCounter >= c.config.retries {
+		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
+			WithField("initiatorIndex", hostinfo.localIndexId).
+			WithField("remoteIndex", hostinfo.remoteIndexId).
+			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+			WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
+			Info("Handshake timed out")
+		//TODO: emit metrics
+		c.pendingHostMap.DeleteHostInfo(hostinfo)
+		return
+	}
 
-		// Ensure the handshake is ready to avoid a race in timer tick and stage 0 handshake generation
-		if hostinfo.HandshakeReady && hostinfo.remote != nil {
-			c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-			err := c.outside.WriteTo(hostinfo.HandshakePacket[0], hostinfo.remote)
-			if err != nil {
-				hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
-					WithField("initiatorIndex", hostinfo.localIndexId).
-					WithField("remoteIndex", hostinfo.remoteIndexId).
-					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-					WithError(err).Error("Failed to send handshake message")
-			} else {
-				//TODO: this log line is assuming a lot of stuff around the cached stage 0 handshake packet, we should
-				// keep the real packet struct around for logging purposes
-				hostinfo.logger(c.l).WithField("udpAddr", hostinfo.remote).
-					WithField("initiatorIndex", hostinfo.localIndexId).
-					WithField("remoteIndex", hostinfo.remoteIndexId).
-					WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-					Info("Handshake message sent")
-			}
-		}
+	// We only care about a lighthouse trigger before the first handshake transmit attempt. This is a very specific
+	// optimization for a fast lighthouse reply
+	//TODO: it would feel better to do this once, anytime, as our delay increases over time
+	if lighthouseTriggered && hostinfo.HandshakeCounter > 0 {
+		// If we didn't return here a lighthouse could cause us to aggressively send handshakes
+		return
+	}
 
-		// Readd to the timer wheel so we continue trying wait HandshakeTryInterval * counter longer for next try
-		if !lighthouseTriggered {
-			//l.Infoln("Interval: ", HandshakeTryInterval*time.Duration(hostinfo.HandshakeCounter))
-			c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
-		}
-	} else {
-		c.pendingHostMap.DeleteHostInfo(hostinfo)
+	// Get a remotes object if we don't already have one.
+	// This is mainly to protect us as this should never be the case
+	if hostinfo.remotes == nil {
+		hostinfo.remotes = c.lightHouse.QueryCache(vpnIP)
 	}
-}
 
-func (c *HandshakeManager) NextInboundHandshakeTimerTick(now time.Time) {
-	c.InboundHandshakeTimer.advance(now)
-	for {
-		ep := c.InboundHandshakeTimer.Purge()
-		if ep == nil {
-			break
+	//TODO: this will generate a load of queries for hosts with only 1 ip (i'm not using a lighthouse, static mapped)
+	if hostinfo.remotes.Len(c.pendingHostMap.preferredRanges) <= 1 {
+		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
+		// Our vpnIP here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
+		// the learned public ip for them. Query again to short circuit the promotion counter
+		c.lightHouse.QueryServer(vpnIP, f)
+	}
+
+	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
+	var sentTo []*udpAddr
+	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udpAddr, _ bool) {
+		c.messageMetrics.Tx(handshake, NebulaMessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+		err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
+		if err != nil {
+			hostinfo.logger(c.l).WithField("udpAddr", addr).
+				WithField("initiatorIndex", hostinfo.localIndexId).
+				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+				WithError(err).Error("Failed to send handshake message")
+
+		} else {
+			sentTo = append(sentTo, addr)
 		}
-		index := ep.(uint32)
+	})
+
+	hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+		WithField("initiatorIndex", hostinfo.localIndexId).
+		WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
+		Info("Handshake message sent")
 
-		c.pendingHostMap.DeleteIndex(index)
+	// Increment the counter to increase our delay, linear backoff
+	hostinfo.HandshakeCounter++
+
+	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
+	if !lighthouseTriggered {
+		//TODO: feel like we dupe handshake real fast in a tight loop, why?
+		c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
 	}
 }
 
@@ -194,6 +179,7 @@ func (c *HandshakeManager) AddVpnIP(vpnIP uint32) *HostInfo {
 	hostinfo := c.pendingHostMap.AddVpnIP(vpnIP)
 	// We lock here and use an array to insert items to prevent locking the
 	// main receive thread for very long by waiting to add items to the pending map
+	//TODO: what lock?
 	c.OutboundHandshakeTimer.Add(vpnIP, c.config.tryInterval)
 
 	return hostinfo
@@ -203,6 +189,7 @@ var (
 	ErrExistingHostInfo    = errors.New("existing hostinfo")
 	ErrAlreadySeen         = errors.New("already seen")
 	ErrLocalIndexCollision = errors.New("local index collision")
+	ErrExistingHandshake   = errors.New("existing handshake")
 )
 
 // CheckAndComplete checks for any conflicts in the main and pending hostmap
@@ -217,17 +204,21 @@ var (
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, overwrite bool, f *Interface) (*HostInfo, error) {
-	c.pendingHostMap.RLock()
-	defer c.pendingHostMap.RUnlock()
+	c.pendingHostMap.Lock()
+	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 
+	// Check if we already have a tunnel with this vpn ip
 	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.hostId]
 	if found && existingHostInfo != nil {
+		// Is it just a delayed handshake packet?
 		if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
 			return existingHostInfo, ErrAlreadySeen
 		}
+
 		if !overwrite {
+			// It's a new handshake and we lost the race
 			return existingHostInfo, ErrExistingHostInfo
 		}
 	}
@@ -237,6 +228,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 	}
+
 	existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
 	if found && existingIndex != hostinfo {
 		// We have a collision, but for a different hostinfo
@@ -252,7 +244,24 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			Info("New host shadows existing host remoteIndex")
 	}
 
+	// Check if we are also handshaking with this vpn ip
+	pendingHostInfo, found := c.pendingHostMap.Hosts[hostinfo.hostId]
+	if found && pendingHostInfo != nil {
+		if !overwrite {
+			// We won, let our pending handshake win
+			return pendingHostInfo, ErrExistingHandshake
+		}
+
+		// We lost, take this handshake and move any cached packets over so they get sent
+		pendingHostInfo.ConnectionState.queueLock.Lock()
+		hostinfo.packetStore = append(hostinfo.packetStore, pendingHostInfo.packetStore...)
+		c.pendingHostMap.unlockedDeleteHostInfo(pendingHostInfo)
+		pendingHostInfo.ConnectionState.queueLock.Unlock()
+		pendingHostInfo.logger(c.l).Info("Handshake race lost, replacing pending handshake with completed tunnel")
+	}
+
 	if existingHostInfo != nil {
+		hostinfo.logger(c.l).Info("Race lost, taking new handshake")
 		// We are going to overwrite this entry, so remove the old references
 		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)
@@ -267,6 +276,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 // won't have a localIndexId collision because we already have an entry in the
 // pendingHostMap
 func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
+	c.pendingHostMap.Lock()
+	defer c.pendingHostMap.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 
@@ -288,6 +299,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	}
 
 	c.mainHostMap.addHostInfo(hostinfo, f)
+	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
 }
 
 // AddIndexHostInfo generates a unique localIndexId for this HostInfo
@@ -359,3 +371,7 @@ func generateIndex(l *logrus.Logger) (uint32, error) {
 	}
 	return index, nil
 }
+
+func hsTimeout(tries int, interval time.Duration) time.Duration {
+	return time.Duration(tries / 2 * ((2 * int(interval)) + (tries-1)*int(interval)))
+}
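
hsTimeout is the closed form of the linear backoff window that the new examples/config.yml comment describes (wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, and so on). A small self-contained check of that arithmetic, not part of the change itself:

package main

import (
	"fmt"
	"time"
)

func main() {
	interval := 100 * time.Millisecond // DefaultHandshakeTryInterval
	retries := 10                      // DefaultHandshakeRetries after this change

	// Sum the per-attempt delays: interval + 2*interval + ... + retries*interval
	total := time.Duration(0)
	for i := 1; i <= retries; i++ {
		total += time.Duration(i) * interval
	}

	// Matches hsTimeout(10, 100ms) = 10/2 * (2*100ms + 9*100ms) = 5.5s,
	// the "5.5 seconds to resolve" figure quoted in examples/config.yml.
	fmt.Println(total) // 5.5s
}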

handshake_manager_test.go (+28 -185)

@@ -8,66 +8,12 @@ import (
 	"github.com/stretchr/testify/assert"
 )
 
-//var ips []uint32 = []uint32{9000, 9999999, 3, 292394923}
-var ips []uint32
-
-func Test_NewHandshakeManagerIndex(t *testing.T) {
-	l := NewTestLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
-	preferredRanges := []*net.IPNet{localrange}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
-
-	now := time.Now()
-	blah.NextInboundHandshakeTimerTick(now)
-
-	var indexes = make([]uint32, 4)
-	var hostinfo = make([]*HostInfo, len(indexes))
-	for i := range indexes {
-		hostinfo[i] = &HostInfo{ConnectionState: &ConnectionState{}}
-	}
-
-	// Add four indexes
-	for i := range indexes {
-		err := blah.AddIndexHostInfo(hostinfo[i])
-		assert.NoError(t, err)
-		indexes[i] = hostinfo[i].localIndexId
-		blah.InboundHandshakeTimer.Add(indexes[i], time.Second*10)
-	}
-	// Confirm they are in the pending index list
-	for _, v := range indexes {
-		assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
-	}
-	// Adding something to pending should not affect the main hostmap
-	assert.Len(t, mainHM.Indexes, 0)
-	// Jump ahead 8 seconds
-	for i := 1; i <= DefaultHandshakeRetries; i++ {
-		next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
-		blah.NextInboundHandshakeTimerTick(next_tick)
-	}
-	// Confirm they are still in the pending index list
-	for _, v := range indexes {
-		assert.Contains(t, blah.pendingHostMap.Indexes, uint32(v))
-	}
-	// Jump ahead 4 more seconds
-	next_tick := now.Add(12 * time.Second)
-	blah.NextInboundHandshakeTimerTick(next_tick)
-	// Confirm they have been removed
-	for _, v := range indexes {
-		assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(v))
-	}
-}
-
 func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 	l := NewTestLogger()
 	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	ips = []uint32{ip2int(net.ParseIP("172.1.1.2"))}
+	ip := ip2int(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
@@ -77,39 +23,30 @@ func Test_NewHandshakeManagerVpnIP(t *testing.T) {
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
-	// Add four "IPs" - which are just uint32s
-	for _, v := range ips {
-		blah.AddVpnIP(v)
-	}
+	i := blah.AddVpnIP(ip)
+	i.remotes = NewRemoteList()
+	i.HandshakeReady = true
+
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)
+
 	// Confirm they are in the pending index list
-	for _, v := range ips {
-		assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
-	}
+	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
 
-	// Jump ahead `HandshakeRetries` ticks
-	cumulative := time.Duration(0)
-	for i := 0; i <= DefaultHandshakeRetries+1; i++ {
-		cumulative += time.Duration(i)*DefaultHandshakeTryInterval + 1
-		next_tick := now.Add(cumulative)
-		//l.Infoln(next_tick)
-		blah.NextOutboundHandshakeTimerTick(next_tick, mw)
+	// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
+	for i := 1; i <= DefaultHandshakeRetries+1; i++ {
+		now = now.Add(time.Duration(i) * DefaultHandshakeTryInterval)
+		blah.NextOutboundHandshakeTimerTick(now, mw)
 	}
 
 	// Confirm they are still in the pending index list
-	for _, v := range ips {
-		assert.Contains(t, blah.pendingHostMap.Hosts, uint32(v))
-	}
-	// Jump ahead 1 more second
-	cumulative += time.Duration(DefaultHandshakeRetries+1) * DefaultHandshakeTryInterval
-	next_tick := now.Add(cumulative)
-	//l.Infoln(next_tick)
-	blah.NextOutboundHandshakeTimerTick(next_tick, mw)
+	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
+
+	// Tick 1 more time, a minute will certainly flush it out
+	blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
+
 	// Confirm they have been removed
-	for _, v := range ips {
-		assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(v))
-	}
+	assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
 }
 
 func Test_NewHandshakeManagerTrigger(t *testing.T) {
@@ -121,7 +58,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
 	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-	lh := &LightHouse{}
+	lh := &LightHouse{addrMap: make(map[uint32]*RemoteList), l: l}
 
 	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udpConn{}, defaultHandshakeConfig)
 
@@ -130,28 +67,25 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 
 	assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 
-	blah.AddVpnIP(ip)
-
+	hi := blah.AddVpnIP(ip)
+	hi.HandshakeReady = true
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
+	assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")
 
-	// Trigger the same method the channel will
+	// Trigger the same method the channel will, but this should set our remotes pointer
 	blah.handleOutbound(ip, mw, true)
+	assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should have done a handshake attempt")
+	assert.NotNil(t, hi.remotes, "Manager should have set my remotes pointer")
 
-	// Make sure the trigger doesn't schedule another timer entry
+	// Make sure the trigger doesn't double schedule the timer entry
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
-	hi := blah.pendingHostMap.Hosts[ip]
-	assert.Nil(t, hi.remote)
 
 	uaddr := NewUDPAddrFromString("10.1.1.1:4242")
-	lh.addrMap = map[uint32]*ip4And6{}
-	lh.addrMap[ip] = &ip4And6{
-		v4: []*Ip4AndPort{NewIp4AndPort(uaddr.IP, uint32(uaddr.Port))},
-		v6: []*Ip6AndPort{},
-	}
+	hi.remotes.unlockedPrependV4(ip, NewIp4AndPort(uaddr.IP, uint32(uaddr.Port)))
 
-	// This should trigger the hostmap to populate the hostinfo
+	// We now have remotes but only the first trigger should have pushed things forward
 	blah.handleOutbound(ip, mw, true)
-	assert.NotNil(t, hi.remote)
+	assert.Equal(t, 1, hi.HandshakeCounter, "Trigger should not have done a handshake attempt")
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 }
 
@@ -166,100 +100,9 @@ func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
 	return c
 }
 
-func Test_NewHandshakeManagerVpnIPcleanup(t *testing.T) {
-	l := NewTestLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	vpnIP = ip2int(net.ParseIP("172.1.1.2"))
-	preferredRanges := []*net.IPNet{localrange}
-	mw := &mockEncWriter{}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
-
-	now := time.Now()
-	blah.NextOutboundHandshakeTimerTick(now, mw)
-
-	hostinfo := blah.AddVpnIP(vpnIP)
-	// Pretned we have an index too
-	err := blah.AddIndexHostInfo(hostinfo)
-	assert.NoError(t, err)
-	blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
-	assert.NotZero(t, hostinfo.localIndexId)
-	assert.Contains(t, blah.pendingHostMap.Indexes, hostinfo.localIndexId)
-
-	// Jump ahead `HandshakeRetries` ticks. Eviction should happen in pending
-	// but not main hostmap
-	cumulative := time.Duration(0)
-	for i := 1; i <= DefaultHandshakeRetries+2; i++ {
-		cumulative += DefaultHandshakeTryInterval * time.Duration(i)
-		next_tick := now.Add(cumulative)
-		blah.NextOutboundHandshakeTimerTick(next_tick, mw)
-	}
-	/*
-		for i := 0; i <= HandshakeRetries+1; i++ {
-			next_tick := now.Add(cumulative)
-			//l.Infoln(next_tick)
-			blah.NextOutboundHandshakeTimerTick(next_tick)
-		}
-	*/
-	/*
-		for i := 0; i <= HandshakeRetries+1; i++ {
-			next_tick := now.Add(time.Duration(i) * time.Second)
-			blah.NextOutboundHandshakeTimerTick(next_tick)
-		}
-	*/
-
-	/*
-		cumulative += HandshakeTryInterval*time.Duration(HandshakeRetries) + 3
-		next_tick := now.Add(cumulative)
-		l.Infoln(cumulative, next_tick)
-		blah.NextOutboundHandshakeTimerTick(next_tick)
-	*/
-	assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(vpnIP))
-	assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(12341234))
-}
-
-func Test_NewHandshakeManagerIndexcleanup(t *testing.T) {
-	l := NewTestLogger()
-	_, tuncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
-	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
-	preferredRanges := []*net.IPNet{localrange}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
-
-	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, &LightHouse{}, &udpConn{}, defaultHandshakeConfig)
-
-	now := time.Now()
-	blah.NextInboundHandshakeTimerTick(now)
-
-	hostinfo := &HostInfo{ConnectionState: &ConnectionState{}}
-	err := blah.AddIndexHostInfo(hostinfo)
-	assert.NoError(t, err)
-	blah.InboundHandshakeTimer.Add(hostinfo.localIndexId, time.Second*10)
-	// Pretned we have an index too
-	blah.pendingHostMap.AddVpnIPHostInfo(101010, hostinfo)
-	assert.Contains(t, blah.pendingHostMap.Hosts, uint32(101010))
-
-	for i := 1; i <= DefaultHandshakeRetries+2; i++ {
-		next_tick := now.Add(DefaultHandshakeTryInterval * time.Duration(i))
-		blah.NextInboundHandshakeTimerTick(next_tick)
-	}
-
-	next_tick := now.Add(DefaultHandshakeTryInterval*DefaultHandshakeRetries + 3)
-	blah.NextInboundHandshakeTimerTick(next_tick)
-	assert.NotContains(t, blah.pendingHostMap.Hosts, uint32(101010))
-	assert.NotContains(t, blah.pendingHostMap.Indexes, uint32(hostinfo.localIndexId))
-}
-
 type mockEncWriter struct {
 }
 
 func (mw *mockEncWriter) SendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
 	return
 }
-
-func (mw *mockEncWriter) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
-	return
-}
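
Note: the tests above never sleep; they fabricate timestamps and hand them to the tick functions so timer wheel expiry is deterministic. A minimal, self-contained sketch of that pattern (the wheel type here is a hypothetical stand-in, not the real SystemTimerWheel):

package main

import (
	"fmt"
	"time"
)

// toyWheel is a hypothetical stand-in for a single timer wheel entry, used only
// to show the fabricated-timestamp style of the tests above.
type toyWheel struct{ deadline time.Time }

// tick reports whether the entry would have expired by "now"
func (w *toyWheel) tick(now time.Time) bool { return now.After(w.deadline) }

func main() {
	start := time.Now()
	w := &toyWheel{deadline: start.Add(10 * time.Second)}
	fmt.Println(w.tick(start))                  // false: nothing has expired yet
	fmt.Println(w.tick(start.Add(time.Minute))) // true: a minute certainly flushes it out
}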

+ 52 - 249
hostmap.go

@@ -1,7 +1,6 @@
 package nebula
 
 import (
-	"encoding/json"
 	"errors"
 	"fmt"
 	"net"
@@ -16,6 +15,7 @@ import (
 
 //const ProbeLen = 100
 const PromoteEvery = 1000
+const ReQueryEvery = 5000
 const MaxRemotes = 10
 
 // How long we should prevent roaming back to the previous IP.
@@ -30,7 +30,6 @@ type HostMap struct {
 	Hosts           map[uint32]*HostInfo
 	preferredRanges []*net.IPNet
 	vpnCIDR         *net.IPNet
-	defaultRoute    uint32
 	unsafeRoutes    *CIDRTree
 	metricsEnabled  bool
 	l               *logrus.Logger
@@ -40,25 +39,21 @@ type HostInfo struct {
 	sync.RWMutex
 
 	remote            *udpAddr
-	Remotes           []*udpAddr
+	remotes           *RemoteList
 	promoteCounter    uint32
 	ConnectionState   *ConnectionState
-	handshakeStart    time.Time
-	HandshakeReady    bool
-	HandshakeCounter  int
-	HandshakeComplete bool
-	HandshakePacket   map[uint8][]byte
-	packetStore       []*cachedPacket
+	handshakeStart    time.Time        //todo: this is an entry in the handshake manager
+	HandshakeReady    bool             //todo: being in the manager means you are ready
+	HandshakeCounter  int              //todo: another handshake manager entry
+	HandshakeComplete bool             //todo: this should go away in favor of ConnectionState.ready
+	HandshakePacket   map[uint8][]byte //todo: another handshake manager entry
+	packetStore       []*cachedPacket  //todo: another handshake manager entry
 	remoteIndexId     uint32
 	localIndexId      uint32
 	hostId            uint32
 	recvError         int
 	remoteCidr        *CIDRTree
 
-	// This is a list of remotes that we have tried to handshake with and have returned from the wrong vpn ip.
-	// They should not be tried again during a handshake
-	badRemotes []*udpAddr
-
 	// lastRebindCount is the other side of Interface.rebindCount, if these values don't match then we need to ask LH
 	// for a punch from the remote end of this tunnel. The goal being to prime their conntrack for our traffic just like
 	// with a handshake
@@ -88,7 +83,6 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
 		Hosts:           h,
 		preferredRanges: preferredRanges,
 		vpnCIDR:         vpnCIDR,
-		defaultRoute:    0,
 		unsafeRoutes:    NewCIDRTree(),
 		l:               l,
 	}
@@ -131,7 +125,6 @@ func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
 	if _, ok := hm.Hosts[vpnIP]; !ok {
 		hm.RUnlock()
 		h = &HostInfo{
-			Remotes:         []*udpAddr{},
 			promoteCounter:  0,
 			hostId:          vpnIP,
 			HandshakePacket: make(map[uint8][]byte, 0),
@@ -239,7 +232,11 @@ func (hm *HostMap) DeleteReverseIndex(index uint32) {
 
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 	hm.Lock()
+	defer hm.Unlock()
+	hm.unlockedDeleteHostInfo(hostinfo)
+}
 
+func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	// Check if this same hostId is in the hostmap with a different instance.
 	// This could happen if we have an entry in the pending hostmap with different
 	// index values than the one in the main hostmap.
@@ -262,7 +259,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) {
 	if len(hm.RemoteIndexes) == 0 {
 		hm.RemoteIndexes = map[uint32]*HostInfo{}
 	}
-	hm.Unlock()
 
 	if hm.l.Level >= logrus.DebugLevel {
 		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
@@ -294,30 +290,6 @@ func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
 	}
 }
 
-func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
-	hm.Lock()
-	i, v := hm.Hosts[vpnIp]
-	if v {
-		i.AddRemote(remote)
-	} else {
-		i = &HostInfo{
-			Remotes:         []*udpAddr{remote.Copy()},
-			promoteCounter:  0,
-			hostId:          vpnIp,
-			HandshakePacket: make(map[uint8][]byte, 0),
-		}
-		i.remote = i.Remotes[0]
-		hm.Hosts[vpnIp] = i
-		if hm.l.Level >= logrus.DebugLevel {
-			hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
-				Debug("Hostmap remote ip added")
-		}
-	}
-	i.ForcePromoteBest(hm.preferredRanges)
-	hm.Unlock()
-	return i
-}
-
 func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
 	return hm.queryVpnIP(vpnIp, nil)
 }
@@ -331,12 +303,13 @@ func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostIn
 func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
-		if promoteIfce != nil {
+		// Do not attempt promotion if you are a lighthouse
+		if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
 			h.TryPromoteBest(hm.preferredRanges, promoteIfce)
 		}
-		//fmt.Println(h.remote)
 		hm.RUnlock()
 		return h, nil
+
 	} else {
 		//return &net.UDPAddr{}, nil, errors.New("Unable to find host")
 		hm.RUnlock()
@@ -362,11 +335,8 @@ func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
 // We already have the hm Lock when this is called, so make sure to not call
 // any other methods that might try to grab it again
 func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
-	remoteCert := hostinfo.ConnectionState.peerCert
-	ip := ip2int(remoteCert.Details.Ips[0].IP)
-
-	f.lightHouse.AddRemoteAndReset(ip, hostinfo.remote)
 	if f.serveDns {
+		remoteCert := hostinfo.ConnectionState.peerCert
 		dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String())
 	}
 
@@ -381,38 +351,21 @@ func (hm *HostMap) addHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 }
 
-func (hm *HostMap) ClearRemotes(vpnIP uint32) {
-	hm.Lock()
-	i := hm.Hosts[vpnIP]
-	if i == nil {
-		hm.Unlock()
-		return
-	}
-	i.remote = nil
-	i.Remotes = nil
-	hm.Unlock()
-}
-
-func (hm *HostMap) SetDefaultRoute(ip uint32) {
-	hm.defaultRoute = ip
-}
-
-func (hm *HostMap) PunchList() []*udpAddr {
-	var list []*udpAddr
+// punchList assembles a list of all non-nil RemoteList pointer entries in this hostmap
+// The caller can then do its work outside of the read lock
+func (hm *HostMap) punchList(rl []*RemoteList) []*RemoteList {
 	hm.RLock()
+	defer hm.RUnlock()
+
 	for _, v := range hm.Hosts {
-		for _, r := range v.Remotes {
-			list = append(list, r)
+		if v.remotes != nil {
+			rl = append(rl, v.remotes)
 		}
-		//	if h, ok := hm.Hosts[vpnIp]; ok {
-		//		hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false)
-		//fmt.Println(h.remote)
-		//	}
 	}
-	hm.RUnlock()
-	return list
+	return rl
 }
 
+// Punchy iterates through the result of punchList() to assemble all known addresses and sends a hole punch packet to them
 func (hm *HostMap) Punchy(conn *udpConn) {
 	var metricsTxPunchy metrics.Counter
 	if hm.metricsEnabled {
@@ -421,13 +374,18 @@ func (hm *HostMap) Punchy(conn *udpConn) {
 		metricsTxPunchy = metrics.NilCounter{}
 	}
 
+	var remotes []*RemoteList
 	b := []byte{1}
 	for {
-		for _, addr := range hm.PunchList() {
-			metricsTxPunchy.Inc(1)
-			conn.WriteTo(b, addr)
+		remotes = hm.punchList(remotes[:0])
+		for _, rl := range remotes {
+			//TODO: CopyAddrs generates garbage but ForEach locks for the work here, figure out which way is better
+			for _, addr := range rl.CopyAddrs(hm.preferredRanges) {
+				metricsTxPunchy.Inc(1)
+				conn.WriteTo(b, addr)
+			}
 		}
-		time.Sleep(time.Second * 30)
+		time.Sleep(time.Second * 10)
 	}
 }
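
The punchList/Punchy split above follows a snapshot pattern: collect the RemoteList pointers while holding the read lock, then do the copying and network writes with no hostmap lock held. A rough, self-contained sketch of that pattern (types and names here are illustrative only, not part of the change):

package main

import (
	"fmt"
	"sync"
)

// snapshotMap is an illustrative stand-in for the hostmap: vpnIp -> remote addresses.
type snapshotMap struct {
	sync.RWMutex
	hosts map[uint32][]string
}

// snapshot copies the per-host address slices into the reusable buffer out while
// holding the read lock, mirroring punchList(rl []*RemoteList) above.
func (m *snapshotMap) snapshot(out [][]string) [][]string {
	m.RLock()
	defer m.RUnlock()
	for _, addrs := range m.hosts {
		if addrs != nil {
			out = append(out, addrs)
		}
	}
	return out
}

func main() {
	m := &snapshotMap{hosts: map[uint32][]string{
		1: {"192.0.2.1:4242"},
		2: {"192.0.2.2:4242"},
	}}

	var buf [][]string
	buf = m.snapshot(buf[:0]) // reuse the backing array between rounds, as Punchy does
	for _, addrs := range buf {
		for _, a := range addrs {
			fmt.Println("punch", a) // the real loop sends a single byte UDP packet here
		}
	}
}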
 
@@ -438,38 +396,15 @@ func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
 	}
 }
 
-func (i *HostInfo) MarshalJSON() ([]byte, error) {
-	return json.Marshal(m{
-		"remote":             i.remote,
-		"remotes":            i.Remotes,
-		"promote_counter":    i.promoteCounter,
-		"connection_state":   i.ConnectionState,
-		"handshake_start":    i.handshakeStart,
-		"handshake_ready":    i.HandshakeReady,
-		"handshake_counter":  i.HandshakeCounter,
-		"handshake_complete": i.HandshakeComplete,
-		"handshake_packet":   i.HandshakePacket,
-		"packet_store":       i.packetStore,
-		"remote_index":       i.remoteIndexId,
-		"local_index":        i.localIndexId,
-		"host_id":            int2ip(i.hostId),
-		"receive_errors":     i.recvError,
-		"last_roam":          i.lastRoam,
-		"last_roam_remote":   i.lastRoamRemote,
-	})
-}
-
 func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
 	i.ConnectionState = cs
 }
 
+// TryPromoteBest handles re-querying lighthouses and probing for better paths
+// NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
-	if i.remote == nil {
-		i.ForcePromoteBest(preferredRanges)
-		return
-	}
-
-	if atomic.AddUint32(&i.promoteCounter, 1)%PromoteEvery == 0 {
+	c := atomic.AddUint32(&i.promoteCounter, 1)
+	if c%PromoteEvery == 0 {
 		// return early if we are already on a preferred remote
 		rIP := i.remote.IP
 		for _, l := range preferredRanges {
@@ -478,87 +413,21 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 			}
 		}
 
-		// We re-query the lighthouse periodically while sending packets, so
-		// check for new remotes in our local lighthouse cache
-		ips := ifce.lightHouse.QueryCache(i.hostId)
-		for _, ip := range ips {
-			i.AddRemote(ip)
-		}
+		i.remotes.ForEach(preferredRanges, func(addr *udpAddr, preferred bool) {
+			if addr == nil || !preferred {
+				return
+			}
 
-		best, preferred := i.getBestRemote(preferredRanges)
-		if preferred && !best.Equals(i.remote) {
 			// Try to send a test packet to that host, this should
 			// cause it to detect a roaming event and switch remotes
-			ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
-		}
-	}
-}
-
-func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) {
-	best, _ := i.getBestRemote(preferredRanges)
-	if best != nil {
-		i.remote = best
-	}
-}
-
-func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) {
-	if len(i.Remotes) > 0 {
-		for _, r := range i.Remotes {
-			for _, l := range preferredRanges {
-				if l.Contains(r.IP) {
-					return r, true
-				}
-			}
-
-			if best == nil || !PrivateIP(r.IP) {
-				best = r
-			}
-			/*
-				for _, r := range i.Remotes {
-					// Must have > 80% probe success to be considered.
-					//fmt.Println("GRADE:", r.addr.IP, r.Grade())
-					if r.Grade() > float64(.8) {
-						if localToMe.Contains(r.addr.IP) == true {
-							best = r.addr
-							break
-							//i.remote = i.Remotes[c].addr
-						} else {
-								//}
-					}
-			*/
-		}
-		return best, false
+			ifce.send(test, testRequest, i.ConnectionState, i, addr, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
+		})
 	}
 
-	return nil, false
-}
-
-// rotateRemote will move remote to the next ip in the list of remote ips for this host
-// This is different than PromoteBest in that what is algorithmically best may not actually work.
-// Only known use case is when sending a stage 0 handshake.
-// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver.
-func (i *HostInfo) rotateRemote() {
-	// We have 0, can't rotate
-	if len(i.Remotes) < 1 {
-		return
+	// Re-query our lighthouses for new remotes occasionally
+	if c%ReQueryEvery == 0 && ifce.lightHouse != nil {
+		ifce.lightHouse.QueryServer(i.hostId, ifce)
 	}
-
-	if i.remote == nil {
-		i.remote = i.Remotes[0]
-		return
-	}
-
-	// We want to look at all but the very last entry since that is handled at the end
-	for x := 0; x < len(i.Remotes)-1; x++ {
-		// Find our current position and move to the next one in the list
-		if i.Remotes[x].Equals(i.remote) {
-			i.remote = i.Remotes[x+1]
-			return
-		}
-	}
-
-	// Our current position was likely the last in the list, start over at 0
-	i.remote = i.Remotes[0]
 }
 
 func (i *HostInfo) cachePacket(l *logrus.Logger, t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
@@ -607,23 +476,13 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger) {
 		}
 	}
 
-	i.badRemotes = make([]*udpAddr, 0)
+	i.remotes.ResetBlockedRemotes()
 	i.packetStore = make([]*cachedPacket, 0)
 	i.ConnectionState.ready = true
 	i.ConnectionState.queueLock.Unlock()
 	i.ConnectionState.certState = nil
 }
 
-func (i *HostInfo) CopyRemotes() []*udpAddr {
-	i.RLock()
-	rc := make([]*udpAddr, len(i.Remotes), len(i.Remotes))
-	for x, addr := range i.Remotes {
-		rc[x] = addr.Copy()
-	}
-	i.RUnlock()
-	return rc
-}
-
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
@@ -631,58 +490,12 @@ func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	return nil
 }
 
-func (i *HostInfo) AddRemote(remote *udpAddr) *udpAddr {
-	if i.unlockedIsBadRemote(remote) {
-		return i.remote
-	}
-
-	for _, r := range i.Remotes {
-		if r.Equals(remote) {
-			return r
-		}
-	}
-
-	// Trim this down if necessary
-	if len(i.Remotes) > MaxRemotes {
-		i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:]
-	}
-
-	rc := remote.Copy()
-	i.Remotes = append(i.Remotes, rc)
-	return rc
-}
-
 func (i *HostInfo) SetRemote(remote *udpAddr) {
-	i.remote = i.AddRemote(remote)
-}
-
-func (i *HostInfo) unlockedBlockRemote(remote *udpAddr) {
-	if !i.unlockedIsBadRemote(remote) {
-		// We copy here because we are taking something else's memory and we can't trust everything
-		i.badRemotes = append(i.badRemotes, remote.Copy())
+	// We copy here because we likely got this remote from a source that reuses the object
+	if !i.remote.Equals(remote) {
+		i.remote = remote.Copy()
+		i.remotes.LearnRemote(i.hostId, remote.Copy())
 	}
-
-	for k, v := range i.Remotes {
-		if v.Equals(remote) {
-			i.Remotes[k] = i.Remotes[len(i.Remotes)-1]
-			i.Remotes = i.Remotes[:len(i.Remotes)-1]
-			return
-		}
-	}
-}
-
-func (i *HostInfo) unlockedIsBadRemote(remote *udpAddr) bool {
-	for _, v := range i.badRemotes {
-		if v.Equals(remote) {
-			return true
-		}
-	}
-	return false
-}
-
-func (i *HostInfo) ClearRemotes() {
-	i.remote = nil
-	i.Remotes = []*udpAddr{}
 }
 
 func (i *HostInfo) ClearConnectionState() {
@@ -805,13 +618,3 @@ func localIps(l *logrus.Logger, allowList *AllowList) *[]net.IP {
 	}
 	return &ips
 }
-
-func PrivateIP(ip net.IP) bool {
-	//TODO: Private for ipv6 or just let it ride?
-	private := false
-	_, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8")
-	_, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12")
-	_, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16")
-	private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
-	return private
-}
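
To make the new cadence concrete: TryPromoteBest now probes preferred remotes every PromoteEvery packets and re-queries the lighthouses every ReQueryEvery packets, both driven by one atomic counter. A minimal sketch of that modulo-counter scheme (the probe and query bodies are placeholders, not the real calls):

package main

import (
	"fmt"
	"sync/atomic"
)

const (
	promoteEvery = 1000 // mirrors PromoteEvery
	reQueryEvery = 5000 // mirrors the new ReQueryEvery
)

func main() {
	var counter uint32
	probes, queries := 0, 0

	// pretend we sent 10,000 packets to this host
	for i := 0; i < 10000; i++ {
		c := atomic.AddUint32(&counter, 1)
		if c%promoteEvery == 0 {
			probes++ // real code: probe preferred remotes via remotes.ForEach
		}
		if c%reQueryEvery == 0 {
			queries++ // real code: ifce.lightHouse.QueryServer(i.hostId, ifce)
		}
	}
	fmt.Println(probes, queries) // 10 probes, 2 lighthouse re-queries
}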

+ 0 - 168
hostmap_test.go

@@ -1,169 +1 @@
 package nebula
-
-import (
-	"net"
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-/*
-func TestHostInfoDestProbe(t *testing.T) {
-	a, _ := net.ResolveUDPAddr("udp", "1.0.0.1:22222")
-	d := NewHostInfoDest(a)
-
-	// 999 probes that all return should give a 100% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		d.ProbeReceived(meh)
-	}
-	assert.Equal(t, d.Grade(), float64(1))
-
-	// 999 probes of which only half return should give a 50% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		if i%2 == 0 {
-			d.ProbeReceived(meh)
-		}
-	}
-	assert.Equal(t, d.Grade(), float64(.5))
-
-	// 999 probes of which none return should give a 0% success rate
-	for i := 0; i < 999; i++ {
-		d.Probe()
-	}
-	assert.Equal(t, d.Grade(), float64(0))
-
-	// 999 probes of which only 1/4 return should give a 25% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		if i%4 == 0 {
-			d.ProbeReceived(meh)
-		}
-	}
-	assert.Equal(t, d.Grade(), float64(.25))
-
-	// 999 probes of which only half return and are duplicates should give a 50% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		if i%2 == 0 {
-			d.ProbeReceived(meh)
-			d.ProbeReceived(meh)
-		}
-	}
-	assert.Equal(t, d.Grade(), float64(.5))
-
-	// 999 probes of which only way old replies return should give a 0% success rate
-	for i := 0; i < 999; i++ {
-		meh := d.Probe()
-		d.ProbeReceived(meh - 101)
-	}
-	assert.Equal(t, d.Grade(), float64(0))
-
-}
-*/
-
-func TestHostmap(t *testing.T) {
-	l := NewTestLogger()
-	_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
-	_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
-	myNets := []*net.IPNet{myNet}
-	preferredRanges := []*net.IPNet{localToMe}
-
-	m := NewHostMap(l, "test", myNet, preferredRanges)
-
-	a := NewUDPAddrFromString("10.127.0.3:11111")
-	b := NewUDPAddrFromString("1.0.0.1:22222")
-	y := NewUDPAddrFromString("10.128.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-
-	info, _ := m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
-
-	// There should be three remotes in the host map
-	assert.Equal(t, 3, len(info.Remotes))
-
-	// Adding an identical remote should not change the count
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-	assert.Equal(t, 3, len(info.Remotes))
-
-	// Adding a fresh remote should add one
-	y = NewUDPAddrFromString("10.18.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-	assert.Equal(t, 4, len(info.Remotes))
-
-	// Query and reference remote should get the first one (and not nil)
-	info, _ = m.QueryVpnIP(ip2int(net.ParseIP("10.128.1.1")))
-	assert.NotNil(t, info.remote)
-
-	// Promotion should ensure that the best remote is chosen (y)
-	info.ForcePromoteBest(myNets)
-	assert.True(t, myNet.Contains(info.remote.IP))
-
-}
-
-func TestHostmapdebug(t *testing.T) {
-	l := NewTestLogger()
-	_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
-	_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
-	preferredRanges := []*net.IPNet{localToMe}
-	m := NewHostMap(l, "test", myNet, preferredRanges)
-
-	a := NewUDPAddrFromString("10.127.0.3:11111")
-	b := NewUDPAddrFromString("1.0.0.1:22222")
-	y := NewUDPAddrFromString("10.128.0.3:11111")
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), b)
-	m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-
-	//t.Errorf("%s", m.DebugRemotes(1))
-}
-
-func TestHostMap_rotateRemote(t *testing.T) {
-	h := HostInfo{}
-	// 0 remotes, no panic
-	h.rotateRemote()
-	assert.Nil(t, h.remote)
-
-	// 1 remote, no panic
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 1}, 0))
-	h.rotateRemote()
-	assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 1})
-
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 2}, 0))
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 3}, 0))
-	h.AddRemote(NewUDPAddr(net.IP{1, 1, 1, 4}, 0))
-
-	//TODO: ensure we are copying and not storing the slice!
-
-	// Rotate through those 3
-	h.rotateRemote()
-	assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 2})
-
-	h.rotateRemote()
-	assert.Equal(t, h.remote.IP, net.IP{1, 1, 1, 3})
-
-	h.rotateRemote()
-	assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 4}, Port: 0})
-
-	// Finally, we should start over
-	h.rotateRemote()
-	assert.Equal(t, h.remote, &udpAddr{IP: net.IP{1, 1, 1, 1}, Port: 0})
-}
-
-func BenchmarkHostmappromote2(b *testing.B) {
-	l := NewTestLogger()
-	for n := 0; n < b.N; n++ {
-		_, myNet, _ := net.ParseCIDR("10.128.0.0/16")
-		_, localToMe, _ := net.ParseCIDR("192.168.1.0/24")
-		preferredRanges := []*net.IPNet{localToMe}
-		m := NewHostMap(l, "test", myNet, preferredRanges)
-		y := NewUDPAddrFromString("10.128.0.3:11111")
-		a := NewUDPAddrFromString("10.127.0.3:11111")
-		g := NewUDPAddrFromString("1.0.0.1:22222")
-		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), a)
-		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), g)
-		m.AddRemote(ip2int(net.ParseIP("10.128.1.1")), y)
-	}
-}

+ 7 - 50
inside.go

@@ -54,10 +54,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 
 	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
 	if dropReason == nil {
-		mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
-		if f.lightHouse != nil && mc%5000 == 0 {
-			f.lightHouse.Query(fwPacket.RemoteIP, f)
-		}
+		f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
 
 	} else if f.l.Level >= logrus.DebugLevel {
 		hostinfo.logger(f.l).
@@ -84,15 +81,13 @@ func (f *Interface) getOrHandshake(vpnIp uint32) *HostInfo {
 			hostinfo = f.handshakeManager.AddVpnIP(vpnIp)
 		}
 	}
-
 	ci := hostinfo.ConnectionState
 
 	if ci != nil && ci.eKey != nil && ci.ready {
 		return hostinfo
 	}
 
-	// Handshake is not ready, we need to grab the lock now before we start
-	// the handshake process
+	// Handshake is not ready, we need to grab the lock now before we start the handshake process
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 
@@ -150,10 +145,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 		return
 	}
 
-	messageCounter := f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
-	if f.lightHouse != nil && messageCounter%5000 == 0 {
-		f.lightHouse.Query(fp.RemoteIP, f)
-	}
+	f.sendNoMetrics(message, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out, 0)
 }
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
@@ -187,50 +179,15 @@ func (f *Interface) sendMessageToVpnIp(t NebulaMessageType, st NebulaMessageSubT
 	f.send(t, st, hostInfo.ConnectionState, hostInfo, hostInfo.remote, p, nb, out)
 }
 
-// SendMessageToAll handles real ip:port lookup and sends to all known addresses for vpnIp
-func (f *Interface) SendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, vpnIp uint32, p, nb, out []byte) {
-	hostInfo := f.getOrHandshake(vpnIp)
-	if hostInfo == nil {
-		if f.l.Level >= logrus.DebugLevel {
-			f.l.WithField("vpnIp", IntIp(vpnIp)).
-				Debugln("dropping SendMessageToAll, vpnIp not in our CIDR or in unsafe routes")
-		}
-		return
-	}
-
-	if hostInfo.ConnectionState.ready == false {
-		// Because we might be sending stored packets, lock here to stop new things going to
-		// the packet queue.
-		hostInfo.ConnectionState.queueLock.Lock()
-		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToAll)
-			hostInfo.ConnectionState.queueLock.Unlock()
-			return
-		}
-		hostInfo.ConnectionState.queueLock.Unlock()
-	}
-
-	f.sendMessageToAll(t, st, hostInfo, p, nb, out)
-	return
-}
-
-func (f *Interface) sendMessageToAll(t NebulaMessageType, st NebulaMessageSubType, hostInfo *HostInfo, p, nb, b []byte) {
-	hostInfo.RLock()
-	for _, r := range hostInfo.Remotes {
-		f.send(t, st, hostInfo.ConnectionState, hostInfo, r, p, nb, b)
-	}
-	hostInfo.RUnlock()
-}
-
 func (f *Interface) send(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte) {
 	f.messageMetrics.Tx(t, st, 1)
 	f.sendNoMetrics(t, st, ci, hostinfo, remote, p, nb, out, 0)
 }
 
-func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) uint64 {
+func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udpAddr, p, nb, out []byte, q int) {
 	if ci.eKey == nil {
 		//TODO: log warning
-		return 0
+		return
 	}
 
 	var err error
@@ -262,7 +219,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 			WithField("udpAddr", remote).WithField("counter", c).
 			WithField("attemptedCounter", c).
 			Error("Failed to encrypt outgoing packet")
-		return c
+		return
 	}
 
 	err = f.writers[q].WriteTo(out, remote)
@@ -270,7 +227,7 @@ func (f *Interface) sendNoMetrics(t NebulaMessageType, st NebulaMessageSubType,
 		hostinfo.logger(f.l).WithError(err).
 			WithField("udpAddr", remote).Error("Failed to write outgoing packet")
 	}
-	return c
+	return
 }
 
 func isMulticast(ip uint32) bool {

+ 117 - 226
lighthouse.go

@@ -13,26 +13,11 @@ import (
 	"github.com/sirupsen/logrus"
 )
 
+//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake?
 //TODO: nodes are roaming lighthouses, this is bad. How are they learning?
 
 var ErrHostNotKnown = errors.New("host not known")
 
-// The maximum number of ip addresses to store for a given vpnIp per address family
-const maxAddrs = 10
-
-type ip4And6 struct {
-	//TODO: adding a lock here could allow us to release the lock on lh.addrMap quicker
-
-	// v4 and v6 store addresses that have been self reported by the client in a server or where all addresses are stored on a client
-	v4 []*Ip4AndPort
-	v6 []*Ip6AndPort
-
-	// Learned addresses are ones that a client does not know about but a lighthouse learned from as a result of the received packet
-	// This is only used if you are a lighthouse server
-	learnedV4 []*Ip4AndPort
-	learnedV6 []*Ip6AndPort
-}
-
 type LightHouse struct {
 	//TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time
 	sync.RWMutex //Because we concurrently read and write to our maps
@@ -42,7 +27,8 @@ type LightHouse struct {
 	punchConn    *udpConn
 
 	// Local cache of answers from light houses
-	addrMap map[uint32]*ip4And6
+	// map of vpn Ip to answers
+	addrMap map[uint32]*RemoteList
 
 	// filters remote addresses allowed for each host
 	// - When we are a lighthouse, this filters what addresses we store and
@@ -81,7 +67,7 @@ func NewLightHouse(l *logrus.Logger, amLighthouse bool, myVpnIpNet *net.IPNet, i
 		amLighthouse: amLighthouse,
 		myVpnIp:      ip2int(myVpnIpNet.IP),
 		myVpnZeros:   uint32(32 - ones),
-		addrMap:      make(map[uint32]*ip4And6),
+		addrMap:      make(map[uint32]*RemoteList),
 		nebulaPort:   nebulaPort,
 		lighthouses:  make(map[uint32]struct{}),
 		staticList:   make(map[uint32]struct{}),
@@ -130,57 +116,79 @@ func (lh *LightHouse) ValidateLHStaticEntries() error {
 	return nil
 }
 
-func (lh *LightHouse) Query(ip uint32, f EncWriter) ([]*udpAddr, error) {
-	//TODO: we need to hold the lock through the next func
+func (lh *LightHouse) Query(ip uint32, f EncWriter) *RemoteList {
 	if !lh.IsLighthouseIP(ip) {
 		lh.QueryServer(ip, f)
 	}
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
-		return TransformLHReplyToUdpAddrs(v), nil
+		return v
 	}
 	lh.RUnlock()
-	return nil, ErrHostNotKnown
+	return nil
 }
 
 // This is asynchronous so no reply should be expected
 func (lh *LightHouse) QueryServer(ip uint32, f EncWriter) {
-	if !lh.amLighthouse {
-		// Send a query to the lighthouses and hope for the best next time
-		query, err := proto.Marshal(NewLhQueryByInt(ip))
-		if err != nil {
-			lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
-			return
-		}
+	if lh.amLighthouse {
+		return
+	}
 
-		lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
-		nb := make([]byte, 12, 12)
-		out := make([]byte, mtu)
-		for n := range lh.lighthouses {
-			f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
-		}
+	if lh.IsLighthouseIP(ip) {
+		return
+	}
+
+	// Send a query to the lighthouses and hope for the best next time
+	query, err := proto.Marshal(NewLhQueryByInt(ip))
+	if err != nil {
+		lh.l.WithError(err).WithField("vpnIp", IntIp(ip)).Error("Failed to marshal lighthouse query payload")
+		return
+	}
+
+	lh.metricTx(NebulaMeta_HostQuery, int64(len(lh.lighthouses)))
+	nb := make([]byte, 12, 12)
+	out := make([]byte, mtu)
+	for n := range lh.lighthouses {
+		f.SendMessageToVpnIp(lightHouse, 0, n, query, nb, out)
 	}
 }
 
-func (lh *LightHouse) QueryCache(ip uint32) []*udpAddr {
-	//TODO: we need to hold the lock through the next func
+func (lh *LightHouse) QueryCache(ip uint32) *RemoteList {
 	lh.RLock()
 	if v, ok := lh.addrMap[ip]; ok {
 		lh.RUnlock()
-		return TransformLHReplyToUdpAddrs(v)
+		return v
 	}
 	lh.RUnlock()
-	return nil
+
+	lh.Lock()
+	defer lh.Unlock()
+	// Add an entry if we don't already have one
+	return lh.unlockedGetRemoteList(ip)
 }
 
-//
-func (lh *LightHouse) queryAndPrepMessage(ip uint32, f func(*ip4And6) (int, error)) (bool, int, error) {
+// queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing
+// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp.
+// If one is found then f() is called with proper locking; f() must return the result of n.MarshalTo()
+func (lh *LightHouse) queryAndPrepMessage(vpnIp uint32, f func(*cache) (int, error)) (bool, int, error) {
 	lh.RLock()
-	if v, ok := lh.addrMap[ip]; ok {
-		n, err := f(v)
+	// Do we have an entry in the main cache?
+	if v, ok := lh.addrMap[vpnIp]; ok {
+		// Swap lh lock for remote list lock
+		v.RLock()
+		defer v.RUnlock()
+
 		lh.RUnlock()
-		return true, n, err
+
+		// vpnIp should also be the owner here since we are a lighthouse.
+		c := v.cache[vpnIp]
+		// Make sure we have a cache entry under the owner before using it
+		if c != nil {
+			n, err := f(c)
+			return true, n, err
+		}
+		return false, 0, nil
 	}
 	lh.RUnlock()
 	return false, 0, nil
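
The locking in queryAndPrepMessage (and in handleHostQueryReply and handleHostUpdateNotification further down) follows the same hand-off: hold the addrMap lock only long enough to find the RemoteList, take the RemoteList's own lock, then release the map lock so other vpnIps are not blocked. A self-contained sketch of that hand-off with simplified stand-in types:

package main

import (
	"fmt"
	"sync"
)

// entry is a simplified stand-in for RemoteList: it carries its own lock so the
// outer map lock can be released while the entry is still being read.
type entry struct {
	sync.RWMutex
	addrs []string
}

// store is a simplified stand-in for the addrMap plus its lock.
type store struct {
	sync.RWMutex
	m map[uint32]*entry
}

// withEntry mirrors the hand-off: hold the map lock only long enough to find the
// entry, take the entry's own lock, then release the map lock before doing the work.
func (s *store) withEntry(key uint32, f func(*entry)) bool {
	s.RLock()
	e, ok := s.m[key]
	if !ok {
		s.RUnlock()
		return false
	}

	e.RLock()
	defer e.RUnlock()
	s.RUnlock() // other keys are no longer blocked while f runs

	f(e)
	return true
}

func main() {
	s := &store{m: map[uint32]*entry{1: {addrs: []string{"192.0.2.1:4242"}}}}
	s.withEntry(1, func(e *entry) { fmt.Println(e.addrs) })
}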
@@ -203,70 +211,47 @@ func (lh *LightHouse) DeleteVpnIP(vpnIP uint32) {
 	lh.Unlock()
 }
 
-// AddRemote is correct way for non LightHouse members to add an address. toAddr will be placed in the learned map
-// static means this is a static host entry from the config file, it should only be used on start up
-func (lh *LightHouse) AddRemote(vpnIP uint32, toAddr *udpAddr, static bool) {
-	if ipv4 := toAddr.IP.To4(); ipv4 != nil {
-		lh.addRemoteV4(vpnIP, NewIp4AndPort(ipv4, uint32(toAddr.Port)), static, true)
-	} else {
-		lh.addRemoteV6(vpnIP, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)), static, true)
-	}
-
-	//TODO: if we do not add due to a config filter we may end up not having any addresses here
-	if static {
-		lh.staticList[vpnIP] = struct{}{}
-	}
-}
-
-// unlockedGetAddrs assumes you have the lh lock
-func (lh *LightHouse) unlockedGetAddrs(vpnIP uint32) *ip4And6 {
-	am, ok := lh.addrMap[vpnIP]
-	if !ok {
-		am = &ip4And6{}
-		lh.addrMap[vpnIP] = am
-	}
-	return am
-}
-
-// addRemoteV4 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated
-func (lh *LightHouse) addRemoteV4(vpnIP uint32, to *Ip4AndPort, static bool, learned bool) {
-	// First we check if the sender thinks this is a static entry
-	// and do nothing if it is not, but should be considered static
-	if static == false {
-		if _, ok := lh.staticList[vpnIP]; ok {
-			return
-		}
-	}
-
+// AddStaticRemote adds a static host entry for vpnIp with ourselves as the owner
+// We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with
+// And we don't want a lighthouse query reply to interfere with our learned cache if we are a client
+func (lh *LightHouse) AddStaticRemote(vpnIp uint32, toAddr *udpAddr) {
 	lh.Lock()
-	defer lh.Unlock()
-	am := lh.unlockedGetAddrs(vpnIP)
+	am := lh.unlockedGetRemoteList(vpnIp)
+	am.Lock()
+	defer am.Unlock()
+	lh.Unlock()
 
-	if learned {
-		if !lh.unlockedShouldAddV4(am.learnedV4, to) {
+	if ipv4 := toAddr.IP.To4(); ipv4 != nil {
+		to := NewIp4AndPort(ipv4, uint32(toAddr.Port))
+		if !lh.unlockedShouldAddV4(to) {
 			return
 		}
-		am.learnedV4 = prependAndLimitV4(am.learnedV4, to)
+		am.unlockedPrependV4(lh.myVpnIp, to)
+
 	} else {
-		if !lh.unlockedShouldAddV4(am.v4, to) {
+		to := NewIp6AndPort(toAddr.IP, uint32(toAddr.Port))
+		if !lh.unlockedShouldAddV6(to) {
 			return
 		}
-		am.v4 = prependAndLimitV4(am.v4, to)
+		am.unlockedPrependV6(lh.myVpnIp, to)
 	}
+
+	// Mark it as static
+	lh.staticList[vpnIp] = struct{}{}
 }
 
-func prependAndLimitV4(cache []*Ip4AndPort, to *Ip4AndPort) []*Ip4AndPort {
-	cache = append(cache, nil)
-	copy(cache[1:], cache)
-	cache[0] = to
-	if len(cache) > MaxRemotes {
-		cache = cache[:maxAddrs]
+// unlockedGetRemoteList assumes you have the lh lock
+func (lh *LightHouse) unlockedGetRemoteList(vpnIP uint32) *RemoteList {
+	am, ok := lh.addrMap[vpnIP]
+	if !ok {
+		am = NewRemoteList()
+		lh.addrMap[vpnIP] = am
 	}
-	return cache
+	return am
 }
 
-// unlockedShouldAddV4 checks if to is allowed by our allow list and is not already present in the cache
-func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool {
+// unlockedShouldAddV4 checks if to is allowed by our allow list
+func (lh *LightHouse) unlockedShouldAddV4(to *Ip4AndPort) bool {
 	allow := lh.remoteAllowList.AllowIpV4(to.Ip)
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", IntIp(to.Ip)).WithField("allow", allow).Trace("remoteAllowList.Allow")
@@ -276,69 +261,21 @@ func (lh *LightHouse) unlockedShouldAddV4(am []*Ip4AndPort, to *Ip4AndPort) bool
 		return false
 	}
 
-	for _, v := range am {
-		if v.Ip == to.Ip && v.Port == to.Port {
-			return false
-		}
-	}
-
 	return true
 }
 
-// addRemoteV6 is a lighthouse internal method that prepends a remote if it is allowed by the allow list and not duplicated
-func (lh *LightHouse) addRemoteV6(vpnIP uint32, to *Ip6AndPort, static bool, learned bool) {
-	// First we check if the sender thinks this is a static entry
-	// and do nothing if it is not, but should be considered static
-	if static == false {
-		if _, ok := lh.staticList[vpnIP]; ok {
-			return
-		}
-	}
-
-	lh.Lock()
-	defer lh.Unlock()
-	am := lh.unlockedGetAddrs(vpnIP)
-
-	if learned {
-		if !lh.unlockedShouldAddV6(am.learnedV6, to) {
-			return
-		}
-		am.learnedV6 = prependAndLimitV6(am.learnedV6, to)
-	} else {
-		if !lh.unlockedShouldAddV6(am.v6, to) {
-			return
-		}
-		am.v6 = prependAndLimitV6(am.v6, to)
-	}
-}
-
-func prependAndLimitV6(cache []*Ip6AndPort, to *Ip6AndPort) []*Ip6AndPort {
-	cache = append(cache, nil)
-	copy(cache[1:], cache)
-	cache[0] = to
-	if len(cache) > MaxRemotes {
-		cache = cache[:maxAddrs]
-	}
-	return cache
-}
-
-// unlockedShouldAddV6 checks if to is allowed by our allow list and is not already present in the cache
-func (lh *LightHouse) unlockedShouldAddV6(am []*Ip6AndPort, to *Ip6AndPort) bool {
+// unlockedShouldAddV6 checks if to is allowed by our allow list
+func (lh *LightHouse) unlockedShouldAddV6(to *Ip6AndPort) bool {
 	allow := lh.remoteAllowList.AllowIpV6(to.Hi, to.Lo)
 	if lh.l.Level >= logrus.TraceLevel {
 		lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow")
 	}
 
+	// We don't check our vpn network here because nebula does not support ipv6 on the inside
 	if !allow {
 		return false
 	}
 
-	for _, v := range am {
-		if v.Hi == to.Hi && v.Lo == to.Lo && v.Port == to.Port {
-			return false
-		}
-	}
-
 	return true
 }
 
@@ -349,13 +286,6 @@ func lhIp6ToIp(v *Ip6AndPort) net.IP {
 	return ip
 }
 
-func (lh *LightHouse) AddRemoteAndReset(vpnIP uint32, toIp *udpAddr) {
-	if lh.amLighthouse {
-		lh.DeleteVpnIP(vpnIP)
-		lh.AddRemote(vpnIP, toIp, false)
-	}
-}
-
 func (lh *LightHouse) IsLighthouseIP(vpnIP uint32) bool {
 	if _, ok := lh.lighthouses[vpnIP]; ok {
 		return true
@@ -496,7 +426,6 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta {
 	return lhh.meta
 }
 
-//TODO: do we need c here?
 func (lhh *LightHouseHandler) HandleRequest(rAddr *udpAddr, vpnIp uint32, p []byte, w EncWriter) {
 	n := lhh.resetMeta()
 	err := n.Unmarshal(p)
@@ -544,13 +473,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	//TODO: we can DRY this further
 	reqVpnIP := n.Details.VpnIp
 	//TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data
-	//TODO: If we use a lock on cache we can avoid holding it on lh.addrMap and keep things moving better
-	found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(cache *ip4And6) (int, error) {
+	found, ln, err := lhh.lh.queryAndPrepMessage(n.Details.VpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostQueryReply
 		n.Details.VpnIp = reqVpnIP
 
-		lhh.coalesceAnswers(cache, n)
+		lhh.coalesceAnswers(c, n)
 
 		return n.MarshalTo(lhh.pb)
 	})
@@ -568,12 +496,12 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	w.SendMessageToVpnIp(lightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 
 	// This signals the other side to punch some zero byte udp packets
-	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(cache *ip4And6) (int, error) {
+	found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) {
 		n = lhh.resetMeta()
 		n.Type = NebulaMeta_HostPunchNotification
 		n.Details.VpnIp = vpnIp
 
-		lhh.coalesceAnswers(cache, n)
+		lhh.coalesceAnswers(c, n)
 
 		return n.MarshalTo(lhh.pb)
 	})
@@ -591,12 +519,24 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp uint32, addr
 	w.SendMessageToVpnIp(lightHouse, 0, reqVpnIP, lhh.pb[:ln], lhh.nb, lhh.out[:0])
 }
 
-func (lhh *LightHouseHandler) coalesceAnswers(cache *ip4And6, n *NebulaMeta) {
-	n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.v4...)
-	n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, cache.learnedV4...)
+func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) {
+	if c.v4 != nil {
+		if c.v4.learned != nil {
+			n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned)
+		}
+		if c.v4.reported != nil && len(c.v4.reported) > 0 {
+			n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...)
+		}
+	}
 
-	n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.v6...)
-	n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, cache.learnedV6...)
+	if c.v6 != nil {
+		if c.v6.learned != nil {
+			n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned)
+		}
+		if c.v6.reported != nil && len(c.v6.reported) > 0 {
+			n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...)
+		}
+	}
 }
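
A worked example of the coalescing order above: per address family, the single learned entry (if present) is listed first, followed by every reported entry. A small sketch with simplified stand-in types, not the real protobuf structs:

package main

import "fmt"

// v4Cache is a simplified stand-in for the per-owner v4 cache: one learned
// address plus any number of reported addresses.
type v4Cache struct {
	learned  string
	reported []string
}

// coalesce lists the learned entry (if any) first, then every reported entry,
// matching the ordering coalesceAnswers uses per address family.
func coalesce(c *v4Cache) []string {
	var out []string
	if c.learned != "" {
		out = append(out, c.learned)
	}
	out = append(out, c.reported...)
	return out
}

func main() {
	c := &v4Cache{
		learned:  "198.51.100.7:4242",
		reported: []string{"192.0.2.1:4242", "192.0.2.2:4242"},
	}
	fmt.Println(coalesce(c)) // learned entry first, then the reported entries
}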
 
 func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32) {
@@ -604,14 +544,14 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp uint32)
 		return
 	}
 
-	// We can't just slam the responses in as they may come from multiple lighthouses and we should coalesce the answers
-	for _, to := range n.Details.Ip4AndPorts {
-		lhh.lh.addRemoteV4(n.Details.VpnIp, to, false, false)
-	}
+	lhh.lh.Lock()
+	am := lhh.lh.unlockedGetRemoteList(n.Details.VpnIp)
+	am.Lock()
+	lhh.lh.Unlock()
 
-	for _, to := range n.Details.Ip6AndPorts {
-		lhh.lh.addRemoteV6(n.Details.VpnIp, to, false, false)
-	}
+	am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	am.Unlock()
 
 	// Non-blocking attempt to trigger, skip if it would block
 	select {
@@ -637,35 +577,13 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp
 	}
 
 	lhh.lh.Lock()
-	defer lhh.lh.Unlock()
-	am := lhh.lh.unlockedGetAddrs(vpnIp)
-
-	//TODO: other note on a lock for am so we can release more quickly and lock our real unit of change which is far less contended
-
-	// We don't accumulate addresses being told to us
-	am.v4 = am.v4[:0]
-	am.v6 = am.v6[:0]
-
-	for _, v := range n.Details.Ip4AndPorts {
-		if lhh.lh.unlockedShouldAddV4(am.v4, v) {
-			am.v4 = append(am.v4, v)
-		}
-	}
-
-	for _, v := range n.Details.Ip6AndPorts {
-		if lhh.lh.unlockedShouldAddV6(am.v6, v) {
-			am.v6 = append(am.v6, v)
-		}
-	}
+	am := lhh.lh.unlockedGetRemoteList(vpnIp)
+	am.Lock()
+	lhh.lh.Unlock()
 
-	// We prefer the first n addresses if we got too big
-	if len(am.v4) > MaxRemotes {
-		am.v4 = am.v4[:MaxRemotes]
-	}
-
-	if len(am.v6) > MaxRemotes {
-		am.v6 = am.v6[:MaxRemotes]
-	}
+	am.unlockedSetV4(vpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4)
+	am.unlockedSetV6(vpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6)
+	am.Unlock()
 }
 
 func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp uint32, w EncWriter) {
@@ -716,33 +634,6 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp u
 	}
 }
 
-func TransformLHReplyToUdpAddrs(ips *ip4And6) []*udpAddr {
-	addrs := make([]*udpAddr, len(ips.v4)+len(ips.v6)+len(ips.learnedV4)+len(ips.learnedV6))
-	i := 0
-
-	for _, v := range ips.learnedV4 {
-		addrs[i] = NewUDPAddrFromLH4(v)
-		i++
-	}
-
-	for _, v := range ips.v4 {
-		addrs[i] = NewUDPAddrFromLH4(v)
-		i++
-	}
-
-	for _, v := range ips.learnedV6 {
-		addrs[i] = NewUDPAddrFromLH6(v)
-		i++
-	}
-
-	for _, v := range ips.v6 {
-		addrs[i] = NewUDPAddrFromLH6(v)
-		i++
-	}
-
-	return addrs
-}
-
 // ipMaskContains checks if testIp is contained by ip after applying a cidr
 // zeros is 32 - bits from net.IPMask.Size()
 func ipMaskContains(ip uint32, zeros uint32, testIp uint32) bool {

+ 100 - 79
lighthouse_test.go

@@ -48,16 +48,16 @@ func Test_lhStaticMapping(t *testing.T) {
 
 	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
 
-	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
+	meh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
 	err := meh.ValidateLHStaticEntries()
 	assert.Nil(t, err)
 
 	lh2 := "10.128.0.3"
 	lh2IP := net.ParseIP(lh2)
 
-	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
-	meh.AddRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)), true)
+	meh = NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 255}}, []uint32{ip2int(lh1IP), ip2int(lh2IP)}, 10, 10003, udpServer, false, 1, false)
+	meh.AddStaticRemote(ip2int(lh1IP), NewUDPAddr(lh1IP, uint16(4242)))
 	err = meh.ValidateLHStaticEntries()
 	assert.EqualError(t, err, "Lighthouse 10.128.0.3 does not have a static_host_map entry")
 }
@@ -73,17 +73,27 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
 
 	hAddr := NewUDPAddrFromString("4.5.6.7:12345")
 	hAddr2 := NewUDPAddrFromString("4.5.6.7:12346")
-	lh.addrMap[3] = &ip4And6{v4: []*Ip4AndPort{
-		NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
-		NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port))},
-	}
+	lh.addrMap[3] = NewRemoteList()
+	lh.addrMap[3].unlockedSetV4(
+		3,
+		[]*Ip4AndPort{
+			NewIp4AndPort(hAddr.IP, uint32(hAddr.Port)),
+			NewIp4AndPort(hAddr2.IP, uint32(hAddr2.Port)),
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
 
 	rAddr := NewUDPAddrFromString("1.2.2.3:12345")
 	rAddr2 := NewUDPAddrFromString("1.2.2.3:12346")
-	lh.addrMap[2] = &ip4And6{v4: []*Ip4AndPort{
-		NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
-		NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port))},
-	}
+	lh.addrMap[2] = NewRemoteList()
+	lh.addrMap[2].unlockedSetV4(
+		3,
+		[]*Ip4AndPort{
+			NewIp4AndPort(rAddr.IP, uint32(rAddr.Port)),
+			NewIp4AndPort(rAddr2.IP, uint32(rAddr2.Port)),
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
 
 	mw := &mockEncWriter{}
 
@@ -173,7 +183,7 @@ func TestLighthouse_Memory(t *testing.T) {
 	assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4)
 
 	// Ensure proper ordering and limiting
-	// Send 12 addrs, get 10 back, one removed on a dupe check the other by limiting
+	// Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe)
 	newLHHostUpdate(
 		myUdpAddr0,
 		myVpnIp,
@@ -191,11 +201,12 @@ func TestLighthouse_Memory(t *testing.T) {
 			myUdpAddr10,
 			myUdpAddr11, // This should get cut
 		}, lhh)
+
 	r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh)
 	assertIp4InArray(
 		t,
 		r.msg.Details.Ip4AndPorts,
-		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, myUdpAddr10,
+		myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9,
 	)
 
 	// Make sure we won't add ips in our vpn network
@@ -247,71 +258,71 @@ func newLHHostUpdate(fromAddr *udpAddr, vpnIp uint32, addrs []*udpAddr, lhh *Lig
 	lhh.HandleRequest(fromAddr, vpnIp, b, w)
 }
 
-func Test_lhRemoteAllowList(t *testing.T) {
-	l := NewTestLogger()
-	c := NewConfig(l)
-	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
-		"10.20.0.0/12": false,
-	}
-	allowList, err := c.GetAllowList("remoteallowlist", false)
-	assert.Nil(t, err)
-
-	lh1 := "10.128.0.2"
-	lh1IP := net.ParseIP(lh1)
-
-	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
-
-	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
-	lh.SetRemoteAllowList(allowList)
-
-	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
-	remote1IP := net.ParseIP("10.20.0.3")
-	lh.AddRemote(ip2int(remote1IP), NewUDPAddr(remote1IP, uint16(4242)), true)
-	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
-	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v4)
-	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].v6)
-
-	// Make sure a good ip enters the cache and addrMap
-	remote2IP := net.ParseIP("10.128.0.3")
-	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
-	lh.AddRemote(ip2int(remote2IP), remote2UDPAddr, true)
-	assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote2UDPAddr)
-
-	// Another good ip gets into the cache, ordering is inverted
-	remote3IP := net.ParseIP("10.128.0.4")
-	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
-	lh.AddRemote(ip2int(remote2IP), remote3UDPAddr, true)
-	assertIp4InArray(t, lh.addrMap[ip2int(remote2IP)].learnedV4, remote3UDPAddr, remote2UDPAddr)
-
-	// If we exceed the length limit we should only have the most recent addresses
-	addedAddrs := []*udpAddr{}
-	for i := 0; i < 11; i++ {
-		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
-		lh.AddRemote(ip2int(remote2IP), remoteUDPAddr, true)
-		// The first entry here is a duplicate, don't add it to the assert list
-		if i != 0 {
-			addedAddrs = append(addedAddrs, remoteUDPAddr)
-		}
-	}
-
-	// We should only have the last 10 of what we tried to add
-	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
-	ln := len(addedAddrs)
-	assertIp4InArray(
-		t,
-		lh.addrMap[ip2int(remote2IP)].learnedV4,
-		addedAddrs[ln-1],
-		addedAddrs[ln-2],
-		addedAddrs[ln-3],
-		addedAddrs[ln-4],
-		addedAddrs[ln-5],
-		addedAddrs[ln-6],
-		addedAddrs[ln-7],
-		addedAddrs[ln-8],
-		addedAddrs[ln-9],
-		addedAddrs[ln-10],
-	)
-}
+//TODO: this is a RemoteList test
+//func Test_lhRemoteAllowList(t *testing.T) {
+//	l := NewTestLogger()
+//	c := NewConfig(l)
+//	c.Settings["remoteallowlist"] = map[interface{}]interface{}{
+//		"10.20.0.0/12": false,
+//	}
+//	allowList, err := c.GetAllowList("remoteallowlist", false)
+//	assert.Nil(t, err)
+//
+//	lh1 := "10.128.0.2"
+//	lh1IP := net.ParseIP(lh1)
+//
+//	udpServer, _ := NewListener(l, "0.0.0.0", 0, true)
+//
+//	lh := NewLightHouse(l, true, &net.IPNet{IP: net.IP{0, 0, 0, 1}, Mask: net.IPMask{255, 255, 255, 0}}, []uint32{ip2int(lh1IP)}, 10, 10003, udpServer, false, 1, false)
+//	lh.SetRemoteAllowList(allowList)
+//
+//	// A disallowed ip should not enter the cache but we should end up with an empty entry in the addrMap
+//	remote1IP := net.ParseIP("10.20.0.3")
+//	remotes := lh.unlockedGetRemoteList(ip2int(remote1IP))
+//	remotes.unlockedPrependV4(ip2int(remote1IP), NewIp4AndPort(remote1IP, 4242))
+//	assert.NotNil(t, lh.addrMap[ip2int(remote1IP)])
+//	assert.Empty(t, lh.addrMap[ip2int(remote1IP)].CopyAddrs([]*net.IPNet{}))
+//
+//	// Make sure a good ip enters the cache and addrMap
+//	remote2IP := net.ParseIP("10.128.0.3")
+//	remote2UDPAddr := NewUDPAddr(remote2IP, uint16(4242))
+//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote2UDPAddr.IP, uint32(remote2UDPAddr.Port)), false, false)
+//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr)
+//
+//	// Another good ip gets into the cache, ordering is inverted
+//	remote3IP := net.ParseIP("10.128.0.4")
+//	remote3UDPAddr := NewUDPAddr(remote3IP, uint16(4243))
+//	lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remote3UDPAddr.IP, uint32(remote3UDPAddr.Port)), false, false)
+//	assertUdpAddrInArray(t, lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}), remote2UDPAddr, remote3UDPAddr)
+//
+//	// If we exceed the length limit we should only have the most recent addresses
+//	addedAddrs := []*udpAddr{}
+//	for i := 0; i < 11; i++ {
+//		remoteUDPAddr := NewUDPAddr(net.IP{10, 128, 0, 4}, uint16(4243+i))
+//		lh.addRemoteV4(ip2int(remote2IP), ip2int(remote2IP), NewIp4AndPort(remoteUDPAddr.IP, uint32(remoteUDPAddr.Port)), false, false)
+//		// The first entry here is a duplicate, don't add it to the assert list
+//		if i != 0 {
+//			addedAddrs = append(addedAddrs, remoteUDPAddr)
+//		}
+//	}
+//
+//	// We should only have the last 10 of what we tried to add
+//	assert.True(t, len(addedAddrs) >= 10, "We should have tried to add at least 10 addresses")
+//	assertUdpAddrInArray(
+//		t,
+//		lh.addrMap[ip2int(remote2IP)].CopyAddrs([]*net.IPNet{}),
+//		addedAddrs[0],
+//		addedAddrs[1],
+//		addedAddrs[2],
+//		addedAddrs[3],
+//		addedAddrs[4],
+//		addedAddrs[5],
+//		addedAddrs[6],
+//		addedAddrs[7],
+//		addedAddrs[8],
+//		addedAddrs[9],
+//	)
+//}
 
 func Test_ipMaskContains(t *testing.T) {
 	assert.True(t, ipMaskContains(ip2int(net.ParseIP("10.0.0.1")), 32-24, ip2int(net.ParseIP("10.0.0.255"))))
@@ -354,6 +365,16 @@ func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...*udpAddr) {
 	}
 }
 
+// assertUdpAddrInArray asserts every address in want is at the same position in have and that the lengths match
+func assertUdpAddrInArray(t *testing.T, have []*udpAddr, want ...*udpAddr) {
+	assert.Len(t, have, len(want))
+	for k, w := range want {
+		if !(have[k].IP.Equal(w.IP) && have[k].Port == w.Port) {
+			assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v; %v", w, k, have))
+		}
+	}
+}
+
 func translateV4toUdpAddr(ips []*Ip4AndPort) []*udpAddr {
 	addrs := make([]*udpAddr, len(ips))
 	for k, v := range ips {

+ 3 - 4
main.go

@@ -221,7 +221,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	}
 
 	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
-	hostMap.SetDefaultRoute(ip2int(net.ParseIP(config.GetString("default_route", "0.0.0.0"))))
+
 	hostMap.addUnsafeRoutes(&unsafeRoutes)
 	hostMap.metricsEnabled = config.GetBool("stats.message_metrics", false)
 
@@ -302,14 +302,14 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 				if err != nil {
 					return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 				}
-				lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true)
+				lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
 			}
 		} else {
 			ip, port, err := parseIPAndPort(fmt.Sprintf("%v", v))
 			if err != nil {
 				return nil, NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp}, err)
 			}
-			lightHouse.AddRemote(ip2int(vpnIp), NewUDPAddr(ip, port), true)
+			lightHouse.AddStaticRemote(ip2int(vpnIp), NewUDPAddr(ip, port))
 		}
 	}
 
@@ -328,7 +328,6 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 	handshakeConfig := HandshakeConfig{
 		tryInterval:   config.GetDuration("handshakes.try_interval", DefaultHandshakeTryInterval),
 		retries:       config.GetInt("handshakes.retries", DefaultHandshakeRetries),
-		waitRotation:  config.GetInt("handshakes.wait_rotation", DefaultHandshakeWaitRotation),
 		triggerBuffer: config.GetInt("handshakes.trigger_buffer", DefaultHandshakeTriggerBuffer),
 
 		messageMetrics: messageMetrics,

+ 6 - 3
outside.go

@@ -132,6 +132,7 @@ func (f *Interface) readOutsidePackets(addr *udpAddr, out []byte, packet []byte,
 	f.connectionManager.In(hostinfo.hostId)
 }
 
+// closeTunnel closes a tunnel locally; it does not send a closeTunnel packet to the remote
 func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	//TODO: this would be better as a single function in ConnectionManager that handled locks appropriately
 	f.connectionManager.ClearIP(hostInfo.hostId)
@@ -140,6 +141,11 @@ func (f *Interface) closeTunnel(hostInfo *HostInfo) {
 	f.hostMap.DeleteHostInfo(hostInfo)
 }
 
+// sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
+func (f *Interface) sendCloseTunnel(h *HostInfo) {
+	f.send(closeTunnel, 0, h.ConnectionState, h, h.remote, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
+}
+
 func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 	if hostDidRoam(hostinfo.remote, addr) {
 		if !f.lightHouse.remoteAllowList.Allow(addr.IP) {
@@ -160,9 +166,6 @@ func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udpAddr) {
 		remoteCopy := *hostinfo.remote
 		hostinfo.lastRoamRemote = &remoteCopy
 		hostinfo.SetRemote(addr)
-		if f.lightHouse.amLighthouse {
-			f.lightHouse.AddRemote(hostinfo.hostId, addr, false)
-		}
 	}
 
 }
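
The new remote_list.go (next) replaces the HostInfo.Remotes slice and the lighthouse's ip4And6 cache with a single RemoteList type. As a rough usage sketch, assuming the API shown in that file (NewRemoteList, unlockedPrependV4, CopyAddrs) and the existing ip2int and NewIp4AndPort helpers, something like the following test-style snippet would exercise it; it is illustrative only and not part of this change:

package nebula

import (
	"net"
	"testing"
)

// TestRemoteListUsageSketch is a hypothetical example: build a RemoteList by hand,
// the way the trigger test above does, and read back copies of its addresses.
func TestRemoteListUsageSketch(t *testing.T) {
	rl := NewRemoteList()
	owner := ip2int(net.ParseIP("10.128.0.2"))

	// prepend two reported v4 addresses under the owner
	rl.unlockedPrependV4(owner, NewIp4AndPort(net.ParseIP("192.0.2.2"), 4242))
	rl.unlockedPrependV4(owner, NewIp4AndPort(net.ParseIP("192.0.2.1"), 4242))

	// CopyAddrs rebuilds the deduplicated list as needed and returns copies
	_, local, _ := net.ParseCIDR("192.0.2.0/24")
	for _, addr := range rl.CopyAddrs([]*net.IPNet{local}) {
		t.Logf("known remote: %v", addr)
	}
}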

+ 500 - 0
remote_list.go

@@ -0,0 +1,500 @@
+package nebula
+
+import (
+	"bytes"
+	"net"
+	"sort"
+	"sync"
+)
+
+// forEachFunc is used to benefit folks that want to do work inside the lock
+type forEachFunc func(addr *udpAddr, preferred bool)
+
+// The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate)
+type checkFuncV4 func(to *Ip4AndPort) bool
+type checkFuncV6 func(to *Ip6AndPort) bool
+
+// CacheMap is a struct that better represents the lighthouse cache for humans
+// The string key is the owner's vpnIp
+type CacheMap map[string]*Cache
+
+// Cache is the other part of CacheMap to better represent the lighthouse cache for humans
+// We don't reason about ipv4 vs ipv6 here
+type Cache struct {
+	Learned  []*udpAddr `json:"learned,omitempty"`
+	Reported []*udpAddr `json:"reported,omitempty"`
+}
+
+//TODO: Seems like we should plop static host entries in here too since they are protected by the lighthouse from deletion
+// We will never clean learned/reported information for them as it stands today
+
+// cache is an internal struct that splits v4 and v6 addresses inside the cache map
+type cache struct {
+	v4 *cacheV4
+	v6 *cacheV6
+}
+
+// cacheV4 stores learned and reported ipv4 records under cache
+type cacheV4 struct {
+	learned  *Ip4AndPort
+	reported []*Ip4AndPort
+}
+
+// cacheV6 stores learned and reported ipv6 records under cache
+type cacheV6 struct {
+	learned  *Ip6AndPort
+	reported []*Ip6AndPort
+}
+
+// RemoteList is a unifying concept for lighthouse servers and clients as well as hostinfos.
+// It serves as a local cache of query replies, host update notifications, and locally learned addresses
+type RemoteList struct {
+	// Every interaction with internals requires a lock!
+	sync.RWMutex
+
+	// A deduplicated set of addresses. Any accessor should lock beforehand.
+	addrs []*udpAddr
+
+	// This is a map of cached v4 and v6 addresses, keyed by the vpnIp of the peer
+	// that told us about the entries underneath.
+	// For learned addresses, this is the vpnIp that sent the packet
+	cache map[uint32]*cache
+
+	// This is a list of remotes that we have tried to handshake with but that answered from the wrong vpn ip.
+	// They should not be tried again during a handshake
+	badRemotes []*udpAddr
+
+	// A flag that the cache may have changed and addrs needs to be rebuilt
+	shouldRebuild bool
+}
+
+// NewRemoteList creates a new empty RemoteList
+func NewRemoteList() *RemoteList {
+	return &RemoteList{
+		addrs: make([]*udpAddr, 0),
+		cache: make(map[uint32]*cache),
+	}
+}
+
+// Len locks and reports the size of the deduplicated address list
+// The deduplication work may need to occur here, so you must pass preferredRanges
+func (r *RemoteList) Len(preferredRanges []*net.IPNet) int {
+	r.Rebuild(preferredRanges)
+	r.RLock()
+	defer r.RUnlock()
+	return len(r.addrs)
+}
+
+// ForEach locks and will call the forEachFunc for every deduplicated address in the list
+// The deduplication work may need to occur here, so you must pass preferredRanges
+func (r *RemoteList) ForEach(preferredRanges []*net.IPNet, forEach forEachFunc) {
+	r.Rebuild(preferredRanges)
+	r.RLock()
+	for _, v := range r.addrs {
+		forEach(v, isPreferred(v.IP, preferredRanges))
+	}
+	r.RUnlock()
+}
+
+// CopyAddrs locks and makes a deep copy of the deduplicated address list
+// The deduplication work may need to occur here, so you must pass preferredRanges
+func (r *RemoteList) CopyAddrs(preferredRanges []*net.IPNet) []*udpAddr {
+	r.Rebuild(preferredRanges)
+
+	r.RLock()
+	defer r.RUnlock()
+	c := make([]*udpAddr, len(r.addrs))
+	for i, v := range r.addrs {
+		c[i] = v.Copy()
+	}
+	return c
+}
+
+// LearnRemote locks and sets the learned slot for the owner vpn ip to the provided addr
+// Currently this is only needed when HostInfo.SetRemote is called as that should cover both handshaking and roaming.
+// It will mark the deduplicated address list as dirty, so do not call it unless new information is available
+//TODO: this needs to support the allow list
+func (r *RemoteList) LearnRemote(ownerVpnIp uint32, addr *udpAddr) {
+	r.Lock()
+	defer r.Unlock()
+	if v4 := addr.IP.To4(); v4 != nil {
+		r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPort(v4, uint32(addr.Port)))
+	} else {
+		r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPort(addr.IP, uint32(addr.Port)))
+	}
+}
+
+// CopyCache locks and creates a more human friendly form of the internal address cache.
+// This may contain duplicates and blocked addresses
+func (r *RemoteList) CopyCache() *CacheMap {
+	r.RLock()
+	defer r.RUnlock()
+
+	cm := make(CacheMap)
+	getOrMake := func(vpnIp string) *Cache {
+		c := cm[vpnIp]
+		if c == nil {
+			c = &Cache{
+				Learned:  make([]*udpAddr, 0),
+				Reported: make([]*udpAddr, 0),
+			}
+			cm[vpnIp] = c
+		}
+		return c
+	}
+
+	for owner, mc := range r.cache {
+		c := getOrMake(IntIp(owner).String())
+
+		if mc.v4 != nil {
+			if mc.v4.learned != nil {
+				c.Learned = append(c.Learned, NewUDPAddrFromLH4(mc.v4.learned))
+			}
+
+			for _, a := range mc.v4.reported {
+				c.Reported = append(c.Reported, NewUDPAddrFromLH4(a))
+			}
+		}
+
+		if mc.v6 != nil {
+			if mc.v6.learned != nil {
+				c.Learned = append(c.Learned, NewUDPAddrFromLH6(mc.v6.learned))
+			}
+
+			for _, a := range mc.v6.reported {
+				c.Reported = append(c.Reported, NewUDPAddrFromLH6(a))
+			}
+		}
+	}
+
+	return &cm
+}
+
+// BlockRemote locks and records the address as bad; it will be excluded from the deduplicated address list
+func (r *RemoteList) BlockRemote(bad *udpAddr) {
+	r.Lock()
+	defer r.Unlock()
+
+	// Check if we already blocked this addr
+	if r.unlockedIsBad(bad) {
+		return
+	}
+
+	// We copy here because we are taking something else's memory and we can't trust everything
+	r.badRemotes = append(r.badRemotes, bad.Copy())
+
+	// Mark that the next interaction must recollect/dedupe
+	r.shouldRebuild = true
+}
+
+// CopyBlockedRemotes locks and makes a deep copy of the blocked remotes list
+func (r *RemoteList) CopyBlockedRemotes() []*udpAddr {
+	r.RLock()
+	defer r.RUnlock()
+
+	c := make([]*udpAddr, len(r.badRemotes))
+	for i, v := range r.badRemotes {
+		c[i] = v.Copy()
+	}
+	return c
+}
+
+// ResetBlockedRemotes locks and clears the blocked remotes list
+func (r *RemoteList) ResetBlockedRemotes() {
+	r.Lock()
+	r.badRemotes = nil
+	r.Unlock()
+}
+
+// Rebuild locks and generates the deduplicated address list only if there is work to be done
+// There is generally no reason to call this directly but it is safe to do so
+func (r *RemoteList) Rebuild(preferredRanges []*net.IPNet) {
+	r.Lock()
+	defer r.Unlock()
+
+	// Only rebuild if the cache changed
+	//TODO: shouldRebuild is probably pointless as we don't check for actual change when lighthouse updates come in
+	if r.shouldRebuild {
+		r.unlockedCollect()
+		r.shouldRebuild = false
+	}
+
+	// Always re-sort, preferredRanges can change via HUP
+	r.unlockedSort(preferredRanges)
+}
+
+// unlockedIsBad assumes you have the write lock and checks if the remote matches any entry in the blocked address list
+func (r *RemoteList) unlockedIsBad(remote *udpAddr) bool {
+	for _, v := range r.badRemotes {
+		if v.Equals(remote) {
+			return true
+		}
+	}
+	return false
+}
+
+// unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the
+// deduplicated address list as dirty
+func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp uint32, to *Ip4AndPort) {
+	r.shouldRebuild = true
+	r.unlockedGetOrMakeV4(ownerVpnIp).learned = to
+}
+
+// unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
+// and marks the deduplicated address list as dirty
+func (r *RemoteList) unlockedSetV4(ownerVpnIp uint32, to []*Ip4AndPort, check checkFuncV4) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV4(ownerVpnIp)
+
+	// Reset the slice
+	c.reported = c.reported[:0]
+
+	// We can't take their array but we can take their pointers
+	for _, v := range to[:minInt(len(to), MaxRemotes)] {
+		if check(v) {
+			c.reported = append(c.reported, v)
+		}
+	}
+}
+
+// unlockedPrependV4 assumes you have the write lock and prepends the address to the reported list for this owner
+// This is only useful for establishing static hosts
+func (r *RemoteList) unlockedPrependV4(ownerVpnIp uint32, to *Ip4AndPort) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV4(ownerVpnIp)
+
+	// We are doing the easy append because this is rarely called
+	c.reported = append([]*Ip4AndPort{to}, c.reported...)
+	if len(c.reported) > MaxRemotes {
+		c.reported = c.reported[:MaxRemotes]
+	}
+}
+
+// unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the
+// deduplicated address list as dirty
+func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp uint32, to *Ip6AndPort) {
+	r.shouldRebuild = true
+	r.unlockedGetOrMakeV6(ownerVpnIp).learned = to
+}
+
+// unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided
+// and marks the deduplicated address list as dirty
+func (r *RemoteList) unlockedSetV6(ownerVpnIp uint32, to []*Ip6AndPort, check checkFuncV6) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV6(ownerVpnIp)
+
+	// Reset the slice
+	c.reported = c.reported[:0]
+
+	// We can't take their array but we can take their pointers
+	for _, v := range to[:minInt(len(to), MaxRemotes)] {
+		if check(v) {
+			c.reported = append(c.reported, v)
+		}
+	}
+}
+
+// unlockedPrependV6 assumes you have the write lock and prepends the address to the reported list for this owner
+// This is only useful for establishing static hosts
+func (r *RemoteList) unlockedPrependV6(ownerVpnIp uint32, to *Ip6AndPort) {
+	r.shouldRebuild = true
+	c := r.unlockedGetOrMakeV6(ownerVpnIp)
+
+	// We are doing the easy append because this is rarely called
+	c.reported = append([]*Ip6AndPort{to}, c.reported...)
+	if len(c.reported) > MaxRemotes {
+		c.reported = c.reported[:MaxRemotes]
+	}
+}
+
+// unlockedGetOrMakeV4 assumes you have the write lock and builds the cache and owner entry. Only the v4 pointer is established.
+// The caller must dirty the learned address cache if required
+func (r *RemoteList) unlockedGetOrMakeV4(ownerVpnIp uint32) *cacheV4 {
+	am := r.cache[ownerVpnIp]
+	if am == nil {
+		am = &cache{}
+		r.cache[ownerVpnIp] = am
+	}
+	// Avoid occupying memory for v6 addresses if we never have any
+	if am.v4 == nil {
+		am.v4 = &cacheV4{}
+	}
+	return am.v4
+}
+
+// unlockedGetOrMakeV6 assumes you have the write lock and builds the cache and owner entry. Only the v6 pointer is established.
+// The caller must dirty the learned address cache if required
+func (r *RemoteList) unlockedGetOrMakeV6(ownerVpnIp uint32) *cacheV6 {
+	am := r.cache[ownerVpnIp]
+	if am == nil {
+		am = &cache{}
+		r.cache[ownerVpnIp] = am
+	}
+	// Avoid occupying memory for v4 addresses if we never have any
+	if am.v6 == nil {
+		am.v6 = &cacheV6{}
+	}
+	return am.v6
+}
+
+// unlockedCollect assumes you have the write lock and collects/transforms the cache into the deduped address list.
+// The result of this function can contain duplicates. unlockedSort handles cleaning it.
+func (r *RemoteList) unlockedCollect() {
+	addrs := r.addrs[:0]
+
+	for _, c := range r.cache {
+		if c.v4 != nil {
+			if c.v4.learned != nil {
+				u := NewUDPAddrFromLH4(c.v4.learned)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+
+			for _, v := range c.v4.reported {
+				u := NewUDPAddrFromLH4(v)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+		}
+
+		if c.v6 != nil {
+			if c.v6.learned != nil {
+				u := NewUDPAddrFromLH6(c.v6.learned)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+
+			for _, v := range c.v6.reported {
+				u := NewUDPAddrFromLH6(v)
+				if !r.unlockedIsBad(u) {
+					addrs = append(addrs, u)
+				}
+			}
+		}
+	}
+
+	r.addrs = addrs
+}
+
+// unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list
+func (r *RemoteList) unlockedSort(preferredRanges []*net.IPNet) {
+	n := len(r.addrs)
+	if n < 2 {
+		return
+	}
+
+	lessFunc := func(i, j int) bool {
+		a := r.addrs[i]
+		b := r.addrs[j]
+		// Preferred addresses first
+
+		aPref := isPreferred(a.IP, preferredRanges)
+		bPref := isPreferred(b.IP, preferredRanges)
+		switch {
+		case aPref && !bPref:
+			// If i is preferred and j is not, i is less than j
+			return true
+
+		case !aPref && bPref:
+			// If j is preferred and i is not, i is not less than j
+			return false
+
+		default:
+			// Both i and j are either preferred or not, sort within that
+		}
+
+		// ipv6 addresses 2nd
+		a4 := a.IP.To4()
+		b4 := b.IP.To4()
+		switch {
+		case a4 == nil && b4 != nil:
+			// If i is v6 and j is v4, i is less than j
+			return true
+
+		case a4 != nil && b4 == nil:
+			// If j is v6 and i is v4, i is not less than j
+			return false
+
+		case a4 != nil && b4 != nil:
+			// Special case for ipv4, a4 and b4 are not nil
+			aPrivate := isPrivateIP(a4)
+			bPrivate := isPrivateIP(b4)
+			switch {
+			case !aPrivate && bPrivate:
+				// If i is a public ip (not private) and j is a private ip, i is less than j
+				return true
+
+			case aPrivate && !bPrivate:
+				// If i is private and j is public (not private), i is not less than j
+				return false
+
+			default:
+				// Both i and j are either public or private, sort within that
+			}
+
+		default:
+			// Both i and j are either ipv4 or ipv6, sort within that
+		}
+
+		// lexical order of ips 3rd
+		c := bytes.Compare(a.IP, b.IP)
+		if c == 0 {
+			// IPs are the same, lexical order of ports 4th
+			return a.Port < b.Port
+		}
+
+		// IP wasn't the same
+		return c < 0
+	}
+
+	// Sort it
+	sort.Slice(r.addrs, lessFunc)
+
+	// Deduplicate
+	a, b := 0, 1
+	for b < n {
+		if !r.addrs[a].Equals(r.addrs[b]) {
+			a++
+			if a != b {
+				r.addrs[a], r.addrs[b] = r.addrs[b], r.addrs[a]
+			}
+		}
+		b++
+	}
+
+	r.addrs = r.addrs[:a+1]
+	return
+}
+
+// minInt returns the minimum integer of a or b
+func minInt(a, b int) int {
+	if a < b {
+		return a
+	}
+	return b
+}
+
+// isPreferred returns true if the ip is contained in the preferredRanges list
+func isPreferred(ip net.IP, preferredRanges []*net.IPNet) bool {
+	//TODO: this would be better in a CIDR6Tree
+	for _, p := range preferredRanges {
+		if p.Contains(ip) {
+			return true
+		}
+	}
+	return false
+}
+
+var _, private24BitBlock, _ = net.ParseCIDR("10.0.0.0/8")
+var _, private20BitBlock, _ = net.ParseCIDR("172.16.0.0/12")
+var _, private16BitBlock, _ = net.ParseCIDR("192.168.0.0/16")
+
+// isPrivateIP returns true if the ip is contained in an RFC 1918 private range
+func isPrivateIP(ip net.IP) bool {
+	//TODO: another great cidrtree option
+	//TODO: Private for ipv6 or just let it ride?
+	return private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
+}

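A short usage sketch of the RemoteList defined above, assuming it runs inside the nebula package where udpAddr, ip2int, NewUDPAddr and the net helpers are available; the function name and addresses are made up for illustration. It walks the typical flow: learn an address from the owner, block one that answered from the wrong vpn ip, then pull the deduplicated, preference-sorted list:

// exampleRemoteListUsage is a hypothetical walkthrough of the public RemoteList surface.
func exampleRemoteListUsage() []*udpAddr {
	rl := NewRemoteList()

	// Record an address learned directly from the owner (normally done via HostInfo.SetRemote)
	owner := ip2int(net.ParseIP("10.1.0.1"))
	rl.LearnRemote(owner, NewUDPAddr(net.ParseIP("192.168.0.5"), 4242))

	// Exclude an address that returned a handshake from the wrong vpn ip
	rl.BlockRemote(NewUDPAddr(net.ParseIP("203.0.113.9"), 4242))

	// Addresses inside preferred ranges are sorted to the front of the list
	_, lan, _ := net.ParseCIDR("192.168.0.0/24")

	// CopyAddrs triggers Rebuild (collect, dedupe, sort) if the cache is dirty
	return rl.CopyAddrs([]*net.IPNet{lan})
}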
+ 228 - 0
remote_list_test.go

@@ -0,0 +1,228 @@
+package nebula
+
+import (
+	"net"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRemoteList_Rebuild(t *testing.T) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is duped
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101}, // this is duped
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is duped
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101}, // this is a dupe
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // almost dupe of 0 with a diff port
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475}, // this is a dupe
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		1,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is duped
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // almost dupe of 0 with a diff port, also gets duped
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	rl.Rebuild([]*net.IPNet{})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// ipv6 first, sorted lexically within
+	assert.Equal(t, "[1::1]:1", rl.addrs[0].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[1].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[2].String())
+
+	// ipv4 last, sorted by public first, then private, lexically within them
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[3].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[4].String())
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[5].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[6].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
+
+	// Now ensure we can hoist ipv4 up
+	_, ipNet, err := net.ParseCIDR("0.0.0.0/0")
+	assert.NoError(t, err)
+	rl.Rebuild([]*net.IPNet{ipNet})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// ipv4 first, public then private, lexically within them
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[0].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[1].String())
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[2].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[3].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[4].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[5].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[6].String())
+
+	// ipv6 last, sorted by public first, then private, lexically within them
+	assert.Equal(t, "[1::1]:1", rl.addrs[7].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[8].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String())
+
+	// Ensure we can hoist a specific ipv4 range over anything else
+	_, ipNet, err = net.ParseCIDR("172.17.0.0/16")
+	assert.NoError(t, err)
+	rl.Rebuild([]*net.IPNet{ipNet})
+	assert.Len(t, rl.addrs, 10, "addrs contains too many entries")
+
+	// Preferred ipv4 first
+	assert.Equal(t, "172.17.0.182:10101", rl.addrs[0].String())
+	assert.Equal(t, "172.17.1.1:10101", rl.addrs[1].String())
+
+	// ipv6 next
+	assert.Equal(t, "[1::1]:1", rl.addrs[2].String())
+	assert.Equal(t, "[1::1]:2", rl.addrs[3].String())
+	assert.Equal(t, "[1:100::1]:1", rl.addrs[4].String())
+
+	// the remaining ipv4 last
+	assert.Equal(t, "70.199.182.92:1475", rl.addrs[5].String())
+	assert.Equal(t, "70.199.182.92:1476", rl.addrs[6].String())
+	assert.Equal(t, "172.18.0.1:10101", rl.addrs[7].String())
+	assert.Equal(t, "172.19.0.1:10101", rl.addrs[8].String())
+	assert.Equal(t, "172.31.0.1:10101", rl.addrs[9].String())
+}
+
+func BenchmarkFullRebuild(b *testing.B) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		0,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	b.Run("no preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{})
+		}
+	})
+
+	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
+	assert.NoError(b, err)
+	b.Run("1 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet})
+		}
+	})
+
+	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
+	assert.NoError(b, err)
+	b.Run("2 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+		}
+	})
+
+	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
+	assert.NoError(b, err)
+	b.Run("3 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+		}
+	})
+}
+
+func BenchmarkSortRebuild(b *testing.B) {
+	rl := NewRemoteList()
+	rl.unlockedSetV4(
+		0,
+		[]*Ip4AndPort{
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1475},
+			{Ip: ip2int(net.ParseIP("172.17.0.182")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.18.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.19.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.31.0.1")), Port: 10101},
+			{Ip: ip2int(net.ParseIP("172.17.1.1")), Port: 10101},   // this is a dupe
+			{Ip: ip2int(net.ParseIP("70.199.182.92")), Port: 1476}, // dupe of 0 with a diff port
+		},
+		func(*Ip4AndPort) bool { return true },
+	)
+
+	rl.unlockedSetV6(
+		0,
+		[]*Ip6AndPort{
+			NewIp6AndPort(net.ParseIP("1::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 2), // dupe of 0 with a diff port
+			NewIp6AndPort(net.ParseIP("1:100::1"), 1),
+			NewIp6AndPort(net.ParseIP("1::1"), 1), // this is a dupe
+		},
+		func(*Ip6AndPort) bool { return true },
+	)
+
+	b.Run("no preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.shouldRebuild = true
+			rl.Rebuild([]*net.IPNet{})
+		}
+	})
+
+	_, ipNet, err := net.ParseCIDR("172.17.0.0/16")
+	rl.Rebuild([]*net.IPNet{ipNet})
+
+	assert.NoError(b, err)
+	b.Run("1 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.Rebuild([]*net.IPNet{ipNet})
+		}
+	})
+
+	_, ipNet2, err := net.ParseCIDR("70.0.0.0/8")
+	rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+
+	assert.NoError(b, err)
+	b.Run("2 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2})
+		}
+	})
+
+	_, ipNet3, err := net.ParseCIDR("0.0.0.0/0")
+	rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+
+	assert.NoError(b, err)
+	b.Run("3 preferred", func(b *testing.B) {
+		for i := 0; i < b.N; i++ {
+			rl.Rebuild([]*net.IPNet{ipNet, ipNet2, ipNet3})
+		}
+	})
+}

+ 39 - 48
ssh.go

@@ -10,8 +10,8 @@ import (
 	"os"
 	"reflect"
 	"runtime/pprof"
+	"sort"
 	"strings"
-	"sync/atomic"
 	"syscall"
 
 	"github.com/sirupsen/logrus"
@@ -335,8 +335,10 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 		return nil
 	}
 
-	hostMap.RLock()
-	defer hostMap.RUnlock()
+	hm := listHostMap(hostMap)
+	sort.Slice(hm, func(i, j int) bool {
+		return bytes.Compare(hm[i].VpnIP, hm[j].VpnIP) < 0
+	})
 
 	if fs.Json || fs.Pretty {
 		js := json.NewEncoder(w.GetWriter())
@@ -344,35 +346,15 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 			js.SetIndent("", "    ")
 		}
 
-		d := make([]m, len(hostMap.Hosts))
-		x := 0
-		var h m
-		for _, v := range hostMap.Hosts {
-			h = m{
-				"vpnIp":         int2ip(v.hostId),
-				"localIndex":    v.localIndexId,
-				"remoteIndex":   v.remoteIndexId,
-				"remoteAddrs":   v.CopyRemotes(),
-				"cachedPackets": len(v.packetStore),
-				"cert":          v.GetCert(),
-			}
-
-			if v.ConnectionState != nil {
-				h["messageCounter"] = atomic.LoadUint64(&v.ConnectionState.atomicMessageCounter)
-			}
-
-			d[x] = h
-			x++
-		}
-
-		err := js.Encode(d)
+		err := js.Encode(hm)
 		if err != nil {
 			//TODO
 			return nil
 		}
+
 	} else {
-		for i, v := range hostMap.Hosts {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(i), v.CopyRemotes()))
+		for _, v := range hm {
+			err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, v.RemoteAddrs))
 			if err != nil {
 				return err
 			}
@@ -389,8 +371,26 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 		return nil
 	}
 
+	type lighthouseInfo struct {
+		VpnIP net.IP    `json:"vpnIp"`
+		Addrs *CacheMap `json:"addrs"`
+	}
+
 	lightHouse.RLock()
-	defer lightHouse.RUnlock()
+	addrMap := make([]lighthouseInfo, len(lightHouse.addrMap))
+	x := 0
+	for k, v := range lightHouse.addrMap {
+		addrMap[x] = lighthouseInfo{
+			VpnIP: int2ip(k),
+			Addrs: v.CopyCache(),
+		}
+		x++
+	}
+	lightHouse.RUnlock()
+
+	sort.Slice(addrMap, func(i, j int) bool {
+		return bytes.Compare(addrMap[i].VpnIP, addrMap[j].VpnIP) < 0
+	})
 
 	if fs.Json || fs.Pretty {
 		js := json.NewEncoder(w.GetWriter())
@@ -398,27 +398,19 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr
 			js.SetIndent("", "    ")
 		}
 
-		d := make([]m, len(lightHouse.addrMap))
-		x := 0
-		var h m
-		for vpnIp, v := range lightHouse.addrMap {
-			h = m{
-				"vpnIp": int2ip(vpnIp),
-				"addrs": TransformLHReplyToUdpAddrs(v),
-			}
-
-			d[x] = h
-			x++
-		}
-
-		err := js.Encode(d)
+		err := js.Encode(addrMap)
 		if err != nil {
 			//TODO
 			return nil
 		}
+
 	} else {
-		for vpnIp, v := range lightHouse.addrMap {
-			err := w.WriteLine(fmt.Sprintf("%s: %s", int2ip(vpnIp), TransformLHReplyToUdpAddrs(v)))
+		for _, v := range addrMap {
+			b, err := json.Marshal(v.Addrs)
+			if err != nil {
+				return err
+			}
+			err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIP, string(b)))
 			if err != nil {
 				return err
 			}
@@ -469,8 +461,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	ips, _ := ifce.lightHouse.Query(vpnIp, ifce)
-	return json.NewEncoder(w.GetWriter()).Encode(ips)
+	return json.NewEncoder(w.GetWriter()).Encode(ifce.lightHouse.Query(vpnIp, ifce).CopyCache())
 }
 
 func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error {
@@ -727,7 +718,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIP(uint32(vpnIp))
+	hostInfo, err := ifce.hostMap.QueryVpnIP(vpnIp)
 	if err != nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
@@ -737,7 +728,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		enc.SetIndent("", "    ")
 	}
 
-	return enc.Encode(hostInfo)
+	return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
 }
 
 func sshReload(fs interface{}, a []string, w sshd.StringWriter) error {

+ 1 - 3
tun_tester.go

@@ -41,9 +41,7 @@ func newTunFromFd(_ *logrus.Logger, _ int, _ *net.IPNet, _ int, _ []route, _ []r
 // These are unencrypted ip layer frames destined for another nebula node.
 // packets should exit the udp side, capture them with udpConn.Get
 func (c *Tun) Send(packet []byte) {
-	if c.l.Level >= logrus.DebugLevel {
-		c.l.Debug("Tun injecting packet")
-	}
+	c.l.WithField("dataLen", len(packet)).Info("Tun receiving injected packet")
 	c.rxPackets <- packet
 }
 

+ 3 - 3
udp_all.go

@@ -13,8 +13,8 @@ type udpAddr struct {
 }
 
 func NewUDPAddr(ip net.IP, port uint16) *udpAddr {
-	addr := udpAddr{IP: make([]byte, len(ip)), Port: port}
-	copy(addr.IP, ip)
+	addr := udpAddr{IP: make([]byte, net.IPv6len), Port: port}
+	copy(addr.IP, ip.To16())
 	return &addr
 }
 
@@ -22,7 +22,7 @@ func NewUDPAddrFromString(s string) *udpAddr {
 	ip, port, err := parseIPAndPort(s)
 	//TODO: handle err
 	_ = err
-	return &udpAddr{IP: ip, Port: port}
+	return &udpAddr{IP: ip.To16(), Port: port}
 }
 
 func (ua *udpAddr) Equals(t *udpAddr) bool {

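As I read it, forcing both constructors onto the 16-byte form keeps RemoteList's bytes.Compare ordering and its adjacency-based dedup consistent, since identical v4 addresses always produce identical byte slices. A tiny hypothetical check, assuming it sits in the nebula package with bytes and net imported:

// sameBytesRegardlessOfInput shows both constructors yielding identical 16-byte IPs.
func sameBytesRegardlessOfInput() bool {
	a := NewUDPAddr(net.ParseIP("192.168.0.5").To4(), 4242) // built from a 4-byte IP
	b := NewUDPAddrFromString("192.168.0.5:4242")           // built from a string
	return bytes.Equal(a.IP, b.IP) && a.Port == b.Port
}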
+ 9 - 28
udp_linux.go

@@ -97,40 +97,21 @@ func (u *udpConn) GetSendBuffer() (int, error) {
 }
 
 func (u *udpConn) LocalAddr() (*udpAddr, error) {
-	var rsa unix.RawSockaddrAny
-	var rLen = unix.SizeofSockaddrAny
-
-	_, _, err := unix.Syscall(
-		unix.SYS_GETSOCKNAME,
-		uintptr(u.sysFd),
-		uintptr(unsafe.Pointer(&rsa)),
-		uintptr(unsafe.Pointer(&rLen)),
-	)
-
-	if err != 0 {
+	sa, err := unix.Getsockname(u.sysFd)
+	if err != nil {
 		return nil, err
 	}
 
 	addr := &udpAddr{}
-	if rsa.Addr.Family == unix.AF_INET {
-		pp := (*unix.RawSockaddrInet4)(unsafe.Pointer(&rsa))
-		addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
-		copy(addr.IP, pp.Addr[:])
-
-	} else if rsa.Addr.Family == unix.AF_INET6 {
-		//TODO: this cast sucks and we can do better
-		pp := (*unix.RawSockaddrInet6)(unsafe.Pointer(&rsa))
-		addr.Port = uint16(rsa.Addr.Data[0])<<8 + uint16(rsa.Addr.Data[1])
-		copy(addr.IP, pp.Addr[:])
-
-	} else {
-		addr.Port = 0
-		addr.IP = []byte{}
+	switch sa := sa.(type) {
+	case *unix.SockaddrInet4:
+		addr.IP = net.IP{sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]}.To16()
+		addr.Port = uint16(sa.Port)
+	case *unix.SockaddrInet6:
+		addr.IP = sa.Addr[0:]
+		addr.Port = uint16(sa.Port)
 	}
 
-	//TODO: Just use this instead?
-	//a, b := unix.Getsockname(u.sysFd)
-
 	return addr, nil
 }
 

+ 9 - 1
udp_tester.go

@@ -3,6 +3,7 @@
 package nebula
 
 import (
+	"fmt"
 	"net"
 
 	"github.com/sirupsen/logrus"
@@ -53,7 +54,14 @@ func NewListener(l *logrus.Logger, ip string, port int, _ bool) (*udpConn, error
 // this is an encrypted packet or a handshake message in most cases
 // packets were transmitted from another nebula node, you can send them with Tun.Send
 func (u *udpConn) Send(packet *UdpPacket) {
-	u.l.Infof("UDP injecting packet %+v", packet)
+	h := &Header{}
+	if err := h.Parse(packet.Data); err != nil {
+		panic(err)
+	}
+	u.l.WithField("header", h).
+		WithField("udpAddr", fmt.Sprintf("%v:%v", packet.FromIp, packet.FromPort)).
+		WithField("dataLen", len(packet.Data)).
+		Info("UDP receiving injected packet")
 	u.rxPackets <- packet
 }