Browse Source

create ConnectionState before adding to HostMap (#535)

We have a few small race conditions with creating the HostInfo.ConnectionState
since we add the host info to the pendingHostMap before we set this
field. We can make everything a lot easier if we just add an "init"
function so that we can set this field in the hostinfo before we add it
to the hostmap.
Wade Simmons 3 năm trước cách đây
mục cha
commit
304b12f63f
7 tập tin đã thay đổi với 42 bổ sung31 xóa
  1. 3 1
      CHANGELOG.md
  2. 3 3
      connection_manager_test.go
  3. 7 7
      handshake_manager.go
  4. 14 2
      handshake_manager_test.go
  5. 7 6
      hostmap.go
  6. 7 11
      inside.go
  7. 1 1
      ssh.go

+ 3 - 1
CHANGELOG.md

@@ -53,10 +53,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   will immediately switch to a preferred remote address after the reception of
   a handshake packet (instead of waiting until 1,000 packets have been sent).
   (#532)
-  
+
 - A race condition when `punchy.respond` is enabled and ensures the correct
   vpn ip is sent a punch back response in highly queried node. (#566)
 
+- Fix a rare crash during handshake due to a race condition. (#535)
+
 ## [1.4.0] - 2021-05-11
 
 ### Added

+ 3 - 3
connection_manager_test.go

@@ -57,7 +57,7 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	out := make([]byte, mtu)
 	nc.HandleMonitorTick(now, p, nb, out)
 	// Add an ip we have established a connection w/ to hostmap
-	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
+	hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		H:         &noise.HandshakeState{},
@@ -126,7 +126,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	out := make([]byte, mtu)
 	nc.HandleMonitorTick(now, p, nb, out)
 	// Add an ip we have established a connection w/ to hostmap
-	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
+	hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		H:         &noise.HandshakeState{},
@@ -232,7 +232,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	defer cancel()
 	nc := newConnectionManager(ctx, l, ifce, 5, 10)
 	ifce.connectionManager = nc
-	hostinfo := nc.hostMap.AddVpnIp(vpnIp)
+	hostinfo, _ := nc.hostMap.AddVpnIp(vpnIp, nil)
 	hostinfo.ConnectionState = &ConnectionState{
 		certState: cs,
 		peerCert:  &peerCert,

+ 7 - 7
handshake_manager.go

@@ -191,13 +191,13 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f udp.EncWriter, l
 	}
 }
 
-func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
-	hostinfo := c.pendingHostMap.AddVpnIp(vpnIp)
-	// We lock here and use an array to insert items to prevent locking the
-	// main receive thread for very long by waiting to add items to the pending map
-	//TODO: what lock?
-	c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
-	c.metricInitiated.Inc(1)
+func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
+	hostinfo, created := c.pendingHostMap.AddVpnIp(vpnIp, init)
+
+	if created {
+		c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
+		c.metricInitiated.Inc(1)
+	}
 
 	return hostinfo
 }

+ 14 - 2
handshake_manager_test.go

@@ -27,7 +27,19 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
-	i := blah.AddVpnIp(ip)
+	var initCalled bool
+	initFunc := func(*HostInfo) {
+		initCalled = true
+	}
+
+	i := blah.AddVpnIp(ip, initFunc)
+	assert.True(t, initCalled)
+
+	initCalled = false
+	i2 := blah.AddVpnIp(ip, initFunc)
+	assert.False(t, initCalled)
+	assert.Same(t, i, i2)
+
 	i.remotes = NewRemoteList()
 	i.HandshakeReady = true
 
@@ -71,7 +83,7 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 
 	assert.Equal(t, 0, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 
-	hi := blah.AddVpnIp(ip)
+	hi := blah.AddVpnIp(ip, nil)
 	hi.HandshakeReady = true
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 	assert.Equal(t, 0, hi.HandshakeCounter, "Should not have attempted a handshake yet")

+ 7 - 6
hostmap.go

@@ -134,24 +134,25 @@ func (hm *HostMap) Add(ip iputil.VpnIp, hostinfo *HostInfo) {
 	hm.Unlock()
 }
 
-func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
-	h := &HostInfo{}
+func (hm *HostMap) AddVpnIp(vpnIp iputil.VpnIp, init func(hostinfo *HostInfo)) (hostinfo *HostInfo, created bool) {
 	hm.RLock()
-	if _, ok := hm.Hosts[vpnIp]; !ok {
+	if h, ok := hm.Hosts[vpnIp]; !ok {
 		hm.RUnlock()
 		h = &HostInfo{
 			promoteCounter:  0,
 			vpnIp:           vpnIp,
 			HandshakePacket: make(map[uint8][]byte, 0),
 		}
+		if init != nil {
+			init(h)
+		}
 		hm.Lock()
 		hm.Hosts[vpnIp] = h
 		hm.Unlock()
-		return h
+		return h, true
 	} else {
-		h = hm.Hosts[vpnIp]
 		hm.RUnlock()
-		return h
+		return h, false
 	}
 }
 

+ 7 - 11
inside.go

@@ -83,7 +83,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 	if err != nil {
 		hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
 		if err != nil {
-			hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
+			hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
 		}
 	}
 	ci := hostinfo.ConnectionState
@@ -102,16 +102,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 		return hostinfo
 	}
 
-	if ci == nil {
-		// if we don't have a connection state, then send a handshake initiation
-		ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
-		// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
-		//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
-		hostinfo.ConnectionState = ci
-	} else if ci.eKey == nil {
-		// if we don't have any state at all, create it
-	}
-
 	// If we have already created the handshake packet, we don't want to call the function at all.
 	if !hostinfo.HandshakeReady {
 		ixHandshakeStage0(f, vpnIp, hostinfo)
@@ -131,6 +121,12 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 	return hostinfo
 }
 
+// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
+// will create the initial Noise ConnectionState
+func (f *Interface) initHostInfo(hostinfo *HostInfo) {
+	hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
+}
+
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
 	fp := &firewall.Packet{}
 	err := newPacket(p, false, fp)

+ 1 - 1
ssh.go

@@ -569,7 +569,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		}
 	}
 
-	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
+	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
 	if addr != nil {
 		hostInfo.SetRemote(addr)
 	}