Bläddra i källkod

Clean up a hostinfo to reduce memory usage (#955)

Nate Brown 1 år sedan
förälder
incheckning
a44e1b8b05
10 ändrade filer med 241 tillägg och 266 borttagningar
  1. 0 3
      connection_state.go
  2. 0 2
      control.go
  3. 1 2
      control_test.go
  4. 0 31
      handshake.go
  5. 43 53
      handshake_ix.go
  6. 164 87
      handshake_manager.go
  7. 10 1
      handshake_manager_test.go
  8. 17 78
      hostmap.go
  9. 5 5
      inside.go
  10. 1 4
      outside.go

+ 0 - 3
connection_state.go

@@ -24,7 +24,6 @@ type ConnectionState struct {
 	messageCounter atomic.Uint64
 	window         *Bits
 	writeLock      sync.Mutex
-	ready          bool
 }
 
 func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
@@ -71,7 +70,6 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
 		H:         hs,
 		initiator: initiator,
 		window:    b,
-		ready:     false,
 		myCert:    certState.Certificate,
 	}
 
@@ -83,6 +81,5 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
 		"certificate":     cs.peerCert,
 		"initiator":       cs.initiator,
 		"message_counter": cs.messageCounter.Load(),
-		"ready":           cs.ready,
 	})
 }

+ 0 - 2
control.go

@@ -41,7 +41,6 @@ type ControlHostInfo struct {
 	LocalIndex             uint32                  `json:"localIndex"`
 	RemoteIndex            uint32                  `json:"remoteIndex"`
 	RemoteAddrs            []*udp.Addr             `json:"remoteAddrs"`
-	CachedPackets          int                     `json:"cachedPackets"`
 	Cert                   *cert.NebulaCertificate `json:"cert"`
 	MessageCounter         uint64                  `json:"messageCounter"`
 	CurrentRemote          *udp.Addr               `json:"currentRemote"`
@@ -234,7 +233,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
 		LocalIndex:             h.localIndexId,
 		RemoteIndex:            h.remoteIndexId,
 		RemoteAddrs:            h.remotes.CopyAddrs(preferredRanges),
-		CachedPackets:          len(h.packetStore),
 		CurrentRelaysToMe:      h.relayState.CopyRelayIps(),
 		CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
 	}

+ 1 - 2
control_test.go

@@ -96,7 +96,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 		LocalIndex:             201,
 		RemoteIndex:            200,
 		RemoteAddrs:            []*udp.Addr{remote2, remote1},
-		CachedPackets:          0,
 		Cert:                   crt.Copy(),
 		MessageCounter:         0,
 		CurrentRemote:          udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
@@ -105,7 +104,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
 	}
 
 	// Make sure we don't have any unexpected fields
-	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
+	assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
 	test.AssertDeepCopyEqual(t, &expectedInfo, thi)
 
 	// Make sure we don't panic if the host info doesn't have a cert yet

+ 0 - 31
handshake.go

@@ -1,31 +0,0 @@
-package nebula
-
-import (
-	"github.com/slackhq/nebula/header"
-	"github.com/slackhq/nebula/udp"
-)
-
-func HandleIncomingHandshake(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H, hostinfo *HostInfo) {
-	// First remote allow list check before we know the vpnIp
-	if addr != nil {
-		if !f.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
-			f.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
-			return
-		}
-	}
-
-	switch h.Subtype {
-	case header.HandshakeIXPSK0:
-		switch h.MessageCounter {
-		case 1:
-			ixHandshakeStage1(f, addr, via, packet, h)
-		case 2:
-			newHostinfo := f.handshakeManager.QueryIndex(h.RemoteIndex)
-			tearDown := ixHandshakeStage2(f, addr, via, newHostinfo, packet, h)
-			if tearDown && newHostinfo != nil {
-				f.handshakeManager.DeleteHostInfo(newHostinfo)
-			}
-		}
-	}
-
-}

+ 43 - 53
handshake_ix.go

@@ -4,6 +4,7 @@ import (
 	"time"
 
 	"github.com/flynn/noise"
+	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
@@ -13,20 +14,20 @@ import (
 
 // This function constructs a handshake packet, but does not actually send it
 // Sending is done by the handshake manager
-func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
-	err := f.handshakeManager.allocateIndex(hostinfo)
+func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
+	err := f.handshakeManager.allocateIndex(hh)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
 		return false
 	}
 
 	certState := f.pki.GetCertState()
 	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
-	hostinfo.ConnectionState = ci
+	hh.hostinfo.ConnectionState = ci
 
 	hsProto := &NebulaHandshakeDetails{
-		InitiatorIndex: hostinfo.localIndexId,
+		InitiatorIndex: hh.hostinfo.localIndexId,
 		Time:           uint64(time.Now().UnixNano()),
 		Cert:           certState.RawCertificateNoKey,
 	}
@@ -39,7 +40,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
 	hsBytes, err = hs.Marshal()
 
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
 		return false
 	}
@@ -49,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
 
 	msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
 	if err != nil {
-		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
+		f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
 			WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
 		return false
 	}
@@ -58,9 +59,8 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
 	// handshake packet 1 from the responder
 	ci.window.Update(f.l, 1)
 
-	hostinfo.HandshakePacket[0] = msg
-	hostinfo.HandshakeReady = true
-	hostinfo.handshakeStart = time.Now()
+	hh.hostinfo.HandshakePacket[0] = msg
+	hh.ready = true
 	return true
 }
 
@@ -140,9 +140,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		},
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -208,19 +205,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 	if err != nil {
 		switch err {
 		case ErrAlreadySeen:
-			// Update remote if preferred (Note we have to switch to locking
-			// the existing hostinfo, and then switch back so the defer Unlock
-			// higher in this function still works)
-			hostinfo.Unlock()
-			existing.Lock()
 			// Update remote if preferred
 			if existing.SetRemoteIfPreferred(f.hostMap, addr) {
 				// Send a test packet to ensure the other side has also switched to
 				// the preferred remote
 				f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
 			}
-			existing.Unlock()
-			hostinfo.Lock()
 
 			msg = existing.HandshakePacket[2]
 			f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
@@ -307,7 +297,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 				WithField("issuer", issuer).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-				WithField("sentCachedPackets", len(hostinfo.packetStore)).
 				Info("Handshake message sent")
 		}
 	} else {
@@ -323,25 +312,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 			WithField("issuer", issuer).
 			WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 			WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
-			WithField("sentCachedPackets", len(hostinfo.packetStore)).
 			Info("Handshake message sent")
 	}
 
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
-	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
+	hostinfo.ConnectionState.messageCounter.Store(2)
+	hostinfo.remotes.ResetBlockedRemotes()
 
 	return
 }
 
-func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
-	if hostinfo == nil {
+func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
+	if hh == nil {
 		// Nothing here to tear down, got a bogus stage 2 packet
 		return true
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
+	hh.Lock()
+	defer hh.Unlock()
 
+	hostinfo := hh.hostinfo
 	if addr != nil {
 		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
 			f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
@@ -350,22 +340,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	}
 
 	ci := hostinfo.ConnectionState
-	if ci.ready {
-		f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
-			WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
-			Info("Handshake is already complete")
-
-		// Update remote if preferred
-		if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
-			// Send a test packet to ensure the other side has also switched to
-			// the preferred remote
-			f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
-		}
-
-		// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
-		return false
-	}
-
 	msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
 	if err != nil {
 		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
@@ -422,22 +396,22 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		f.handshakeManager.DeleteHostInfo(hostinfo)
 
 		// Create a new hostinfo/handshake for the intended vpn ip
-		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) {
+		f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
 			//TODO: this doesnt know if its being added or is being used for caching a packet
 			// Block the current used address
-			newHostInfo.remotes = hostinfo.remotes
-			newHostInfo.remotes.BlockRemote(addr)
+			newHH.hostinfo.remotes = hostinfo.remotes
+			newHH.hostinfo.remotes.BlockRemote(addr)
 
 			// Get the correct remote list for the host we did handshake with
 			hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)
 
-			f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
-				WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
+			f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
+				WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
 				Info("Blocked addresses for handshakes")
 
 			// Swap the packet store to benefit the original intended recipient
-			newHostInfo.packetStore = hostinfo.packetStore
-			hostinfo.packetStore = []*cachedPacket{}
+			newHH.packetStore = hh.packetStore
+			hh.packetStore = []*cachedPacket{}
 
 			// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
 			hostinfo.vpnIp = vpnIp
@@ -450,7 +424,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	// Mark packet 2 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 2)
 
-	duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
+	duration := time.Since(hh.startTime).Nanoseconds()
 	f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
 		WithField("certName", certName).
 		WithField("fingerprint", fingerprint).
@@ -458,7 +432,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 		WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 		WithField("durationNs", duration).
-		WithField("sentCachedPackets", len(hostinfo.packetStore)).
+		WithField("sentCachedPackets", len(hh.packetStore)).
 		Info("Handshake message received")
 
 	hostinfo.remoteIndexId = hs.Details.ResponderIndex
@@ -482,7 +456,23 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 	// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
 	f.handshakeManager.Complete(hostinfo, f)
 	f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
-	hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
+
+	hostinfo.ConnectionState.messageCounter.Store(2)
+
+	if f.l.Level >= logrus.DebugLevel {
+		hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
+	}
+
+	if len(hh.packetStore) > 0 {
+		nb := make([]byte, 12, 12)
+		out := make([]byte, mtu)
+		for _, cp := range hh.packetStore {
+			cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
+		}
+		f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
+	}
+
+	hostinfo.remotes.ResetBlockedRemotes()
 	f.metricHandshakes.Update(duration)
 
 	return false

+ 164 - 87
handshake_manager.go

@@ -46,8 +46,8 @@ type HandshakeManager struct {
 	// Mutex for interacting with the vpnIps and indexes maps
 	sync.RWMutex
 
-	vpnIps  map[iputil.VpnIp]*HostInfo
-	indexes map[uint32]*HostInfo
+	vpnIps  map[iputil.VpnIp]*HandshakeHostInfo
+	indexes map[uint32]*HandshakeHostInfo
 
 	mainHostMap            *HostMap
 	lightHouse             *LightHouse
@@ -64,10 +64,47 @@ type HandshakeManager struct {
 	trigger chan iputil.VpnIp
 }
 
+type HandshakeHostInfo struct {
+	sync.Mutex
+
+	startTime   time.Time       // Time that we first started trying with this handshake
+	ready       bool            // Is the handshake ready
+	counter     int             // How many attempts have we made so far
+	lastRemotes []*udp.Addr     // Remotes that we sent to during the previous attempt
+	packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
+
+	hostinfo *HostInfo
+}
+
+func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
+	if len(hh.packetStore) < 100 {
+		tempPacket := make([]byte, len(packet))
+		copy(tempPacket, packet)
+
+		hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, tempPacket})
+		if l.Level >= logrus.DebugLevel {
+			hh.hostinfo.logger(l).
+				WithField("length", len(hh.packetStore)).
+				WithField("stored", true).
+				Debugf("Packet store")
+		}
+
+	} else {
+		m.dropped.Inc(1)
+
+		if l.Level >= logrus.DebugLevel {
+			hh.hostinfo.logger(l).
+				WithField("length", len(hh.packetStore)).
+				WithField("stored", false).
+				Debugf("Packet store")
+		}
+	}
+}
+
 func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
 	return &HandshakeManager{
-		vpnIps:                 map[iputil.VpnIp]*HostInfo{},
-		indexes:                map[uint32]*HostInfo{},
+		vpnIps:                 map[iputil.VpnIp]*HandshakeHostInfo{},
+		indexes:                map[uint32]*HandshakeHostInfo{},
 		mainHostMap:            mainHostMap,
 		lightHouse:             lightHouse,
 		outside:                outside,
@@ -97,6 +134,31 @@ func (c *HandshakeManager) Run(ctx context.Context) {
 	}
 }
 
+func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
+	// First remote allow list check before we know the vpnIp
+	if addr != nil {
+		if !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
+			hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
+			return
+		}
+	}
+
+	switch h.Subtype {
+	case header.HandshakeIXPSK0:
+		switch h.MessageCounter {
+		case 1:
+			ixHandshakeStage1(hm.f, addr, via, packet, h)
+
+		case 2:
+			newHostinfo := hm.queryIndex(h.RemoteIndex)
+			tearDown := ixHandshakeStage2(hm.f, addr, via, newHostinfo, packet, h)
+			if tearDown && newHostinfo != nil {
+				hm.DeleteHostInfo(newHostinfo.hostinfo)
+			}
+		}
+	}
+}
+
 func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	c.OutboundHandshakeTimer.Advance(now)
 	for {
@@ -108,41 +170,35 @@ func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
 	}
 }
 
-func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
-	hostinfo := c.QueryVpnIp(vpnIp)
-	if hostinfo == nil {
-		return
-	}
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
-	// We may have raced to completion but now that we have a lock we should ensure we have not yet completed.
-	if hostinfo.HandshakeComplete {
-		// Ensure we don't exist in the pending hostmap anymore since we have completed
-		c.DeleteHostInfo(hostinfo)
+func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
+	hh := hm.queryVpnIp(vpnIp)
+	if hh == nil {
 		return
 	}
+	hh.Lock()
+	defer hh.Unlock()
 
+	hostinfo := hh.hostinfo
 	// If we are out of time, clean up
-	if hostinfo.HandshakeCounter >= c.config.retries {
-		hostinfo.logger(c.l).WithField("udpAddrs", hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)).
-			WithField("initiatorIndex", hostinfo.localIndexId).
-			WithField("remoteIndex", hostinfo.remoteIndexId).
+	if hh.counter >= hm.config.retries {
+		hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
+			WithField("initiatorIndex", hh.hostinfo.localIndexId).
+			WithField("remoteIndex", hh.hostinfo.remoteIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-			WithField("durationNs", time.Since(hostinfo.handshakeStart).Nanoseconds()).
+			WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
 			Info("Handshake timed out")
-		c.metricTimedOut.Inc(1)
-		c.DeleteHostInfo(hostinfo)
+		hm.metricTimedOut.Inc(1)
+		hm.DeleteHostInfo(hostinfo)
 		return
 	}
 
 	// Increment the counter to increase our delay, linear backoff
-	hostinfo.HandshakeCounter++
+	hh.counter++
 
 	// Check if we have a handshake packet to transmit yet
-	if !hostinfo.HandshakeReady {
-		if !ixHandshakeStage0(c.f, hostinfo) {
-			c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+	if !hh.ready {
+		if !ixHandshakeStage0(hm.f, hh) {
+			hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
 			return
 		}
 	}
@@ -152,11 +208,11 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
 	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
 	if hostinfo.remotes == nil {
-		hostinfo.remotes = c.lightHouse.QueryCache(vpnIp)
+		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
 	}
 
-	remotes := hostinfo.remotes.CopyAddrs(c.mainHostMap.preferredRanges)
-	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hostinfo.HandshakeLastRemotes)
+	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
+	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)
 
 	// We only care about a lighthouse trigger if we have new remotes to send to.
 	// This is a very specific optimization for a fast lighthouse reply.
@@ -165,25 +221,25 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 		return
 	}
 
-	hostinfo.HandshakeLastRemotes = remotes
+	hh.lastRemotes = remotes
 
 	// TODO: this will generate a load of queries for hosts with only 1 ip
 	// (such as ones registered to the lighthouse with only a private IP)
 	// So we only do it one time after attempting 5 handshakes already.
-	if len(remotes) <= 1 && hostinfo.HandshakeCounter == 5 {
+	if len(remotes) <= 1 && hh.counter == 5 {
 		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
 		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
 		// the learned public ip for them. Query again to short circuit the promotion counter
-		c.lightHouse.QueryServer(vpnIp, c.f)
+		hm.lightHouse.QueryServer(vpnIp, hm.f)
 	}
 
 	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
 	var sentTo []*udp.Addr
-	hostinfo.remotes.ForEach(c.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
-		c.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
-		err := c.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
+	hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
+		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
+		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
 		if err != nil {
-			hostinfo.logger(c.l).WithField("udpAddr", addr).
+			hostinfo.logger(hm.l).WithField("udpAddr", addr).
 				WithField("initiatorIndex", hostinfo.localIndexId).
 				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 				WithError(err).Error("Failed to send handshake message")
@@ -196,63 +252,63 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 	// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
 	// so only log when the list of remotes has changed
 	if remotesHaveChanged {
-		hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Info("Handshake message sent")
-	} else if c.l.IsLevelEnabled(logrus.DebugLevel) {
-		hostinfo.logger(c.l).WithField("udpAddrs", sentTo).
+	} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
+		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
 			WithField("initiatorIndex", hostinfo.localIndexId).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
 			Debug("Handshake message sent")
 	}
 
-	if c.config.useRelays && len(hostinfo.remotes.relays) > 0 {
-		hostinfo.logger(c.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
+	if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
+		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
 		// Send a RelayRequest to all known Relay IP's
 		for _, relay := range hostinfo.remotes.relays {
 			// Don't relay to myself, and don't relay through the host I'm trying to connect to
-			if *relay == vpnIp || *relay == c.lightHouse.myVpnIp {
+			if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
 				continue
 			}
-			relayHostInfo := c.mainHostMap.QueryVpnIp(*relay)
+			relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
 			if relayHostInfo == nil || relayHostInfo.remote == nil {
-				hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
-				c.f.Handshake(*relay)
+				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
+				hm.f.Handshake(*relay)
 				continue
 			}
 			// Check the relay HostInfo to see if we already established a relay through it
 			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
 				switch existingRelay.State {
 				case Established:
-					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Send handshake via relay")
-					c.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
+					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
+					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
 				case Requested:
-					hostinfo.logger(c.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
+					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
 					// Re-send the CreateRelay request, in case the previous one was lost.
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: existingRelay.LocalIndex,
-						RelayFromIp:         uint32(c.lightHouse.myVpnIp),
+						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
 						RelayToIp:           uint32(vpnIp),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
-						hostinfo.logger(c.l).
+						hostinfo.logger(hm.l).
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
 						// This must send over the hostinfo, not over hm.Hosts[ip]
-						c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						c.l.WithFields(logrus.Fields{
-							"relayFrom":           c.lightHouse.myVpnIp,
+						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						hm.l.WithFields(logrus.Fields{
+							"relayFrom":           hm.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": existingRelay.LocalIndex,
 							"relay":               *relay}).
 							Info("send CreateRelayRequest")
 					}
 				default:
-					hostinfo.logger(c.l).
+					hostinfo.logger(hm.l).
 						WithField("vpnIp", vpnIp).
 						WithField("state", existingRelay.State).
 						WithField("relay", relayHostInfo.vpnIp).
@@ -261,26 +317,26 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 			} else {
 				// No relays exist or requested yet.
 				if relayHostInfo.remote != nil {
-					idx, err := AddRelay(c.l, relayHostInfo, c.mainHostMap, vpnIp, nil, TerminalType, Requested)
+					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
 					if err != nil {
-						hostinfo.logger(c.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
+						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
 					}
 
 					m := NebulaControl{
 						Type:                NebulaControl_CreateRelayRequest,
 						InitiatorRelayIndex: idx,
-						RelayFromIp:         uint32(c.lightHouse.myVpnIp),
+						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
 						RelayToIp:           uint32(vpnIp),
 					}
 					msg, err := m.Marshal()
 					if err != nil {
-						hostinfo.logger(c.l).
+						hostinfo.logger(hm.l).
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						c.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
-						c.l.WithFields(logrus.Fields{
-							"relayFrom":           c.lightHouse.myVpnIp,
+						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+						hm.l.WithFields(logrus.Fields{
+							"relayFrom":           hm.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
 							"initiatorRelayIndex": idx,
 							"relay":               *relay}).
@@ -293,13 +349,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggere
 
 	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
 	if !lighthouseTriggered {
-		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval*time.Duration(hostinfo.HandshakeCounter))
+		hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
 	}
 }
 
 // GetOrHandshake will try to find a hostinfo with a fully formed tunnel or start a new handshake if one is not present
 // The 2nd argument will be true if the hostinfo is ready to transmit traffic
-func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) (*HostInfo, bool) {
+func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	// Check the main hostmap and maintain a read lock if our host is not there
 	hm.mainHostMap.RLock()
 	if h, ok := hm.mainHostMap.Hosts[vpnIp]; ok {
@@ -316,16 +372,16 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos
 }
 
 // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip
-func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HostInfo)) *HostInfo {
+func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
 	hm.Lock()
 
-	if hostinfo, ok := hm.vpnIps[vpnIp]; ok {
+	if hh, ok := hm.vpnIps[vpnIp]; ok {
 		// We are already trying to handshake with this vpn ip
 		if cacheCb != nil {
-			cacheCb(hostinfo)
+			cacheCb(hh)
 		}
 		hm.Unlock()
-		return hostinfo
+		return hh.hostinfo
 	}
 
 	hostinfo := &HostInfo{
@@ -338,12 +394,16 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Hos
 		},
 	}
 
-	hm.vpnIps[vpnIp] = hostinfo
+	hh := &HandshakeHostInfo{
+		hostinfo:  hostinfo,
+		startTime: time.Now(),
+	}
+	hm.vpnIps[vpnIp] = hh
 	hm.metricInitiated.Inc(1)
 	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)
 
 	if cacheCb != nil {
-		cacheCb(hostinfo)
+		cacheCb(hh)
 	}
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
@@ -416,8 +476,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		return existingIndex, ErrLocalIndexCollision
 	}
 
-	existingIndex, found = c.indexes[hostinfo.localIndexId]
-	if found && existingIndex != hostinfo {
+	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
+	if found && existingPendingIndex.hostinfo != hostinfo {
 		// We have a collision, but for a different hostinfo
 		return existingIndex, ErrLocalIndexCollision
 	}
@@ -461,7 +521,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
 // allocateIndex generates a unique localIndexId for this HostInfo
 // and adds it to the pendingHostMap. Will error if we are unable to generate
 // a unique localIndexId
-func (hm *HandshakeManager) allocateIndex(h *HostInfo) error {
+func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
 	hm.mainHostMap.RLock()
 	defer hm.mainHostMap.RUnlock()
 	hm.Lock()
@@ -477,8 +537,8 @@ func (hm *HandshakeManager) allocateIndex(h *HostInfo) error {
 		_, inMain := hm.mainHostMap.Indexes[index]
 
 		if !inMain && !inPending {
-			h.localIndexId = index
-			hm.indexes[index] = h
+			hh.hostinfo.localIndexId = index
+			hm.indexes[index] = hh
 			return nil
 		}
 	}
@@ -495,12 +555,12 @@ func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
 func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	delete(c.vpnIps, hostinfo.vpnIp)
 	if len(c.vpnIps) == 0 {
-		c.vpnIps = map[iputil.VpnIp]*HostInfo{}
+		c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
 	}
 
 	delete(c.indexes, hostinfo.localIndexId)
-	if len(c.vpnIps) == 0 {
-		c.indexes = map[uint32]*HostInfo{}
+	if len(c.indexes) == 0 {
+		c.indexes = map[uint32]*HandshakeHostInfo{}
 	}
 
 	if c.l.Level >= logrus.DebugLevel {
@@ -510,16 +570,33 @@ func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 	}
 }
 
-func (c *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
-	c.RLock()
-	defer c.RUnlock()
-	return c.vpnIps[vpnIp]
+func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
+	hh := hm.queryVpnIp(vpnIp)
+	if hh != nil {
+		return hh.hostinfo
+	}
+	return nil
 }
 
-func (c *HandshakeManager) QueryIndex(index uint32) *HostInfo {
-	c.RLock()
-	defer c.RUnlock()
-	return c.indexes[index]
+func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
+	hm.RLock()
+	defer hm.RUnlock()
+	return hm.vpnIps[vpnIp]
+}
+
+func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo {
+	hh := hm.queryIndex(index)
+	if hh != nil {
+		return hh.hostinfo
+	}
+	return nil
+}
+
+func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
+	hm.RLock()
+	defer hm.RUnlock()
+	return hm.indexes[index]
 }
 
 func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
@@ -531,7 +608,7 @@ func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
 	defer c.RUnlock()
 
 	for _, v := range c.vpnIps {
-		f(v)
+		f(v.hostinfo)
 	}
 }
 
@@ -540,7 +617,7 @@ func (c *HandshakeManager) ForEachIndex(f controlEach) {
 	defer c.RUnlock()
 
 	for _, v := range c.indexes {
-		f(v)
+		f(v.hostinfo)
 	}
 }
 

+ 10 - 1
handshake_manager_test.go

@@ -5,6 +5,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/test"
@@ -21,7 +22,16 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	mainHM := NewHostMap(l, vpncidr, preferredRanges)
 	lh := newTestLighthouse()
 
+	cs := &CertState{
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
+	}
+
 	blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig)
+	blah.f = &Interface{handshakeManager: blah, pki: &PKI{}, l: l}
+	blah.f.pki.cs.Store(cs)
 
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now)
@@ -31,7 +41,6 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	assert.Same(t, i, i2)
 
 	i.remotes = NewRemoteList(nil)
-	i.HandshakeReady = true
 
 	// Adding something to pending should not affect the main hostmap
 	assert.Len(t, mainHM.Hosts, 0)

+ 17 - 78
hostmap.go

@@ -21,6 +21,7 @@ const defaultPromoteEvery = 1000       // Count of packets sent before we try mo
 const defaultReQueryEvery = 5000       // Count of packets sent before re-querying a hostinfo to the lighthouse
 const defaultReQueryWait = time.Minute // Minimum amount of seconds to wait before re-querying a hostinfo the lighthouse. Evaluated every ReQueryEvery
 const MaxRemotes = 10
+const maxRecvError = 4
 
 // MaxHostInfosPerVpnIp is the max number of hostinfos we will track for a given vpn ip
 // 5 allows for an initial handshake and each host pair re-handshaking twice
@@ -196,25 +197,20 @@ func (rs *RelayState) InsertRelay(ip iputil.VpnIp, idx uint32, r *Relay) {
 }
 
 type HostInfo struct {
-	sync.RWMutex
-
-	remote               *udp.Addr
-	remotes              *RemoteList
-	promoteCounter       atomic.Uint32
-	ConnectionState      *ConnectionState
-	handshakeStart       time.Time   //todo: this an entry in the handshake manager
-	HandshakeReady       bool        //todo: being in the manager means you are ready
-	HandshakeCounter     int         //todo: another handshake manager entry
-	HandshakeLastRemotes []*udp.Addr //todo: another handshake manager entry, which remotes we sent to last time
-	HandshakeComplete    bool        //todo: this should go away in favor of ConnectionState.ready
-	HandshakePacket      map[uint8][]byte
-	packetStore          []*cachedPacket //todo: this is other handshake manager entry
-	remoteIndexId        uint32
-	localIndexId         uint32
-	vpnIp                iputil.VpnIp
-	recvError            int
-	remoteCidr           *cidr.Tree4
-	relayState           RelayState
+	remote          *udp.Addr
+	remotes         *RemoteList
+	promoteCounter  atomic.Uint32
+	ConnectionState *ConnectionState
+	remoteIndexId   uint32
+	localIndexId    uint32
+	vpnIp           iputil.VpnIp
+	recvError       atomic.Uint32
+	remoteCidr      *cidr.Tree4
+	relayState      RelayState
+
+	// HandshakePacket records the packets used to create this hostinfo
+	// We need these to avoid replayed handshake packets creating new hostinfos which causes churn
+	HandshakePacket map[uint8][]byte
 
 	// nextLHQuery is the earliest we can ask the lighthouse for new information.
 	// This is used to limit lighthouse re-queries in chatty clients
@@ -412,7 +408,6 @@ func (hm *HostMap) QueryIndex(index uint32) *HostInfo {
 }
 
 func (hm *HostMap) QueryRelayIndex(index uint32) *HostInfo {
-	//TODO: we probably just want to return bool instead of error, or at least a static error
 	hm.RLock()
 	if h, ok := hm.Relays[index]; ok {
 		hm.RUnlock()
@@ -535,10 +530,7 @@ func (hm *HostMap) ForEachIndex(f controlEach) {
 func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
 	c := i.promoteCounter.Add(1)
 	if c%ifce.tryPromoteEvery.Load() == 0 {
-		// The lock here is currently protecting i.remote access
-		i.RLock()
 		remote := i.remote
-		i.RUnlock()
 
 		// return early if we are already on a preferred remote
 		if remote != nil {
@@ -573,58 +565,6 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface)
 	}
 }
 
-func (i *HostInfo) unlockedCachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
-	//TODO: return the error so we can log with more context
-	if len(i.packetStore) < 100 {
-		tempPacket := make([]byte, len(packet))
-		copy(tempPacket, packet)
-		//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
-		i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
-		if l.Level >= logrus.DebugLevel {
-			i.logger(l).
-				WithField("length", len(i.packetStore)).
-				WithField("stored", true).
-				Debugf("Packet store")
-		}
-
-	} else if l.Level >= logrus.DebugLevel {
-		m.dropped.Inc(1)
-		i.logger(l).
-			WithField("length", len(i.packetStore)).
-			WithField("stored", false).
-			Debugf("Packet store")
-	}
-}
-
-// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
-func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
-	//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
-	//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
-	//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
-
-	i.HandshakeComplete = true
-	//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
-	// Clamping it to 2 gets us out of the woods for now
-	i.ConnectionState.messageCounter.Store(2)
-
-	if l.Level >= logrus.DebugLevel {
-		i.logger(l).Debugf("Sending %d stored packets", len(i.packetStore))
-	}
-
-	if len(i.packetStore) > 0 {
-		nb := make([]byte, 12, 12)
-		out := make([]byte, mtu)
-		for _, cp := range i.packetStore {
-			cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out)
-		}
-		m.sent.Inc(int64(len(i.packetStore)))
-	}
-
-	i.remotes.ResetBlockedRemotes()
-	i.packetStore = make([]*cachedPacket, 0)
-	i.ConnectionState.ready = true
-}
-
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {
 	if i.ConnectionState != nil {
 		return i.ConnectionState.peerCert
@@ -681,9 +621,8 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
 }
 
 func (i *HostInfo) RecvErrorExceeded() bool {
-	if i.recvError < 3 {
-		i.recvError += 1
-		return false
-	}
-	return true
+	if i.recvError.Add(1) >= maxRecvError {
+		return true
+	}
+	return false
 }

+ 5 - 5
inside.go

@@ -44,8 +44,8 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		return
 	}
 
-	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(h *HostInfo) {
-		h.unlockedCachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
+	hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) {
+		hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
 	})
 
 	if hostinfo == nil {
@@ -108,7 +108,7 @@ func (f *Interface) Handshake(vpnIp iputil.VpnIp) {
 
 // getOrHandshake returns nil if the vpnIp is not routable.
 // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
-func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(info *HostInfo)) (*HostInfo, bool) {
+func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
 	if !ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, vpnIp) {
 		vpnIp = f.inside.RouteFor(vpnIp)
 		if vpnIp == 0 {
@@ -143,8 +143,8 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
 func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte) {
-	hostInfo, ready := f.getOrHandshake(vpnIp, func(h *HostInfo) {
-		h.unlockedCachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
+	hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) {
+		hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 	})
 
 	if hostInfo == nil {

+ 1 - 4
outside.go

@@ -198,7 +198,7 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 	case header.Handshake:
 		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
-		HandleIncomingHandshake(f, addr, via, packet, h, hostinfo)
+		f.handshakeManager.HandleIncoming(addr, via, packet, h)
 		return
 
 	case header.RecvError:
@@ -455,9 +455,6 @@ func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
 		return
 	}
 
-	hostinfo.Lock()
-	defer hostinfo.Unlock()
-
 	if !hostinfo.RecvErrorExceeded() {
 		return
 	}