浏览代码

Combine ca, cert, and key handling (#952)

Nate Brown 1 年之前
父节点
当前提交
5a131b2975
共有 17 个文件被更改,包括 381 次插入和 294 次删除
  1. 0 163
      cert.go
  2. 2 7
      cmd/nebula-service/main.go
  3. 2 7
      cmd/nebula/main.go
  4. 5 5
      connection_manager.go
  5. 19 16
      connection_manager_test.go
  6. 4 4
      connection_state.go
  7. 1 1
      control_tester.go
  8. 4 4
      handshake_ix.go
  9. 2 2
      inside.go
  10. 9 48
      interface.go
  11. 1 1
      lighthouse.go
  12. 16 30
      main.go
  13. 1 1
      outside.go
  14. 248 0
      pki.go
  15. 1 1
      ssh.go
  16. 24 4
      util/error.go
  17. 42 0
      util/error_test.go

+ 0 - 163
cert.go

@@ -1,163 +0,0 @@
-package nebula
-
-import (
-	"errors"
-	"fmt"
-	"io/ioutil"
-	"strings"
-	"time"
-
-	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
-	"github.com/slackhq/nebula/config"
-)
-
-type CertState struct {
-	certificate         *cert.NebulaCertificate
-	rawCertificate      []byte
-	rawCertificateNoKey []byte
-	publicKey           []byte
-	privateKey          []byte
-}
-
-func NewCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
-	// Marshal the certificate to ensure it is valid
-	rawCertificate, err := certificate.Marshal()
-	if err != nil {
-		return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
-	}
-
-	publicKey := certificate.Details.PublicKey
-	cs := &CertState{
-		rawCertificate: rawCertificate,
-		certificate:    certificate, // PublicKey has been set to nil above
-		privateKey:     privateKey,
-		publicKey:      publicKey,
-	}
-
-	cs.certificate.Details.PublicKey = nil
-	rawCertNoKey, err := cs.certificate.Marshal()
-	if err != nil {
-		return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
-	}
-	cs.rawCertificateNoKey = rawCertNoKey
-	// put public key back
-	cs.certificate.Details.PublicKey = cs.publicKey
-	return cs, nil
-}
-
-func NewCertStateFromConfig(c *config.C) (*CertState, error) {
-	var pemPrivateKey []byte
-	var err error
-
-	privPathOrPEM := c.GetString("pki.key", "")
-
-	if privPathOrPEM == "" {
-		return nil, errors.New("no pki.key path or PEM data provided")
-	}
-
-	if strings.Contains(privPathOrPEM, "-----BEGIN") {
-		pemPrivateKey = []byte(privPathOrPEM)
-		privPathOrPEM = "<inline>"
-	} else {
-		pemPrivateKey, err = ioutil.ReadFile(privPathOrPEM)
-		if err != nil {
-			return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
-		}
-	}
-
-	rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
-	if err != nil {
-		return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
-	}
-
-	var rawCert []byte
-
-	pubPathOrPEM := c.GetString("pki.cert", "")
-
-	if pubPathOrPEM == "" {
-		return nil, errors.New("no pki.cert path or PEM data provided")
-	}
-
-	if strings.Contains(pubPathOrPEM, "-----BEGIN") {
-		rawCert = []byte(pubPathOrPEM)
-		pubPathOrPEM = "<inline>"
-	} else {
-		rawCert, err = ioutil.ReadFile(pubPathOrPEM)
-		if err != nil {
-			return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
-		}
-	}
-
-	nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
-	if err != nil {
-		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
-	}
-
-	if nebulaCert.Expired(time.Now()) {
-		return nil, fmt.Errorf("nebula certificate for this host is expired")
-	}
-
-	if len(nebulaCert.Details.Ips) == 0 {
-		return nil, fmt.Errorf("no IPs encoded in certificate")
-	}
-
-	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
-		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
-	}
-
-	return NewCertState(nebulaCert, rawKey)
-}
-
-func loadCAFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
-	var rawCA []byte
-	var err error
-
-	caPathOrPEM := c.GetString("pki.ca", "")
-	if caPathOrPEM == "" {
-		return nil, errors.New("no pki.ca path or PEM data provided")
-	}
-
-	if strings.Contains(caPathOrPEM, "-----BEGIN") {
-		rawCA = []byte(caPathOrPEM)
-
-	} else {
-		rawCA, err = ioutil.ReadFile(caPathOrPEM)
-		if err != nil {
-			return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
-		}
-	}
-
-	CAs, err := cert.NewCAPoolFromBytes(rawCA)
-	if errors.Is(err, cert.ErrExpired) {
-		var expired int
-		for _, cert := range CAs.CAs {
-			if cert.Expired(time.Now()) {
-				expired++
-				l.WithField("cert", cert).Warn("expired certificate present in CA pool")
-			}
-		}
-
-		if expired >= len(CAs.CAs) {
-			return nil, errors.New("no valid CA certificates present")
-		}
-
-	} else if err != nil {
-		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
-	}
-
-	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
-		l.WithField("fingerprint", fp).Info("Blocklisting cert")
-		CAs.BlocklistFingerprint(fp)
-	}
-
-	// Support deprecated config for at least one minor release to allow for migrations
-	//TODO: remove in 2022 or later
-	for _, fp := range c.GetStringSlice("pki.blacklist", []string{}) {
-		l.WithField("fingerprint", fp).Info("Blocklisting cert")
-		l.Warn("pki.blacklist is deprecated and will not be supported in a future release. Please migrate your config to use pki.blocklist")
-		CAs.BlocklistFingerprint(fp)
-	}
-
-	return CAs, nil
-}

+ 2 - 7
cmd/nebula-service/main.go

@@ -59,13 +59,8 @@ func main() {
 	}
 
 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
-
-	switch v := err.(type) {
-	case util.ContextualError:
-		v.Log(l)
-		os.Exit(1)
-	case error:
-		l.WithError(err).Error("Failed to start")
+	if err != nil {
+		util.LogWithContextIfNeeded("Failed to start", err, l)
 		os.Exit(1)
 	}
 

+ 2 - 7
cmd/nebula/main.go

@@ -53,13 +53,8 @@ func main() {
 	}
 
 	ctrl, err := nebula.Main(c, *configTest, Build, l, nil)
-
-	switch v := err.(type) {
-	case util.ContextualError:
-		v.Log(l)
-		os.Exit(1)
-	case error:
-		l.WithError(err).Error("Failed to start")
+	if err != nil {
+		util.LogWithContextIfNeeded("Failed to start", err, l)
 		os.Exit(1)
 	}
 

+ 5 - 5
connection_manager.go

@@ -405,8 +405,8 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {
 		return false
 	}
 
-	certState := n.intf.certState.Load()
-	return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)
+	certState := n.intf.pki.GetCertState()
+	return bytes.Equal(current.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature)
 }
 
 func (n *connectionManager) swapPrimary(current, primary *HostInfo) {
@@ -427,7 +427,7 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn
 		return false
 	}
 
-	valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool)
+	valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool())
 	if valid {
 		return false
 	}
@@ -464,8 +464,8 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
 }
 
 func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {
-	certState := n.intf.certState.Load()
-	if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {
+	certState := n.intf.pki.GetCertState()
+	if bytes.Equal(hostinfo.ConnectionState.certState.Certificate.Signature, certState.Certificate.Signature) {
 		return
 	}
 

+ 19 - 16
connection_manager_test.go

@@ -44,10 +44,10 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 	// Very incomplete mock objects
 	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 	cs := &CertState{
-		rawCertificate:      []byte{},
-		privateKey:          []byte{},
-		certificate:         &cert.NebulaCertificate{},
-		rawCertificateNoKey: []byte{},
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -57,10 +57,11 @@ func Test_NewConnectionManagerTest(t *testing.T) {
 		outside:          &udp.NoopConn{},
 		firewall:         &Firewall{},
 		lightHouse:       lh,
+		pki:              &PKI{},
 		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                l,
 	}
-	ifce.certState.Store(cs)
+	ifce.pki.cs.Store(cs)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
@@ -123,10 +124,10 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	// Very incomplete mock objects
 	hostMap := NewHostMap(l, vpncidr, preferredRanges)
 	cs := &CertState{
-		rawCertificate:      []byte{},
-		privateKey:          []byte{},
-		certificate:         &cert.NebulaCertificate{},
-		rawCertificateNoKey: []byte{},
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -136,10 +137,11 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 		outside:          &udp.NoopConn{},
 		firewall:         &Firewall{},
 		lightHouse:       lh,
+		pki:              &PKI{},
 		handshakeManager: NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                l,
 	}
-	ifce.certState.Store(cs)
+	ifce.pki.cs.Store(cs)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())
@@ -242,10 +244,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 	peerCert.Sign(cert.Curve_CURVE25519, privCA)
 
 	cs := &CertState{
-		rawCertificate:      []byte{},
-		privateKey:          []byte{},
-		certificate:         &cert.NebulaCertificate{},
-		rawCertificateNoKey: []byte{},
+		RawCertificate:      []byte{},
+		PrivateKey:          []byte{},
+		Certificate:         &cert.NebulaCertificate{},
+		RawCertificateNoKey: []byte{},
 	}
 
 	lh := newTestLighthouse()
@@ -258,9 +260,10 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
 		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udp.NoopConn{}, defaultHandshakeConfig),
 		l:                 l,
 		disconnectInvalid: true,
-		caPool:            ncp,
+		pki:               &PKI{},
 	}
-	ifce.certState.Store(cs)
+	ifce.pki.cs.Store(cs)
+	ifce.pki.caPool.Store(ncp)
 
 	// Create manager
 	ctx, cancel := context.WithCancel(context.Background())

+ 4 - 4
connection_state.go

@@ -30,15 +30,15 @@ type ConnectionState struct {
 
 func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
 	var dhFunc noise.DHFunc
-	curCertState := f.certState.Load()
+	curCertState := f.pki.GetCertState()
 
-	switch curCertState.certificate.Details.Curve {
+	switch curCertState.Certificate.Details.Curve {
 	case cert.Curve_CURVE25519:
 		dhFunc = noise.DH25519
 	case cert.Curve_P256:
 		dhFunc = noiseutil.DHP256
 	default:
-		l.Errorf("invalid curve: %s", curCertState.certificate.Details.Curve)
+		l.Errorf("invalid curve: %s", curCertState.Certificate.Details.Curve)
 		return nil
 	}
 	cs := noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256)
@@ -46,7 +46,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
 		cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256)
 	}
 
-	static := noise.DHKey{Private: curCertState.privateKey, Public: curCertState.publicKey}
+	static := noise.DHKey{Private: curCertState.PrivateKey, Public: curCertState.PublicKey}
 
 	b := NewBits(ReplayWindow)
 	// Clear out bit 0, we never transmit it and we don't want it showing as packet loss

+ 1 - 1
control_tester.go

@@ -161,7 +161,7 @@ func (c *Control) GetHostmap() *HostMap {
 }
 
 func (c *Control) GetCert() *cert.NebulaCertificate {
-	return c.f.certState.Load().certificate
+	return c.f.pki.GetCertState().Certificate
 }
 
 func (c *Control) ReHandshake(vpnIp iputil.VpnIp) {

+ 4 - 4
handshake_ix.go

@@ -33,7 +33,7 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
 	hsProto := &NebulaHandshakeDetails{
 		InitiatorIndex: hostinfo.localIndexId,
 		Time:           uint64(time.Now().UnixNano()),
-		Cert:           ci.certState.rawCertificateNoKey,
+		Cert:           ci.certState.RawCertificateNoKey,
 	}
 
 	hsBytes := []byte{}
@@ -91,7 +91,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		return
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
+	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
@@ -155,7 +155,7 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
 		Info("Handshake message received")
 
 	hs.Details.ResponderIndex = myIndex
-	hs.Details.Cert = ci.certState.rawCertificateNoKey
+	hs.Details.Cert = ci.certState.RawCertificateNoKey
 	// Update the time in case their clock is way off from ours
 	hs.Details.Time = uint64(time.Now().UnixNano())
 
@@ -399,7 +399,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
 		return true
 	}
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
+	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool())
 	if err != nil {
 		f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).

+ 2 - 2
inside.go

@@ -69,7 +69,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
 		ci.queueLock.Unlock()
 	}
 
-	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
+	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason == nil {
 		f.sendNoMetrics(header.Message, 0, ci, hostinfo, nil, packet, nb, out, q)
 
@@ -183,7 +183,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
 	}
 
 	// check if packet is in outbound fw rules
-	dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.caPool, nil)
+	dropReason := f.firewall.Drop(p, *fp, false, hostinfo, f.pki.GetCAPool(), nil)
 	if dropReason != nil {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("fwPacket", fp).

+ 9 - 48
interface.go

@@ -13,7 +13,6 @@ import (
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
-	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/config"
 	"github.com/slackhq/nebula/firewall"
 	"github.com/slackhq/nebula/header"
@@ -28,7 +27,7 @@ type InterfaceConfig struct {
 	HostMap                 *HostMap
 	Outside                 udp.Conn
 	Inside                  overlay.Device
-	certState               *CertState
+	pki                     *PKI
 	Cipher                  string
 	Firewall                *Firewall
 	ServeDns                bool
@@ -41,7 +40,6 @@ type InterfaceConfig struct {
 	routines                int
 	MessageMetrics          *MessageMetrics
 	version                 string
-	caPool                  *cert.NebulaCAPool
 	disconnectInvalid       bool
 	relayManager            *relayManager
 	punchy                  *Punchy
@@ -58,7 +56,7 @@ type Interface struct {
 	hostMap            *HostMap
 	outside            udp.Conn
 	inside             overlay.Device
-	certState          atomic.Pointer[CertState]
+	pki                *PKI
 	cipher             string
 	firewall           *Firewall
 	connectionManager  *connectionManager
@@ -71,7 +69,6 @@ type Interface struct {
 	dropLocalBroadcast bool
 	dropMulticast      bool
 	routines           int
-	caPool             *cert.NebulaCAPool
 	disconnectInvalid  bool
 	closed             atomic.Bool
 	relayManager       *relayManager
@@ -152,15 +149,17 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	if c.Inside == nil {
 		return nil, errors.New("no inside interface (tun)")
 	}
-	if c.certState == nil {
+	if c.pki == nil {
 		return nil, errors.New("no certificate state")
 	}
 	if c.Firewall == nil {
 		return nil, errors.New("no firewall rules")
 	}
 
-	myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
+	certificate := c.pki.GetCertState().Certificate
+	myVpnIp := iputil.Ip2VpnIp(certificate.Details.Ips[0].IP)
 	ifce := &Interface{
+		pki:                c.pki,
 		hostMap:            c.HostMap,
 		outside:            c.Outside,
 		inside:             c.Inside,
@@ -170,14 +169,13 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 		handshakeManager:   c.HandshakeManager,
 		createTime:         time.Now(),
 		lightHouse:         c.lightHouse,
-		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].Mask),
+		localBroadcast:     myVpnIp | ^iputil.Ip2VpnIp(certificate.Details.Ips[0].Mask),
 		dropLocalBroadcast: c.DropLocalBroadcast,
 		dropMulticast:      c.DropMulticast,
 		routines:           c.routines,
 		version:            c.version,
 		writers:            make([]udp.Conn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
-		caPool:             c.caPool,
 		disconnectInvalid:  c.disconnectInvalid,
 		myVpnIp:            myVpnIp,
 		relayManager:       c.relayManager,
@@ -198,7 +196,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
 	ifce.reQueryEvery.Store(c.reQueryEvery)
 	ifce.reQueryWait.Store(int64(c.reQueryWait))
 
-	ifce.certState.Store(c.certState)
 	ifce.connectionManager = newConnectionManager(ctx, c.l, ifce, c.checkInterval, c.pendingDeletionInterval, c.punchy)
 
 	return ifce, nil
@@ -295,8 +292,6 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) {
 }
 
 func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
-	c.RegisterReloadCallback(f.reloadCA)
-	c.RegisterReloadCallback(f.reloadCertKey)
 	c.RegisterReloadCallback(f.reloadFirewall)
 	c.RegisterReloadCallback(f.reloadSendRecvError)
 	c.RegisterReloadCallback(f.reloadMisc)
@@ -305,40 +300,6 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
 	}
 }
 
-func (f *Interface) reloadCA(c *config.C) {
-	// reload and check regardless
-	// todo: need mutex?
-	newCAs, err := loadCAFromConfig(f.l, c)
-	if err != nil {
-		f.l.WithError(err).Error("Could not refresh trusted CA certificates")
-		return
-	}
-
-	f.caPool = newCAs
-	f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
-}
-
-func (f *Interface) reloadCertKey(c *config.C) {
-	// reload and check in all cases
-	cs, err := NewCertStateFromConfig(c)
-	if err != nil {
-		f.l.WithError(err).Error("Could not refresh client cert")
-		return
-	}
-
-	// did IP in cert change? if so, don't set
-	currentCert := f.certState.Load().certificate
-	oldIPs := currentCert.Details.Ips
-	newIPs := cs.certificate.Details.Ips
-	if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
-		f.l.WithField("new_ip", newIPs[0]).WithField("old_ip", oldIPs[0]).Error("IP in new cert was different from old")
-		return
-	}
-
-	f.certState.Store(cs)
-	f.l.WithField("cert", cs.certificate).Info("Client cert refreshed from disk")
-}
-
 func (f *Interface) reloadFirewall(c *config.C) {
 	//TODO: need to trigger/detect if the certificate changed too
 	if c.HasChanged("firewall") == false {
@@ -346,7 +307,7 @@ func (f *Interface) reloadFirewall(c *config.C) {
 		return
 	}
 
-	fw, err := NewFirewallFromConfig(f.l, f.certState.Load().certificate, c)
+	fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c)
 	if err != nil {
 		f.l.WithError(err).Error("Error while creating firewall during reload")
 		return
@@ -438,7 +399,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
 			f.firewall.EmitStats()
 			f.handshakeManager.EmitStats()
 			udpStats()
-			certExpirationGauge.Update(int64(f.certState.Load().certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
+			certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second))
 		}
 	}
 }

+ 1 - 1
lighthouse.go

@@ -132,7 +132,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C,
 	c.RegisterReloadCallback(func(c *config.C) {
 		err := h.reload(c, false)
 		switch v := err.(type) {
-		case util.ContextualError:
+		case *util.ContextualError:
 			v.Log(l)
 		case error:
 			l.WithError(err).Error("failed to reload lighthouse")

+ 16 - 30
main.go

@@ -3,7 +3,6 @@ package nebula
 import (
 	"context"
 	"encoding/binary"
-	"errors"
 	"fmt"
 	"net"
 	"time"
@@ -46,7 +45,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	err := configLogger(l, c)
 	if err != nil {
-		return nil, util.NewContextualError("Failed to configure the logger", nil, err)
+		return nil, util.ContextualizeIfNeeded("Failed to configure the logger", err)
 	}
 
 	c.RegisterReloadCallback(func(c *config.C) {
@@ -56,28 +55,20 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		}
 	})
 
-	caPool, err := loadCAFromConfig(l, c)
+	pki, err := NewPKIFromConfig(l, c)
 	if err != nil {
-		//The errors coming out of loadCA are already nicely formatted
-		return nil, util.NewContextualError("Failed to load ca from config", nil, err)
+		return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
 	}
-	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
 
-	cs, err := NewCertStateFromConfig(c)
+	certificate := pki.GetCertState().Certificate
+	fw, err := NewFirewallFromConfig(l, certificate, c)
 	if err != nil {
-		//The errors coming out of NewCertStateFromConfig are already nicely formatted
-		return nil, util.NewContextualError("Failed to load certificate from config", nil, err)
-	}
-	l.WithField("cert", cs.certificate).Debug("Client nebula certificate")
-
-	fw, err := NewFirewallFromConfig(l, cs.certificate, c)
-	if err != nil {
-		return nil, util.NewContextualError("Error while loading firewall rules", nil, err)
+		return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
 	}
 	l.WithField("firewallHash", fw.GetRuleHash()).Info("Firewall started")
 
 	// TODO: make sure mask is 4 bytes
-	tunCidr := cs.certificate.Details.Ips[0]
+	tunCidr := certificate.Details.Ips[0]
 
 	ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd"))
 	wireSSHReload(l, ssh, c)
@@ -85,7 +76,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if c.GetBool("sshd.enabled", false) {
 		sshStart, err = configSSH(l, ssh, c)
 		if err != nil {
-			return nil, util.NewContextualError("Error while configuring the sshd", nil, err)
+			return nil, util.ContextualizeIfNeeded("Error while configuring the sshd", err)
 		}
 	}
 
@@ -136,7 +127,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 		tun, err = overlay.NewDeviceFromConfig(c, l, tunCidr, tunFd, routines)
 		if err != nil {
-			return nil, util.NewContextualError("Failed to get a tun/tap device", nil, err)
+			return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err)
 		}
 
 		defer func() {
@@ -160,7 +151,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		} else {
 			listenHost, err = net.ResolveIPAddr("ip", rawListenHost)
 			if err != nil {
-				return nil, util.NewContextualError("Failed to resolve listen.host", nil, err)
+				return nil, util.ContextualizeIfNeeded("Failed to resolve listen.host", err)
 			}
 		}
 
@@ -182,7 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		for _, rawPreferredRange := range rawPreferredRanges {
 			_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
 			if err != nil {
-				return nil, util.NewContextualError("Failed to parse preferred ranges", nil, err)
+				return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
 			}
 			preferredRanges = append(preferredRanges, preferredRange)
 		}
@@ -195,7 +186,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	if rawLocalRange != "" {
 		_, localRange, err := net.ParseCIDR(rawLocalRange)
 		if err != nil {
-			return nil, util.NewContextualError("Failed to parse local_range", nil, err)
+			return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
 		}
 
 		// Check if the entry for local_range was already specified in
@@ -222,11 +213,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 
 	punchy := NewPunchyFromConfig(l, c)
 	lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
-	switch {
-	case errors.As(err, &util.ContextualError{}):
-		return nil, err
-	case err != nil:
-		return nil, util.NewContextualError("Failed to initialize lighthouse handler", nil, err)
+	if err != nil {
+		return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err)
 	}
 
 	var messageMetrics *MessageMetrics
@@ -266,7 +254,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		HostMap:                 hostMap,
 		Inside:                  tun,
 		Outside:                 udpConns[0],
-		certState:               cs,
+		pki:                     pki,
 		Cipher:                  c.GetString("cipher", "aes"),
 		Firewall:                fw,
 		ServeDns:                serveDns,
@@ -282,7 +270,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
-		caPool:                  caPool,
 		disconnectInvalid:       c.GetBool("pki.disconnect_invalid", false),
 		relayManager:            NewRelayManager(ctx, l, hostMap, c),
 		punchy:                  punchy,
@@ -321,9 +308,8 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
 	// TODO - stats third-party modules start uncancellable goroutines. Update those libs to accept
 	// a context so that they can exit when the context is Done.
 	statsStart, err := startStats(l, c, buildVersion, configTest)
-
 	if err != nil {
-		return nil, util.NewContextualError("Failed to start stats emitter", nil, err)
+		return nil, util.ContextualizeIfNeeded("Failed to start stats emitter", err)
 	}
 
 	if configTest {

+ 1 - 1
outside.go

@@ -404,7 +404,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return false
 	}
 
-	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
+	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
 	if dropReason != nil {
 		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, out, q)
 		if f.l.Level >= logrus.DebugLevel {

+ 248 - 0
pki.go

@@ -0,0 +1,248 @@
+package nebula
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"strings"
+	"sync/atomic"
+	"time"
+
+	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
+	"github.com/slackhq/nebula/config"
+	"github.com/slackhq/nebula/util"
+)
+
+type PKI struct {
+	cs     atomic.Pointer[CertState]
+	caPool atomic.Pointer[cert.NebulaCAPool]
+	l      *logrus.Logger
+}
+
+type CertState struct {
+	Certificate         *cert.NebulaCertificate
+	RawCertificate      []byte
+	RawCertificateNoKey []byte
+	PublicKey           []byte
+	PrivateKey          []byte
+}
+
+func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) {
+	pki := &PKI{l: l}
+	err := pki.reload(c, true)
+	if err != nil {
+		return nil, err
+	}
+
+	c.RegisterReloadCallback(func(c *config.C) {
+		rErr := pki.reload(c, false)
+		if rErr != nil {
+			util.LogWithContextIfNeeded("Failed to reload PKI from config", rErr, l)
+		}
+	})
+
+	return pki, nil
+}
+
+func (p *PKI) GetCertState() *CertState {
+	return p.cs.Load()
+}
+
+func (p *PKI) GetCAPool() *cert.NebulaCAPool {
+	return p.caPool.Load()
+}
+
+func (p *PKI) reload(c *config.C, initial bool) error {
+	err := p.reloadCert(c, initial)
+	if err != nil {
+		if initial {
+			return err
+		}
+		err.Log(p.l)
+	}
+
+	err = p.reloadCAPool(c)
+	if err != nil {
+		if initial {
+			return err
+		}
+		err.Log(p.l)
+	}
+
+	return nil
+}
+
+func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError {
+	cs, err := newCertStateFromConfig(c)
+	if err != nil {
+		return util.NewContextualError("Could not load client cert", nil, err)
+	}
+
+	if !initial {
+		// did IP in cert change? if so, don't set
+		currentCert := p.cs.Load().Certificate
+		oldIPs := currentCert.Details.Ips
+		newIPs := cs.Certificate.Details.Ips
+		if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() {
+			return util.NewContextualError(
+				"IP in new cert was different from old",
+				m{"new_ip": newIPs[0], "old_ip": oldIPs[0]},
+				nil,
+			)
+		}
+	}
+
+	p.cs.Store(cs)
+	if initial {
+		p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate")
+	} else {
+		p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk")
+	}
+	return nil
+}
+
+func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError {
+	caPool, err := loadCAPoolFromConfig(p.l, c)
+	if err != nil {
+		return util.NewContextualError("Failed to load ca from config", nil, err)
+	}
+
+	p.caPool.Store(caPool)
+	p.l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
+	return nil
+}
+
+func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) {
+	// Marshal the certificate to ensure it is valid
+	rawCertificate, err := certificate.Marshal()
+	if err != nil {
+		return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err)
+	}
+
+	publicKey := certificate.Details.PublicKey
+	cs := &CertState{
+		RawCertificate: rawCertificate,
+		Certificate:    certificate,
+		PrivateKey:     privateKey,
+		PublicKey:      publicKey,
+	}
+
+	cs.Certificate.Details.PublicKey = nil
+	rawCertNoKey, err := cs.Certificate.Marshal()
+	if err != nil {
+		return nil, fmt.Errorf("error marshalling certificate no key: %s", err)
+	}
+	cs.RawCertificateNoKey = rawCertNoKey
+	// put public key back
+	cs.Certificate.Details.PublicKey = cs.PublicKey
+	return cs, nil
+}
+
+func newCertStateFromConfig(c *config.C) (*CertState, error) {
+	var pemPrivateKey []byte
+	var err error
+
+	privPathOrPEM := c.GetString("pki.key", "")
+	if privPathOrPEM == "" {
+		return nil, errors.New("no pki.key path or PEM data provided")
+	}
+
+	if strings.Contains(privPathOrPEM, "-----BEGIN") {
+		pemPrivateKey = []byte(privPathOrPEM)
+		privPathOrPEM = "<inline>"
+
+	} else {
+		pemPrivateKey, err = os.ReadFile(privPathOrPEM)
+		if err != nil {
+			return nil, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err)
+		}
+	}
+
+	rawKey, _, curve, err := cert.UnmarshalPrivateKey(pemPrivateKey)
+	if err != nil {
+		return nil, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err)
+	}
+
+	var rawCert []byte
+
+	pubPathOrPEM := c.GetString("pki.cert", "")
+	if pubPathOrPEM == "" {
+		return nil, errors.New("no pki.cert path or PEM data provided")
+	}
+
+	if strings.Contains(pubPathOrPEM, "-----BEGIN") {
+		rawCert = []byte(pubPathOrPEM)
+		pubPathOrPEM = "<inline>"
+
+	} else {
+		rawCert, err = os.ReadFile(pubPathOrPEM)
+		if err != nil {
+			return nil, fmt.Errorf("unable to read pki.cert file %s: %s", pubPathOrPEM, err)
+		}
+	}
+
+	nebulaCert, _, err := cert.UnmarshalNebulaCertificateFromPEM(rawCert)
+	if err != nil {
+		return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err)
+	}
+
+	if nebulaCert.Expired(time.Now()) {
+		return nil, fmt.Errorf("nebula certificate for this host is expired")
+	}
+
+	if len(nebulaCert.Details.Ips) == 0 {
+		return nil, fmt.Errorf("no IPs encoded in certificate")
+	}
+
+	if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil {
+		return nil, fmt.Errorf("private key is not a pair with public key in nebula cert")
+	}
+
+	return newCertState(nebulaCert, rawKey)
+}
+
+func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.NebulaCAPool, error) {
+	var rawCA []byte
+	var err error
+
+	caPathOrPEM := c.GetString("pki.ca", "")
+	if caPathOrPEM == "" {
+		return nil, errors.New("no pki.ca path or PEM data provided")
+	}
+
+	if strings.Contains(caPathOrPEM, "-----BEGIN") {
+		rawCA = []byte(caPathOrPEM)
+
+	} else {
+		rawCA, err = os.ReadFile(caPathOrPEM)
+		if err != nil {
+			return nil, fmt.Errorf("unable to read pki.ca file %s: %s", caPathOrPEM, err)
+		}
+	}
+
+	caPool, err := cert.NewCAPoolFromBytes(rawCA)
+	if errors.Is(err, cert.ErrExpired) {
+		var expired int
+		for _, crt := range caPool.CAs {
+			if crt.Expired(time.Now()) {
+				expired++
+				l.WithField("cert", crt).Warn("expired certificate present in CA pool")
+			}
+		}
+
+		if expired >= len(caPool.CAs) {
+			return nil, errors.New("no valid CA certificates present")
+		}
+
+	} else if err != nil {
+		return nil, fmt.Errorf("error while adding CA certificate to CA trust store: %s", err)
+	}
+
+	for _, fp := range c.GetStringSlice("pki.blocklist", []string{}) {
+		l.WithField("fingerprint", fp).Info("Blocklisting cert")
+		caPool.BlocklistFingerprint(fp)
+	}
+
+	return caPool, nil
+}

+ 1 - 1
ssh.go

@@ -754,7 +754,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit
 		return nil
 	}
 
-	cert := ifce.certState.Load().certificate
+	cert := ifce.pki.GetCertState().Certificate
 	if len(a) > 0 {
 		parsedIp := net.ParseIP(a[0])
 		if parsedIp == nil {

+ 24 - 4
util/error.go

@@ -12,18 +12,38 @@ type ContextualError struct {
 	Context   string
 }
 
-func NewContextualError(msg string, fields map[string]interface{}, realError error) ContextualError {
-	return ContextualError{Context: msg, Fields: fields, RealError: realError}
+// NewContextualError returns a pointer to a ContextualError with Context set to msg, Fields set to fields,
+// and RealError set to realError.
+func NewContextualError(msg string, fields map[string]interface{}, realError error) *ContextualError {
+	return &ContextualError{Context: msg, Fields: fields, RealError: realError}
 }
 
-func (ce ContextualError) Error() string {
+// ContextualizeIfNeeded is a helper function to turn an error into a ContextualError if it is not already one
+func ContextualizeIfNeeded(msg string, err error) error {
+	switch err.(type) {
+	case *ContextualError:
+		return err
+	default:
+		return NewContextualError(msg, nil, err)
+	}
+}
+
+// LogWithContextIfNeeded is a helper function to log an error line for an error or ContextualError
+func LogWithContextIfNeeded(msg string, err error, l *logrus.Logger) {
+	switch v := err.(type) {
+	case *ContextualError:
+		v.Log(l)
+	default:
+		l.WithError(err).Error(msg)
+	}
+}
+
+// Error implements the error interface. It returns Context when RealError is nil and the message of
+// RealError otherwise.
+func (ce *ContextualError) Error() string {
 	if ce.RealError == nil {
 		return ce.Context
 	}
 	return ce.RealError.Error()
 }
 
-func (ce ContextualError) Unwrap() error {
+func (ce *ContextualError) Unwrap() error {
 	if ce.RealError == nil {
 		return errors.New(ce.Context)
 	}

+ 42 - 0
util/error_test.go

@@ -2,6 +2,7 @@ package util
 
 import (
 	"errors"
+	"fmt"
 	"testing"
 
 	"github.com/sirupsen/logrus"
@@ -67,3 +68,44 @@ func TestContextualError_Log(t *testing.T) {
 	e.Log(l)
 	assert.Equal(t, []string{"level=error error=error\n"}, tl.Logs)
 }
+
+// TestLogWithContextIfNeeded checks that a ContextualError is logged with its own context while a plain
+// error is logged with the fallback message.
+func TestLogWithContextIfNeeded(t *testing.T) {
+	tl := NewTestLogWriter()
+
+	l := logrus.New()
+	l.Out = tl
+	l.Formatter = &logrus.TextFormatter{
+		DisableTimestamp: true,
+		DisableColors:    true,
+	}
+
+	// A ContextualError carries its own message, so the fallback is dropped
+	tl.Reset()
+	ctxErr := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
+	LogWithContextIfNeeded("This should get thrown away", ctxErr, l)
+	assert.Equal(t, []string{"level=error msg=\"test message\" error=error field=1\n"}, tl.Logs)
+
+	// A plain error is logged with the fallback message
+	tl.Reset()
+	plainErr := fmt.Errorf("this is a normal error")
+	LogWithContextIfNeeded("Fallback context woo", plainErr, l)
+	assert.Equal(t, []string{"level=error msg=\"Fallback context woo\" error=\"this is a normal error\"\n"}, tl.Logs)
+}
+
+func TestContextualizeIfNeeded(t *testing.T) {
+	// Test ignoring fallback context
+	e := NewContextualError("test message", m{"field": "1"}, errors.New("error"))
+	assert.Same(t, e, ContextualizeIfNeeded("should be ignored", e))
+
+	// Test using fallback context
+	err := fmt.Errorf("this is a normal error")
+	cErr := ContextualizeIfNeeded("Fallback context woo", err)
+
+	switch v := cErr.(type) {
+	case *ContextualError:
+		assert.Equal(t, err, v.RealError)
+	default:
+		t.Error("Error was not wrapped")
+		t.Fail()
+	}
+}