| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490 | package nebulaimport (	"bytes"	"context"	"sync"	"time"	"github.com/rcrowley/go-metrics"	"github.com/sirupsen/logrus"	"github.com/slackhq/nebula/cert"	"github.com/slackhq/nebula/header"	"github.com/slackhq/nebula/iputil"	"github.com/slackhq/nebula/udp")type trafficDecision intconst (	doNothing      trafficDecision = 0	deleteTunnel   trafficDecision = 1 // delete the hostinfo on our side, do not notify the remote	closeTunnel    trafficDecision = 2 // delete the hostinfo and notify the remote	swapPrimary    trafficDecision = 3	migrateRelays  trafficDecision = 4	tryRehandshake trafficDecision = 5)type connectionManager struct {	in     map[uint32]struct{}	inLock *sync.RWMutex	out     map[uint32]struct{}	outLock *sync.RWMutex	// relayUsed holds which relay localIndexs are in use	relayUsed     map[uint32]struct{}	relayUsedLock *sync.RWMutex	hostMap                 *HostMap	trafficTimer            *LockingTimerWheel[uint32]	intf                    *Interface	pendingDeletion         map[uint32]struct{}	punchy                  *Punchy	checkInterval           time.Duration	pendingDeletionInterval time.Duration	metricsTxPunchy         metrics.Counter	l *logrus.Logger}func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface, checkInterval, pendingDeletionInterval time.Duration, punchy *Punchy) *connectionManager {	var max time.Duration	if checkInterval < pendingDeletionInterval {		max = pendingDeletionInterval	} else {		max = checkInterval	}	nc := &connectionManager{		hostMap:                 intf.hostMap,		in:                      make(map[uint32]struct{}),		inLock:                  &sync.RWMutex{},		out:                     make(map[uint32]struct{}),		outLock:                 &sync.RWMutex{},		relayUsed:               make(map[uint32]struct{}),		relayUsedLock:           &sync.RWMutex{},		trafficTimer:            NewLockingTimerWheel[uint32](time.Millisecond*500, max),		intf:                    intf,		pendingDeletion:         make(map[uint32]struct{}),		checkInterval:           checkInterval,		pendingDeletionInterval: pendingDeletionInterval,		punchy:                  punchy,		metricsTxPunchy:         metrics.GetOrRegisterCounter("messages.tx.punchy", nil),		l:                       l,	}	nc.Start(ctx)	return nc}func (n *connectionManager) In(localIndex uint32) {	n.inLock.RLock()	// If this already exists, return	if _, ok := n.in[localIndex]; ok {		n.inLock.RUnlock()		return	}	n.inLock.RUnlock()	n.inLock.Lock()	n.in[localIndex] = struct{}{}	n.inLock.Unlock()}func (n *connectionManager) Out(localIndex uint32) {	n.outLock.RLock()	// If this already exists, return	if _, ok := n.out[localIndex]; ok {		n.outLock.RUnlock()		return	}	n.outLock.RUnlock()	n.outLock.Lock()	n.out[localIndex] = struct{}{}	n.outLock.Unlock()}func (n *connectionManager) RelayUsed(localIndex uint32) {	n.relayUsedLock.RLock()	// If this already exists, return	if _, ok := n.relayUsed[localIndex]; ok {		n.relayUsedLock.RUnlock()		return	}	n.relayUsedLock.RUnlock()	n.relayUsedLock.Lock()	n.relayUsed[localIndex] = struct{}{}	n.relayUsedLock.Unlock()}// getAndResetTrafficCheck returns if there was any inbound or outbound traffic within the last tick and// resets the state for this local indexfunc (n *connectionManager) getAndResetTrafficCheck(localIndex uint32) (bool, bool) {	n.inLock.Lock()	n.outLock.Lock()	_, in := n.in[localIndex]	_, out := n.out[localIndex]	delete(n.in, localIndex)	delete(n.out, localIndex)	n.inLock.Unlock()	n.outLock.Unlock()	return in, out}func (n *connectionManager) AddTrafficWatch(localIndex uint32) {	// Use a write lock directly because it should be incredibly rare that we are ever already tracking this index	n.outLock.Lock()	if _, ok := n.out[localIndex]; ok {		n.outLock.Unlock()		return	}	n.out[localIndex] = struct{}{}	n.trafficTimer.Add(localIndex, n.checkInterval)	n.outLock.Unlock()}func (n *connectionManager) Start(ctx context.Context) {	go n.Run(ctx)}func (n *connectionManager) Run(ctx context.Context) {	//TODO: this tick should be based on the min wheel tick? Check firewall	clockSource := time.NewTicker(500 * time.Millisecond)	defer clockSource.Stop()	p := []byte("")	nb := make([]byte, 12, 12)	out := make([]byte, mtu)	for {		select {		case <-ctx.Done():			return		case now := <-clockSource.C:			n.trafficTimer.Advance(now)			for {				localIndex, has := n.trafficTimer.Purge()				if !has {					break				}				n.doTrafficCheck(localIndex, p, nb, out, now)			}		}	}}func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) {	decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now)	switch decision {	case deleteTunnel:		if n.hostMap.DeleteHostInfo(hostinfo) {			// Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap			n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp)		}	case closeTunnel:		n.intf.sendCloseTunnel(hostinfo)		n.intf.closeTunnel(hostinfo)	case swapPrimary:		n.swapPrimary(hostinfo, primary)	case migrateRelays:		n.migrateRelayUsed(hostinfo, primary)	case tryRehandshake:		n.tryRehandshake(hostinfo)	}	n.resetRelayTrafficCheck(hostinfo)}func (n *connectionManager) resetRelayTrafficCheck(hostinfo *HostInfo) {	if hostinfo != nil {		n.relayUsedLock.Lock()		defer n.relayUsedLock.Unlock()		// No need to migrate any relays, delete usage info now.		for _, idx := range hostinfo.relayState.CopyRelayForIdxs() {			delete(n.relayUsed, idx)		}	}}func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) {	relayFor := oldhostinfo.relayState.CopyAllRelayFor()	for _, r := range relayFor {		existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp)		var index uint32		var relayFrom iputil.VpnIp		var relayTo iputil.VpnIp		switch {		case ok && existing.State == Established:			// This relay already exists in newhostinfo, then do nothing.			continue		case ok && existing.State == Requested:			// The relay exists in a Requested state; re-send the request			index = existing.LocalIndex			switch r.Type {			case TerminalType:				relayFrom = newhostinfo.vpnIp				relayTo = existing.PeerIp			case ForwardingType:				relayFrom = existing.PeerIp				relayTo = newhostinfo.vpnIp			default:				// should never happen			}		case !ok:			n.relayUsedLock.RLock()			if _, relayUsed := n.relayUsed[r.LocalIndex]; !relayUsed {				// The relay hasn't been used; don't migrate it.				n.relayUsedLock.RUnlock()				continue			}			n.relayUsedLock.RUnlock()			// The relay doesn't exist at all; create some relay state and send the request.			var err error			index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested)			if err != nil {				n.l.WithError(err).Error("failed to migrate relay to new hostinfo")				continue			}			switch r.Type {			case TerminalType:				relayFrom = newhostinfo.vpnIp				relayTo = r.PeerIp			case ForwardingType:				relayFrom = r.PeerIp				relayTo = newhostinfo.vpnIp			default:				// should never happen			}		}		// Send a CreateRelayRequest to the peer.		req := NebulaControl{			Type:                NebulaControl_CreateRelayRequest,			InitiatorRelayIndex: index,			RelayFromIp:         uint32(relayFrom),			RelayToIp:           uint32(relayTo),		}		msg, err := req.Marshal()		if err != nil {			n.l.WithError(err).Error("failed to marshal Control message to migrate relay")		} else {			n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu))			n.l.WithFields(logrus.Fields{				"relayFrom":           iputil.VpnIp(req.RelayFromIp),				"relayTo":             iputil.VpnIp(req.RelayToIp),				"initiatorRelayIndex": req.InitiatorRelayIndex,				"responderRelayIndex": req.ResponderRelayIndex,				"vpnIp":               newhostinfo.vpnIp}).				Info("send CreateRelayRequest")		}	}}func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) {	n.hostMap.RLock()	defer n.hostMap.RUnlock()	hostinfo := n.hostMap.Indexes[localIndex]	if hostinfo == nil {		n.l.WithField("localIndex", localIndex).Debugf("Not found in hostmap")		delete(n.pendingDeletion, localIndex)		return doNothing, nil, nil	}	if n.isInvalidCertificate(now, hostinfo) {		delete(n.pendingDeletion, hostinfo.localIndexId)		return closeTunnel, hostinfo, nil	}	primary := n.hostMap.Hosts[hostinfo.vpnIp]	mainHostInfo := true	if primary != nil && primary != hostinfo {		mainHostInfo = false	}	// Check for traffic on this hostinfo	inTraffic, outTraffic := n.getAndResetTrafficCheck(localIndex)	// A hostinfo is determined alive if there is incoming traffic	if inTraffic {		decision := doNothing		if n.l.Level >= logrus.DebugLevel {			hostinfo.logger(n.l).				WithField("tunnelCheck", m{"state": "alive", "method": "passive"}).				Debug("Tunnel status")		}		delete(n.pendingDeletion, hostinfo.localIndexId)		if mainHostInfo {			decision = tryRehandshake		} else {			if n.shouldSwapPrimary(hostinfo, primary) {				decision = swapPrimary			} else {				// migrate the relays to the primary, if in use.				decision = migrateRelays			}		}		n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)		if !outTraffic {			// Send a punch packet to keep the NAT state alive			n.sendPunch(hostinfo)		}		return decision, hostinfo, primary	}	if _, ok := n.pendingDeletion[hostinfo.localIndexId]; ok {		// We have already sent a test packet and nothing was returned, this hostinfo is dead		hostinfo.logger(n.l).			WithField("tunnelCheck", m{"state": "dead", "method": "active"}).			Info("Tunnel status")		delete(n.pendingDeletion, hostinfo.localIndexId)		return deleteTunnel, hostinfo, nil	}	if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo {		if !outTraffic {			// If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel.			// Just maintain NAT state if configured to do so.			n.sendPunch(hostinfo)			n.trafficTimer.Add(hostinfo.localIndexId, n.checkInterval)			return doNothing, nil, nil		}		if n.punchy.GetTargetEverything() {			// This is similar to the old punchy behavior with a slight optimization.			// We aren't receiving traffic but we are sending it, punch on all known			// ips in case we need to re-prime NAT state			n.sendPunch(hostinfo)		}		if n.l.Level >= logrus.DebugLevel {			hostinfo.logger(n.l).				WithField("tunnelCheck", m{"state": "testing", "method": "active"}).				Debug("Tunnel status")		}		// Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues		n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out)	} else {		if n.l.Level >= logrus.DebugLevel {			hostinfo.logger(n.l).Debugf("Hostinfo sadness")		}	}	n.pendingDeletion[hostinfo.localIndexId] = struct{}{}	n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval)	return doNothing, nil, nil}func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool {	// The primary tunnel is the most recent handshake to complete locally and should work entirely fine.	// If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary.	// Let's sort this out.	if current.vpnIp < n.intf.myVpnIp {		// Only one side should flip primary because if both flip then we may never resolve to a single tunnel.		// vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping.		// The remotes vpn ip is lower than mine. I will not flip.		return false	}	certState := n.intf.certState.Load()	return bytes.Equal(current.ConnectionState.certState.certificate.Signature, certState.certificate.Signature)}func (n *connectionManager) swapPrimary(current, primary *HostInfo) {	n.hostMap.Lock()	// Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake.	if n.hostMap.Hosts[current.vpnIp] == primary {		n.hostMap.unlockedMakePrimary(current)	}	n.hostMap.Unlock()}// isInvalidCertificate will check if we should destroy a tunnel if pki.disconnect_invalid is true and// the certificate is no longer valid. Block listed certificates will skip the pki.disconnect_invalid// check and return true.func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostInfo) bool {	remoteCert := hostinfo.GetCert()	if remoteCert == nil {		return false	}	valid, err := remoteCert.VerifyWithCache(now, n.intf.caPool)	if valid {		return false	}	if !n.intf.disconnectInvalid && err != cert.ErrBlockListed {		// Block listed certificates should always be disconnected		return false	}	fingerprint, _ := remoteCert.Sha256Sum()	hostinfo.logger(n.l).WithError(err).		WithField("fingerprint", fingerprint).		Info("Remote certificate is no longer valid, tearing down the tunnel")	return true}func (n *connectionManager) sendPunch(hostinfo *HostInfo) {	if !n.punchy.GetPunch() {		// Punching is disabled		return	}	if n.punchy.GetTargetEverything() {		hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {			n.metricsTxPunchy.Inc(1)			n.intf.outside.WriteTo([]byte{1}, addr)		})	} else if hostinfo.remote != nil {		n.metricsTxPunchy.Inc(1)		n.intf.outside.WriteTo([]byte{1}, hostinfo.remote)	}}func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) {	certState := n.intf.certState.Load()	if bytes.Equal(hostinfo.ConnectionState.certState.certificate.Signature, certState.certificate.Signature) {		return	}	n.l.WithField("vpnIp", hostinfo.vpnIp).		WithField("reason", "local certificate is not current").		Info("Re-handshaking with remote")	//TODO: this is copied from getOrHandshake to keep the extra checks out of the hot path, figure it out	newHostinfo := n.intf.handshakeManager.AddVpnIp(hostinfo.vpnIp, n.intf.initHostInfo)	if !newHostinfo.HandshakeReady {		ixHandshakeStage0(n.intf, newHostinfo.vpnIp, newHostinfo)	}	//If this is a static host, we don't need to wait for the HostQueryReply	//We can trigger the handshake right now	if _, ok := n.intf.lightHouse.GetStaticHostList()[hostinfo.vpnIp]; ok {		select {		case n.intf.handshakeManager.trigger <- hostinfo.vpnIp:		default:		}	}}
 |