Browse Source

Dont apply race avoidance to existing handshakes, use the handshake time to determine who wins (#451)

Co-authored-by: Wade Simmons <[email protected]>
Nathan Brown 4 years ago
parent
commit
db23fdf9bc
3 changed files with 23 additions and 17 deletions
  1. 13 13
      handshake_ix.go
  2. 5 4
      handshake_manager.go
  3. 5 0
      hostmap.go

+ 13 - 13
handshake_ix.go

@@ -119,11 +119,12 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 	}
 
 	hostinfo := &HostInfo{
-		ConnectionState: ci,
-		localIndexId:    myIndex,
-		remoteIndexId:   hs.Details.InitiatorIndex,
-		hostId:          vpnIP,
-		HandshakePacket: make(map[uint8][]byte, 0),
+		ConnectionState:   ci,
+		localIndexId:      myIndex,
+		remoteIndexId:     hs.Details.InitiatorIndex,
+		hostId:            vpnIP,
+		HandshakePacket:   make(map[uint8][]byte, 0),
+		lastHandshakeTime: hs.Details.Time,
 	}
 
 	hostinfo.Lock()
@@ -138,6 +139,8 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 
 	hs.Details.ResponderIndex = myIndex
 	hs.Details.Cert = ci.certState.rawCertificateNoKey
+	// Update the time in case their clock is way off from ours
+	hs.Details.Time = uint64(time.Now().Unix())
 
 	hsBytes, err := proto.Marshal(hs)
 	if err != nil {
@@ -204,18 +207,15 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 			}
 			return
 		case ErrExistingHostInfo:
-			// This means there was an existing tunnel and we didn't win
-			// handshake avoidance
-
-			//TODO: sprinkle the new protobuf stuff in here, send a reply to get the recv_errors flowing
-			//TODO: if not new send a test packet like old
-
+			// This means there was an existing tunnel and this handshake was older than the one we are currently based on
 			f.l.WithField("vpnIp", IntIp(vpnIP)).WithField("udpAddr", addr).
 				WithField("certName", certName).
+				WithField("oldHandshakeTime", existing.lastHandshakeTime).
+				WithField("newHandshakeTime", hostinfo.lastHandshakeTime).
 				WithField("fingerprint", fingerprint).
 				WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
 				WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
-				Info("Prevented a handshake race")
+				Info("Handshake too old")
 
 			// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues
 			f.SendMessageToVpnIp(test, testRequest, vpnIP, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
@@ -394,7 +394,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		Info("Handshake message received")
 
 	hostinfo.remoteIndexId = hs.Details.ResponderIndex
-	hs.Details.Cert = ci.certState.rawCertificateNoKey
+	hostinfo.lastHandshakeTime = hs.Details.Time
 
 	// Store their cert and our symmetric keys
 	ci.peerCert = remoteCert

+ 5 - 4
handshake_manager.go

@@ -199,7 +199,7 @@ var (
 // exact same handshake packet
 //
 // ErrExistingHostInfo if we already have an entry in the hostmap for this
-// VpnIP and overwrite was false.
+// VpnIP and the new handshake was older than the one we currently have
 //
 // ErrLocalIndexCollision if we already have an entry in the main or pending
 // hostmap for the hostinfo.localIndexId.
@@ -217,10 +217,12 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 			return existingHostInfo, ErrAlreadySeen
 		}
 
-		if !overwrite {
-			// It's a new handshake and we lost the race
+		// Is this a newer handshake?
+		if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime {
 			return existingHostInfo, ErrExistingHostInfo
 		}
+
+		existingHostInfo.logger(c.l).Info("Taking new handshake")
 	}
 
 	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
@@ -261,7 +263,6 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket
 	}
 
 	if existingHostInfo != nil {
-		hostinfo.logger(c.l).Info("Race lost, taking new handshake")
 		// We are going to overwrite this entry, so remove the old references
 		delete(c.mainHostMap.Hosts, existingHostInfo.hostId)
 		delete(c.mainHostMap.Indexes, existingHostInfo.localIndexId)

+ 5 - 0
hostmap.go

@@ -59,6 +59,11 @@ type HostInfo struct {
 	// with a handshake
 	lastRebindCount int8
 
+	// lastHandshakeTime records the time the remote side told us about at the stage when the handshake was completed locally
+	// Stage 1 packet will contain it if I am a responder, stage 2 packet if I am an initiator
+	// This is used to avoid an attack where a handshake packet is replayed after some time
+	lastHandshakeTime uint64
+
 	lastRoam       time.Time
 	lastRoamRemote *udpAddr
 }