فهرست منبع

even spicier change to rehandshake if we detect our cert is lower-version than our peer, and we have a newer-version cert available

JackDoan 17 ساعت پیش
والد
کامیت
654cb4b8b4
4فایلهای تغییر یافته به همراه104 افزوده شده و 7 حذف شده
  1. 27 3
      connection_manager.go
  2. 71 2
      e2e/tunnels_test.go
  3. 4 1
      handshake_ix.go
  4. 2 1
      handshake_manager.go

+ 27 - 3
connection_manager.go

@@ -354,7 +354,6 @@ func (cm *connectionManager) makeTrafficDecision(localIndex uint32, now time.Tim
 
 		if mainHostInfo {
 			decision = tryRehandshake
-
 		} else {
 			if cm.shouldSwapPrimary(hostinfo) {
 				decision = swapPrimary
@@ -554,8 +553,33 @@ func (cm *connectionManager) sendPunch(hostinfo *HostInfo) {
 func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 	cs := cm.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
-	myCrt := cs.getCertificate(curCrt.Version())
-	if myCrt != nil && curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+	curCrtVersion := curCrt.Version()
+	myCrt := cs.getCertificate(curCrtVersion)
+	if myCrt == nil {
+		cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+			WithField("version", curCrtVersion).
+			WithField("reason", "local certificate removed").
+			Info("Re-handshaking with remote")
+		cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil)
+		return
+	}
+	peerCrt := hostinfo.ConnectionState.peerCert
+	if peerCrt != nil && curCrtVersion < peerCrt.Certificate.Version() {
+		// if our certificate version is less than theirs, and we have a matching version available, rehandshake?
+		if cs.getCertificate(peerCrt.Certificate.Version()) != nil {
+			//todo trigger rehandshake with specific cert?
+			cm.l.WithField("vpnAddrs", hostinfo.vpnAddrs).
+				WithField("version", curCrtVersion).
+				WithField("peerVersion", peerCrt.Certificate.Version()).
+				WithField("reason", "local certificate version mismatch with peer, correcting").
+				Info("Re-handshaking with remote")
+			cm.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(hh *HandshakeHostInfo) {
+				hh.initiatingVersionOverride = peerCrt.Certificate.Version()
+			})
+			return
+		}
+	}
+	if curCrtVersion >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
 		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}

+ 71 - 2
e2e/tunnels_test.go

@@ -224,11 +224,13 @@ func TestCertDowngrade(t *testing.T) {
 	for {
 		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
 		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
-		if c == nil {
+		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+		if c == nil || c2 == nil {
 			r.Log("nil")
 		} else {
 			version := c.Cert.Version()
-			r.Logf("version %d", version)
+			theirVersion := c2.Cert.Version()
+			r.Logf("version %d,%d", version, theirVersion)
 			if version == cert.Version1 {
 				break
 			}
@@ -249,3 +251,70 @@ func TestCertDowngrade(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestCertMismatchCorrection(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	myCert, _, myPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	//assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	//r.Log("yay")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		c2 := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false)
+		if c == nil || c2 == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			theirVersion := c2.Cert.Version()
+			r.Logf("version %d,%d", version, theirVersion)
+			if version == theirVersion {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*5 {
+			r.Log("wtf")
+		}
+		if since > time.Second*10 {
+			r.Log("wtf")
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}

+ 4 - 1
handshake_ix.go

@@ -23,9 +23,12 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
 		return false
 	}
 
-	// If we're connecting to a v6 address we must use a v2 cert
 	cs := f.pki.getCertState()
 	v := cs.initiatingVersion
+	if hh.initiatingVersionOverride != cert.VersionPre1 {
+		v = hh.initiatingVersionOverride
+	}
+	// If we're connecting to a v6 address we must use a v2 cert
 	for _, a := range hh.hostinfo.vpnAddrs {
 		if a.Is6() {
 			v = cert.Version2

+ 2 - 1
handshake_manager.go

@@ -74,7 +74,8 @@ type HandshakeHostInfo struct {
 	lastRemotes []netip.AddrPort // Remotes that we sent to during the previous attempt
 	packetStore []*cachedPacket  // A set of packets to be transmitted once the handshake completes
 
-	hostinfo *HostInfo
+	hostinfo                  *HostInfo
+	initiatingVersionOverride cert.Version
 }
 
 func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {