Browse Source

use SIGHUP to restart daemon

Matthew R. Kasun 3 years ago
parent
commit
d36de447ac
4 changed files with 92 additions and 43 deletions
  1. 16 17
      netclient/daemon/common.go
  2. 41 26
      netclient/functions/daemon.go
  3. 3 0
      netclient/functions/install.go
  4. 32 0
      netclient/ncutils/pid.go

+ 16 - 17
netclient/daemon/common.go

@@ -2,8 +2,13 @@ package daemon
 
 import (
 	"errors"
+	"fmt"
+	"os"
 	"runtime"
+	"syscall"
 	"time"
+
+	"github.com/gravitl/netmaker/netclient/ncutils"
 )
 
 // InstallDaemon - Calls the correct function to install the netclient as a daemon service on the given operating system.
@@ -28,24 +33,18 @@ func InstallDaemon() error {
 
 // Restart - restarts a system daemon
 func Restart() error {
-	os := runtime.GOOS
-	var err error
-
-	time.Sleep(time.Second)
-
-	switch os {
-	case "windows":
-		RestartWindowsDaemon()
-	case "darwin":
-		RestartLaunchD()
-	case "linux":
-		RestartSystemD()
-	case "freebsd":
-		FreebsdDaemon("restart")
-	default:
-		err = errors.New("this os is not yet supported for daemon mode. Run join cmd with flag '--daemon off'")
+	pid, err := ncutils.ReadPID()
+	if err != nil {
+		return fmt.Errorf("failed to find pid %w", err)
 	}
-	return err
+	p, err := os.FindProcess(pid)
+	if err != nil {
+		return fmt.Errorf("failed to find running process for pid %d -- %w", pid, err)
+	}
+	if err := p.Signal(syscall.SIGHUP); err != nil {
+		return fmt.Errorf("SIGHUP failed -- %w", err)
+	}
+	return nil
 }
 
 // Stop - stops a system daemon

+ 41 - 26
netclient/functions/daemon.go

@@ -30,7 +30,7 @@ import (
 )
 
 var messageCache = new(sync.Map)
-var networkcontext = new(sync.Map)
+var serverSet map[string]bool
 
 const lastNodeUpdate = "lnu"
 const lastPeerUpdate = "lpu"
@@ -43,19 +43,53 @@ type cachedMessage struct {
 // Daemon runs netclient daemon from command line
 func Daemon() error {
 	UpdateClientConfig()
-	serverSet := make(map[string]bool)
+	if err := ncutils.SavePID(); err != nil {
+		return err
+	}
+	serverSet = make(map[string]bool)
 	// == initial pull of all networks ==
 	networks, _ := ncutils.GetSystemNetworks()
 	if len(networks) == 0 {
 		return errors.New("no networks")
 	}
-	pubNetworks = append(pubNetworks, networks...)
 	// set ipforwarding on startup
 	err := local.SetIPForwarding()
 	if err != nil {
 		logger.Log(0, err.Error())
 	}
 
+	// == add waitgroup and cancel for checkin routine ==
+	wg := sync.WaitGroup{}
+	quit := make(chan os.Signal, 1)
+	reset := make(chan os.Signal, 1)
+	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
+	signal.Notify(reset, syscall.SIGHUP)
+	cancel := startGoRoutines(&wg)
+	for {
+		select {
+		case <-quit:
+			cancel()
+			logger.Log(0, "shutting down netclient daemon")
+			wg.Wait()
+			logger.Log(0, "shutdown complete")
+			return nil
+		case <-reset:
+			logger.Log(0, "received reset")
+			cancel()
+			wg.Wait()
+			logger.Log(0, "restarting daemon")
+			cancel = startGoRoutines(&wg)
+		}
+	}
+}
+
+func startGoRoutines(wg *sync.WaitGroup) context.CancelFunc {
+	defer wg.Done()
+	ctx, cancel := context.WithCancel(context.Background())
+	wg.Add(1)
+	go Checkin(ctx, wg)
+	serverSet := make(map[string]bool)
+	networks, _ := ncutils.GetSystemNetworks()
 	for _, network := range networks {
 		logger.Log(3, "initializing network", network)
 		cfg := config.ClientConfig{}
@@ -69,30 +103,10 @@ func Daemon() error {
 			// == subscribe to all nodes for each on machine ==
 			serverSet[server] = true
 			logger.Log(1, "started daemon for server ", server)
-			ctx, cancel := context.WithCancel(context.Background())
-			networkcontext.Store(server, cancel)
-			go messageQueue(ctx, &cfg)
-		}
-	}
-
-	// == add waitgroup and cancel for checkin routine ==
-	wg := sync.WaitGroup{}
-	ctx, cancel := context.WithCancel(context.Background())
-	wg.Add(1)
-	go Checkin(ctx, &wg)
-	quit := make(chan os.Signal, 1)
-	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
-	<-quit
-	for server := range serverSet {
-		if cancel, ok := networkcontext.Load(server); ok {
-			cancel.(context.CancelFunc)()
+			go messageQueue(ctx, wg, &cfg)
 		}
 	}
-	cancel()
-	logger.Log(0, "shutting down netclient daemon")
-	wg.Wait()
-	logger.Log(0, "shutdown complete")
-	return nil
+	return cancel
 }
 
 // UpdateKeys -- updates private key and returns new publickey
@@ -167,7 +181,8 @@ func unsubscribeNode(client mqtt.Client, nodeCfg *config.ClientConfig) {
 
 // sets up Message Queue and subsribes/publishes updates to/from server
 // the client should subscribe to ALL nodes that exist on server locally
-func messageQueue(ctx context.Context, cfg *config.ClientConfig) {
+func messageQueue(ctx context.Context, wg *sync.WaitGroup, cfg *config.ClientConfig) {
+	defer wg.Done()
 	logger.Log(0, "netclient daemon started for server: ", cfg.Server.Server)
 	client, err := setupMQTT(cfg, false)
 	if err != nil {

+ 3 - 0
netclient/functions/install.go

@@ -1,6 +1,8 @@
 package functions
 
 import (
+	"time"
+
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/netclient/daemon"
 )
@@ -12,5 +14,6 @@ func Install() error {
 		logger.Log(0, "error installing daemon", err.Error())
 		return err
 	}
+	time.Sleep(time.Second * 5)
 	return daemon.Restart()
 }

+ 32 - 0
netclient/ncutils/pid.go

@@ -0,0 +1,32 @@
+package ncutils
+
+import (
+	"fmt"
+	"os"
+	"strconv"
+)
+
+// PIDFILE - path/name of pid file
+const PIDFILE = "/var/run/netclient.pid"
+
+// SavePID - saves the pid of running program to disk
+func SavePID() error {
+	pid := os.Getpid()
+	if err := os.WriteFile(PIDFILE, []byte(fmt.Sprintf("%d", pid)), 0644); err != nil {
+		return fmt.Errorf("could not write to pid file %w", err)
+	}
+	return nil
+}
+
+// ReadPID - reads a previously saved pid from disk
+func ReadPID() (int, error) {
+	bytes, err := os.ReadFile(PIDFILE)
+	if err != nil {
+		return 0, fmt.Errorf("could not read pid file %w", err)
+	}
+	pid, err := strconv.Atoi(string(bytes))
+	if err != nil {
+		return 0, fmt.Errorf("pid file contents invalid %w", err)
+	}
+	return pid, nil
+}