| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350 | package functionsimport (	"context"	"errors"	"fmt"	"os"	"os/signal"	"strings"	"sync"	"syscall"	"time"	mqtt "github.com/eclipse/paho.mqtt.golang"	"github.com/gravitl/netmaker/logger"	"github.com/gravitl/netmaker/mq"	"github.com/gravitl/netmaker/netclient/auth"	"github.com/gravitl/netmaker/netclient/config"	"github.com/gravitl/netmaker/netclient/global_settings"	"github.com/gravitl/netmaker/netclient/local"	"github.com/gravitl/netmaker/netclient/ncutils"	"github.com/gravitl/netmaker/netclient/wireguard"	"golang.zx2c4.com/wireguard/wgctrl/wgtypes")var messageCache = new(sync.Map)var serverSet map[string]boolvar mqclient mqtt.Clientconst lastNodeUpdate = "lnu"const lastPeerUpdate = "lpu"type cachedMessage struct {	Message  string	LastSeen time.Time}// Daemon runs netclient daemon from command linefunc Daemon() error {	logger.Log(0, "netclient daemon started -- version:", ncutils.Version)	UpdateClientConfig()	if err := ncutils.SavePID(); err != nil {		return err	}	// reference required to eliminate unused statticcheck	serverSet = make(map[string]bool)	serverSet["dummy"] = false	// set ipforwarding on startup	err := local.SetIPForwarding()	if err != nil {		logger.Log(0, err.Error())	}	// == add waitgroup and cancel for checkin routine ==	wg := sync.WaitGroup{}	quit := make(chan os.Signal, 1)	reset := make(chan os.Signal, 1)	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)	signal.Notify(reset, syscall.SIGHUP)	cancel := startGoRoutines(&wg)	for {		select {		case <-quit:			cancel()			logger.Log(0, "shutting down netclient daemon")			wg.Wait()			if mqclient != nil {				mqclient.Disconnect(250)			}			logger.Log(0, "shutdown complete")			return nil		case <-reset:			logger.Log(0, "received reset")			cancel()			wg.Wait()			if mqclient != nil {				mqclient.Disconnect(250)			}			logger.Log(0, "restarting daemon")			cancel = startGoRoutines(&wg)		}	}}func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc {	ctx, cancel := context.WithCancel(context.Background())	serverSet := make(map[string]bool)	networks, _ := ncutils.GetSystemNetworks()	for _, network := range networks {		logger.Log(3, "initializing network", network)		cfg := config.ClientConfig{}		cfg.Network = network		cfg.ReadConfig()		if cfg.Node.Connected == "yes" {			if err := wireguard.ApplyConf(&cfg.Node, cfg.Node.Interface, ncutils.GetNetclientPathSpecific()+cfg.Node.Interface+".conf"); err != nil {				logger.Log(0, "failed to start ", cfg.Node.Interface, "wg interface", err.Error())			}			if cfg.PublicIPService != "" {				global_settings.PublicIPServices[network] = cfg.PublicIPService			}		}		server := cfg.Server.Server		if !serverSet[server] {			// == subscribe to all nodes for each on machine ==			serverSet[server] = true			logger.Log(1, "started daemon for server ", server)			local.SetNetmakerDomainRoute(cfg.Server.API)			wg.Add(1)			go messageQueue(ctx, wg, &cfg)		}	}	wg.Add(1)	go Checkin(ctx, wg)	return cancel}// UpdateKeys -- updates private key and returns new publickeyfunc UpdateKeys(nodeCfg *config.ClientConfig, client mqtt.Client) error {	logger.Log(0, "interface:", nodeCfg.Node.Interface, "received message to update wireguard keys for network ", nodeCfg.Network)	key, err := wgtypes.GeneratePrivateKey()	if err != nil {		logger.Log(0, "network:", nodeCfg.Node.Network, "error generating privatekey ", err.Error())		return err	}	file := ncutils.GetNetclientPathSpecific() + nodeCfg.Node.Interface + ".conf"	if err := wireguard.UpdatePrivateKey(file, key.String()); err != nil {		logger.Log(0, "network:", nodeCfg.Node.Network, "error updating wireguard key ", err.Error())		return err	}	if storeErr := wireguard.StorePrivKey(key.String(), nodeCfg.Network); storeErr != nil {		logger.Log(0, "network:", nodeCfg.Network, "failed to save private key", storeErr.Error())		return storeErr	}	nodeCfg.Node.PublicKey = key.PublicKey().String()	PublishNodeUpdate(nodeCfg)	return nil}// == Private ==// sets MQ client subscriptions for a specific node config// should be called for each node belonging to a given serverfunc setSubscriptions(client mqtt.Client, nodeCfg *config.ClientConfig) {	if token := client.Subscribe(fmt.Sprintf("update/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID), 0, mqtt.MessageHandler(NodeUpdate)); token.WaitTimeout(mq.MQ_TIMEOUT*time.Second) && token.Error() != nil {		if token.Error() == nil {			logger.Log(0, "network:", nodeCfg.Node.Network, "connection timeout")		} else {			logger.Log(0, "network:", nodeCfg.Node.Network, token.Error().Error())		}		return	}	logger.Log(3, fmt.Sprintf("subscribed to node updates for node %s update/%s/%s", nodeCfg.Node.Name, nodeCfg.Node.Network, nodeCfg.Node.ID))	if token := client.Subscribe(fmt.Sprintf("peers/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID), 0, mqtt.MessageHandler(UpdatePeers)); token.Wait() && token.Error() != nil {		logger.Log(0, "network", nodeCfg.Node.Network, token.Error().Error())		return	}	logger.Log(3, fmt.Sprintf("subscribed to peer updates for node %s peers/%s/%s", nodeCfg.Node.Name, nodeCfg.Node.Network, nodeCfg.Node.ID))}// on a delete usually, pass in the nodecfg to unsubscribe client broker communications// for the node in nodeCfgfunc unsubscribeNode(client mqtt.Client, nodeCfg *config.ClientConfig) {	client.Unsubscribe(fmt.Sprintf("update/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID))	var ok = true	if token := client.Unsubscribe(fmt.Sprintf("update/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID)); token.WaitTimeout(mq.MQ_TIMEOUT*time.Second) && token.Error() != nil {		if token.Error() == nil {			logger.Log(1, "network:", nodeCfg.Node.Network, "unable to unsubscribe from updates for node ", nodeCfg.Node.Name, "\n", "connection timeout")		} else {			logger.Log(1, "network:", nodeCfg.Node.Network, "unable to unsubscribe from updates for node ", nodeCfg.Node.Name, "\n", token.Error().Error())		}		ok = false	}	if token := client.Unsubscribe(fmt.Sprintf("peers/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID)); token.WaitTimeout(mq.MQ_TIMEOUT*time.Second) && token.Error() != nil {		if token.Error() == nil {			logger.Log(1, "network:", nodeCfg.Node.Network, "unable to unsubscribe from peer updates for node", nodeCfg.Node.Name, "\n", "connection timeout")		} else {			logger.Log(1, "network:", nodeCfg.Node.Network, "unable to unsubscribe from peer updates for node", nodeCfg.Node.Name, "\n", token.Error().Error())		}		ok = false	}	if ok {		logger.Log(1, "network:", nodeCfg.Node.Network, "successfully unsubscribed node ", nodeCfg.Node.ID, " : ", nodeCfg.Node.Name)	}}// sets up Message Queue and subsribes/publishes updates to/from server// the client should subscribe to ALL nodes that exist on server locallyfunc messageQueue(ctx context.Context, wg *sync.WaitGroup, cfg *config.ClientConfig) {	defer wg.Done()	logger.Log(0, "network:", cfg.Node.Network, "netclient message queue started for server:", cfg.Server.Server)	err := setupMQTT(cfg)	if err != nil {		logger.Log(0, "unable to connect to broker", cfg.Server.Server, err.Error())		return	}	//defer mqclient.Disconnect(250)	<-ctx.Done()	logger.Log(0, "shutting down message queue for server", cfg.Server.Server)}// func setMQTTSingenton creates a connection to broker for single use (ie to publish a message)// only to be called from cli (eg. connect/disconnect, join, leave) and not from daemon ---func setupMQTTSingleton(cfg *config.ClientConfig) error {	opts := mqtt.NewClientOptions()	server := cfg.Server.Server	port := cfg.Server.MQPort	pass, err := os.ReadFile(ncutils.GetNetclientPathSpecific() + "secret-" + cfg.Network)	if err != nil {		return fmt.Errorf("could not read secrets file %w", err)	}	opts.AddBroker("wss://" + server + ":" + port)	opts.SetUsername(cfg.Node.ID)	opts.SetPassword(string(pass))	mqclient = mqtt.NewClient(opts)	var connecterr error	opts.SetClientID(ncutils.MakeRandomString(23))	if token := mqclient.Connect(); !token.WaitTimeout(30*time.Second) || token.Error() != nil {		logger.Log(0, "unable to connect to broker, retrying ...")		if token.Error() == nil {			connecterr = errors.New("connect timeout")		} else {			connecterr = token.Error()		}	}	return connecterr}// setupMQTT creates a connection to broker and returns client// this function is primarily used to create a connection to publish to the brokerfunc setupMQTT(cfg *config.ClientConfig) error {	opts := mqtt.NewClientOptions()	server := cfg.Server.Server	port := cfg.Server.MQPort	pass, err := os.ReadFile(ncutils.GetNetclientPathSpecific() + "secret-" + cfg.Network)	if err != nil {		return fmt.Errorf("could not read secrets file %w", err)	}	opts.AddBroker(fmt.Sprintf("wss://%s:%s", server, port))	opts.SetUsername(cfg.Node.ID)	opts.SetPassword(string(pass))	opts.SetClientID(ncutils.MakeRandomString(23))	opts.SetDefaultPublishHandler(All)	opts.SetAutoReconnect(true)	opts.SetConnectRetry(true)	opts.SetConnectRetryInterval(time.Second << 2)	opts.SetKeepAlive(time.Minute >> 1)	opts.SetWriteTimeout(time.Minute)	opts.SetOnConnectHandler(func(client mqtt.Client) {		networks, err := ncutils.GetSystemNetworks()		if err != nil {			logger.Log(0, "error retriving networks", err.Error())		}		for _, network := range networks {			var currNodeCfg config.ClientConfig			currNodeCfg.Network = network			currNodeCfg.ReadConfig()			setSubscriptions(client, &currNodeCfg)		}	})	opts.SetOrderMatters(true)	opts.SetResumeSubs(true)	opts.SetConnectionLostHandler(func(c mqtt.Client, e error) {		logger.Log(0, "network:", cfg.Node.Network, "detected broker connection lost for", cfg.Server.Server)	})	mqclient = mqtt.NewClient(opts)	var connecterr error	for count := 0; count < 3; count++ {		connecterr = nil		if token := mqclient.Connect(); !token.WaitTimeout(30*time.Second) || token.Error() != nil {			logger.Log(0, "unable to connect to broker, retrying ...")			if token.Error() == nil {				connecterr = errors.New("connect timeout")			} else {				connecterr = token.Error()			}			if err := checkBroker(cfg.Server.Server, cfg.Server.MQPort); err != nil {				logger.Log(0, "could not connect to broker", cfg.Server.Server, err.Error())			}		}	}	if connecterr != nil {		logger.Log(0, "failed to establish connection to broker: ", connecterr.Error())		return connecterr	}	return nil}// publishes a message to server to update peers on this peer's behalffunc publishSignal(nodeCfg *config.ClientConfig, signal byte) error {	if err := publish(nodeCfg, fmt.Sprintf("signal/%s", nodeCfg.Node.ID), []byte{signal}, 1); err != nil {		return err	}	return nil}func parseNetworkFromTopic(topic string) string {	return strings.Split(topic, "/")[1]}// should only ever use node client configsfunc decryptMsg(nodeCfg *config.ClientConfig, msg []byte) ([]byte, error) {	if len(msg) <= 24 { // make sure message is of appropriate length		return nil, fmt.Errorf("recieved invalid message from broker %v", msg)	}	// setup the keys	diskKey, keyErr := auth.RetrieveTrafficKey(nodeCfg.Node.Network)	if keyErr != nil {		return nil, keyErr	}	serverPubKey, err := ncutils.ConvertBytesToKey(nodeCfg.Node.TrafficKeys.Server)	if err != nil {		return nil, err	}	return ncutils.DeChunk(msg, serverPubKey, diskKey)}// == Message Caches ==func insert(network, which, cache string) {	var newMessage = cachedMessage{		Message:  cache,		LastSeen: time.Now(),	}	messageCache.Store(fmt.Sprintf("%s%s", network, which), newMessage)}func read(network, which string) string {	val, isok := messageCache.Load(fmt.Sprintf("%s%s", network, which))	if isok {		var readMessage = val.(cachedMessage) // fetch current cached message		if readMessage.LastSeen.IsZero() {			return ""		}		if time.Now().After(readMessage.LastSeen.Add(time.Hour * 24)) { // check if message has been there over a minute			messageCache.Delete(fmt.Sprintf("%s%s", network, which)) // remove old message if expired			return ""		}		return readMessage.Message // return current message if not expired	}	return ""}
 |