瀏覽代碼

Pull hostmap and pending hostmap apart, remove unused functions (#843)

Nate Brown 2 年之前
父節點
當前提交
a10baeee92
共有 16 個文件被更改,包括 291 次插入294 次删除
  1. 12 8
      connection_manager_test.go
  2. 32 31
      control.go
  3. 5 5
      control_test.go
  4. 3 3
      control_tester.go
  5. 2 2
      dns_server.go
  6. 1 1
      handshake.go
  7. 1 1
      handshake_ix.go
  8. 115 36
      handshake_manager.go
  9. 4 4
      handshake_manager_test.go
  10. 52 135
      hostmap.go
  11. 14 14
      hostmap_test.go
  12. 4 7
      inside.go
  13. 2 2
      main.go
  14. 7 10
      outside.go
  15. 6 5
      relay_manager.go
  16. 31 30
      ssh.go

+ 12 - 8
connection_manager_test.go

@@ -42,7 +42,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 	cs := &CertState{
 		rawCertificate:      []byte{},
 		privateKey:          []byte{},
@@ -121,7 +121,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	preferredRanges := []*net.IPNet{localrange}
 
 	// Very incomplete mock objects
-	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 	cs := &CertState{
 		rawCertificate:      []byte{},
 		privateKey:          []byte{},
@@ -207,7 +207,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
 	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
 	preferredRanges := []*net.IPNet{localrange}
-	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
+	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 
 	// Generate keys for CA and peer's cert.
 	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
@@ -268,12 +268,16 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	punchy := NewPunchyFromConfig(l, config.NewC(l))
 	nc := newConnectionManager(ctx, l, ifce, 5, 10, punchy)
 	ifce.connectionManager = nc
-	hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
-	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		peerCert:  &peerCert,
-		H:         &noise.HandshakeState{},
+
+	hostinfo := &HostInfo{
+		vpnIp: vpnIp,
+		ConnectionState: &ConnectionState{
+			certState: cs,
+			peerCert:  &peerCert,
+			H:         &noise.HandshakeState{},
+		},
 	}
+	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 	// Move ahead 45s.
 	// Check if to disconnect with invalid certificate.

+ 32 - 31
control.go

@@ -17,6 +17,15 @@ import (
 // Every interaction here needs to take extra care to copy memory and not return or use arguments "as is" when touching
 // core. This means copying IP objects, slices, de-referencing pointers and taking the actual value, etc
 
+type controlEach func(h *HostInfo)
+
+type controlHostLister interface {
+	QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo
+	ForEachIndex(each controlEach)
+	ForEachVpnIp(each controlEach)
+	GetPreferredRanges() []*net.IPNet
+}
+
 type Control struct {
 	f          *Interface
 	l          *logrus.Logger
@@ -98,7 +107,7 @@ func (c *Control) RebindUDPServer() {
 // ListHostmapHosts returns details about the actual or pending (handshaking) hostmap by vpn ip
 func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
 	if pendingMap {
-		return listHostMapHosts(c.f.handshakeManager.pendingHostMap)
+		return listHostMapHosts(c.f.handshakeManager)
 	} else {
 		return listHostMapHosts(c.f.hostMap)
 	}
@@ -107,7 +116,7 @@ func (c *Control) ListHostmapHosts(pendingMap bool) []ControlHostInfo {
 // ListHostmapIndexes returns details about the actual or pending (handshaking) hostmap by local index id
 func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 	if pendingMap {
-		return listHostMapIndexes(c.f.handshakeManager.pendingHostMap)
+		return listHostMapIndexes(c.f.handshakeManager)
 	} else {
 		return listHostMapIndexes(c.f.hostMap)
 	}
@@ -115,15 +124,15 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo {
 
 // GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found
 func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlHostInfo {
-	var hm *HostMap
+	var hl controlHostLister
 	if pending {
-		hm = c.f.handshakeManager.pendingHostMap
+		hl = c.f.handshakeManager
 	} else {
-		hm = c.f.hostMap
+		hl = c.f.hostMap
 	}
 
-	h, err := hm.QueryVpnIp(vpnIp)
-	if err != nil {
+	h := hl.QueryVpnIp(vpnIp)
+	if h == nil {
 		return nil
 	}
 
@@ -133,8 +142,8 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
 
 // SetRemoteForTunnel forces a tunnel to use a specific remote
 func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *ControlHostInfo {
-	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return nil
 	}
 
@@ -145,8 +154,8 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
 
 // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well.
 func (c *Control) CloseTunnel(vpnIp iputil.VpnIp, localOnly bool) bool {
-	hostInfo, err := c.f.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := c.f.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return false
 	}
 
@@ -241,28 +250,20 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 	return chi
 }
 
-func listHostMapHosts(hm *HostMap) []ControlHostInfo {
-	hm.RLock()
-	hosts := make([]ControlHostInfo, len(hm.Hosts))
-	i := 0
-	for _, v := range hm.Hosts {
-		hosts[i] = copyHostInfo(v, hm.preferredRanges)
-		i++
-	}
-	hm.RUnlock()
-
+func listHostMapHosts(hl controlHostLister) []ControlHostInfo {
+	hosts := make([]ControlHostInfo, 0)
+	pr := hl.GetPreferredRanges()
+	hl.ForEachVpnIp(func(hostinfo *HostInfo) {
+		hosts = append(hosts, copyHostInfo(hostinfo, pr))
+	})
 	return hosts
 }
 
-func listHostMapIndexes(hm *HostMap) []ControlHostInfo {
-	hm.RLock()
-	hosts := make([]ControlHostInfo, len(hm.Indexes))
-	i := 0
-	for _, v := range hm.Indexes {
-		hosts[i] = copyHostInfo(v, hm.preferredRanges)
-		i++
-	}
-	hm.RUnlock()
-
+func listHostMapIndexes(hl controlHostLister) []ControlHostInfo {
+	hosts := make([]ControlHostInfo, 0)
+	pr := hl.GetPreferredRanges()
+	hl.ForEachIndex(func(hostinfo *HostInfo) {
+		hosts = append(hosts, copyHostInfo(hostinfo, pr))
+	})
 	return hosts
 }

+ 5 - 5
control_test.go

@@ -18,7 +18,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	l := test.NewLogger()
 	// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
 	// To properly ensure we are not exposing core memory to the caller
-	hm := NewHostMap(l, "test", &net.IPNet{}, make([]*net.IPNet, 0))
+	hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
 	remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
 	remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
 	ipNet := net.IPNet{
@@ -50,7 +50,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	remotes := NewRemoteList(nil)
 	remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
 	remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
-	hm.Add(iputil.Ip2VpnIp(ipNet.IP), &HostInfo{
+	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
@@ -64,9 +64,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 			relayForByIp:  map[iputil.VpnIp]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
-	})
+	}, &Interface{})
 
-	hm.Add(iputil.Ip2VpnIp(ipNet2.IP), &HostInfo{
+	hm.unlockedAddHostInfo(&HostInfo{
 		remote:  remote1,
 		remotes: remotes,
 		ConnectionState: &ConnectionState{
@@ -80,7 +80,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 			relayForByIp:  map[iputil.VpnIp]*Relay{},
 			relayForByIdx: map[uint32]*Relay{},
 		},
-	})
+	}, &Interface{})
 
 	c := Control{
 		f: &Interface{

+ 3 - 3
control_tester.go

@@ -147,12 +147,12 @@ func (c *Control) GetUDPAddr() string {
 }
 
 func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
-	hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[iputil.Ip2VpnIp(vpnIp)]
-	if !ok {
+	hostinfo := c.f.handshakeManager.QueryVpnIp(iputil.Ip2VpnIp(vpnIp))
+	if hostinfo == nil {
 		return false
 	}
 
-	c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
+	c.f.handshakeManager.DeleteHostInfo(hostinfo)
 	return true
 }
 

+ 2 - 2
dns_server.go

@@ -47,8 +47,8 @@ func (d *dnsRecords) QueryCert(data string) string {
 		return ""
 	}
 	iip := iputil.Ip2VpnIp(ip)
-	hostinfo, err := d.hostMap.QueryVpnIp(iip)
-	if err != nil {
+	hostinfo := d.hostMap.QueryVpnIp(iip)
+	if hostinfo == nil {
 		return ""
 	}
 	q := hostinfo.GetCert()

+ 1 - 1
handshake.go

@@ -20,7 +20,7 @@ func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packe
 		case 1:
 			ixHandshakeStage1(f, addr, via, packet, h)
 		case 2:
-			newHostinfo, _ := f.handshakeManager.QueryIndex(h.RemoteIndex)
+			newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
 			tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
 			if tearDown && newHostinfo != nil {
 				f.handshakeManager.DeleteHostInfo(newHostinfo)

+ 1 - 1
handshake_ix.go

@@ -422,7 +422,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 			Info("Incorrect host responded to handshake")
 
 		// Release our old handshake from pending, it should not continue
-		f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
+		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
 		//TODO: this adds it to the timer wheel in a way that aggressively retries

+ 115 - 36
handshake_manager.go

@@ -7,6 +7,7 @@ import (
 	"encoding/binary"
 	"errors"
 	"net"
+	"sync"
 	"time"
 
 	"github.com/rcrowley/go-metrics"
@@ -42,7 +43,12 @@ type HandshakeConfig struct {
 }
 
 type HandshakeManager struct {
-	pendingHostMap         *HostMap
+	// Mutex for interacting with the vpnIps and indexes maps
+	sync.RWMutex
+
+	vpnIps  map[iputil.VpnIp]*HostInfo
+	indexes map[uint32]*HostInfo
+
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
 	outside                udp.Conn
@@ -59,7 +65,8 @@ type HandshakeManager struct {
 
 func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges []*net.IPNet, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		pendingHostMap:         NewHostMap(l, "pending", tunCidr, preferredRanges),
+		vpnIps:                 map[iputil.VpnIp]*HostInfo{},
+		indexes:                map[uint32]*HostInfo{},
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
@@ -101,8 +108,8 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f EncWr
 }
 
 func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, lighthouseTriggered bool) {
-	hostinfo, err := c.pendingHostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostinfo := c.QueryVpnIp(vpnIp)
+	if hostinfo == nil {
 		return
 	}
 	hostinfo.Lock()
@@ -111,7 +118,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 	// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
 	if hostinfo.HandshakeComplete {
 		// Ensure we don't exist in the pending hostmap anymore since we have completed
-		c.pendingHostMap.DeleteHostInfo(hostinfo)
+		c.DeleteHostInfo(hostinfo)
 		return
 	}
 
@@ -125,14 +132,14 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 
 	// If we are out of time, clean up
 	if hostinfo.HandshakeCounter >= c.config.retries {
-		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)).
+		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("remoteIndex", hostinfo.remoteIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
 			Info("Handshake timed out")
 		c.metricTimedOut.Inc(1)
-		c.pendingHostMap.DeleteHostInfo(hostinfo)
+		c.DeleteHostInfo(hostinfo)
 		return
 	}
 
@@ -144,7 +151,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 		hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
 	}
 
-	remotes := hostinfo.remotes.CopyAddrs(c.pendingHostMap.preferredRanges)
+	remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)
 	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
@@ -168,9 +175,9 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
-	hostinfo.remotes.ForEach(c.pendingHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+	hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
 		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-		err = c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
+		err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
 			hostinfo.logger(c.l).WithField("udpAddr", addr).
 				WithField("initiatorIndex", hostinfo.localIndexId).
@@ -204,9 +211,9 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 			if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
 				continue
 			}
-			relayHostInfo, err := c.mainHostMap.QueryVpnIp(*relay)
-			if err != nil || relayHostInfo.remote == nil {
-				hostinfo.logger(c.l).WithError(err).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
+			relayHostInfo := c.mainHostMap.QueryVpnIp(*relay)
+			if relayHostInfo == nil || relayHostInfo.remote == nil {
+				hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
 				f.Handshake(*relay)
 				continue
 			}
@@ -289,14 +296,35 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 	}
 }
 
+// AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
 func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
-	hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
+	// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
+	c.Lock()
+	defer c.Unlock()
+
+	if hostinfo, ok := c.vpnIps[vpnIp]; ok {
+		// We are already tracking this vpn ip
+		return hostinfo
+	}
 
-	if created {
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
-		c.metricInitiated.Inc(1)
+	hostinfo := &HostInfo{
+		vpnIp:           vpnIp,
+		HandshakePacket: make(map[uint8][]byte, 0),
+		relayState: RelayState{
+			relays:        map[iputil.VpnIp]struct{}{},
+			relayForByIp:  map[iputil.VpnIp]*Relay{},
+			relayForByIdx: map[uint32]*Relay{},
+		},
 	}
 
+	if init != nil {
+		init(hostinfo)
+	}
+
+	c.vpnIps[vpnIp] = hostinfo
+	c.metricInitiated.Inc(1)
+	c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
+
 	return hostinfo
 }
 
@@ -318,8 +346,8 @@ var (
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
 func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
-	c.pendingHostMap.Lock()
-	defer c.pendingHostMap.Unlock()
+	c.Lock()
+	defer c.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 
@@ -350,7 +378,7 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingIndex, found = c.pendingHostMap.Indexes[hostinfo.localIndexId]
+	existingIndex, found = c.indexes[hostinfo.localIndexId]
 	if found && existingIndex != hostinfo {
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
@@ -373,8 +401,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 // won't have a localIndexId collision because we already have an entry in the
 // pendingHostMap. An existing hostinfo is returned if there was one.
 func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
-	c.pendingHostMap.Lock()
-	defer c.pendingHostMap.Unlock()
+	c.Lock()
+	defer c.Unlock()
 	c.mainHostMap.Lock()
 	defer c.mainHostMap.Unlock()
 
@@ -388,7 +416,7 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 	}
 
 	// We need to remove from the pending hostmap first to avoid undoing work when after to the main hostmap.
-	c.pendingHostMap.unlockedDeleteHostInfo(hostinfo)
+	c.unlockedDeleteHostInfo(hostinfo)
 	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
 }
 
@@ -396,8 +424,8 @@ func (c *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 // and adds it to the pendingHostMap. Will error if we are unable to generate
 // a unique localIndexId
 func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
-	c.pendingHostMap.Lock()
-	defer c.pendingHostMap.Unlock()
+	c.Lock()
+	defer c.Unlock()
 	c.mainHostMap.RLock()
 	defer c.mainHostMap.RUnlock()
 
@@ -407,12 +435,12 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
 			return err
 		}
 
-		_, inPending := c.pendingHostMap.Indexes[index]
+		_, inPending := c.indexes[index]
 		_, inMain := c.mainHostMap.Indexes[index]
 
 		if !inMain && !inPending {
 			h.localIndexId = index
-			c.pendingHostMap.Indexes[index] = h
+			c.indexes[index] = h
 			return nil
 		}
 	}
@@ -420,22 +448,73 @@ func (c *HandshakeManager) AddIndexHostInfo(h *HostInfo) error {
 	return errors.New("failed to generate unique localIndexId")
 }
 
-func (c *HandshakeManager) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
-	c.pendingHostMap.addRemoteIndexHostInfo(index, h)
+func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
+	c.Lock()
+	defer c.Unlock()
+	c.unlockedDeleteHostInfo(hostinfo)
 }
 
-func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
-	//l.Debugln("Deleting pending hostinfo :", hostinfo)
-	c.pendingHostMap.DeleteHostInfo(hostinfo)
+func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
+	delete(c.vpnIps, hostinfo.vpnIp)
+	if len(c.vpnIps) == 0 {
+		c.vpnIps = map[iputil.VpnIp]*HostInfo{}
+	}
+
+	delete(c.indexes, hostinfo.localIndexId)
+	if len(c.vpnIps) == 0 {
+		c.indexes = map[uint32]*HostInfo{}
+	}
+
+	if c.l.Level >= logrus.DebugLevel {
+		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
+			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
+			Debug("Pending hostmap hostInfo deleted")
+	}
+}
+
+func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+	c.RLock()
+	defer c.RUnlock()
+	return c.vpnIps[vpnIp]
+}
+
+func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo {
+	c.RLock()
+	defer c.RUnlock()
+	return c.indexes[index]
+}
+
+func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
+	return c.mainHostMap.preferredRanges
 }
 
-func (c *HandshakeManager) QueryIndex(index uint32) (*HostInfo, error) {
-	return c.pendingHostMap.QueryIndex(index)
+func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
+	c.RLock()
+	defer c.RUnlock()
+
+	for _, v := range c.vpnIps {
+		f(v)
+	}
+}
+
+func (c *HandshakeManager) ForEachIndex(f controlEach) {
+	c.RLock()
+	defer c.RUnlock()
+
+	for _, v := range c.indexes {
+		f(v)
+	}
 }
 
 func (c *HandshakeManager) EmitStats() {
-	c.pendingHostMap.EmitStats("pending")
-	c.mainHostMap.EmitStats("main")
+	c.RLock()
+	hostLen := len(c.vpnIps)
+	indexLen := len(c.indexes)
+	c.RUnlock()
+
+	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen))
+	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen))
+	c.mainHostMap.EmitStats()
 }
 
 // Utility functions below

+ 4 - 4
handshake_manager_test.go

@@ -20,7 +20,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
 	preferredRanges := []*net.IPNet{localrange}
 	mw := &mockEncWriter{}
-	mainHM := NewHostMap(l, "test", vpncidr, preferredRanges)
+	mainHM := NewHostMap(l, vpncidr, preferredRanges)
 	lh := newTestLighthouse()
 
 	blah := NewHandshakeManager(l, tuncidr, preferredRanges, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
@@ -48,7 +48,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.Len(t, mainHM.Hosts, 0)
 
 	// Confirm they are in the pending index list
-	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
+	assert.Contains(t, blah.vpnIps, ip)
 
 	// Jump ahead `HandshakeRetries` ticks, offset by one to get the sleep logic right
 	for i := 1; i <= DefaultHandshakeRetries+1; i++ {
@@ -57,13 +57,13 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	}
 
 	// Confirm they are still in the pending index list
-	assert.Contains(t, blah.pendingHostMap.Hosts, ip)
+	assert.Contains(t, blah.vpnIps, ip)
 
 	// Tick 1 more time, a minute will certainly flush it out
 	blah.NextOutboundHandshakeTimerTick(now.Add(time.Minute), mw)
 
 	// Confirm they have been removed
-	assert.NotContains(t, blah.pendingHostMap.Hosts, ip)
+	assert.NotContains(t, blah.vpnIps, ip)
 }
 
 func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {

+ 52 - 135
hostmap.go

@@ -2,7 +2,6 @@ package nebula
 
 import (
 	"errors"
-	"fmt"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -52,7 +51,6 @@ type Relay struct {
 
 type HostMap struct {
 	sync.RWMutex    //Because we concurrently read and write to our maps
-	name            string
 	Indexes         map[uint32]*HostInfo
 	Relays          map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
 	RemoteIndexes   map[uint32]*HostInfo
@@ -203,13 +201,13 @@ type HostInfo struct {
 	remotes              *RemoteList
 	promoteCounter       atomic.Uint32
 	ConnectionState      *ConnectionState
-	handshakeStart       time.Time        //todo: this an entry in the handshake manager
-	HandshakeReady       bool             //todo: being in the manager means you are ready
-	HandshakeCounter     int              //todo: another handshake manager entry
-	HandshakeLastRemotes []*udp.Addr      //todo: another handshake manager entry, which remotes we sent to last time
-	HandshakeComplete    bool             //todo: this should go away in favor of ConnectionState.ready
-	HandshakePacket      map[uint8][]byte //todo: this is other handshake manager entry
-	packetStore          []*cachedPacket  //todo: this is other handshake manager entry
+	handshakeStart       time.Time   //todo: this an entry in the handshake manager
+	HandshakeReady       bool        //todo: being in the manager means you are ready
+	HandshakeCounter     int         //todo: another handshake manager entry
+	HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
+	HandshakeComplete    bool        //todo: this should go away in favor of ConnectionState.ready
+	HandshakePacket      map[uint8][]byte
+	packetStore          []*cachedPacket //todo: this is other handshake manager entry
 	remoteIndexId        uint32
 	localIndexId         uint32
 	vpnIp                iputil.VpnIp
@@ -255,13 +253,12 @@ type cachedPacketMetrics struct {
 	dropped metrics.Counter
 }
 
-func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
+func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
 	h := map[iputil.VpnIp]*HostInfo{}
 	i := map[uint32]*HostInfo{}
 	r := map[uint32]*HostInfo{}
 	relays := map[uint32]*HostInfo{}
 	m := HostMap{
-		name:            name,
 		Indexes:         i,
 		Relays:          relays,
 		RemoteIndexes:   r,
@@ -273,8 +270,8 @@ func NewHostMap(l *logrus.Logger, name string, vpnCIDR *net.IPNet, preferredRang
 	return &m
 }
 
-// UpdateStats takes a name and reports host and index counts to the stats collection system
-func (hm *HostMap) EmitStats(name string) {
+// EmitStats reports host, index, and relay counts to the stats collection system
+func (hm *HostMap) EmitStats() {
 	hm.RLock()
 	hostLen := len(hm.Hosts)
 	indexLen := len(hm.Indexes)
@@ -282,10 +279,10 @@ func (hm *HostMap) EmitStats(name string) {
 	relaysLen := len(hm.Relays)
 	hm.RUnlock()
 
-	metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
-	metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
-	metrics.GetOrRegisterGauge("hostmap."+name+".remoteIndexes", nil).Update(int64(remoteIndexLen))
-	metrics.GetOrRegisterGauge("hostmap."+name+".relayIndexes", nil).Update(int64(relaysLen))
+	metrics.GetOrRegisterGauge("hostmap.main.hosts", nil).Update(int64(hostLen))
+	metrics.GetOrRegisterGauge("hostmap.main.indexes", nil).Update(int64(indexLen))
+	metrics.GetOrRegisterGauge("hostmap.main.remoteIndexes", nil).Update(int64(remoteIndexLen))
+	metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen))
 }
 
 func (hm *HostMap) RemoveRelay(localIdx uint32) {
@@ -299,88 +296,6 @@ func (hm *HostMap) RemoveRelay(localIdx uint32) {
 	hm.Unlock()
 }
 
-func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
-	hm.RLock()
-	if i, ok := hm.Hosts[vpnIp]; ok {
-		index := i.localIndexId
-		hm.RUnlock()
-		return index, nil
-	}
-	hm.RUnlock()
-	return 0, errors.New("vpn IP not found")
-}
-
-func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
-	hm.Lock()
-	hm.Hosts[ip] = hostinfo
-	hm.Unlock()
-}
-
-func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) {
-	hm.RLock()
-	if h, ok := hm.Hosts[vpnIp]; !ok {
-		hm.RUnlock()
-		h = &HostInfo{
-			vpnIp:           vpnIp,
-			HandshakePacket: make(map[uint8][]byte, 0),
-			relayState: RelayState{
-				relays:        map[iputil.VpnIp]struct{}{},
-				relayForByIp:  map[iputil.VpnIp]*Relay{},
-				relayForByIdx: map[uint32]*Relay{},
-			},
-		}
-		if init != nil {
-			init(h)
-		}
-		hm.Lock()
-		hm.Hosts[vpnIp] = h
-		hm.Unlock()
-		return h, true
-	} else {
-		hm.RUnlock()
-		return h, false
-	}
-}
-
-// Only used by pendingHostMap when the remote index is not initially known
-func (hm *HostMap) addRemoteIndexHostInfo(index uint32, h *HostInfo) {
-	hm.Lock()
-	h.remoteIndexId = index
-	hm.RemoteIndexes[index] = h
-	hm.Unlock()
-
-	if hm.l.Level > logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
-			"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": h.vpnIp}}).
-			Debug("Hostmap remoteIndex added")
-	}
-}
-
-// DeleteReverseIndex is used to clean up on recv_error
-// This function should only ever be called on the pending hostmap
-func (hm *HostMap) DeleteReverseIndex(index uint32) {
-	hm.Lock()
-	hostinfo, ok := hm.RemoteIndexes[index]
-	if ok {
-		delete(hm.Indexes, hostinfo.localIndexId)
-		delete(hm.RemoteIndexes, index)
-
-		// Check if we have an entry under hostId that matches the same hostinfo
-		// instance. Clean it up as well if we do (they might not match in pendingHostmap)
-		var hostinfo2 *HostInfo
-		hostinfo2, ok = hm.Hosts[hostinfo.vpnIp]
-		if ok && hostinfo2 == hostinfo {
-			delete(hm.Hosts, hostinfo.vpnIp)
-		}
-	}
-	hm.Unlock()
-
-	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
-			Debug("Hostmap remote index deleted")
-	}
-}
-
 // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip
 func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	// Delete the host itself, ensuring it's not modified anymore
@@ -393,12 +308,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	return final
 }
 
-func (hm *HostMap) DeleteRelayIdx(localIdx uint32) {
-	hm.Lock()
-	defer hm.Unlock()
-	delete(hm.RemoteIndexes, localIdx)
-}
-
 func (hm *HostMap) MakePrimary(hostinfo *HostInfo) {
 	hm.Lock()
 	defer hm.Unlock()
@@ -476,7 +385,7 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "mapTotalSize": len(hm.Hosts),
+		hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts),
 			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
@@ -486,55 +395,41 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 }
 
-func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
-	//TODO: we probably just want to return bool instead of error, or at least a static error
+func (hm *HostMap) QueryIndex(index uint32) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Indexes[index]; ok {
 		hm.RUnlock()
-		return h, nil
+		return h
 	} else {
 		hm.RUnlock()
-		return nil, errors.New("unable to find index")
+		return nil
 	}
 }
 
-// Retrieves a HostInfo by Index. Returns whether the HostInfo is primary at time of query.
-// This helper exists so that the hostinfo.prev pointer can be read while the hostmap lock is held.
-func (hm *HostMap) QueryIndexIsPrimary(index uint32) (*HostInfo, bool, error) {
-	//TODO: we probably just want to return bool instead of error, or at least a static error
-	hm.RLock()
-	if h, ok := hm.Indexes[index]; ok {
-		hm.RUnlock()
-		return h, h.prev == nil, nil
-	} else {
-		hm.RUnlock()
-		return nil, false, errors.New("unable to find index")
-	}
-}
-func (hm *HostMap) QueryRelayIndex(index uint32) (*HostInfo, error) {
+func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo {
 	//TODO: we probably just want to return bool instead of error, or at least a static error
 	hm.RLock()
 	if h, ok := hm.Relays[index]; ok {
 		hm.RUnlock()
-		return h, nil
+		return h
 	} else {
 		hm.RUnlock()
-		return nil, errors.New("unable to find index")
+		return nil
 	}
 }
 
-func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
+func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.RemoteIndexes[index]; ok {
 		hm.RUnlock()
-		return h, nil
+		return h
 	} else {
 		hm.RUnlock()
-		return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
+		return nil
 	}
 }
 
-func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
+func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 	return hm.queryVpnIp(vpnIp, nil)
 }
 
@@ -558,11 +453,11 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*Host
 
 // PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
 // `PromoteEvery` calls to this function for a given host.
-func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
+func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) *HostInfo {
 	return hm.queryVpnIp(vpnIp, ifce)
 }
 
-func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*HostInfo, error) {
+func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostInfo {
 	hm.RLock()
 	if h, ok := hm.Hosts[vpnIp]; ok {
 		hm.RUnlock()
@@ -570,12 +465,12 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) (*Host
 		if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
 			h.TryPromoteBest(hm.preferredRanges, promoteIfce)
 		}
-		return h, nil
+		return h
 
 	}
 
 	hm.RUnlock()
-	return nil, errors.New("unable to find host")
+	return nil
 }
 
 // unlockedAddHostInfo assumes you have a write-lock and will add a hostinfo object to the hostmap Indexes and RemoteIndexes maps.
@@ -598,7 +493,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo
 
 	if hm.l.Level >= logrus.DebugLevel {
-		hm.l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
+		hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts),
 			"hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}).
 			Debug("Hostmap vpnIp added")
 	}
@@ -614,6 +509,28 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
 	}
 }
 
+func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
+	return hm.preferredRanges
+}
+
+func (hm *HostMap) ForEachVpnIp(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
+
+	for _, v := range hm.Hosts {
+		f(v)
+	}
+}
+
+func (hm *HostMap) ForEachIndex(f controlEach) {
+	hm.RLock()
+	defer hm.RUnlock()
+
+	for _, v := range hm.Indexes {
+		f(v)
+	}
+}
+
 // TryPromoteBest handles re-querying lighthouses and probing for better paths
 // NOTE: It is an error to call this if you are a lighthouse since they should not roam clients!
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {

+ 14 - 14
hostmap_test.go

@@ -11,7 +11,7 @@ import (
 func TestHostMap_MakePrimary(t *testing.T) {
 	l := test.NewLogger()
 	hm := NewHostMap(
-		l, "test",
+		l,
 		&net.IPNet{
 			IP:   net.IP{10, 0, 0, 1},
 			Mask: net.IPMask{255, 255, 255, 0},
@@ -32,7 +32,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.unlockedAddHostInfo(h1, f)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4
-	prim, _ := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(1)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -47,7 +47,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h3)
 
 	// Make sure we go h3 -> h1 -> h2 -> h4
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h3.localIndexId, prim.localIndexId)
 	assert.Equal(t, h1.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -62,7 +62,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -77,7 +77,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 	hm.MakePrimary(h4)
 
 	// Make sure we go h4 -> h3 -> h1 -> h2
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -92,7 +92,7 @@ func TestHostMap_MakePrimary(t *testing.T) {
 func TestHostMap_DeleteHostInfo(t *testing.T) {
 	l := test.NewLogger()
 	hm := NewHostMap(
-		l, "test",
+		l,
 		&net.IPNet{
 			IP:   net.IP{10, 0, 0, 1},
 			Mask: net.IPMask{255, 255, 255, 0},
@@ -119,11 +119,11 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	// h6 should be deleted
 	assert.Nil(t, h6.next)
 	assert.Nil(t, h6.prev)
-	_, err := hm.QueryIndex(h6.localIndexId)
-	assert.Error(t, err)
+	h := hm.QueryIndex(h6.localIndexId)
+	assert.Nil(t, h)
 
 	// Make sure we go h1 -> h2 -> h3 -> h4 -> h5
-	prim, _ := hm.QueryVpnIp(1)
+	prim := hm.QueryVpnIp(1)
 	assert.Equal(t, h1.localIndexId, prim.localIndexId)
 	assert.Equal(t, h2.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -142,7 +142,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h1.next)
 
 	// Make sure we go h2 -> h3 -> h4 -> h5
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h3.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -160,7 +160,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h3.next)
 
 	// Make sure we go h2 -> h4 -> h5
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -176,7 +176,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h5.next)
 
 	// Make sure we go h2 -> h4
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h2.localIndexId, prim.localIndexId)
 	assert.Equal(t, h4.localIndexId, prim.next.localIndexId)
 	assert.Nil(t, prim.prev)
@@ -190,7 +190,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h2.next)
 
 	// Make sure we only have h4
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Equal(t, h4.localIndexId, prim.localIndexId)
 	assert.Nil(t, prim.prev)
 	assert.Nil(t, prim.next)
@@ -202,6 +202,6 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
 	assert.Nil(t, h4.next)
 
 	// Make sure we have nil
-	prim, _ = hm.QueryVpnIp(1)
+	prim = hm.QueryVpnIp(1)
 	assert.Nil(t, prim)
 }

+ 4 - 7
inside.go

@@ -121,14 +121,10 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 			return nil
 		}
 	}
-	hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
 
-	//if err != nil || hostinfo.ConnectionState == nil {
-	if err != nil {
-		hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
-		if err != nil {
-			hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
-		}
+	hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
+	if hostinfo == nil {
+		hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
 	}
 	ci := hostinfo.ConnectionState
 
@@ -137,6 +133,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 	}
 
 	// Handshake is not ready, we need to grab the lock now before we start the handshake process
+	//TODO: move this to handshake manager
 	hostinfo.Lock()
 	defer hostinfo.Unlock()
 

+ 2 - 2
main.go

@@ -212,7 +212,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	}
 
-	hostMap := NewHostMap(l, "main", tunCidr, preferredRanges)
+	hostMap := NewHostMap(l, tunCidr, preferredRanges)
 	hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)
 
 	l.
@@ -339,7 +339,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	//TODO: check if we _should_ be emitting stats
 	go ifce.emitStats(ctx, c.GetDuration("stats.interval", time.Second*10))
 
-	attachCommands(l, c, ssh, hostMap, handshakeManager.pendingHostMap, lightHouse, ifce)
+	attachCommands(l, c, ssh, ifce)
 
 	// Start DNS server last to allow using the nebula IP as lighthouse.dns.host
 	var dnsStart func()

+ 7 - 10
outside.go

@@ -64,9 +64,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 	var hostinfo *HostInfo
 	// verify if we've seen this index before, otherwise respond to the handshake initiation
 	if h.Type == header.Message && h.Subtype == header.MessageRelay {
-		hostinfo, _ = f.hostMap.QueryRelayIndex(h.RemoteIndex)
+		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
 	} else {
-		hostinfo, _ = f.hostMap.QueryIndex(h.RemoteIndex)
+		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
 	}
 
 	var ci *ConnectionState
@@ -449,12 +449,9 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 			Debug("Recv error received")
 	}
 
-	// First, clean up in the pending hostmap
-	f.handshakeManager.pendingHostMap.DeleteReverseIndex(h.RemoteIndex)
-
-	hostinfo, err := f.hostMap.QueryReverseIndex(h.RemoteIndex)
-	if err != nil {
-		f.l.Debugln(err, ": ", h.RemoteIndex)
+	hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
+	if hostinfo == nil {
+		f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
 		return
 	}
 
@@ -464,14 +461,14 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 	if !hostinfo.RecvErrorExceeded() {
 		return
 	}
+
 	if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
 		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
 		return
 	}
 
 	f.closeTunnel(hostinfo)
-	// We also delete it from pending hostmap to allow for
-	// fast reconnect.
+	// We also delete it from pending hostmap to allow for fast reconnect.
 	f.handshakeManager.DeleteHostInfo(hostinfo)
 }
 

+ 6 - 5
relay_manager.go

@@ -131,9 +131,9 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		return
 	}
 	// I'm the middle man. Let the initiator know that the I've established the relay they requested.
-	peerHostInfo, err := rm.hostmap.QueryVpnIp(relay.PeerIp)
-	if err != nil {
-		rm.l.WithError(err).WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
+	peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp)
+	if peerHostInfo == nil {
+		rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer")
 		return
 	}
 	peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(target)
@@ -240,8 +240,8 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		if !rm.GetAmRelay() {
 			return
 		}
-		peer, err := rm.hostmap.QueryVpnIp(target)
-		if err != nil {
+		peer := rm.hostmap.QueryVpnIp(target)
+		if peer == nil {
 			// Try to establish a connection to this host. If we get a future relay request,
 			// we'll be ready!
 			f.getOrHandshake(target)
@@ -253,6 +253,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		}
 		sendCreateRequest := false
 		var index uint32
+		var err error
 		targetRelay, ok := peer.relayState.QueryRelayForByIp(from)
 		if ok {
 			index = targetRelay.LocalIndex

+ 31 - 30
ssh.go

@@ -3,6 +3,7 @@ package nebula
 import (
 	"bytes"
 	"encoding/json"
+	"errors"
 	"flag"
 	"fmt"
 	"io/ioutil"
@@ -168,7 +169,7 @@ func configSSH(l *logrus.Logger, ssh *sshd.SSHServer, c *config.C) (func(), erro
 	return runner, nil
 }
 
-func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap *HostMap, pendingHostMap *HostMap, lightHouse *LightHouse, ifce *Interface) {
+func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Interface) {
 	ssh.RegisterCommand(&sshd.Command{
 		Name:             "list-hostmap",
 		ShortDescription: "List all known previously connected hosts",
@@ -181,7 +182,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshListHostMap(hostMap, fs, w)
+			return sshListHostMap(f.hostMap, fs, w)
 		},
 	})
 
@@ -197,7 +198,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshListHostMap(pendingHostMap, fs, w)
+			return sshListHostMap(f.handshakeManager, fs, w)
 		},
 	})
 
@@ -212,7 +213,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshListLighthouseMap(lightHouse, fs, w)
+			return sshListLighthouseMap(f.lightHouse, fs, w)
 		},
 	})
 
@@ -277,7 +278,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 		Name:             "version",
 		ShortDescription: "Prints the currently running version of nebula",
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshVersion(ifce, fs, a, w)
+			return sshVersion(f, fs, a, w)
 		},
 	})
 
@@ -293,7 +294,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshPrintCert(ifce, fs, a, w)
+			return sshPrintCert(f, fs, a, w)
 		},
 	})
 
@@ -307,7 +308,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshPrintTunnel(ifce, fs, a, w)
+			return sshPrintTunnel(f, fs, a, w)
 		},
 	})
 
@@ -321,7 +322,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshPrintRelays(ifce, fs, a, w)
+			return sshPrintRelays(f, fs, a, w)
 		},
 	})
 
@@ -335,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshChangeRemote(ifce, fs, a, w)
+			return sshChangeRemote(f, fs, a, w)
 		},
 	})
 
@@ -349,7 +350,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshCloseTunnel(ifce, fs, a, w)
+			return sshCloseTunnel(f, fs, a, w)
 		},
 	})
 
@@ -364,7 +365,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 			return fl, &s
 		},
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshCreateTunnel(ifce, fs, a, w)
+			return sshCreateTunnel(f, fs, a, w)
 		},
 	})
 
@@ -373,12 +374,12 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, hostMap
 		ShortDescription: "Query the lighthouses for the provided vpn ip",
 		Help:             "This command is asynchronous. Only currently known udp ips will be printed.",
 		Callback: func(fs interface{}, a []string, w sshd.StringWriter) error {
-			return sshQueryLighthouse(ifce, fs, a, w)
+			return sshQueryLighthouse(f, fs, a, w)
 		},
 	})
 }
 
-func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error {
+func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) error {
 	fs, ok := a.(*sshListHostMapFlags)
 	if !ok {
 		//TODO: error
@@ -387,9 +388,9 @@ func sshListHostMap(hostMap *HostMap, a interface{}, w sshd.StringWriter) error
 
 	var hm []ControlHostInfo
 	if fs.ByIndex {
-		hm = listHostMapIndexes(hostMap)
+		hm = listHostMapIndexes(hl)
 	} else {
-		hm = listHostMapHosts(hostMap)
+		hm = listHostMapHosts(hl)
 	}
 
 	sort.Slice(hm, func(i, j int) bool {
@@ -546,8 +547,8 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 
@@ -588,12 +589,12 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, _ := ifce.hostMap.QueryVpnIp(vpnIp)
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already exists"))
 	}
 
-	hostInfo, _ = ifce.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
+	hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp)
 	if hostInfo != nil {
 		return w.WriteLine(fmt.Sprintf("Tunnel already handshaking"))
 	}
@@ -645,8 +646,8 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}
 
@@ -765,8 +766,8 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 			return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 		}
 
-		hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-		if err != nil {
+		hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+		if hostInfo == nil {
 			return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 		}
 
@@ -851,9 +852,9 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 	for k, v := range relays {
 		ro := RelayOutput{NebulaIp: v.vpnIp}
 		co.Relays = append(co.Relays, &ro)
-		relayHI, err := ifce.hostMap.QueryVpnIp(v.vpnIp)
-		if err != nil {
-			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: err})
+		relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp)
+		if relayHI == nil {
+			ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")})
 			continue
 		}
 		for _, vpnIp := range relayHI.relayState.CopyRelayForIps() {
@@ -889,8 +890,8 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 					rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k)
 				}
 			}
-			relayedHI, err := ifce.hostMap.QueryVpnIp(vpnIp)
-			if err == nil {
+			relayedHI := ifce.hostMap.QueryVpnIp(vpnIp)
+			if relayedHI != nil {
 				rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...)
 			}
 
@@ -925,8 +926,8 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
 		return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0]))
 	}
 
-	hostInfo, err := ifce.hostMap.QueryVpnIp(vpnIp)
-	if err != nil {
+	hostInfo := ifce.hostMap.QueryVpnIp(vpnIp)
+	if hostInfo == nil {
 		return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0]))
 	}