瀏覽代碼

Use atomic.Pointer for certState (#833)

Nate Brown 2 年之前
父節點
當前提交
6b3d42efa5
共有 5 個文件被更改,包括 12 次插入11 次删除
  1. 3 3
      connection_manager_test.go
  2. 1 1
      connection_state.go
  3. 1 1
      control_tester.go
  4. 6 5
      interface.go
  5. 1 1
      ssh.go

+ 3 - 3
connection_manager_test.go

@@ -54,12 +54,12 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		hostMap:          hostMap,
 		inside:           &test.NoopTun{},
 		outside:          &udp.Conn{},
-		certState:        cs,
 		firewall:         &Firewall{},
 		lightHouse:       lh,
 		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                l,
 	}
+	ifce.certState.Store(cs)
 	now := time.Now()
 
 	// Create manager
@@ -130,12 +130,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		hostMap:          hostMap,
 		inside:           &test.NoopTun{},
 		outside:          &udp.Conn{},
-		certState:        cs,
 		firewall:         &Firewall{},
 		lightHouse:       lh,
 		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
 		l:                l,
 	}
+	ifce.certState.Store(cs)
 	now := time.Now()
 
 	// Create manager
@@ -245,7 +245,6 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		hostMap:           hostMap,
 		inside:            &test.NoopTun{},
 		outside:           &udp.Conn{},
-		certState:         cs,
 		firewall:          &Firewall{},
 		lightHouse:        lh,
 		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.Conn{}, defaultHandshakeConfig),
@@ -253,6 +252,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		disconnectInvalid: true,
 		caPool:            ncp,
 	}
+	ifce.certState.Store(cs)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())

+ 1 - 1
connection_state.go

@@ -33,7 +33,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
 		cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
 	}
 
-	curCertState := f.certState
+	curCertState := f.certState.Load()
 	static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
 
 	b := NewBits(ReplayWindow)

+ 1 - 1
control_tester.go

@@ -161,5 +161,5 @@ func (c *Control) GetHostmap() *HostMap {
 }
 
 func (c *Control) GetCert() *cert.NebulaCertificate {
-	return c.f.certState.certificate
+	return c.f.certState.Load().certificate
 }

+ 6 - 5
interface.go

@@ -52,7 +52,7 @@ type Interface struct {
 	hostMap            *HostMap
 	outside            *udp.Conn
 	inside             overlay.Device
-	certState          *CertState
+	certState          atomic.Pointer[CertState]
 	cipher             string
 	firewall           *Firewall
 	connectionManager  *connectionManager
@@ -141,7 +141,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		hostMap:            c.HostMap,
 		outside:            c.Outside,
 		inside:             c.Inside,
-		certState:          c.certState,
 		cipher:             c.Cipher,
 		firewall:           c.Firewall,
 		serveDns:           c.ServeDns,
@@ -172,6 +171,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		l: c.l,
 	}
 
+	ifce.certState.Store(c.certState)
 	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval)
 
 	return ifce, nil
@@ -298,14 +298,15 @@ func (f *Interface) reloadCertKey(c *config.C) {
 	}
 
 	// did IP in cert change? if so, don't set
-	oldIPs := f.certState.certificate.Details.Ips
+	currentCert := f.certState.Load().certificate
+	oldIPs := currentCert.Details.Ips
 	newIPs := cs.certificate.Details.Ips
 	if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
 		f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
 		return
 	}
 
-	f.certState = cs
+	f.certState.Store(cs)
 	f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
 }
 
@@ -316,7 +317,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 	}
 
-	fw, err := NewFirewallFromConfig(f.l, f.certState.certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return

+ 1 - 1
ssh.go

@@ -753,7 +753,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 		return nil
 	}
 
-	cert := ifce.certState.certificate
+	cert := ifce.certState.Load().certificate
 	if len(a) > 0 {
 		parsedIp := net.ParseIP(a[0])
 		if parsedIp == nil {