Переглянути джерело

squish bug from cert removals

JackDoan 2 днів тому
батько
коміт
8431da4e66
2 змінених файлів з 101 додано та 1 видалено
  1. 5 1
      connection_manager.go
  2. 96 0
      e2e/tunnels_test.go

+ 5 - 1
connection_manager.go

@@ -461,6 +461,10 @@ func (cm *connectionManager) shouldSwapPrimary(current *HostInfo) bool {
 	}
 
 	crt := cm.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version())
+	if crt == nil {
+		//my cert was reloaded away. We should definitely swap from this tunnel
+		return true
+	}
 	// If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things
 	// settle down.
 	return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature())
@@ -551,7 +555,7 @@ func (cm *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 	cs := cm.intf.pki.getCertState()
 	curCrt := hostinfo.ConnectionState.myCert
 	myCrt := cs.getCertificate(curCrt.Version())
-	if curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
+	if myCrt != nil && curCrt.Version() >= cs.initiatingVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true {
 		// The current tunnel is using the latest certificate and version, no need to rehandshake.
 		return
 	}

+ 96 - 0
e2e/tunnels_test.go

@@ -153,3 +153,99 @@ func TestCertUpgrade(t *testing.T) {
 	myControl.Stop()
 	theirControl.Stop()
 }
+
+func TestCertDowngrade(t *testing.T) {
+	// The goal of this test is to ensure the shortest inactivity timeout will close the tunnel on both sides
+	// under ideal conditions
+	ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+	caB, err := ca.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	ca2, _, caKey2, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{})
+
+	ca2B, err := ca2.MarshalPEM()
+	if err != nil {
+		panic(err)
+	}
+	caStr := fmt.Sprintf("%s\n%s", caB, ca2B)
+
+	myCert, _, myPrivKey, myCertPem := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.1/24")}, nil, []string{})
+	myCert2, _ := cert_test.NewTestCertDifferentVersion(myCert, cert.Version2, ca2, caKey2)
+
+	theirCert, _, theirPrivKey, _ := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.MustParsePrefix("10.128.0.2/24")}, nil, []string{})
+	theirCert2, _ := cert_test.NewTestCertDifferentVersion(theirCert, cert.Version2, ca2, caKey2)
+
+	myControl, myVpnIpNet, myUdpAddr, myC := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{myCert2}, myPrivKey, m{})
+	theirControl, theirVpnIpNet, theirUdpAddr, _ := newServer([]cert.Certificate{ca, ca2}, []cert.Certificate{theirCert, theirCert2}, theirPrivKey, m{})
+
+	// Share our underlay information
+	myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr)
+	theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr)
+
+	// Start the servers
+	myControl.Start()
+	theirControl.Start()
+
+	r := router.NewR(t, myControl, theirControl)
+	defer r.RenderFlow()
+
+	r.Log("Assert the tunnel between me and them works")
+	assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r)
+	r.Log("yay")
+	assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+	r.Log("yay")
+	//todo ???
+	time.Sleep(1 * time.Second)
+	r.FlushAll()
+
+	mc := m{
+		"pki": m{
+			"ca":   caStr,
+			"cert": string(myCertPem),
+			"key":  string(myPrivKey),
+		},
+		"firewall": myC.Settings["firewall"],
+		"listen":   myC.Settings["listen"],
+		"logging":  myC.Settings["logging"],
+		"timers":   myC.Settings["timers"],
+	}
+
+	cb, err := yaml.Marshal(mc)
+	if err != nil {
+		panic(err)
+	}
+
+	r.Logf("reload new v1-only config")
+	err = myC.ReloadConfigString(string(cb))
+	assert.NoError(t, err)
+	r.Log("yay, spin until their sees it")
+	waitStart := time.Now()
+	for {
+		assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r)
+		c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false)
+		if c == nil {
+			r.Log("nil")
+		} else {
+			version := c.Cert.Version()
+			r.Logf("version %d", version)
+			if version == cert.Version1 {
+				break
+			}
+		}
+		since := time.Since(waitStart)
+		if since > time.Second*5 {
+			r.Log("wtf")
+		}
+		if since > time.Second*10 {
+			r.Log("wtf")
+			t.Fatal("Cert should be new by now")
+		}
+		time.Sleep(time.Second)
+	}
+
+	r.RenderHostmaps("Final hostmaps", myControl, theirControl)
+
+	myControl.Stop()
+	theirControl.Stop()
+}