Browse Source

Rehandshaking (#838)

Co-authored-by: Brad Higgins <[email protected]>
Co-authored-by: Wade Simmons <[email protected]>
Nate Brown 2 years ago
parent
commit
03e4a7f988
15 changed files with 761 additions and 172 deletions
  1. 236 29
      connection_manager.go
  2. 4 4
      connection_manager_test.go
  3. 14 0
      control_tester.go
  4. 385 22
      e2e/handshakes_test.go
  5. 4 4
      e2e/helpers_test.go
  6. 5 5
      e2e/router/router.go
  7. 5 4
      handshake_manager.go
  8. 4 0
      handshake_manager_test.go
  9. 34 41
      hostmap.go
  10. 12 24
      inside.go
  11. 1 0
      interface.go
  12. 17 0
      lighthouse_test.go
  13. 12 13
      outside.go
  14. 1 1
      punchy.go
  15. 27 25
      relay_manager.go

+ 236 - 29
connection_manager.go

@@ -1,6 +1,7 @@
 package nebula
 
 import (
+	"bytes"
 	"context"
 	"sync"
 	"time"
@@ -8,9 +9,20 @@ import (
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/header"
+	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 )
 
+type trafficDecision int
+
+const (
+	doNothing     trafficDecision = 0 // no follow-up action is taken for the hostinfo
+	deleteTunnel  trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote
+	closeTunnel   trafficDecision = 2 // delete the hostinfo and notify the remote
+	swapPrimary   trafficDecision = 3 // make the checked hostinfo the primary for its vpn ip
+	migrateRelays trafficDecision = 4 // re-request the in-use relays of the checked hostinfo on the primary
+)
+
 type connectionManager struct {
 	in     map[uint32]struct{}
 	inLock *sync.RWMutex
@@ -18,6 +30,10 @@ type connectionManager struct {
 	out     map[uint32]struct{}
 	outLock *sync.RWMutex
 
+	// relayUsed holds which relay localIndexes are in use
+	relayUsed     map[uint32]struct{}
+	relayUsedLock *sync.RWMutex
+
 	hostMap                 *HostMap
 	trafficTimer            *LockingTimerWheel[uint32]
 	intf                    *Interface
@@ -44,6 +60,8 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		inLock:                  &sync.RWMutex{},
 		out:                     make(map[uint32]struct{}),
 		outLock:                 &sync.RWMutex{},
+		relayUsed:               make(map[uint32]struct{}),
+		relayUsedLock:           &sync.RWMutex{},
 		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),
 		intf:                    intf,
 		pendingDeletion:         make(map[uint32]struct{}),
@@ -84,6 +102,19 @@ func (n *connectionManager) Out(localIndex uint32) {
 	n.outLock.Unlock()
 }
 
+func (n *connectionManager) RelayUsed(localIndex uint32) {
+	n.relayUsedLock.RLock()
+	// If this already exists, return
+	if _, ok := n.relayUsed[localIndex]; ok {
+		n.relayUsedLock.RUnlock()
+		return
+	}
+	n.relayUsedLock.RUnlock()
+	n.relayUsedLock.Lock()
+	n.relayUsed[localIndex] = struct{}{}
+	n.relayUsedLock.Unlock()
+}
+
 // getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and
 // resets the state for this local index
 func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {
@@ -99,8 +130,15 @@ func (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bo
 }
 
 func (n *connectionManager) AddTrafficWatch(localIndex uint32) {
-	n.Out(localIndex)
+	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index
+	n.outLock.Lock()
+	if _, ok := n.out[localIndex]; ok {
+		n.outLock.Unlock()
+		return
+	}
+	n.out[localIndex] = struct{}{}
 	n.trafficTimer.Add(localIndex, n.checkInterval)
+	n.outLock.Unlock()
 }
 
 func (n *connectionManager) Start(ctx context.Context) {
@@ -136,18 +174,130 @@ func (n *connectionManager) Run(ctx context.Context) {
 }
 
 func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {
-	hostinfo, err := n.hostMap.QueryIndex(localIndex)
-	if err != nil {
+	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)
+
+	switch decision {
+	case deleteTunnel:
+		n.hostMap.DeleteHostInfo(hostinfo)
+
+	case closeTunnel:
+		n.intf.sendCloseTunnel(hostinfo)
+		n.intf.closeTunnel(hostinfo)
+
+	case swapPrimary:
+		n.swapPrimary(hostinfo, primary)
+
+	case migrateRelays:
+		n.migrateRelayUsed(hostinfo, primary)
+	}
+
+	n.resetRelayTrafficCheck(hostinfo)
+}
+
+func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {
+	if hostinfo != nil {
+		n.relayUsedLock.Lock()
+		defer n.relayUsedLock.Unlock()
+		// No need to migrate any relays, delete usage info now.
+		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {
+			delete(n.relayUsed, idx)
+		}
+	}
+}
+
+func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {
+	relayFor := oldhostinfo.relayState.CopyAllRelayFor()
+
+	for _, r := range relayFor {
+		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)
+
+		var index uint32
+		var relayFrom iputil.VpnIp
+		var relayTo iputil.VpnIp
+		switch {
+		case ok && existing.State == Established:
+			// This relay already exists in newhostinfo; nothing to do.
+			continue
+		case ok && existing.State == Requested:
+			// The relay exists in a Requested state; re-send the request
+			index = existing.LocalIndex
+			switch r.Type {
+			case TerminalType:
+				relayFrom = newhostinfo.vpnIp
+				relayTo = existing.PeerIp
+			case ForwardingType:
+				relayFrom = existing.PeerIp
+				relayTo = newhostinfo.vpnIp
+			default:
+				// should never happen
+			}
+		case !ok:
+			n.relayUsedLock.RLock()
+			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {
+				// The relay hasn't been used; don't migrate it.
+				n.relayUsedLock.RUnlock()
+				continue
+			}
+			n.relayUsedLock.RUnlock()
+			// The relay doesn't exist at all; create some relay state and send the request.
+			var err error
+			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)
+			if err != nil {
+				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")
+				continue
+			}
+			switch r.Type {
+			case TerminalType:
+				relayFrom = newhostinfo.vpnIp
+				relayTo = r.PeerIp
+			case ForwardingType:
+				relayFrom = r.PeerIp
+				relayTo = newhostinfo.vpnIp
+			default:
+				// should never happen
+			}
+		}
+
+		// Send a CreateRelayRequest to the peer.
+		req := NebulaControl{
+			Type:                NebulaControl_CreateRelayRequest,
+			InitiatorRelayIndex: index,
+			RelayFromIp:         uint32(relayFrom),
+			RelayToIp:           uint32(relayTo),
+		}
+		msg, err := req.Marshal()
+		if err != nil {
+			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")
+		} else {
+			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))
+			n.l.WithFields(logrus.Fields{
+				"relayFrom":           iputil.VpnIp(req.RelayFromIp),
+				"relayTo":             iputil.VpnIp(req.RelayToIp),
+				"initiatorRelayIndex": req.InitiatorRelayIndex,
+				"responderRelayIndex": req.ResponderRelayIndex,
+				"vpnIp":               newhostinfo.vpnIp}).
+				Info("send CreateRelayRequest")
+		}
+	}
+}
+
+func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {
+	n.hostMap.RLock()
+	defer n.hostMap.RUnlock()
+
+	hostinfo := n.hostMap.Indexes[localIndex]
+	if hostinfo == nil {
 		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")
 		delete(n.pendingDeletion, localIndex)
-		return
+		return doNothing, nil, nil
 	}
 
-	if n.handleInvalidCertificate(now, hostinfo) {
-		return
+	if n.isInvalidCertificate(now, hostinfo) {
+		delete(n.pendingDeletion, hostinfo.localIndexId)
+		return closeTunnel, hostinfo, nil
 	}
 
-	primary, _ := n.hostMap.QueryVpnIp(hostinfo.vpnIp)
+	primary := n.hostMap.Hosts[hostinfo.vpnIp]
 	mainHostInfo := true
 	if primary != nil && primary != hostinfo {
 		mainHostInfo = false
@@ -158,6 +308,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 
 	// A hostinfo is determined alive if there is incoming traffic
 	if inTraffic {
+		decision := doNothing
 		if n.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(n.l).
 				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).
@@ -165,11 +316,14 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 		}
 		delete(n.pendingDeletion, hostinfo.localIndexId)
 
-		if !mainHostInfo {
-			if hostinfo.vpnIp > n.intf.myVpnIp {
-				// We are receiving traffic on the non primary hostinfo and we really just want 1 tunnel. Make
-				// This the primary and prime the old primary hostinfo for testing
-				n.hostMap.MakePrimary(hostinfo)
+		if mainHostInfo {
+			n.tryRehandshake(hostinfo)
+		} else {
+			if n.shouldSwapPrimary(hostinfo, primary) {
+				decision = swapPrimary
+			} else {
+				// migrate the relays to the primary, if in use.
+				decision = migrateRelays
 			}
 		}
 
@@ -180,7 +334,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 			n.sendPunch(hostinfo)
 		}
 
-		return
+		return decision, hostinfo, primary
 	}
 
 	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {
@@ -189,22 +343,17 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).
 			Info("Tunnel status")
 
-		n.hostMap.DeleteHostInfo(hostinfo)
 		delete(n.pendingDeletion, hostinfo.localIndexId)
-		return
+		return deleteTunnel, hostinfo, nil
 	}
 
-	hostinfo.logger(n.l).
-		WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
-		Debug("Tunnel status")
-
 	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {
 		if !outTraffic {
 			// If we aren't sending or receiving traffic then it's an unused tunnel and we don't need to test the tunnel.
 			// Just maintain NAT state if configured to do so.
 			n.sendPunch(hostinfo)
 			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
-			return
+			return doNothing, nil, nil
 
 		}
 
@@ -218,22 +367,58 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte,
 		if n.intf.lightHouse.IsLighthouseIP(hostinfo.vpnIp) {
 			// We are sending traffic to the lighthouse, let recv_error sort out any issues instead of testing the tunnel
 			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)
-			return
+			return doNothing, nil, nil
+		}
+
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).
+				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
+				Debug("Tunnel status")
 		}
 
 		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
-		n.intf.sendMessageToVpnIp(header.Test, header.TestRequest, hostinfo, p, nb, out)
+		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)
 
 	} else {
-		hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		if n.l.Level >= logrus.DebugLevel {
+			hostinfo.logger(n.l).Debugf("Hostinfo sadness")
+		}
 	}
 
 	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}
 	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)
+	return doNothing, nil, nil
+}
+
+func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
+
+	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.
+	// If we are here then we have multiple tunnels for a host pair and the two sides disagree on which tunnel is primary.
+	// Let's sort this out.
+
+	if current.vpnIp < n.intf.myVpnIp {
+		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.
+		// vpn ip is static across all tunnels for this host pair so let's use that to determine who is flipping.
+		// The remote's vpn ip is lower than mine. I will not flip.
+		return false
+	}
+
+	certState := n.intf.certState.Load()
+	return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
 }
 
-// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
-func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
+func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
+	n.hostMap.Lock()
+	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.
+	if n.hostMap.Hosts[current.vpnIp] == primary {
+		n.hostMap.unlockedMakePrimary(current)
+	}
+	n.hostMap.Unlock()
+}
+
+// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and
+// the certificate is no longer valid
+func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {
 	if !n.intf.disconnectInvalid {
 		return false
 	}
@@ -253,10 +438,6 @@ func (n *connectionManager) handleInvalidCertificate(now time.Time, hostinfo *Ho
 		WithField("fingerprint", fingerprint).
 		Info("Remote certificate is no longer valid, tearing down the tunnel")
 
-	// Inform the remote and close the tunnel locally
-	n.intf.sendCloseTunnel(hostinfo)
-	n.intf.closeTunnel(hostinfo)
-	delete(n.pendingDeletion, hostinfo.localIndexId)
 	return true
 }
 
@@ -277,3 +458,29 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)
 	}
 }
+
+func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
+	certState := n.intf.certState.Load()
+	if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
+		return
+	}
+
+	n.l.WithField("vpnIp", hostinfo.vpnIp).
+		WithField("reason", "local certificate is not current").
+		Info("Re-handshaking with remote")
+
+	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
+	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
+	if !newHostinfo.HandshakeReady {
+		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
+	}
+
+	//If this is a static host, we don't need to wait for the HostQueryReply
+	//We can trigger the handshake right now
+	if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
+		select {
+		case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:
+		default:
+		}
+	}
+}

+ 4 - 4
connection_manager_test.go

@@ -279,13 +279,13 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	// Check if to disconnect with invalid certificate.
 	// Should be alive.
 	nextTick := now.Add(45 * time.Second)
-	destroyed := nc.handleInvalidCertificate(nextTick, hostinfo)
-	assert.False(t, destroyed)
+	invalid := nc.isInvalidCertificate(nextTick, hostinfo)
+	assert.False(t, invalid)
 
 	// Move ahead 61s.
 	// Check if to disconnect with invalid certificate.
 	// Should be disconnected.
 	nextTick = now.Add(61 * time.Second)
-	destroyed = nc.handleInvalidCertificate(nextTick, hostinfo)
-	assert.True(t, destroyed)
+	invalid = nc.isInvalidCertificate(nextTick, hostinfo)
+	assert.True(t, invalid)
 }

+ 14 - 0
control_tester.go

@@ -163,3 +163,17 @@ func (c *Control) GetHostmap() *HostMap {
 func (c *Control) GetCert() *cert.NebulaCertificate {
 	return c.f.certState.Load().certificate
 }
+
+func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
+	hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
+	ixHandshakeStage0(c.f, vpnIp, hostinfo)
+
+	// If this is a static host, we don't need to wait for the HostQueryReply
+	// We can trigger the handshake right now
+	if _, ok := c.f.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {
+		select {
+		case c.f.handshakeManager.trigger <- hostinfo.vpnIp:
+		default:
+		}
+	}
+}

+ 385 - 22
e2e/handshakes_test.go

@@ -4,6 +4,7 @@
 package e2e
 
 import (
+	"fmt"
 	"net"
 	"testing"
 	"time"
@@ -15,12 +16,13 @@ import (
 	"github.com/slackhq/nebula/iputil"
 	"github.com/slackhq/nebula/udp"
 	"github.com/stretchr/testify/assert"
+	"gopkg.in/yaml.v2"
 )
 
 func BenchmarkHotPath(b *testing.B) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, _, _, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -43,8 +45,8 @@ func BenchmarkHotPath(b *testing.B) {
 
 func TestGoodHandshake(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -98,9 +100,9 @@ func TestWrongResponderHandshake(t *testing.T) {
 	// The IPs here are chosen on purpose:
 	// The current remote handling will sort by preference, public, and then lexically.
 	// So we need them to have a higher address than evil (we could apply a preference though)
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
-	evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
+	evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
 
 	// Add their real udp addr, which should be tried after evil.
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -163,8 +165,8 @@ func TestStage1Race(t *testing.T) {
 	// But will eventually collapse down to a single tunnel
 
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Put their info in our lighthouse and vice versa
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -240,8 +242,8 @@ func TestStage1Race(t *testing.T) {
 
 func TestUncleanShutdownRaceLoser(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -289,8 +291,8 @@ func TestUncleanShutdownRaceLoser(t *testing.T) {
 
 func TestUncleanShutdownRaceWinner(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
@@ -340,9 +342,9 @@ func TestUncleanShutdownRaceWinner(t *testing.T) {
 
 func TestRelays(t *testing.T) {
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
@@ -371,9 +373,9 @@ func TestRelays(t *testing.T) {
 func TestStage1RaceRelays(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 
 	// Teach my how to get to the relay and that their can be reached via the relay
 	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
@@ -418,9 +420,9 @@ func TestStage1RaceRelays(t *testing.T) {
 func TestStage1RaceRelays2(t *testing.T) {
 	//NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay
 	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
-	myControl, myVpnIpNet, myUdpAddr := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
-	relayControl, relayVpnIpNet, relayUdpAddr := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
-	theirControl, theirVpnIpNet, theirUdpAddr := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
 	l := NewTestLogger()
 
 	// Teach my how to get to the relay and that their can be reached via the relay
@@ -503,5 +505,366 @@ func TestStage1RaceRelays2(t *testing.T) {
 	//
 	////TODO: assert hostmaps
 }
+func TestRehandshakingRelays(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me     ", net.IP{10, 0, 0, 1}, m{"relay": m{"use_relays": true}})
+	relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay  ", net.IP{10, 0, 0, 128}, m{"relay": m{"am_relay": true}})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them   ", net.IP{10, 0, 0, 2}, m{"relay": m{"use_relays": true}})
+
+	// Teach my how to get to the relay and that their can be reached via the relay
+	myControl.InjectLightHouseAddr(relayVpnIpNet.IP, relayUdpAddr)
+	myControl.InjectRelays(theirVpnIpNet.IP, []net.IP{relayVpnIpNet.IP})
+	relayControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, relayControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	relayControl.Start()
+	theirControl.Start()
+
+	t.Log("Trigger a handshake from me to them via the relay")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+
+	p := r.RouteForAllUntilTxTun(theirControl)
+	r.Log("Assert the tunnel works")
+	assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.IP, theirVpnIpNet.IP, 80, 80)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+
+	// When I update the certificate for the relay, both me and them will have 2 host infos for the relay,
+	// and the main host infos will not have any relay state to handle the me<->relay<->them tunnel.
+	r.Log("Renew relay certificate and spin until me and them sees it")
+	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	relayConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(myNextPEM),
+		"key":  string(myNextPrivKey),
+	}
+	rc, err := yaml.Marshal(relayConfig.Settings)
+	assert.NoError(t, err)
+	relayConfig.ReloadConfigString(string(rc))
+
+	for {
+		r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, myVpnIpNet.IP, relayVpnIpNet.IP, myControl, relayControl, r)
+		c := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between my and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	for {
+		r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet")
+		assertTunnel(t, theirVpnIpNet.IP, relayVpnIpNet.IP, theirControl, relayControl, r)
+		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(relayVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			r.Log("Certificate between their and relay is updated!")
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	r.Log("Assert the relay tunnel still works")
+	assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+	r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl)
+	// We should have two hostinfos on all sides
+	for len(myControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("myControl hostinfos got cleaned up!")
+	for len(theirControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("theirControl hostinfos got cleaned up!")
+	for len(relayControl.GetHostmap().Indexes) != 2 {
+		t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes))
+		r.Log("Assert the relay tunnel still works")
+		assertTunnel(t, theirVpnIpNet.IP, myVpnIpNet.IP, theirControl, myControl, r)
+		r.Log("yupitdoes")
+		time.Sleep(time.Second)
+	}
+	t.Logf("relayControl hostinfos got cleaned up!")
+}
+
+func TestRehandshaking(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Stand up a tunnel between me and them")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	r.Log("Renew my certificate and spin until their sees it")
+	_, _, myNextPrivKey, myNextPEM := newTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	myConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(myNextPEM),
+		"key":  string(myNextPrivKey),
+	}
+	rc, err := yaml.Marshal(myConfig.Settings)
+	assert.NoError(t, err)
+	myConfig.ReloadConfigString(string(rc))
+
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+		if len(c.Cert.Details.Groups) != 0 {
+			// We have a new certificate now
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	// Flip their firewall to only allow the new group, to catch the tunnels reverting incorrectly
+	rc, err = yaml.Marshal(theirConfig.Settings)
+	assert.NoError(t, err)
+	var theirNewConfig m
+	assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
+	theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
+	theirFirewall["inbound"] = []m{{
+		"proto": "any",
+		"port":  "any",
+		"group": "new group",
+	}}
+	rc, err = yaml.Marshal(theirNewConfig)
+	assert.NoError(t, err)
+	theirConfig.ReloadConfigString(string(rc))
+
+	r.Log("Spin until there is only 1 tunnel")
+	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		t.Log("Connection manager hasn't ticked yet")
+		time.Sleep(time.Second)
+	}
+
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
+	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
+
+	// Make sure the correct tunnel won
+	c := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	assert.Contains(t, c.Cert.Details.Groups, "new group")
+
+	// We should only have a single tunnel now on both sides
+	assert.Len(t, myFinalHostmapHosts, 1)
+	assert.Len(t, theirFinalHostmapHosts, 1)
+	assert.Len(t, myFinalHostmapIndexes, 1)
+	assert.Len(t, theirFinalHostmapIndexes, 1)
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+// TestRehandshakingLoser checks that when the handshake race loser renews their certificate and rehandshakes,
+// the tunnel that survives is the one carrying the new certificate. In this test it is their certificate
+// that gets renewed.
+func TestRehandshakingLoser(t *testing.T) {
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me  ", net.IP{10, 0, 0, 2}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 1}, nil)
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Build a router so we don't have to reason who gets which packet
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	t.Log("Stand up a tunnel between me and them")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	// Print the local index of the hostinfo each side holds for the other
+	tt1 := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+	tt2 := theirControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(myVpnIpNet.IP), false)
+	fmt.Println(tt1.LocalIndex, tt2.LocalIndex)
+
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	r.Log("Renew their certificate and spin until mine sees it")
+	_, _, theirNextPrivKey, theirNextPEM := newTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"})
+
+	caB, err := ca.MarshalToPEM()
+	if err != nil {
+		panic(err)
+	}
+
+	// Swap their pki settings for the renewed certificate and reload their config
+	theirConfig.Settings["pki"] = m{
+		"ca":   string(caB),
+		"cert": string(theirNextPEM),
+		"key":  string(theirNextPrivKey),
+	}
+	rc, err := yaml.Marshal(theirConfig.Settings)
+	assert.NoError(t, err)
+	theirConfig.ReloadConfigString(string(rc))
+
+	// Keep exercising the tunnel until the certificate I hold for them carries the new group
+	for {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+
+		_, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"]
+		if theirNewGroup {
+			break
+		}
+
+		time.Sleep(time.Second)
+	}
+
+	// Flip my firewall to only allow their new group, so a tunnel that reverts to the old certificate is caught
+	rc, err = yaml.Marshal(myConfig.Settings)
+	assert.NoError(t, err)
+	var myNewConfig m
+	assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
+	myFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
+	myFirewall["inbound"] = []m{{
+		"proto": "any",
+		"port":  "any",
+		"group": "their new group",
+	}}
+	rc, err = yaml.Marshal(myNewConfig)
+	assert.NoError(t, err)
+	myConfig.ReloadConfigString(string(rc))
+
+	r.Log("Spin until there is only 1 tunnel")
+	for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 {
+		assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+		t.Log("Connection manager hasn't ticked yet")
+		time.Sleep(time.Second)
+	}
+
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+	myFinalHostmapHosts := myControl.ListHostmapHosts(false)
+	myFinalHostmapIndexes := myControl.ListHostmapIndexes(false)
+	theirFinalHostmapHosts := theirControl.ListHostmapHosts(false)
+	theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false)
+
+	// Make sure the correct tunnel won
+	theirCertInMe := myControl.GetHostInfoByVpnIp(iputil.Ip2VpnIp(theirVpnIpNet.IP), false)
+	assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group")
+
+	// We should only have a single tunnel now on both sides
+	assert.Len(t, myFinalHostmapHosts, 1)
+	assert.Len(t, theirFinalHostmapHosts, 1)
+	assert.Len(t, myFinalHostmapIndexes, 1)
+	assert.Len(t, theirFinalHostmapIndexes, 1)
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+	myControl.Stop()
+	theirControl.Stop()
+}
+
+func TestRaceRegression(t *testing.T) {
+	// This test forces me to receive stage 1, stage 2, and then stage 1 again from them.
+	// We had a bug where the duplicate handshake was not found, so we responded to the final stage 1,
+	// which caused a cross-linked hostinfo.
+	ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
+	myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
+
+	// Put their info in our lighthouse and vice versa
+	myControl.InjectLightHouseAddr(theirVpnIpNet.IP, theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet.IP, myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	//them rx stage:1 initiatorIndex=642843150 responderIndex=0
+	//me rx   stage:1 initiatorIndex=120607833 responderIndex=0
+	//them rx stage:1 initiatorIndex=642843150 responderIndex=0
+	//me rx   stage:2 initiatorIndex=642843150 responderIndex=3701775874
+	//me rx   stage:1 initiatorIndex=120607833 responderIndex=0
+	//them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089
+
+	t.Log("Start both handshakes")
+	myControl.InjectTunUDPPacket(theirVpnIpNet.IP, 80, 80, []byte("Hi from me"))
+	theirControl.InjectTunUDPPacket(myVpnIpNet.IP, 80, 80, []byte("Hi from them"))
+
+	t.Log("Get both stage 1")
+	myStage1ForThem := myControl.GetFromUDP(true)
+	theirStage1ForMe := theirControl.GetFromUDP(true)
+
+	t.Log("Inject them in a special way")
+	theirControl.InjectUDPPacket(myStage1ForThem)
+	myControl.InjectUDPPacket(theirStage1ForMe)
+	theirControl.InjectUDPPacket(myStage1ForThem)
+
+	//TODO: ensure the packets read below are actually stage 2
+	t.Log("Get both stage 2")
+	myStage2ForThem := myControl.GetFromUDP(true)
+	theirStage2ForMe := theirControl.GetFromUDP(true)
+
+	t.Log("Inject them in a special way again")
+	myControl.InjectUDPPacket(theirStage2ForMe)
+	myControl.InjectUDPPacket(theirStage1ForMe)
+	theirControl.InjectUDPPacket(myStage2ForThem)
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	t.Log("Flush the packets")
+	r.RouteForAllUntilTxTun(myControl)
+	r.RouteForAllUntilTxTun(theirControl)
+	r.RenderHostmaps("Starting hostmaps", myControl, theirControl)
+
+	t.Log("Make sure the tunnel still works")
+	assertTunnel(t, myVpnIpNet.IP, theirVpnIpNet.IP, myControl, theirControl, r)
+
+	myControl.Stop()
+	theirControl.Stop()
+}
 
+//TODO: test
+// Race winner renews and handshakes
+// Race loser renews and handshakes
+// Does race winner repin the cert to old?
 //TODO: add a test with many lies

+ 4 - 4
e2e/helpers_test.go

@@ -30,7 +30,7 @@ import (
 type m map[string]interface{}
 
 // newSimpleServer creates a nebula instance with many assumptions
-func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr) {
+func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, overrides m) (*nebula.Control, *net.IPNet, *net.UDPAddr, *config.C) {
 	l := NewTestLogger()
 
 	vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
@@ -78,8 +78,8 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 			"level":            l.Level.String(),
 		},
 		"timers": m{
-			"pending_deletion_interval": 4,
-			"connection_alive_interval": 4,
+			"pending_deletion_interval": 2,
+			"connection_alive_interval": 2,
 		},
 	}
 
@@ -105,7 +105,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
 		panic(err)
 	}
 
-	return control, vpnIpNet, &udpAddr
+	return control, vpnIpNet, &udpAddr, c
 }
 
 // newTestCaCert will generate a CA cert

+ 5 - 5
e2e/router/router.go

@@ -215,7 +215,7 @@ func (r *R) renderFlow() {
 			continue
 		}
 		participants[addr] = struct{}{}
-		sanAddr := strings.Replace(addr, ":", "#58;", 1)
+		sanAddr := strings.Replace(addr, ":", "-", 1)
 		participantsVals = append(participantsVals, sanAddr)
 		fmt.Fprintf(
 			f, "    participant %s as Nebula: %s<br/>UDP: %s\n",
@@ -252,9 +252,9 @@ func (r *R) renderFlow() {
 
 			fmt.Fprintf(f,
 				"    %s%s%s: %s(%s), index %v, counter: %v\n",
-				strings.Replace(p.from.GetUDPAddr(), ":", "#58;", 1),
+				strings.Replace(p.from.GetUDPAddr(), ":", "-", 1),
 				line,
-				strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1),
+				strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
 				h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter,
 			)
 		}
@@ -758,8 +758,8 @@ func (r *R) formatUdpPacket(p *packet) string {
 	data := packet.ApplicationLayer()
 	return fmt.Sprintf(
 		"    %s-->>%s: src port: %v<br/>dest port: %v<br/>data: \"%v\"\n",
-		strings.Replace(from, ":", "#58;", 1),
-		strings.Replace(p.to.GetUDPAddr(), ":", "#58;", 1),
+		strings.Replace(from, ":", "-", 1),
+		strings.Replace(p.to.GetUDPAddr(), ":", "-", 1),
 		udp.SrcPort,
 		udp.DstPort,
 		string(data.Payload()),

+ 5 - 4
handshake_manager.go

@@ -231,7 +231,8 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu))
+						// This must send over the hostinfo, not over hm.Hosts[ip]
+						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						c.l.WithFields(logrus.Fields{
 							"relayFrom":           c.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
@@ -266,7 +267,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 							WithError(err).
 							Error("Failed to marshal Control message to create relay")
 					} else {
-						f.SendMessageToVpnIp(header.Control, 0, *relay, msg, make([]byte, 12), make([]byte, mtu))
+						f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
 						c.l.WithFields(logrus.Fields{
 							"relayFrom":           c.lightHouse.myVpnIp,
 							"relayTo":             vpnIp,
@@ -328,8 +329,8 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 		testHostInfo := existingHostInfo
 		for testHostInfo != nil {
 			// Is it just a delayed handshake packet?
-			if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], existingHostInfo.HandshakePacket[handshakePacket]) {
-				return existingHostInfo, ErrAlreadySeen
+			if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], testHostInfo.HandshakePacket[handshakePacket]) {
+				return testHostInfo, ErrAlreadySeen
 			}
 
 			testHostInfo = testHostInfo.next

+ 4 - 0
handshake_manager_test.go

@@ -88,4 +88,8 @@ func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte
 	return
 }
 
+// SendMessageToHostInfo is a no-op on the mock; every argument is ignored.
+func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {}
+
 func (mw *mockEncWriter) Handshake(vpnIP iputil.VpnIp) {}

+ 34 - 41
hostmap.go

@@ -32,6 +32,7 @@ const RoamingSuppressSeconds = 2
 
 const (
 	Requested = iota
+	PeerRequested
 	Established
 )
 
@@ -79,6 +80,16 @@ func (rs *RelayState) DeleteRelay(ip iputil.VpnIp) {
 	delete(rs.relays, ip)
 }
 
+// CopyAllRelayFor returns a new slice holding every *Relay currently stored in relayForByIdx.
+// The snapshot is taken under the read lock; the Relay pointers themselves are shared, not cloned.
+func (rs *RelayState) CopyAllRelayFor() []*Relay {
+	rs.RLock()
+	defer rs.RUnlock()
+
+	// Size the slice up front and fill it by position rather than appending
+	relays := make([]*Relay, len(rs.relayForByIdx))
+	next := 0
+	for _, relay := range rs.relayForByIdx {
+		relays[next] = relay
+		next++
+	}
+	return relays
+}
+
 func (rs *RelayState) GetRelayForByIp(ip iputil.VpnIp) (*Relay, bool) {
 	rs.RLock()
 	defer rs.RUnlock()
@@ -279,29 +290,13 @@ func (hm *HostMap) EmitStats(name string) {
 
 func (hm *HostMap) RemoveRelay(localIdx uint32) {
 	hm.Lock()
-	hiRelay, ok := hm.Relays[localIdx]
+	_, ok := hm.Relays[localIdx]
 	if !ok {
 		hm.Unlock()
 		return
 	}
 	delete(hm.Relays, localIdx)
 	hm.Unlock()
-	ip, ok := hiRelay.relayState.RemoveRelay(localIdx)
-	if !ok {
-		return
-	}
-	hiPeer, err := hm.QueryVpnIp(ip)
-	if err != nil {
-		return
-	}
-	var otherPeerIdx uint32
-	hiPeer.relayState.DeleteRelay(hiRelay.vpnIp)
-	relay, ok := hiPeer.relayState.GetRelayForByIp(hiRelay.vpnIp)
-	if ok {
-		otherPeerIdx = relay.LocalIndex
-	}
-	// I am a relaying host. I need to remove the other relay, too.
-	hm.RemoveRelay(otherPeerIdx)
 }
 
 func (hm *HostMap) GetIndexByVpnIp(vpnIp iputil.VpnIp) (uint32, error) {
@@ -395,29 +390,6 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool {
 	hm.unlockedDeleteHostInfo(hostinfo)
 	hm.Unlock()
 
-	// And tear down all the relays going through this host, if final
-	for _, localIdx := range hostinfo.relayState.CopyRelayForIdxs() {
-		hm.RemoveRelay(localIdx)
-	}
-
-	if final {
-		// And tear down the relays this deleted hostInfo was using to be reached
-		teardownRelayIdx := []uint32{}
-		for _, relayIp := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, err := hm.QueryVpnIp(relayIp)
-			if err != nil {
-				hm.l.WithError(err).WithField("relay", relayIp).Info("Missing relay host in hostmap")
-			} else {
-				if r, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp); ok {
-					teardownRelayIdx = append(teardownRelayIdx, r.LocalIndex)
-				}
-			}
-		}
-		for _, localIdx := range teardownRelayIdx {
-			hm.RemoveRelay(localIdx)
-		}
-	}
-
 	return final
 }
 
@@ -508,6 +480,10 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
 			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
 			Debug("Hostmap hostInfo deleted")
 	}
+
+	for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() {
+		delete(hm.Relays, localRelayIdx)
+	}
 }
 
 func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
@@ -562,6 +538,24 @@ func (hm *HostMap) QueryVpnIp(vpnIp iputil.VpnIp) (*HostInfo, error) {
 	return hm.queryVpnIp(vpnIp, nil)
 }
 
+// QueryVpnIpRelayFor walks the chain of hostinfos stored under relayHostIp (following each entry's next
+// pointer) and returns the first one holding an Established relay for targetIp, together with that relay.
+// An error is returned when relayHostIp has no entry in Hosts, or when no hostinfo in the chain has an
+// Established relay for targetIp. The whole lookup runs under the hostmap read lock.
+func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp iputil.VpnIp) (*HostInfo, *Relay, error) {
+	hm.RLock()
+	defer hm.RUnlock()
+
+	head, found := hm.Hosts[relayHostIp]
+	if !found {
+		return nil, nil, errors.New("unable to find host")
+	}
+
+	for candidate := head; candidate != nil; candidate = candidate.next {
+		relay, has := candidate.relayState.QueryRelayForByIp(targetIp)
+		if !has || relay.State != Established {
+			// No usable relay on this hostinfo, try the next one in the chain
+			continue
+		}
+		return candidate, relay, nil
+	}
+
+	return nil, nil, errors.New("unable to find host with relay")
+}
+
 // PromoteBestQueryVpnIp will attempt to lazily switch to the best remote every
 // `PromoteEvery` calls to this function for a given host.
 func (hm *HostMap) PromoteBestQueryVpnIp(vpnIp iputil.VpnIp, ifce *Interface) (*HostInfo, error) {
@@ -709,7 +703,6 @@ func (i *HostInfo) handshakeComplete(l *logrus.Logger, m *cachedPacketMetrics) {
 	i.packetStore = make([]*cachedPacket, 0)
 	i.ConnectionState.ready = true
 	i.ConnectionState.queueLock.Unlock()
-	i.ConnectionState.certState = nil
 }
 
 func (i *HostInfo) GetCert() *cert.NebulaCertificate {

+ 12 - 24
inside.go

@@ -57,7 +57,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 	}
 	ci := hostinfo.ConnectionState
 
-	if ci.ready == false {
+	if !ci.ready {
 		// Because we might be sending stored packets, lock here to stop new things going to
 		// the packet queue.
 		ci.queueLock.Lock()
@@ -177,7 +177,7 @@ func (f *Interface) initHostInfo(hostinfo *HostInfo) {
 	hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
 }
 
-func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
+func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
 	fp := &firewall.Packet{}
 	err := newPacket(p, false, fp)
 	if err != nil {
@@ -186,7 +186,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	}
 
 	// check if packet is in outbound fw rules
-	dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
+	dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil)
 	if dropReason != nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("fwPacket", fp).
@@ -196,7 +196,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 		return
 	}
 
-	f.sendNoMetrics(header.Message, st, hostInfo.ConnectionState, hostInfo, nil, p, nb, out, 0)
+	f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, nil, p, nb, out, 0)
 }
 
 // SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp
@@ -215,19 +215,18 @@ func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSu
 		// the packet queue.
 		hostInfo.ConnectionState.queueLock.Lock()
 		if !hostInfo.ConnectionState.ready {
-			hostInfo.cachePacket(f.l, t, st, p, f.sendMessageToVpnIp, f.cachedPacketMetrics)
+			hostInfo.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
 			hostInfo.ConnectionState.queueLock.Unlock()
 			return
 		}
 		hostInfo.ConnectionState.queueLock.Unlock()
 	}
 
-	f.sendMessageToVpnIp(t, st, hostInfo, p, nb, out)
-	return
+	f.SendMessageToHostInfo(t, st, hostInfo, p, nb, out)
 }
 
-func (f *Interface) sendMessageToVpnIp(t header.MessageType, st header.MessageSubType, hostInfo *HostInfo, p, nb, out []byte) {
-	f.send(t, st, hostInfo.ConnectionState, hostInfo, p, nb, out)
+func (f *Interface) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hi *HostInfo, p, nb, out []byte) {
+	f.send(t, st, hi.ConnectionState, hi, p, nb, out)
 }
 
 func (f *Interface) send(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, p, nb, out []byte) {
@@ -302,6 +301,7 @@ func (f *Interface) SendVia(via *HostInfo,
 	if err != nil {
 		via.logger(f.l).WithError(err).Info("Failed to WriteTo in sendVia")
 	}
+	f.connectionManager.RelayUsed(relay.LocalIndex)
 }
 
 func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType, ci *ConnectionState, hostinfo *HostInfo, remote *udp.Addr, p, nb, out []byte, q int) {
@@ -372,31 +372,19 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType
 	} else {
 		// Try to send via a relay
 		for _, relayIP := range hostinfo.relayState.CopyRelayIps() {
-			relayHostInfo, err := f.hostMap.QueryVpnIp(relayIP)
+			relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP)
 			if err != nil {
+				hostinfo.relayState.DeleteRelay(relayIP)
 				hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo")
 				continue
 			}
-			relay, ok := relayHostInfo.relayState.QueryRelayForByIp(hostinfo.vpnIp)
-			if !ok {
-				hostinfo.logger(f.l).
-					WithField("relay", relayHostInfo.vpnIp).
-					WithField("relayTo", hostinfo.vpnIp).
-					Info("sendNoMetrics relay missing object for target")
-				continue
-			}
 			f.SendVia(relayHostInfo, relay, out, nb, fullOut[:header.Len+len(out)], true)
 			break
 		}
 	}
-	return
 }
 
 func isMulticast(ip iputil.VpnIp) bool {
 	// Class D multicast
-	if (((ip >> 24) & 0xff) & 0xf0) == 0xe0 {
-		return true
-	}
-
-	return false
+	return (((ip >> 24) & 0xff) & 0xf0) == 0xe0
 }

+ 1 - 0
interface.go

@@ -99,6 +99,7 @@ type EncWriter interface {
 		nocopy bool,
 	)
 	SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, nb, out []byte)
+	SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
 	Handshake(vpnIp iputil.VpnIp)
 }
 

+ 17 - 0
lighthouse_test.go

@@ -377,6 +377,23 @@ func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte
 func (tw *testEncWriter) Handshake(vpnIp iputil.VpnIp) {
 }
 
+// SendMessageToHostInfo decodes p as a NebulaMeta message and, when no metaFilter is set or the message
+// type matches it, records the message in tw.lastReply along with t, st and hostinfo.vpnIp.
+// It panics if p cannot be unmarshaled.
+func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, _, _ []byte) {
+	msg := &NebulaMeta{}
+	// Check the decode error before touching msg so a failed decode is never recorded as the last reply
+	if err := msg.Unmarshal(p); err != nil {
+		panic(err)
+	}
+
+	if tw.metaFilter == nil || msg.Type == *tw.metaFilter {
+		tw.lastReply = testLhReply{
+			nebType:    t,
+			nebSubType: st,
+			vpnIp:      hostinfo.vpnIp,
+			msg:        msg,
+		}
+	}
+}
+
 func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp iputil.VpnIp, p, _, _ []byte) {
 	msg := &NebulaMeta{}
 	err := msg.Unmarshal(p)

+ 12 - 13
outside.go

@@ -83,7 +83,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 
 		switch h.Subtype {
 		case header.MessageNone:
-			f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache)
+			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
+				return
+			}
 		case header.MessageRelay:
 			// The entire body is sent as AD, not encrypted.
 			// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
@@ -100,7 +102,9 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 			signedPayload = signedPayload[header.Len:]
 			// Pull the Roaming parts up here, and return in all call paths.
 			f.handleHostRoaming(hostinfo, addr)
+			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
 			f.connectionManager.In(hostinfo.localIndexId)
+			f.connectionManager.RelayUsed(h.RemoteIndex)
 
 			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
 			if !ok {
@@ -118,17 +122,11 @@ func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byt
 				return
 			case ForwardingType:
 				// Find the target HostInfo relay object
-				targetHI, err := f.hostMap.QueryVpnIp(relay.PeerIp)
+				targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
 				if err != nil {
 					hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
 					return
 				}
-				// find the target Relay info object
-				targetRelay, ok := targetHI.relayState.QueryRelayForByIp(hostinfo.vpnIp)
-				if !ok {
-					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp}).Info("Failed to find relay in hostinfo")
-					return
-				}
 
 				// If that relay is Established, forward the payload through it
 				if targetRelay.State == Established {
@@ -382,7 +380,7 @@ func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []
 	return out, nil
 }
 
-func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) {
+func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
 	var err error
 
 	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
@@ -390,20 +388,20 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
 		//TODO: maybe after build 64 is out? 06/14/2018 - NB
 		//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
-		return
+		return false
 	}
 
 	err = newPacket(out, true, fwPacket)
 	if err != nil {
 		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
 			Warnf("Error while validating inbound packet")
-		return
+		return false
 	}
 
 	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
 		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 			Debugln("dropping out of window packet")
-		return
+		return false
 	}
 
 	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
@@ -414,7 +412,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 				WithField("reason", dropReason).
 				Debugln("dropping inbound packet")
 		}
-		return
+		return false
 	}
 
 	f.connectionManager.In(hostinfo.localIndexId)
@@ -422,6 +420,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 	if err != nil {
 		f.l.WithError(err).Error("Failed to write to tun")
 	}
+	return true
 }
 
 func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {

+ 1 - 1
punchy.go

@@ -75,7 +75,7 @@ func (p *Punchy) reload(c *config.C, initial bool) {
 	}
 
 	if initial || c.HasChanged("punchy.target_all_remotes") {
-		p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", true))
+		p.punchEverything.Store(c.GetBool("punchy.target_all_remotes", false))
 		if !initial {
 			p.l.WithField("target_all_remotes", p.GetTargetEverything()).Info("punchy.target_all_remotes changed")
 		}

+ 27 - 25
relay_manager.go

@@ -141,27 +141,29 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *
 		rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo")
 		return
 	}
-	peerRelay.State = Established
-	resp := NebulaControl{
-		Type:                NebulaControl_CreateRelayResponse,
-		ResponderRelayIndex: peerRelay.LocalIndex,
-		InitiatorRelayIndex: peerRelay.RemoteIndex,
-		RelayFromIp:         uint32(peerHostInfo.vpnIp),
-		RelayToIp:           uint32(target),
-	}
-	msg, err := resp.Marshal()
-	if err != nil {
-		rm.l.
-			WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
-	} else {
-		f.SendMessageToVpnIp(header.Control, 0, peerHostInfo.vpnIp, msg, make([]byte, 12), make([]byte, mtu))
-		rm.l.WithFields(logrus.Fields{
-			"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
-			"relayTo":             iputil.VpnIp(resp.RelayToIp),
-			"initiatorRelayIndex": resp.InitiatorRelayIndex,
-			"responderRelayIndex": resp.ResponderRelayIndex,
-			"vpnIp":               peerHostInfo.vpnIp}).
-			Info("send CreateRelayResponse")
+	if peerRelay.State == PeerRequested {
+		peerRelay.State = Established
+		resp := NebulaControl{
+			Type:                NebulaControl_CreateRelayResponse,
+			ResponderRelayIndex: peerRelay.LocalIndex,
+			InitiatorRelayIndex: peerRelay.RemoteIndex,
+			RelayFromIp:         uint32(peerHostInfo.vpnIp),
+			RelayToIp:           uint32(target),
+		}
+		msg, err := resp.Marshal()
+		if err != nil {
+			rm.l.
+				WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
+		} else {
+			f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu))
+			rm.l.WithFields(logrus.Fields{
+				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
+				"relayTo":             iputil.VpnIp(resp.RelayToIp),
+				"initiatorRelayIndex": resp.InitiatorRelayIndex,
+				"responderRelayIndex": resp.ResponderRelayIndex,
+				"vpnIp":               peerHostInfo.vpnIp}).
+				Info("send CreateRelayResponse")
+		}
 	}
 }
 
@@ -223,7 +225,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 			logMsg.
 				WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
 		} else {
-			f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu))
+			f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 			rm.l.WithFields(logrus.Fields{
 				"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
 				"relayTo":             iputil.VpnIp(resp.RelayToIp),
@@ -278,7 +280,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 				logMsg.
 					WithError(err).Error("relayManager Failed to marshal Control message to create relay")
 			} else {
-				f.SendMessageToVpnIp(header.Control, 0, target, msg, make([]byte, 12), make([]byte, mtu))
+				f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu))
 				rm.l.WithFields(logrus.Fields{
 					"relayFrom":           iputil.VpnIp(req.RelayFromIp),
 					"relayTo":             iputil.VpnIp(req.RelayToIp),
@@ -292,7 +294,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 		relay, ok := h.relayState.QueryRelayForByIp(target)
 		if !ok {
 			// Add the relay
-			state := Requested
+			state := PeerRequested
 			if targetRelay != nil && targetRelay.State == Established {
 				state = Established
 			}
@@ -324,7 +326,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N
 					rm.l.
 						WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay")
 				} else {
-					f.SendMessageToVpnIp(header.Control, 0, h.vpnIp, msg, make([]byte, 12), make([]byte, mtu))
+					f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu))
 					rm.l.WithFields(logrus.Fields{
 						"relayFrom":           iputil.VpnIp(resp.RelayFromIp),
 						"relayTo":             iputil.VpnIp(resp.RelayToIp),