Browse Source

Merge pull request #1382 from gravitl/feature_v0.14.6_daemon_restart

Feature v0.14.6 daemon restart
dcarns 3 years ago
parent
commit
80d0568f64

+ 25 - 3
netclient/daemon/common.go

@@ -2,8 +2,13 @@ package daemon
 
 
 import (
 import (
 	"errors"
 	"errors"
+	"fmt"
+	"os"
 	"runtime"
 	"runtime"
+	"syscall"
 	"time"
 	"time"
+
+	"github.com/gravitl/netmaker/netclient/ncutils"
 )
 )
 
 
 // InstallDaemon - Calls the correct function to install the netclient as a daemon service on the given operating system.
 // InstallDaemon - Calls the correct function to install the netclient as a daemon service on the given operating system.
@@ -28,11 +33,28 @@ func InstallDaemon() error {
 
 
 // Restart - restarts a system daemon: on Windows it calls RestartWindowsDaemon, otherwise it sends SIGHUP to the pid returned by ncutils.ReadPID
 // Restart - restarts a system daemon: on Windows it calls RestartWindowsDaemon, otherwise it sends SIGHUP to the pid returned by ncutils.ReadPID
 func Restart() error {
 func Restart() error {
+	if ncutils.IsWindows() {
+		RestartWindowsDaemon()
+		return nil
+	}
+	pid, err := ncutils.ReadPID()
+	if err != nil {
+		return fmt.Errorf("failed to find pid %w", err)
+	}
+	p, err := os.FindProcess(pid)
+	if err != nil {
+		return fmt.Errorf("failed to find running process for pid %d -- %w", pid, err)
+	}
+	if err := p.Signal(syscall.SIGHUP); err != nil {
+		return fmt.Errorf("SIGHUP failed -- %w", err)
+	}
+	return nil
+}
+
+// Start - starts system daemon
+func Start() error {
 	os := runtime.GOOS
 	os := runtime.GOOS
 	var err error
 	var err error
-
-	time.Sleep(time.Second)
-
 	switch os {
 	switch os {
 	case "windows":
 	case "windows":
 		RestartWindowsDaemon()
 		RestartWindowsDaemon()

+ 1 - 0
netclient/daemon/freebsd.go

@@ -108,6 +108,7 @@ func FreebsdDaemon(command string) {
 
 
 // CleanupFreebsd - removes config files and netclient binary
 // CleanupFreebsd - removes config files and netclient binary
 func CleanupFreebsd() {
 func CleanupFreebsd() {
+	RemoveFreebsdDaemon()
 	if err := os.RemoveAll(ncutils.GetNetclientPath()); err != nil {
 	if err := os.RemoveAll(ncutils.GetNetclientPath()); err != nil {
 		logger.Log(1, "Removing netclient configs: ", err.Error())
 		logger.Log(1, "Removing netclient configs: ", err.Error())
 	}
 	}

+ 1 - 0
netclient/daemon/systemd.go

@@ -83,6 +83,7 @@ func RestartSystemD() {
 
 
 // CleanupLinux - cleans up netclient configs
 // CleanupLinux - cleans up netclient configs
 func CleanupLinux() {
 func CleanupLinux() {
+	RemoveSystemDServices()
 	if err := os.RemoveAll(ncutils.GetNetclientPath()); err != nil {
 	if err := os.RemoveAll(ncutils.GetNetclientPath()); err != nil {
 		logger.Log(1, "Removing netclient configs: ", err.Error())
 		logger.Log(1, "Removing netclient configs: ", err.Error())
 	}
 	}

+ 1 - 22
netclient/functions/common.go

@@ -139,6 +139,7 @@ func Uninstall() error {
 		}
 		}
 	}
 	}
 	err = nil
 	err = nil
+
 	// clean up OS specific stuff
 	// clean up OS specific stuff
 	if ncutils.IsWindows() {
 	if ncutils.IsWindows() {
 		daemon.CleanupWindows()
 		daemon.CleanupWindows()
@@ -213,31 +214,9 @@ func LeaveNetwork(network string) error {
 		logger.Log(1, "removed ", node.Network, " network locally")
 		logger.Log(1, "removed ", node.Network, " network locally")
 	}
 	}
 
 
-	currentNets, err := ncutils.GetSystemNetworks()
-	if err != nil || len(currentNets) <= 1 {
-		daemon.Stop() // stop system daemon if last network
-		return RemoveLocalInstance(cfg, network)
-	}
 	return daemon.Restart()
 	return daemon.Restart()
 }
 }
 
 
-// RemoveLocalInstance - remove all netclient files locally for a network
-func RemoveLocalInstance(cfg *config.ClientConfig, networkName string) error {
-
-	if cfg.Daemon != "off" {
-		if ncutils.IsWindows() {
-			// TODO: Remove job?
-		} else if ncutils.IsMac() {
-			//TODO: Delete mac daemon
-		} else if ncutils.IsFreeBSD() {
-			daemon.RemoveFreebsdDaemon()
-		} else {
-			daemon.RemoveSystemDServices()
-		}
-	}
-	return nil
-}
-
 // DeleteInterface - delete an interface of a network
 // DeleteInterface - delete an interface of a network
 func DeleteInterface(ifacename string, postdown string) error {
 func DeleteInterface(ifacename string, postdown string) error {
 	return wireguard.RemoveConf(ifacename, true)
 	return wireguard.RemoveConf(ifacename, true)

+ 46 - 32
netclient/functions/daemon.go

@@ -30,7 +30,8 @@ import (
 )
 )
 
 
 var messageCache = new(sync.Map)
 var messageCache = new(sync.Map)
-var networkcontext = new(sync.Map)
+
+var serverSet map[string]bool
 
 
 const lastNodeUpdate = "lnu"
 const lastNodeUpdate = "lnu"
 const lastPeerUpdate = "lpu"
 const lastPeerUpdate = "lpu"
@@ -42,20 +43,51 @@ type cachedMessage struct {
 
 
 // Daemon runs netclient daemon from command line
 // Daemon runs netclient daemon from command line
 func Daemon() error {
 func Daemon() error {
+	logger.Log(0, "netclient daemon started -- version:", ncutils.Version)
 	UpdateClientConfig()
 	UpdateClientConfig()
-	serverSet := make(map[string]bool)
-	// == initial pull of all networks ==
-	networks, _ := ncutils.GetSystemNetworks()
-	if len(networks) == 0 {
-		return errors.New("no networks")
+	if err := ncutils.SavePID(); err != nil {
+		return err
 	}
 	}
-	pubNetworks = append(pubNetworks, networks...)
+	// reference required to eliminate unused staticcheck
+	serverSet = make(map[string]bool)
+	serverSet["dummy"] = false
 	// set ipforwarding on startup
 	// set ipforwarding on startup
 	err := local.SetIPForwarding()
 	err := local.SetIPForwarding()
 	if err != nil {
 	if err != nil {
 		logger.Log(0, err.Error())
 		logger.Log(0, err.Error())
 	}
 	}
 
 
+	// == add waitgroup and cancel for checkin routine ==
+	wg := sync.WaitGroup{}
+	quit := make(chan os.Signal, 1)
+	reset := make(chan os.Signal, 1)
+	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+	signal.Notify(reset, syscall.SIGHUP)
+	cancel := startGoRoutines(&wg)
+	for {
+		select {
+		case <-quit:
+			cancel()
+			logger.Log(0, "shutting down netclient daemon")
+			wg.Wait()
+			logger.Log(0, "shutdown complete")
+			return nil
+		case <-reset:
+			logger.Log(0, "received reset")
+			cancel()
+			wg.Wait()
+			logger.Log(0, "restarting daemon")
+			cancel = startGoRoutines(&wg)
+		}
+	}
+}
+
+func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc {
+	ctx, cancel := context.WithCancel(context.Background())
+	wg.Add(1)
+	go Checkin(ctx, wg)
+	serverSet := make(map[string]bool)
+	networks, _ := ncutils.GetSystemNetworks()
 	for _, network := range networks {
 	for _, network := range networks {
 		logger.Log(3, "initializing network", network)
 		logger.Log(3, "initializing network", network)
 		cfg := config.ClientConfig{}
 		cfg := config.ClientConfig{}
@@ -69,30 +101,11 @@ func Daemon() error {
 			// == subscribe to all nodes for each on machine ==
 			// == subscribe to all nodes for each on machine ==
 			serverSet[server] = true
 			serverSet[server] = true
 			logger.Log(1, "started daemon for server ", server)
 			logger.Log(1, "started daemon for server ", server)
-			ctx, cancel := context.WithCancel(context.Background())
-			networkcontext.Store(server, cancel)
-			go messageQueue(ctx, &cfg)
+			wg.Add(1)
+			go messageQueue(ctx, wg, &cfg)
 		}
 		}
 	}
 	}
-
-	// == add waitgroup and cancel for checkin routine ==
-	wg := sync.WaitGroup{}
-	ctx, cancel := context.WithCancel(context.Background())
-	wg.Add(1)
-	go Checkin(ctx, &wg)
-	quit := make(chan os.Signal, 1)
-	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
-	<-quit
-	for server := range serverSet {
-		if cancel, ok := networkcontext.Load(server); ok {
-			cancel.(context.CancelFunc)()
-		}
-	}
-	cancel()
-	logger.Log(0, "shutting down netclient daemon")
-	wg.Wait()
-	logger.Log(0, "shutdown complete")
-	return nil
+	return cancel
 }
 }
 
 
 // UpdateKeys -- updates private key and returns new publickey
 // UpdateKeys -- updates private key and returns new publickey
@@ -167,8 +180,9 @@ func unsubscribeNode(client mqtt.Client, nodeCfg *config.ClientConfig) {
 
 
 // sets up Message Queue and subscribes/publishes updates to/from server
 // sets up Message Queue and subscribes/publishes updates to/from server
 // the client should subscribe to ALL nodes that exist on server locally
 // the client should subscribe to ALL nodes that exist on server locally
-func messageQueue(ctx context.Context, cfg *config.ClientConfig) {
-	logger.Log(0, "netclient daemon started for server: ", cfg.Server.Server)
+func messageQueue(ctx context.Context, wg *sync.WaitGroup, cfg *config.ClientConfig) {
+	defer wg.Done()
+	logger.Log(0, "netclient message queue started for server: ", cfg.Server.Server)
 	client, err := setupMQTT(cfg, false)
 	client, err := setupMQTT(cfg, false)
 	if err != nil {
 	if err != nil {
 		logger.Log(0, "unable to connect to broker", cfg.Server.Server, err.Error())
 		logger.Log(0, "unable to connect to broker", cfg.Server.Server, err.Error())
@@ -176,7 +190,7 @@ func messageQueue(ctx context.Context, cfg *config.ClientConfig) {
 	}
 	}
 	defer client.Disconnect(250)
 	defer client.Disconnect(250)
 	<-ctx.Done()
 	<-ctx.Done()
-	logger.Log(0, "shutting down daemon for server ", cfg.Server.Server)
+	logger.Log(0, "shutting down message queue for server ", cfg.Server.Server)
 }
 }
 
 
 // NewTLSConf sets up tls configuration to connect to broker securely
 // NewTLSConf sets up tls configuration to connect to broker securely

+ 3 - 0
netclient/functions/install.go

@@ -1,6 +1,8 @@
 package functions
 package functions
 
 
 import (
 import (
+	"time"
+
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/netclient/daemon"
 	"github.com/gravitl/netmaker/netclient/daemon"
 )
 )
@@ -12,5 +14,6 @@ func Install() error {
 		logger.Log(0, "error installing daemon", err.Error())
 		logger.Log(0, "error installing daemon", err.Error())
 		return err
 		return err
 	}
 	}
+	time.Sleep(time.Second * 5)
 	return daemon.Restart()
 	return daemon.Restart()
 }
 }

+ 6 - 1
netclient/functions/join.go

@@ -212,7 +212,12 @@ func JoinNetwork(cfg *config.ClientConfig, privateKey string) error {
 		}
 		}
 	}
 	}
 
 
-	daemon.Restart()
+	if err := daemon.Restart(); err != nil {
+		log.Println("daemon restart failed ", err)
+		if err := daemon.Start(); err != nil {
+			return err
+		}
+	}
 	return nil
 	return nil
 }
 }
 
 

+ 49 - 45
netclient/functions/mqpublish.go

@@ -19,12 +19,10 @@ import (
 	"github.com/gravitl/netmaker/tls"
 	"github.com/gravitl/netmaker/tls"
 )
 )
 
 
-// pubNetworks hold the currently publishable networks
-var pubNetworks []string
-
 // Checkin  -- go routine that checks for public or local ip changes, publishes changes
 // Checkin  -- go routine that checks for public or local ip changes, publishes changes
 //   if there are no updates, simply "pings" the server as a checkin
 //   if there are no updates, simply "pings" the server as a checkin
 func Checkin(ctx context.Context, wg *sync.WaitGroup) {
 func Checkin(ctx context.Context, wg *sync.WaitGroup) {
+	logger.Log(2, "starting checkin goroutine")
 	defer wg.Done()
 	defer wg.Done()
 	for {
 	for {
 		select {
 		select {
@@ -33,52 +31,58 @@ func Checkin(ctx context.Context, wg *sync.WaitGroup) {
 			return
 			return
 			//delay should be configurable -> use cfg.Node.NetworkSettings.DefaultCheckInInterval ??
 			//delay should be configurable -> use cfg.Node.NetworkSettings.DefaultCheckInInterval ??
 		case <-time.After(time.Second * 60):
 		case <-time.After(time.Second * 60):
-			for _, network := range pubNetworks {
-				var nodeCfg config.ClientConfig
-				nodeCfg.Network = network
-				nodeCfg.ReadConfig()
-				if nodeCfg.Node.IsStatic != "yes" {
-					extIP, err := ncutils.GetPublicIP()
-					if err != nil {
-						logger.Log(1, "error encountered checking public ip addresses: ", err.Error())
-					}
-					if nodeCfg.Node.Endpoint != extIP && extIP != "" {
-						logger.Log(1, "endpoint has changed from ", nodeCfg.Node.Endpoint, " to ", extIP)
-						nodeCfg.Node.Endpoint = extIP
-						if err := PublishNodeUpdate(&nodeCfg); err != nil {
-							logger.Log(0, "could not publish endpoint change")
-						}
-					}
-					intIP, err := getPrivateAddr()
-					if err != nil {
-						logger.Log(1, "error encountered checking private ip addresses: ", err.Error())
-					}
-					if nodeCfg.Node.LocalAddress != intIP && intIP != "" {
-						logger.Log(1, "local Address has changed from ", nodeCfg.Node.LocalAddress, " to ", intIP)
-						nodeCfg.Node.LocalAddress = intIP
-						if err := PublishNodeUpdate(&nodeCfg); err != nil {
-							logger.Log(0, "could not publish local address change")
-						}
-					}
-					_ = UpdateLocalListenPort(&nodeCfg)
+			checkin()
+		}
+	}
+}
+
+func checkin() {
+	networks, _ := ncutils.GetSystemNetworks()
+	logger.Log(3, "checkin with server(s) for all networks")
+	for _, network := range networks {
+		var nodeCfg config.ClientConfig
+		nodeCfg.Network = network
+		nodeCfg.ReadConfig()
+		if nodeCfg.Node.IsStatic != "yes" {
+			extIP, err := ncutils.GetPublicIP()
+			if err != nil {
+				logger.Log(1, "error encountered checking public ip addresses: ", err.Error())
+			}
+			if nodeCfg.Node.Endpoint != extIP && extIP != "" {
+				logger.Log(1, "endpoint has changed from ", nodeCfg.Node.Endpoint, " to ", extIP)
+				nodeCfg.Node.Endpoint = extIP
+				if err := PublishNodeUpdate(&nodeCfg); err != nil {
+					logger.Log(0, "could not publish endpoint change")
+				}
+			}
+			intIP, err := getPrivateAddr()
+			if err != nil {
+				logger.Log(1, "error encountered checking private ip addresses: ", err.Error())
+			}
+			if nodeCfg.Node.LocalAddress != intIP && intIP != "" {
+				logger.Log(1, "local Address has changed from ", nodeCfg.Node.LocalAddress, " to ", intIP)
+				nodeCfg.Node.LocalAddress = intIP
+				if err := PublishNodeUpdate(&nodeCfg); err != nil {
+					logger.Log(0, "could not publish local address change")
+				}
+			}
+			_ = UpdateLocalListenPort(&nodeCfg)
 
 
-				} else if nodeCfg.Node.IsLocal == "yes" && nodeCfg.Node.LocalRange != "" {
-					localIP, err := ncutils.GetLocalIP(nodeCfg.Node.LocalRange)
-					if err != nil {
-						logger.Log(1, "error encountered checking local ip addresses: ", err.Error())
-					}
-					if nodeCfg.Node.Endpoint != localIP && localIP != "" {
-						logger.Log(1, "endpoint has changed from "+nodeCfg.Node.Endpoint+" to ", localIP)
-						nodeCfg.Node.Endpoint = localIP
-						if err := PublishNodeUpdate(&nodeCfg); err != nil {
-							logger.Log(0, "could not publish localip change")
-						}
-					}
+		} else if nodeCfg.Node.IsLocal == "yes" && nodeCfg.Node.LocalRange != "" {
+			localIP, err := ncutils.GetLocalIP(nodeCfg.Node.LocalRange)
+			if err != nil {
+				logger.Log(1, "error encountered checking local ip addresses: ", err.Error())
+			}
+			if nodeCfg.Node.Endpoint != localIP && localIP != "" {
+				logger.Log(1, "endpoint has changed from "+nodeCfg.Node.Endpoint+" to ", localIP)
+				nodeCfg.Node.Endpoint = localIP
+				if err := PublishNodeUpdate(&nodeCfg); err != nil {
+					logger.Log(0, "could not publish localip change")
 				}
 				}
-				Hello(&nodeCfg)
-				checkCertExpiry(&nodeCfg)
 			}
 			}
 		}
 		}
+		Hello(&nodeCfg)
+		checkCertExpiry(&nodeCfg)
 	}
 	}
 }
 }
 
 

+ 46 - 0
netclient/ncutils/pid.go

@@ -0,0 +1,46 @@
+package ncutils
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+	"strings"
+)
+
+// PIDFILE - path/name of the pid file written by SavePID and read by ReadPID
+const PIDFILE = "/var/run/netclient.pid"
+
+// WindowsPIDError - error returned from pid functions on windows, where pid tracking is not supported
+type WindowsPIDError struct{}
+
+// Error - implements the error interface with a fixed message stating that pid tracking is not supported on windows
+func (*WindowsPIDError) Error() string {
+	return "pid tracking not supported on windows"
+}
+
+// SavePID - saves the pid of running program to disk
+//
+// On Windows no pid file is kept, so this is a no-op that returns nil.
+// Returning an error there would abort any caller that treats a SavePID
+// failure as fatal (the daemon startup path does), even though nothing
+// actually went wrong.
+//
+// On every other OS the current process id is written as decimal text to
+// PIDFILE with mode 0644; a write failure is returned wrapped.
+func SavePID() error {
+	if IsWindows() {
+		return nil
+	}
+	pid := strconv.Itoa(os.Getpid())
+	if err := os.WriteFile(PIDFILE, []byte(pid), 0644); err != nil {
+		return fmt.Errorf("could not write to pid file %w", err)
+	}
+	return nil
+}
+
+// ReadPID - reads a previously saved pid from disk
+//
+// Returns a *WindowsPIDError on Windows. Otherwise the contents of PIDFILE
+// are parsed as a decimal process id; leading/trailing whitespace (for
+// example a trailing newline) is ignored. An error is returned when the
+// file cannot be read, is not a number, or holds a non-positive value.
+func ReadPID() (int, error) {
+	if IsWindows() {
+		return 0, &WindowsPIDError{}
+	}
+	data, err := os.ReadFile(PIDFILE)
+	if err != nil {
+		return 0, fmt.Errorf("could not read pid file %w", err)
+	}
+	pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
+	if err != nil {
+		return 0, fmt.Errorf("pid file contents invalid %w", err)
+	}
+	// A pid of zero or below never names a single process; reject it so
+	// callers cannot end up signalling something other than the daemon.
+	if pid <= 0 {
+		return 0, fmt.Errorf("pid file contents invalid: non-positive pid %d", pid)
+	}
+	return pid, nil
+}