| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 | package functionsimport (	"context"	"errors"	"fmt"	"os"	"os/signal"	"strings"	"sync"	"syscall"	"time"	mqtt "github.com/eclipse/paho.mqtt.golang"	"github.com/go-ping/ping"	"github.com/gravitl/netmaker/logger"	"github.com/gravitl/netmaker/models"	"github.com/gravitl/netmaker/netclient/auth"	"github.com/gravitl/netmaker/netclient/config"	"github.com/gravitl/netmaker/netclient/daemon"	"github.com/gravitl/netmaker/netclient/ncutils"	"github.com/gravitl/netmaker/netclient/wireguard"	"golang.zx2c4.com/wireguard/wgctrl/wgtypes")var messageCache = new(sync.Map)var networkcontext = new(sync.Map)const lastNodeUpdate = "lnu"const lastPeerUpdate = "lpu"type cachedMessage struct {	Message  string	LastSeen time.Time}// Daemon runs netclient daemon from command linefunc Daemon() error {	// == initial pull of all networks ==	networks, _ := ncutils.GetSystemNetworks()	for _, network := range networks {		//temporary code --- remove in version v0.13.0		removeHostDNS(network, ncutils.IsWindows())		// end of code to be removed in version v0.13.0		var cfg config.ClientConfig		cfg.Network = network		cfg.ReadConfig()		initialPull(cfg.Network)	}	// == get all the comms networks on machine ==	commsNetworks, err := getCommsNetworks(networks[:])	if err != nil {		return errors.New("no comm networks exist")	}	// == subscribe to all nodes on each comms network on machine ==	for currCommsNet := range commsNetworks {		logger.Log(1, "started comms network daemon, ", currCommsNet)		ctx, cancel := context.WithCancel(context.Background())		networkcontext.Store(currCommsNet, cancel)		go messageQueue(ctx, currCommsNet)	}	// == add waitgroup and cancel for checkin routine ==	wg := sync.WaitGroup{}	ctx, cancel := context.WithCancel(context.Background())	wg.Add(1)	go Checkin(ctx, &wg, commsNetworks)	quit := make(chan os.Signal, 1)	signal.Notify(quit, syscall.SIGTERM, os.Interrupt, os.Kill)	<-quit	for currCommsNet := range commsNetworks {		if cancel, ok := networkcontext.Load(currCommsNet); ok {			cancel.(context.CancelFunc)()		}	}	cancel()	logger.Log(0, "shutting down netclient daemon")	wg.Wait()	logger.Log(0, "shutdown complete")	return nil}// UpdateKeys -- updates private key and returns new publickeyfunc UpdateKeys(nodeCfg *config.ClientConfig, client mqtt.Client) error {	logger.Log(0, "received message to update wireguard keys for network ", nodeCfg.Network)	key, err := wgtypes.GeneratePrivateKey()	if err != nil {		logger.Log(0, "error generating privatekey ", err.Error())		return err	}	file := ncutils.GetNetclientPathSpecific() + nodeCfg.Node.Interface + ".conf"	if err := wireguard.UpdatePrivateKey(file, key.String()); err != nil {		logger.Log(0, "error updating wireguard key ", err.Error())		return err	}	if storeErr := wireguard.StorePrivKey(key.String(), nodeCfg.Network); storeErr != nil {		logger.Log(0, "failed to save private key", storeErr.Error())		return storeErr	}	nodeCfg.Node.PublicKey = key.PublicKey().String()	var commsCfg = getCommsCfgByNode(&nodeCfg.Node)	PublishNodeUpdate(&commsCfg, nodeCfg)	return nil}// PingServer -- checks if server is reachable// use commsCfg only*func PingServer(commsCfg *config.ClientConfig) error {	node := getServerAddress(commsCfg)	pinger, err := ping.NewPinger(node)	if err != nil {		return err	}	pinger.Timeout = 2 * time.Second	pinger.Run()	stats := pinger.Statistics()	if stats.PacketLoss == 100 {		return errors.New("ping error")	}	return nil}// == Private ==// sets MQ client subscriptions for a specific node config// should be called for each node belonging to a given comms networkfunc setSubscriptions(client mqtt.Client, nodeCfg *config.ClientConfig) {	if nodeCfg.DebugOn {		if token := client.Subscribe("#", 0, nil); token.Wait() && token.Error() != nil {			logger.Log(0, token.Error().Error())			return		}		logger.Log(0, "subscribed to all topics for debugging purposes")	}	if token := client.Subscribe(fmt.Sprintf("update/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID), 0, mqtt.MessageHandler(NodeUpdate)); token.Wait() && token.Error() != nil {		logger.Log(0, token.Error().Error())		return	}	if nodeCfg.DebugOn {		logger.Log(0, fmt.Sprintf("subscribed to node updates for node %s update/%s/%s", nodeCfg.Node.Name, nodeCfg.Node.Network, nodeCfg.Node.ID))	}	if token := client.Subscribe(fmt.Sprintf("peers/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID), 0, mqtt.MessageHandler(UpdatePeers)); token.Wait() && token.Error() != nil {		logger.Log(0, token.Error().Error())		return	}	if nodeCfg.DebugOn {		logger.Log(0, fmt.Sprintf("subscribed to peer updates for node %s peers/%s/%s", nodeCfg.Node.Name, nodeCfg.Node.Network, nodeCfg.Node.ID))	}}// on a delete usually, pass in the nodecfg to unsubscribe client broker communications// for the node in nodeCfgfunc unsubscribeNode(client mqtt.Client, nodeCfg *config.ClientConfig) {	client.Unsubscribe(fmt.Sprintf("update/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID))	var ok = true	if token := client.Unsubscribe(fmt.Sprintf("update/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID)); token.Wait() && token.Error() != nil {		logger.Log(1, "unable to unsubscribe from updates for node ", nodeCfg.Node.Name, "\n", token.Error().Error())		ok = false	}	if token := client.Unsubscribe(fmt.Sprintf("peers/%s/%s", nodeCfg.Node.Network, nodeCfg.Node.ID)); token.Wait() && token.Error() != nil {		logger.Log(1, "unable to unsubscribe from peer updates for node ", nodeCfg.Node.Name, "\n", token.Error().Error())		ok = false	}	if ok {		logger.Log(1, "successfully unsubscribed node ", nodeCfg.Node.ID, " : ", nodeCfg.Node.Name)	}}// sets up Message Queue and subsribes/publishes updates to/from server// the client should subscribe to ALL nodes that exist on unique comms network locallyfunc messageQueue(ctx context.Context, commsNet string) {	var commsCfg config.ClientConfig	commsCfg.Network = commsNet	commsCfg.ReadConfig()	logger.Log(0, "netclient daemon started for network: ", commsNet)	client := setupMQTT(&commsCfg, false)	defer client.Disconnect(250)	<-ctx.Done()	logger.Log(0, "shutting down daemon for comms network ", commsNet)}// setupMQTT creates a connection to broker and return client// utilizes comms client configs to setup connectionsfunc setupMQTT(commsCfg *config.ClientConfig, publish bool) mqtt.Client {	opts := mqtt.NewClientOptions()	server := getServerAddress(commsCfg)	opts.AddBroker(server + ":1883")             // TODO get the appropriate port of the comms mq server	opts.ClientID = ncutils.MakeRandomString(23) // helps avoid id duplication on broker	opts.SetDefaultPublishHandler(All)	opts.SetAutoReconnect(true)	opts.SetConnectRetry(true)	opts.SetConnectRetryInterval(time.Second << 2)	opts.SetKeepAlive(time.Minute >> 1)	opts.SetWriteTimeout(time.Minute)	opts.SetOnConnectHandler(func(client mqtt.Client) {		if !publish {			networks, err := ncutils.GetSystemNetworks()			if err != nil {				logger.Log(0, "error retriving networks ", err.Error())			}			for _, network := range networks {				var currNodeCfg config.ClientConfig				currNodeCfg.Network = network				currNodeCfg.ReadConfig()				setSubscriptions(client, &currNodeCfg)			}		}	})	opts.SetOrderMatters(true)	opts.SetResumeSubs(true)	opts.SetConnectionLostHandler(func(c mqtt.Client, e error) {		logger.Log(0, "detected broker connection lost, running pull for ", commsCfg.Node.Network)		_, err := Pull(commsCfg.Node.Network, true)		if err != nil {			logger.Log(0, "could not run pull, server unreachable: ", err.Error())			logger.Log(0, "waiting to retry...")		}		logger.Log(0, "connection re-established with mqtt server")	})	client := mqtt.NewClient(opts)	tperiod := time.Now().Add(12 * time.Second)	for {		//if after 12 seconds, try a gRPC pull on the last try		if time.Now().After(tperiod) {			logger.Log(0, "running pull for ", commsCfg.Node.Network)			_, err := Pull(commsCfg.Node.Network, true)			if err != nil {				logger.Log(0, "could not run pull, exiting ", commsCfg.Node.Network, " setup: ", err.Error())				return client			}			time.Sleep(time.Second)		}		if token := client.Connect(); token.Wait() && token.Error() != nil {			logger.Log(0, "unable to connect to broker, retrying ...")			if time.Now().After(tperiod) {				logger.Log(0, "could not connect to broker, exiting ", commsCfg.Node.Network, " setup: ", token.Error().Error())				if strings.Contains(token.Error().Error(), "connectex") || strings.Contains(token.Error().Error(), "i/o timeout") {					logger.Log(0, "connection issue detected.. pulling and restarting daemon")					Pull(commsCfg.Node.Network, true)					daemon.Restart()				}				return client			}		} else {			break		}		time.Sleep(2 * time.Second)	}	return client}// publishes a message to server to update peers on this peer's behalffunc publishSignal(commsCfg, nodeCfg *config.ClientConfig, signal byte) error {	if err := publish(commsCfg, nodeCfg, fmt.Sprintf("signal/%s", nodeCfg.Node.ID), []byte{signal}, 1); err != nil {		return err	}	return nil}func initialPull(network string) {	logger.Log(0, "pulling latest config for ", network)	var configPath = fmt.Sprintf("%snetconfig-%s", ncutils.GetNetclientPathSpecific(), network)	fileInfo, err := os.Stat(configPath)	if err != nil {		logger.Log(0, "could not stat config file: ", configPath)		return	}	// speed up UDP rest	if !fileInfo.ModTime().IsZero() && time.Now().After(fileInfo.ModTime().Add(time.Minute)) {		sleepTime := 2		for {			_, err := Pull(network, true)			if err == nil {				break			}			if sleepTime > 3600 {				sleepTime = 3600			}			logger.Log(0, "failed to pull for network ", network)			logger.Log(0, fmt.Sprintf("waiting %d seconds to retry...", sleepTime))			time.Sleep(time.Second * time.Duration(sleepTime))			sleepTime = sleepTime * 2		}		time.Sleep(time.Second << 1)	}}func parseNetworkFromTopic(topic string) string {	return strings.Split(topic, "/")[1]}// should only ever use node client configsfunc decryptMsg(nodeCfg *config.ClientConfig, msg []byte) ([]byte, error) {	if len(msg) <= 24 { // make sure message is of appropriate length		return nil, fmt.Errorf("recieved invalid message from broker %v", msg)	}	// setup the keys	diskKey, keyErr := auth.RetrieveTrafficKey(nodeCfg.Node.Network)	if keyErr != nil {		return nil, keyErr	}	serverPubKey, err := ncutils.ConvertBytesToKey(nodeCfg.Node.TrafficKeys.Server)	if err != nil {		return nil, err	}	return ncutils.DeChunk(msg, serverPubKey, diskKey)}func getServerAddress(cfg *config.ClientConfig) string {	var server models.ServerAddr	for _, server = range cfg.Node.NetworkSettings.DefaultServerAddrs {		if server.Address != "" && server.IsLeader {			break		}	}	return server.Address}func getCommsNetworks(networks []string) (map[string]bool, error) {	var cfg config.ClientConfig	var response = make(map[string]bool, 1)	for _, network := range networks {		cfg.Network = network		cfg.ReadConfig()		response[cfg.Node.CommID] = true	}	return response, nil}func getCommsCfgByNode(node *models.Node) config.ClientConfig {	var commsCfg config.ClientConfig	commsCfg.Network = node.CommID	commsCfg.ReadConfig()	return commsCfg}// == Message Caches ==func insert(network, which, cache string) {	var newMessage = cachedMessage{		Message:  cache,		LastSeen: time.Now(),	}	messageCache.Store(fmt.Sprintf("%s%s", network, which), newMessage)}func read(network, which string) string {	val, isok := messageCache.Load(fmt.Sprintf("%s%s", network, which))	if isok {		var readMessage = val.(cachedMessage) // fetch current cached message		if readMessage.LastSeen.IsZero() {			return ""		}		if time.Now().After(readMessage.LastSeen.Add(time.Minute * 10)) { // check if message has been there over a minute			messageCache.Delete(fmt.Sprintf("%s%s", network, which)) // remove old message if expired			return ""		}		return readMessage.Message // return current message if not expired	}	return ""}// == End Message Caches ==
 |