Browse Source

We only need the certificate in ConnectionState (#953)

Nate Brown 1 year ago
parent
commit
7edcf620c0
9 changed files with 37 additions and 51 deletions
  1. 3 3
      connection_manager.go
  2. 10 9
      connection_manager_test.go
  3. 11 9
      connection_state.go
  4. 1 1
      control_tester.go
  5. 7 4
      handshake_ix.go
  6. 1 5
      handshake_manager.go
  7. 2 11
      handshake_manager_test.go
  8. 1 8
      inside.go
  9. 1 1
      ssh.go

+ 3 - 3
connection_manager.go

@@ -406,7 +406,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 	}
 	}
 
 
 	certState := n.intf.pki.GetCertState()
 	certState := n.intf.pki.GetCertState()
-	return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
+	return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature)
 }
 }
 
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
@@ -465,7 +465,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 	certState := n.intf.pki.GetCertState()
 	certState := n.intf.pki.GetCertState()
-	if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
+	if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) {
 		return
 		return
 	}
 	}
 
 
@@ -474,7 +474,7 @@ func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
 		Info("Re-handshaking with remote")
 		Info("Re-handshaking with remote")
 
 
 	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
 	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out
-	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)
+	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp)
 	if !newHostinfo.HandshakeReady {
 	if !newHostinfo.HandshakeReady {
 		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
 		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)
 	}
 	}

+ 10 - 9
connection_manager_test.go

@@ -79,8 +79,8 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		remoteIndexId: 9901,
 		remoteIndexId: 9901,
 	}
 	}
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		H:         &noise.HandshakeState{},
+		myCert: &cert.NebulaCertificate{},
+		H:      &noise.HandshakeState{},
 	}
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 
@@ -159,8 +159,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		remoteIndexId: 9901,
 		remoteIndexId: 9901,
 	}
 	}
 	hostinfo.ConnectionState = &ConnectionState{
 	hostinfo.ConnectionState = &ConnectionState{
-		certState: cs,
-		H:         &noise.HandshakeState{},
+		myCert: &cert.NebulaCertificate{},
+		H:      &noise.HandshakeState{},
 	}
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 
 
@@ -222,7 +222,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			PublicKey: pubCA,
 			PublicKey: pubCA,
 		},
 		},
 	}
 	}
-	caCert.Sign(cert.Curve_CURVE25519, privCA)
+
+	assert.NoError(t, caCert.Sign(cert.Curve_CURVE25519, privCA))
 	ncp := &cert.NebulaCAPool{
 	ncp := &cert.NebulaCAPool{
 		CAs: cert.NewCAPool().CAs,
 		CAs: cert.NewCAPool().CAs,
 	}
 	}
@@ -241,7 +242,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 			Issuer:    "ca",
 			Issuer:    "ca",
 		},
 		},
 	}
 	}
-	peerCert.Sign(cert.Curve_CURVE25519, privCA)
+	assert.NoError(t, peerCert.Sign(cert.Curve_CURVE25519, privCA))
 
 
 	cs := &CertState{
 	cs := &CertState{
 		RawCertificate:      []byte{},
 		RawCertificate:      []byte{},
@@ -275,9 +276,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	hostinfo := &HostInfo{
 	hostinfo := &HostInfo{
 		vpnIp: vpnIp,
 		vpnIp: vpnIp,
 		ConnectionState: &ConnectionState{
 		ConnectionState: &ConnectionState{
-			certState: cs,
-			peerCert:  &peerCert,
-			H:         &noise.HandshakeState{},
+			myCert:   &cert.NebulaCertificate{},
+			peerCert: &peerCert,
+			H:        &noise.HandshakeState{},
 		},
 		},
 	}
 	}
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)
 	nc.hostMap.unlockedAddHostInfo(hostinfo, ifce)

+ 11 - 9
connection_state.go

@@ -18,7 +18,7 @@ type ConnectionState struct {
 	eKey           *NebulaCipherState
 	eKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	dKey           *NebulaCipherState
 	H              *noise.HandshakeState
 	H              *noise.HandshakeState
-	certState      *CertState
+	myCert         *cert.NebulaCertificate
 	peerCert       *cert.NebulaCertificate
 	peerCert       *cert.NebulaCertificate
 	initiator      bool
 	initiator      bool
 	messageCounter atomic.Uint64
 	messageCounter atomic.Uint64
@@ -28,25 +28,27 @@ type ConnectionState struct {
 	ready          bool
 	ready          bool
 }
 }
 
 
-func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
+func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
 	var dhFunc noise.DHFunc
 	var dhFunc noise.DHFunc
-	curCertState := f.pki.GetCertState()
 
 
-	switch curCertState.Certificate.Details.Curve {
+	switch certState.Certificate.Details.Curve {
 	case cert.Curve_CURVE25519:
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
 	case cert.Curve_P256:
 		dhFunc = noiseutil.DHP256
 		dhFunc = noiseutil.DHP256
 	default:
 	default:
-		l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
+		l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve)
 		return nil
 		return nil
 	}
 	}
-	cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
-	if f.cipher == "chachapoly" {
+
+	var cs noise.CipherSuite
+	if cipher == "chachapoly" {
 		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
+	} else {
+		cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
 	}
 	}
 
 
-	static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
+	static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey}
 
 
 	b := NewBits(ReplayWindow)
 	b := NewBits(ReplayWindow)
 	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
 	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
@@ -72,7 +74,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
 		initiator: initiator,
 		initiator: initiator,
 		window:    b,
 		window:    b,
 		ready:     false,
 		ready:     false,
-		certState: curCertState,
+		myCert:    certState.Certificate,
 	}
 	}
 
 
 	return ci
 	return ci

+ 1 - 1
control_tester.go

@@ -165,7 +165,7 @@ func (c *Control) GetCert() *cert.NebulaCertificate {
 }
 }
 
 
 func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
 func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {
-	hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp, c.f.initHostInfo)
+	hostinfo := c.f.handshakeManager.AddVpnIp(vpnIp)
 	ixHandshakeStage0(c.f, vpnIp, hostinfo)
 	ixHandshakeStage0(c.f, vpnIp, hostinfo)
 
 
 	// If this is a static host, we don't need to wait for the HostQueryReply
 	// If this is a static host, we don't need to wait for the HostQueryReply

+ 7 - 4
handshake_ix.go

@@ -28,12 +28,14 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 		return
 		return
 	}
 	}
 
 
-	ci := hostinfo.ConnectionState
+	certState := f.pki.GetCertState()
+	ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
+	hostinfo.ConnectionState = ci
 
 
 	hsProto := &NebulaHandshakeDetails{
 	hsProto := &NebulaHandshakeDetails{
 		InitiatorIndex: hostinfo.localIndexId,
 		InitiatorIndex: hostinfo.localIndexId,
 		Time:           uint64(time.Now().UnixNano()),
 		Time:           uint64(time.Now().UnixNano()),
-		Cert:           ci.certState.RawCertificateNoKey,
+		Cert:           certState.RawCertificateNoKey,
 	}
 	}
 
 
 	hsBytes := []byte{}
 	hsBytes := []byte{}
@@ -69,7 +71,8 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 }
 }
 
 
 func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
 func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
-	ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
+	certState := f.pki.GetCertState()
+	ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0)
 	// Mark packet 1 as seen so it doesn't show up as missed
 	// Mark packet 1 as seen so it doesn't show up as missed
 	ci.window.Update(f.l, 1)
 	ci.window.Update(f.l, 1)
 
 
@@ -155,7 +158,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		Info("Handshake message received")
 		Info("Handshake message received")
 
 
 	hs.Details.ResponderIndex = myIndex
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = ci.certState.RawCertificateNoKey
+	hs.Details.Cert = certState.RawCertificateNoKey
 	// Update the time in case their clock is way off from ours
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
 

+ 1 - 5
handshake_manager.go

@@ -297,7 +297,7 @@ func (c *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, f EncWriter, light
 }
 }
 
 
 // AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
 // AddVpnIp will try to handshake with the provided vpn ip and return the hostinfo for it.
-func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *HostInfo {
+func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp) *HostInfo {
 	// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
 	// A write lock is used to avoid having to recheck the map and trading a read lock for a write lock
 	c.Lock()
 	c.Lock()
 	defer c.Unlock()
 	defer c.Unlock()
@@ -317,10 +317,6 @@ func (c *HandshakeManager) AddVpnIp(vpnIp iputil.VpnIp, init func(*HostInfo)) *H
 		},
 		},
 	}
 	}
 
 
-	if init != nil {
-		init(hostinfo)
-	}
-
 	c.vpnIps[vpnIp] = hostinfo
 	c.vpnIps[vpnIp] = hostinfo
 	c.metricInitiated.Inc(1)
 	c.metricInitiated.Inc(1)
 	c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)
 	c.OutboundHandshakeTimer.Add(vpnIp, c.config.tryInterval)

+ 2 - 11
handshake_manager_test.go

@@ -28,17 +28,8 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
 	now := time.Now()
 	now := time.Now()
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 	blah.NextOutboundHandshakeTimerTick(now, mw)
 
 
-	var initCalled bool
-	initFunc := func(*HostInfo) {
-		initCalled = true
-	}
-
-	i := blah.AddVpnIp(ip, initFunc)
-	assert.True(t, initCalled)
-
-	initCalled = false
-	i2 := blah.AddVpnIp(ip, initFunc)
-	assert.False(t, initCalled)
+	i := blah.AddVpnIp(ip)
+	i2 := blah.AddVpnIp(ip)
 	assert.Same(t, i, i2)
 	assert.Same(t, i, i2)
 
 
 	i.remotes = NewRemoteList(nil)
 	i.remotes = NewRemoteList(nil)

+ 1 - 8
inside.go

@@ -1,7 +1,6 @@
 package nebula
 package nebula
 
 
 import (
 import (
-	"github.com/flynn/noise"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
 	"github.com/slackhq/nebula/header"
@@ -124,7 +123,7 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 
 
 	hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
 	hostinfo := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
 	if hostinfo == nil {
 	if hostinfo == nil {
-		hostinfo = f.handshakeManager.AddVpnIp(vpnIp, f.initHostInfo)
+		hostinfo = f.handshakeManager.AddVpnIp(vpnIp)
 	}
 	}
 	ci := hostinfo.ConnectionState
 	ci := hostinfo.ConnectionState
 
 
@@ -168,12 +167,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
 	return hostinfo
 	return hostinfo
 }
 }
 
 
-// initHostInfo is the init function to pass to (*HandshakeManager).AddVpnIP that
-// will create the initial Noise ConnectionState
-func (f *Interface) initHostInfo(hostinfo *HostInfo) {
-	hostinfo.ConnectionState = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
-}
-
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
 func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
 	fp := &firewall.Packet{}
 	fp := &firewall.Packet{}
 	err := newPacket(p, false, fp)
 	err := newPacket(p, false, fp)

+ 1 - 1
ssh.go

@@ -607,7 +607,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW
 		}
 		}
 	}
 	}
 
 
-	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp, ifce.initHostInfo)
+	hostInfo = ifce.handshakeManager.AddVpnIp(vpnIp)
 	if addr != nil {
 	if addr != nil {
 		hostInfo.SetRemote(addr)
 		hostInfo.SetRemote(addr)
 	}
 	}