Ver código fonte

Don't use a global ca pool (#426)

Nathan Brown 4 anos atrás
pai
commit
883e09a392
6 arquivos alterados com 16 adições e 14 exclusões
  1. 0 2
      cert.go
  2. 2 2
      handshake_ix.go
  3. 2 2
      inside.go
  4. 6 2
      interface.go
  5. 3 3
      main.go
  6. 3 3
      outside.go

+ 0 - 2
cert.go

@@ -11,8 +11,6 @@ import (
 	"github.com/slackhq/nebula/cert"
 	"github.com/slackhq/nebula/cert"
 )
 )
 
 
-var trustedCAs *cert.NebulaCAPool
-
 type CertState struct {
 type CertState struct {
 	certificate         *cert.NebulaCertificate
 	certificate         *cert.NebulaCertificate
 	rawCertificate      []byte
 	rawCertificate      []byte

+ 2 - 2
handshake_ix.go

@@ -96,7 +96,7 @@ func ixHandshakeStage1(f *Interface, addr *udpAddr, packet []byte, h *Header) {
 		return
 		return
 	}
 	}
 
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
+	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("udpAddr", addr).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
 			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
@@ -318,7 +318,7 @@ func ixHandshakeStage2(f *Interface, addr *udpAddr, hostinfo *HostInfo, packet [
 		return true
 		return true
 	}
 	}
 
 
-	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert)
+	remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
 	if err != nil {
 	if err != nil {
 		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 		f.l.WithError(err).WithField("vpnIp", IntIp(hostinfo.hostId)).WithField("udpAddr", addr).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
 			WithField("cert", remoteCert).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).

+ 2 - 2
inside.go

@@ -52,7 +52,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *FirewallPacket,
 		ci.queueLock.Unlock()
 		ci.queueLock.Unlock()
 	}
 	}
 
 
-	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, trustedCAs, localCache)
+	dropReason := f.firewall.Drop(packet, *fwPacket, false, hostinfo, f.caPool, localCache)
 	if dropReason == nil {
 	if dropReason == nil {
 		mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
 		mc := f.sendNoMetrics(message, 0, ci, hostinfo, hostinfo.remote, packet, nb, out, q)
 		if f.lightHouse != nil && mc%5000 == 0 {
 		if f.lightHouse != nil && mc%5000 == 0 {
@@ -140,7 +140,7 @@ func (f *Interface) sendMessageNow(t NebulaMessageType, st NebulaMessageSubType,
 	}
 	}
 
 
 	// check if packet is in outbound fw rules
 	// check if packet is in outbound fw rules
-	dropReason := f.firewall.Drop(p, *fp, false, hostInfo, trustedCAs, nil)
+	dropReason := f.firewall.Drop(p, *fp, false, hostInfo, f.caPool, nil)
 	if dropReason != nil {
 	if dropReason != nil {
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
 			f.l.WithField("fwPacket", fp).
 			f.l.WithField("fwPacket", fp).

+ 6 - 2
interface.go

@@ -10,6 +10,7 @@ import (
 
 
 	"github.com/rcrowley/go-metrics"
 	"github.com/rcrowley/go-metrics"
 	"github.com/sirupsen/logrus"
 	"github.com/sirupsen/logrus"
+	"github.com/slackhq/nebula/cert"
 )
 )
 
 
 const mtu = 9001
 const mtu = 9001
@@ -41,6 +42,7 @@ type InterfaceConfig struct {
 	routines                int
 	routines                int
 	MessageMetrics          *MessageMetrics
 	MessageMetrics          *MessageMetrics
 	version                 string
 	version                 string
+	caPool                  *cert.NebulaCAPool
 
 
 	ConntrackCacheTimeout time.Duration
 	ConntrackCacheTimeout time.Duration
 	l                     *logrus.Logger
 	l                     *logrus.Logger
@@ -63,6 +65,7 @@ type Interface struct {
 	dropMulticast      bool
 	dropMulticast      bool
 	udpBatchSize       int
 	udpBatchSize       int
 	routines           int
 	routines           int
+	caPool             *cert.NebulaCAPool
 
 
 	// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
 	// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
 	rebindCount int8
 	rebindCount int8
@@ -111,6 +114,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 		version:            c.version,
 		version:            c.version,
 		writers:            make([]*udpConn, c.routines),
 		writers:            make([]*udpConn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
+		caPool:             c.caPool,
 
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,
 
 
@@ -218,8 +222,8 @@ func (f *Interface) reloadCA(c *Config) {
 		return
 		return
 	}
 	}
 
 
-	trustedCAs = newCAs
-	f.l.WithField("fingerprints", trustedCAs.GetFingerprints()).Info("Trusted CA certificates refreshed")
+	f.caPool = newCAs
+	f.l.WithField("fingerprints", f.caPool.GetFingerprints()).Info("Trusted CA certificates refreshed")
 }
 }
 
 
 func (f *Interface) reloadCertKey(c *Config) {
 func (f *Interface) reloadCertKey(c *Config) {

+ 3 - 3
main.go

@@ -42,13 +42,12 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		}
 		}
 	})
 	})
 
 
-	// trustedCAs is currently a global, so loadCA operates on that global directly
-	trustedCAs, err = loadCAFromConfig(l, config)
+	caPool, err := loadCAFromConfig(l, config)
 	if err != nil {
 	if err != nil {
 		//The errors coming out of loadCA are already nicely formatted
 		//The errors coming out of loadCA are already nicely formatted
 		return nil, NewContextualError("Failed to load ca from config", nil, err)
 		return nil, NewContextualError("Failed to load ca from config", nil, err)
 	}
 	}
-	l.WithField("fingerprints", trustedCAs.GetFingerprints()).Debug("Trusted CA fingerprints")
+	l.WithField("fingerprints", caPool.GetFingerprints()).Debug("Trusted CA fingerprints")
 
 
 	cs, err := NewCertStateFromConfig(config)
 	cs, err := NewCertStateFromConfig(config)
 	if err != nil {
 	if err != nil {
@@ -365,6 +364,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		routines:                routines,
 		routines:                routines,
 		MessageMetrics:          messageMetrics,
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
 		version:                 buildVersion,
+		caPool:                  caPool,
 
 
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,
 		l:                     l,

+ 3 - 3
outside.go

@@ -280,7 +280,7 @@ func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out
 		return
 		return
 	}
 	}
 
 
-	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, trustedCAs, localCache)
+	dropReason := f.firewall.Drop(out, *fwPacket, true, hostinfo, f.caPool, localCache)
 	if dropReason != nil {
 	if dropReason != nil {
 		if f.l.Level >= logrus.DebugLevel {
 		if f.l.Level >= logrus.DebugLevel {
 			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
 			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
@@ -368,7 +368,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N
 }
 }
 */
 */
 
 
-func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*cert.NebulaCertificate, error) {
+func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
 	pk := h.PeerStatic()
 	pk := h.PeerStatic()
 
 
 	if pk == nil {
 	if pk == nil {
@@ -397,7 +397,7 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte) (*ce
 	}
 	}
 
 
 	c, _ := cert.UnmarshalNebulaCertificate(recombined)
 	c, _ := cert.UnmarshalNebulaCertificate(recombined)
-	isValid, err := c.Verify(time.Now(), trustedCAs)
+	isValid, err := c.Verify(time.Now(), caPool)
 	if err != nil {
 	if err != nil {
 		return c, fmt.Errorf("certificate validation failed: %s", err)
 		return c, fmt.Errorf("certificate validation failed: %s", err)
 	} else if !isValid {
 	} else if !isValid {