Forráskód Böngészése

Teardown tunnel automatically if peer's certificate expired (#370)

Donatas Abraitis 3 éve
szülő
commit
32e2619323
5 módosított fájl, 167 hozzáadás és 19 törlés
  1. 65 18
      connection_manager.go
  2. 95 0
      connection_manager_test.go
  3. 3 1
      examples/config.yml
  4. 3 0
      interface.go
  5. 1 0
      main.go

+ 65 - 18
connection_manager.go

@@ -166,7 +166,23 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 		// Check for traffic coming back in from this host.
 		traf := n.CheckIn(vpnIP)
 
-		// If we saw incoming packets from this ip, just return
+		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
+		if err != nil {
+			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+
+			if !n.intf.disconnectInvalid {
+				n.ClearIP(vpnIP)
+				n.ClearPendingDeletion(vpnIP)
+				continue
+			}
+		}
+
+		if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
+			continue
+		}
+
+		// If we saw an incoming packets from this ip and peer's certificate is not
+		// expired, just ignore.
 		if traf {
 			if n.l.Level >= logrus.DebugLevel {
 				n.l.WithField("vpnIp", IntIp(vpnIP)).
@@ -178,15 +194,6 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 			continue
 		}
 
-		// If we didn't we may need to probe or destroy the conn
-		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
-		if err != nil {
-			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
-			continue
-		}
-
 		hostinfo.logger(n.l).
 			WithField("tunnelCheck", m{"state": "testing", "method": "active"}).
 			Debug("Tunnel status")
@@ -213,22 +220,31 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 
 		vpnIP := ep.(uint32)
 
-		// If we saw incoming packets from this ip, just return
+		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
+		if err != nil {
+			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
+
+			if !n.intf.disconnectInvalid {
+				n.ClearIP(vpnIP)
+				n.ClearPendingDeletion(vpnIP)
+				continue
+			}
+		}
+
+		if n.handleInvalidCertificate(now, vpnIP, hostinfo) {
+			continue
+		}
+
+		// If we saw an incoming packets from this ip and peer's certificate is not
+		// expired, just ignore.
 		traf := n.CheckIn(vpnIP)
 		if traf {
 			n.l.WithField("vpnIp", IntIp(vpnIP)).
 				WithField("tunnelCheck", m{"state": "alive", "method": "active"}).
 				Debug("Tunnel status")
-			n.ClearIP(vpnIP)
-			n.ClearPendingDeletion(vpnIP)
-			continue
-		}
 
-		hostinfo, err := n.hostMap.QueryVpnIP(vpnIP)
-		if err != nil {
 			n.ClearIP(vpnIP)
 			n.ClearPendingDeletion(vpnIP)
-			n.l.Debugf("Not found in hostmap: %s", IntIp(vpnIP))
 			continue
 		}
 
@@ -256,3 +272,34 @@ func (n *connectionManager) HandleDeletionTick(now time.Time) {
 		}
 	}
 }
+
+// handleInvalidCertificates will destroy a tunnel if pki.disconnect_invalid is true and the certificate is no longer valid
+func (n *connectionManager) handleInvalidCertificate(now time.Time, vpnIP uint32, hostinfo *HostInfo) bool {
+	if !n.intf.disconnectInvalid {
+		return false
+	}
+
+	remoteCert := hostinfo.GetCert()
+	if remoteCert == nil {
+		return false
+	}
+
+	valid, err := remoteCert.Verify(now, n.intf.caPool)
+	if valid {
+		return false
+	}
+
+	fingerprint, _ := remoteCert.Sha256Sum()
+	n.l.WithField("vpnIp", IntIp(vpnIP)).WithError(err).
+		WithField("certName", remoteCert.Details.Name).
+		WithField("fingerprint", fingerprint).
+		Info("Remote certificate is no longer valid, tearing down the tunnel")
+
+	// Inform the remote and close the tunnel locally
+	n.intf.sendCloseTunnel(hostinfo)
+	n.intf.closeTunnel(hostinfo, false)
+
+	n.ClearIP(vpnIP)
+	n.ClearPendingDeletion(vpnIP)
+	return true
+}

+ 95 - 0
connection_manager_test.go

@@ -1,6 +1,8 @@
 package nebula
 
 import (
+	"crypto/ed25519"
+	"crypto/rand"
 	"net"
 	"testing"
 	"time"
@@ -148,3 +150,96 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
 	assert.Contains(t, nc.hostMap.Hosts, vpnIP)
 
 }
+
+// Check if we can disconnect the peer.
+// Validate if the peer's certificate is invalid (expired, etc.)
+// Disconnect only if disconnectInvalid: true is set.
+func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
+	now := time.Now()
+	l := NewTestLogger()
+	ipNet := net.IPNet{
+		IP:   net.IPv4(172, 1, 1, 2),
+		Mask: net.IPMask{255, 255, 255, 0},
+	}
+	_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
+	_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
+	preferredRanges := []*net.IPNet{localrange}
+	hostMap := NewHostMap(l, "test", vpncidr, preferredRanges)
+
+	// Generate keys for CA and peer's cert.
+	pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
+	caCert := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:      "ca",
+			NotBefore: now,
+			NotAfter:  now.Add(1 * time.Hour),
+			IsCA:      true,
+			PublicKey: pubCA,
+		},
+	}
+	caCert.Sign(privCA)
+	ncp := &cert.NebulaCAPool{
+		CAs: cert.NewCAPool().CAs,
+	}
+	ncp.CAs["ca"] = &caCert
+
+	pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
+	peerCert := cert.NebulaCertificate{
+		Details: cert.NebulaCertificateDetails{
+			Name:      "host",
+			Ips:       []*net.IPNet{&ipNet},
+			Subnets:   []*net.IPNet{},
+			NotBefore: now,
+			NotAfter:  now.Add(60 * time.Second),
+			PublicKey: pubCrt,
+			IsCA:      false,
+			Issuer:    "ca",
+		},
+	}
+	peerCert.Sign(privCA)
+
+	cs := &CertState{
+		rawCertificate:      []byte{},
+		privateKey:          []byte{},
+		certificate:         &cert.NebulaCertificate{},
+		rawCertificateNoKey: []byte{},
+	}
+
+	lh := NewLightHouse(l, false, &net.IPNet{IP: net.IP{0, 0, 0, 0}, Mask: net.IPMask{0, 0, 0, 0}}, []uint32{}, 1000, 0, &udpConn{}, false, 1, false)
+	ifce := &Interface{
+		hostMap:           hostMap,
+		inside:            &Tun{},
+		outside:           &udpConn{},
+		certState:         cs,
+		firewall:          &Firewall{},
+		lightHouse:        lh,
+		handshakeManager:  NewHandshakeManager(l, vpncidr, preferredRanges, hostMap, lh, &udpConn{}, defaultHandshakeConfig),
+		l:                 l,
+		disconnectInvalid: true,
+		caPool:            ncp,
+	}
+
+	// Create manager
+	nc := newConnectionManager(l, ifce, 5, 10)
+	ifce.connectionManager = nc
+	hostinfo := nc.hostMap.AddVpnIP(vpnIP)
+	hostinfo.ConnectionState = &ConnectionState{
+		certState: cs,
+		peerCert:  &peerCert,
+		H:         &noise.HandshakeState{},
+	}
+
+	// Move ahead 45s.
+	// Check if to disconnect with invalid certificate.
+	// Should be alive.
+	nextTick := now.Add(45 * time.Second)
+	destroyed := nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
+	assert.False(t, destroyed)
+
+	// Move ahead 61s.
+	// Check if to disconnect with invalid certificate.
+	// Should be disconnected.
+	nextTick = now.Add(61 * time.Second)
+	destroyed = nc.handleInvalidCertificate(nextTick, vpnIP, hostinfo)
+	assert.True(t, destroyed)
+}

+ 3 - 1
examples/config.yml

@@ -7,9 +7,11 @@ pki:
   ca: /etc/nebula/ca.crt
   cert: /etc/nebula/host.crt
   key: /etc/nebula/host.key
-  #blocklist is a list of certificate fingerprints that we will refuse to talk to
+  # blocklist is a list of certificate fingerprints that we will refuse to talk to
   #blocklist:
   #  - c99d4e650533b92061b09918e838a5a0a6aaee21eed1d12fd937682865936c72
+  # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid.
+  #disconnect_invalid: false
 
 # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network).
 # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel.

+ 3 - 0
interface.go

@@ -43,6 +43,7 @@ type InterfaceConfig struct {
 	MessageMetrics          *MessageMetrics
 	version                 string
 	caPool                  *cert.NebulaCAPool
+	disconnectInvalid       bool
 
 	ConntrackCacheTimeout time.Duration
 	l                     *logrus.Logger
@@ -67,6 +68,7 @@ type Interface struct {
 	udpBatchSize       int
 	routines           int
 	caPool             *cert.NebulaCAPool
+	disconnectInvalid  bool
 
 	// rebindCount is used to decide if an active tunnel should trigger a punch notification through a lighthouse
 	rebindCount int8
@@ -118,6 +120,7 @@ func NewInterface(c *InterfaceConfig) (*Interface, error) {
 		writers:            make([]*udpConn, c.routines),
 		readers:            make([]io.ReadWriteCloser, c.routines),
 		caPool:             c.caPool,
+		disconnectInvalid:  c.disconnectInvalid,
 		myVpnIp:            ip2int(c.certState.certificate.Details.Ips[0].IP),
 
 		conntrackCacheTimeout: c.ConntrackCacheTimeout,

+ 1 - 0
main.go

@@ -371,6 +371,7 @@ func Main(config *Config, configTest bool, buildVersion string, logger *logrus.L
 		MessageMetrics:          messageMetrics,
 		version:                 buildVersion,
 		caPool:                  caPool,
+		disconnectInvalid:       config.GetBool("pki.disconnect_invalid", false),
 
 		ConntrackCacheTimeout: conntrackCacheTimeout,
 		l:                     l,