Browse Source

Merge pull request #1053 from gravitl/refactor_v0.13.1_mq_timeout

Refactor v0.13.1 mq timeout
Alex Feiszli 3 years ago
parent
commit
9cebde77f5
2 changed files with 46 additions and 51 deletions
  1. 33 46
      netclient/functions/daemon.go
  2. 13 5
      netclient/functions/mqpublish.go

+ 33 - 46
netclient/functions/daemon.go

@@ -2,6 +2,8 @@ package functions
 
 import (
 	"context"
+	"crypto/ed25519"
+	"crypto/rand"
 	"crypto/tls"
 	"crypto/x509"
 	"errors"
@@ -22,6 +24,7 @@ import (
 	"github.com/gravitl/netmaker/netclient/daemon"
 	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/netclient/wireguard"
+	ssl "github.com/gravitl/netmaker/tls"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
@@ -38,8 +41,7 @@ type cachedMessage struct {
 
 // Daemon runs netclient daemon from command line
 func Daemon() error {
-	var exists = struct{}{}
-	serverSet := make(map[string]struct{})
+	serverSet := make(map[string]config.ClientConfig)
 	// == initial pull of all networks ==
 	networks, _ := ncutils.GetSystemNetworks()
 	if len(networks) == 0 {
@@ -50,7 +52,7 @@ func Daemon() error {
 		cfg := config.ClientConfig{}
 		cfg.Network = network
 		cfg.ReadConfig()
-		serverSet[cfg.Server.Server] = exists
+		serverSet[cfg.Server.Server] = cfg
 		//temporary code --- remove in version v0.13.0
 		removeHostDNS(network, ncutils.IsWindows())
 		// end of code to be removed in version v0.13.0
@@ -58,11 +60,11 @@ func Daemon() error {
 	}
 
 	// == subscribe to all nodes for each server on machine ==
-	for server := range serverSet {
+	for server, config := range serverSet {
 		logger.Log(1, "started daemon for server ", server)
 		ctx, cancel := context.WithCancel(context.Background())
 		networkcontext.Store(server, cancel)
-		go messageQueue(ctx, server)
+		go messageQueue(ctx, &config)
 	}
 
 	// == add waitgroup and cancel for checkin routine ==
@@ -115,10 +117,11 @@ func PingServer(cfg *config.ClientConfig) error {
 		return err
 	}
 	pinger.Timeout = 2 * time.Second
+	pinger.Count = 3
 	pinger.Run()
 	stats := pinger.Statistics()
 	if stats.PacketLoss == 100 {
-		return errors.New("ping error")
+		return errors.New("ping error " + fmt.Sprintf("%f", stats.PacketLoss))
 	}
 	logger.Log(3, "ping of server", cfg.Server.Server, "was successful")
 	return nil
@@ -168,12 +171,12 @@ func unsubscribeNode(client mqtt.Client, nodeCfg *config.ClientConfig) {
 
 // sets up Message Queue and subscribes/publishes updates to/from server
 // the client should subscribe to ALL nodes that exist on server locally
-func messageQueue(ctx context.Context, server string) {
-	logger.Log(0, "netclient daemon started for server: ", server)
-	client := setupMQTT(nil, server, false)
+func messageQueue(ctx context.Context, cfg *config.ClientConfig) {
+	logger.Log(0, "netclient daemon started for server: ", cfg.Server.Server)
+	client := setupMQTT(cfg, false)
 	defer client.Disconnect(250)
 	<-ctx.Done()
-	logger.Log(0, "shutting down daemon for server ", server)
+	logger.Log(0, "shutting down daemon for server ", cfg.Server.Server)
 }
 
 // NewTLSConfig sets up TLS configuration to connect to broker securely
@@ -204,11 +207,9 @@ func NewTLSConfig(server string) *tls.Config {
 
 // setupMQTT creates a connection to broker and returns client
 // this function is primarily used to create a connection to publish to the broker
-func setupMQTT(cfg *config.ClientConfig, server string, publish bool) mqtt.Client {
+func setupMQTT(cfg *config.ClientConfig, publish bool) mqtt.Client {
 	opts := mqtt.NewClientOptions()
-	if cfg != nil {
-		server = cfg.Server.Server
-	}
+	server := cfg.Server.Server
 	opts.AddBroker("ssl://" + server + ":8883") // TODO get the appropriate port of the comms mq server
 	opts.SetTLSConfig(NewTLSConfig(server))
 	opts.SetClientID(ncutils.MakeRandomString(23))
@@ -236,44 +237,30 @@ func setupMQTT(cfg *config.ClientConfig, server string, publish bool) mqtt.Clien
 	opts.SetOrderMatters(true)
 	opts.SetResumeSubs(true)
 	opts.SetConnectionLostHandler(func(c mqtt.Client, e error) {
-		logger.Log(0, "detected broker connection lost, running pull for ", cfg.Node.Network)
-		_, err := Pull(cfg.Node.Network, true)
-		if err != nil {
-			logger.Log(0, "could not run pull, server unreachable: ", err.Error())
-			logger.Log(0, "waiting to retry...")
-		}
-		logger.Log(0, "connection re-established with mqtt server")
+		logger.Log(0, "detected broker connection lost for", cfg.Server.Server)
 	})
 	client := mqtt.NewClient(opts)
-
-	tperiod := time.Now().Add(12 * time.Second)
-	for {
-		//if after 12 seconds, try a pull on the last try
-		if time.Now().After(tperiod) {
-			logger.Log(0, "running pull for ", cfg.Node.Network)
-			_, err := Pull(cfg.Node.Network, true)
-			if err != nil {
-				logger.Log(0, "could not run pull, exiting ", cfg.Node.Network, " setup: ", err.Error())
-				return client
-			}
-			time.Sleep(time.Second)
+	for token := client.Connect(); !token.WaitTimeout(30*time.Second) || token.Error() != nil; token = client.Connect() {
+		logger.Log(0, "unable to connect to broker, retrying ...")
+		var err error
+		if token.Error() == nil {
+			err = errors.New("connect timeout")
+		} else {
+			err = token.Error()
 		}
-		if token := client.Connect(); token.Wait() && token.Error() != nil {
-
-			logger.Log(0, "unable to connect to broker, retrying ...")
-			if time.Now().After(tperiod) {
-				logger.Log(0, "could not connect to broker, exiting ", cfg.Node.Network, " setup: ", token.Error().Error())
-				if strings.Contains(token.Error().Error(), "connectex") || strings.Contains(token.Error().Error(), "i/o timeout") {
-					logger.Log(0, "connection issue detected.. pulling and restarting daemon")
-					Pull(cfg.Node.Network, true)
-					daemon.Restart()
+		logger.Log(0, "could not connect to broker", cfg.Server.Server, err.Error())
+		if strings.Contains(err.Error(), "connectex") || strings.Contains(err.Error(), "connect timeout") {
+			logger.Log(0, "connection issue detected.. attempt connection with new certs")
+			key, err := ssl.ReadKey(ncutils.GetNetclientPath() + ncutils.GetSeparator() + "client.key")
+			if err != nil {
+				_, *key, err = ed25519.GenerateKey(rand.Reader)
+				if err != nil {
+					log.Fatal("could not generate new key")
 				}
-				return client
 			}
-		} else {
-			break
+			RegisterWithServer(key, cfg)
+			daemon.Restart()
 		}
-		time.Sleep(2 * time.Second)
 	}
 	return client
 }

+ 13 - 5
netclient/functions/mqpublish.go

@@ -76,7 +76,7 @@ func Checkin(ctx context.Context, wg *sync.WaitGroup) {
 					}
 				}
 				if err := PingServer(&nodeCfg); err != nil {
-					logger.Log(0, "could not ping server for , ", nodeCfg.Network, "\n", err.Error())
+					logger.Log(0, "could not ping server for", nodeCfg.Network, nodeCfg.Server.Server+"\n", err.Error())
 				} else {
 					Hello(&nodeCfg)
 				}
@@ -128,17 +128,25 @@ func publish(nodeCfg *config.ClientConfig, dest string, msg []byte, qos byte) er
 		return err
 	}
 
-	client := setupMQTT(nodeCfg, "", true)
+	client := setupMQTT(nodeCfg, true)
 	defer client.Disconnect(250)
 	encrypted, err := ncutils.Chunk(msg, serverPubKey, trafficPrivKey)
 	if err != nil {
 		return err
 	}
 
-	if token := client.Publish(dest, qos, false, encrypted); token.Wait() && token.Error() != nil {
-		return token.Error()
+	if token := client.Publish(dest, qos, false, encrypted); !token.WaitTimeout(30*time.Second) || token.Error() != nil {
+		logger.Log(0, "could not connect to broker at "+nodeCfg.Server.Server+":8883")
+		var err error
+		if token.Error() == nil {
+			err = errors.New("connection timeout")
+		} else {
+			err = token.Error()
+		}
+		if err != nil {
+			return token.Error()
+		}
 	}
-
 	return nil
 }