Browse Source

refactored some client leave & cache and server join logic

0xdcarns 3 years ago
parent
commit
f7258bf98f

+ 26 - 0
controllers/node_grpc.go

@@ -11,6 +11,7 @@ import (
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/serverctl"
 	"github.com/gravitl/netmaker/serverctl"
 )
 )
@@ -104,6 +105,31 @@ func (s *NodeServiceServer) CreateNode(ctx context.Context, req *nodepb.Object)
 
 
 	runUpdates(&node, false)
 	runUpdates(&node, false)
 
 
+	go func(node *models.Node) {
+		if node.UDPHolePunch == "yes" {
+			var currentServerNodeID, getErr = logic.GetNetworkServerNodeID(node.Network)
+			if getErr != nil {
+				return
+			}
+			var currentServerNode, currErr = logic.GetNodeByID(currentServerNodeID)
+			if currErr != nil {
+				return
+			}
+			for i := 0; i < 5; i++ {
+				if logic.HasPeerConnected(node) {
+					if logic.ShouldPublishPeerPorts(&currentServerNode) {
+						err = mq.PublishPeerUpdate(&currentServerNode)
+						if err != nil {
+							logger.Log(1, "error publishing port updates when node", node.Name, "joined")
+						}
+						break
+					}
+				}
+				time.Sleep(time.Second << 1) // allow time for client to startup
+			}
+		}
+	}(&node)
+
 	return response, nil
 	return response, nil
 }
 }
 
 

+ 21 - 0
logic/wireguard.go

@@ -25,6 +25,27 @@ func RemoveConf(iface string, printlog bool) error {
 	return err
 	return err
 }
 }
 
 
+// HasPeerConnected - checks if a client node has connected over WG
+func HasPeerConnected(node *models.Node) bool {
+	client, err := wgctrl.New()
+	if err != nil {
+		return false
+	}
+	defer client.Close()
+	device, err := client.Device(node.Interface)
+	if err != nil {
+		return false
+	}
+	for _, peer := range device.Peers {
+		if peer.PublicKey.String() == node.PublicKey {
+			if peer.Endpoint != nil {
+				return true
+			}
+		}
+	}
+	return false
+}
+
 // == Private Functions ==
 // == Private Functions ==
 
 
 // gets the server peers locally
 // gets the server peers locally

+ 15 - 15
netclient/cli_options/cmds.go

@@ -62,21 +62,21 @@ func GetCommands(cliFlags []cli.Flag) []*cli.Command {
 				return err
 				return err
 			},
 			},
 		},
 		},
-		{
-			Name:  "push",
-			Usage: "Push configuration changes to server.",
-			Flags: cliFlags,
-			// the action, or code that will be executed when
-			// we execute our `ns` command
-			Action: func(c *cli.Context) error {
-				cfg, _, err := config.GetCLIConfig(c)
-				if err != nil {
-					return err
-				}
-				err = command.Push(cfg)
-				return err
-			},
-		},
+		// {
+		// 	Name:  "push",
+		// 	Usage: "Push configuration changes to server.",
+		// 	Flags: cliFlags,
+		// 	// the action, or code that will be executed when
+		// 	// we execute our `ns` command
+		// 	Action: func(c *cli.Context) error {
+		// 		cfg, _, err := config.GetCLIConfig(c)
+		// 		if err != nil {
+		// 			return err
+		// 		}
+		// 		err = command.Push(cfg)
+		// 		return err
+		// 	},
+		// },
 		{
 		{
 			Name:  "pull",
 			Name:  "pull",
 			Usage: "Pull latest configuration and peers from server.",
 			Usage: "Pull latest configuration and peers from server.",

+ 7 - 2
netclient/functions/common.go

@@ -185,7 +185,7 @@ func LeaveNetwork(network string) error {
 			}
 			}
 		}
 		}
 	}
 	}
-	//extra network route setting required for freebsd and windows
+	// extra network route setting required for freebsd and windows, TODO mac??
 	if ncutils.IsWindows() {
 	if ncutils.IsWindows() {
 		ip, mask, err := ncutils.GetNetworkIPMask(node.NetworkSettings.AddressRange)
 		ip, mask, err := ncutils.GetNetworkIPMask(node.NetworkSettings.AddressRange)
 		if err != nil {
 		if err != nil {
@@ -197,7 +197,12 @@ func LeaveNetwork(network string) error {
 	} else if ncutils.IsLinux() {
 	} else if ncutils.IsLinux() {
 		_, _ = ncutils.RunCmd("ip -4 route del "+node.NetworkSettings.AddressRange+" dev "+node.Interface, false)
 		_, _ = ncutils.RunCmd("ip -4 route del "+node.NetworkSettings.AddressRange+" dev "+node.Interface, false)
 	}
 	}
-	return RemoveLocalInstance(cfg, network)
+
+	currentNets, err := ncutils.GetSystemNetworks()
+	if err != nil || len(currentNets) <= 1 {
+		return RemoveLocalInstance(cfg, network)
+	}
+	return daemon.Restart()
 }
 }
 
 
 // RemoveLocalInstance - remove all netclient files locally for a network
 // RemoveLocalInstance - remove all netclient files locally for a network

+ 32 - 28
netclient/functions/daemon.go

@@ -30,17 +30,31 @@ var messageCache = new(sync.Map)
 const lastNodeUpdate = "lnu"
 const lastNodeUpdate = "lnu"
 const lastPeerUpdate = "lpu"
 const lastPeerUpdate = "lpu"
 
 
+type cachedMessage struct {
+	Message  string
+	LastSeen time.Time
+}
+
 func insert(network, which, cache string) {
 func insert(network, which, cache string) {
-	// var mu sync.Mutex
-	// mu.Lock()
-	// defer mu.Unlock()
-	messageCache.Store(fmt.Sprintf("%s%s", network, which), cache)
+	var newMessage = cachedMessage{
+		Message:  cache,
+		LastSeen: time.Now(),
+	}
+	ncutils.Log("storing new message: " + cache)
+	messageCache.Store(fmt.Sprintf("%s%s", network, which), newMessage)
 }
 }
 
 
 func read(network, which string) string {
 func read(network, which string) string {
 	val, isok := messageCache.Load(fmt.Sprintf("%s%s", network, which))
 	val, isok := messageCache.Load(fmt.Sprintf("%s%s", network, which))
 	if isok {
 	if isok {
-		return fmt.Sprintf("%v", val)
+		var readMessage = val.(cachedMessage)                        // fetch current cached message
+		if time.Now().After(readMessage.LastSeen.Add(time.Minute)) { // check if message has been there over a minute
+			messageCache.Delete(fmt.Sprintf("%s%s", network, which)) // remove old message if expired
+			ncutils.Log("cached message expired")
+			return ""
+		}
+		ncutils.Log("cache hit, skipping probably " + readMessage.Message)
+		return readMessage.Message // return current message if not expired
 	}
 	}
 	return ""
 	return ""
 }
 }
@@ -219,6 +233,7 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) {
 		newNode.OS = runtime.GOOS
 		newNode.OS = runtime.GOOS
 		// check if interface needs to delta
 		// check if interface needs to delta
 		ifaceDelta := ncutils.IfaceDelta(&cfg.Node, &newNode)
 		ifaceDelta := ncutils.IfaceDelta(&cfg.Node, &newNode)
+		shouldDNSChange := cfg.Node.DNSOn != newNode.DNSOn
 
 
 		cfg.Node = newNode
 		cfg.Node = newNode
 		switch newNode.Action {
 		switch newNode.Action {
@@ -265,24 +280,15 @@ func NodeUpdate(client mqtt.Client, msg mqtt.Message) {
 				ncutils.Log("error resubscribing after interface change " + err.Error())
 				ncutils.Log("error resubscribing after interface change " + err.Error())
 				return
 				return
 			}
 			}
-		}
-		/*
-			else {
-				ncutils.Log("syncing conf to " + file)
-				err = wireguard.SyncWGQuickConf(cfg.Node.Interface, file)
-				if err != nil {
-					ncutils.Log("error syncing wg after peer update " + err.Error())
-					return
+			if newNode.DNSOn == "yes" {
+				ncutils.Log("setting up DNS")
+				if err = local.UpdateDNS(cfg.Node.Interface, cfg.Network, cfg.Server.CoreDNSAddr); err != nil {
+					ncutils.Log("error applying dns" + err.Error())
 				}
 				}
 			}
 			}
-		*/
+		}
 		//deal with DNS
 		//deal with DNS
-		if newNode.DNSOn == "yes" {
-			ncutils.Log("setting up DNS")
-			if err = local.UpdateDNS(cfg.Node.Interface, cfg.Network, cfg.Server.CoreDNSAddr); err != nil {
-				ncutils.Log("error applying dns" + err.Error())
-			}
-		} else {
+		if newNode.DNSOn != "yes" && shouldDNSChange {
 			ncutils.Log("settng DNS off")
 			ncutils.Log("settng DNS off")
 			_, err := ncutils.RunCmd("/usr/bin/resolvectl revert "+cfg.Node.Interface, true)
 			_, err := ncutils.RunCmd("/usr/bin/resolvectl revert "+cfg.Node.Interface, true)
 			if err != nil {
 			if err != nil {
@@ -311,14 +317,12 @@ func UpdatePeers(client mqtt.Client, msg mqtt.Message) {
 			return
 			return
 		}
 		}
 		// see if cache hit, if so skip
 		// see if cache hit, if so skip
-		/*
-			var currentMessage = read(peerUpdate.Network, lastPeerUpdate)
-			if currentMessage == string(data) {
-				return
-			}
-		*/
+		var currentMessage = read(peerUpdate.Network, lastPeerUpdate)
+		if currentMessage == string(data) {
+			ncutils.Log("cache hit")
+			return
+		}
 		insert(peerUpdate.Network, lastPeerUpdate, string(data))
 		insert(peerUpdate.Network, lastPeerUpdate, string(data))
-		ncutils.Log("update peer handler")
 
 
 		file := ncutils.GetNetclientPathSpecific() + cfg.Node.Interface + ".conf"
 		file := ncutils.GetNetclientPathSpecific() + cfg.Node.Interface + ".conf"
 		err = wireguard.UpdateWgPeers(file, peerUpdate.Peers)
 		err = wireguard.UpdateWgPeers(file, peerUpdate.Peers)
@@ -326,13 +330,13 @@ func UpdatePeers(client mqtt.Client, msg mqtt.Message) {
 			ncutils.Log("error updating wireguard peers" + err.Error())
 			ncutils.Log("error updating wireguard peers" + err.Error())
 			return
 			return
 		}
 		}
-		ncutils.Log("syncing conf to " + file)
 		//err = wireguard.SyncWGQuickConf(cfg.Node.Interface, file)
 		//err = wireguard.SyncWGQuickConf(cfg.Node.Interface, file)
 		err = wireguard.SetPeers(cfg.Node.Interface, cfg.Node.PersistentKeepalive, peerUpdate.Peers)
 		err = wireguard.SetPeers(cfg.Node.Interface, cfg.Node.PersistentKeepalive, peerUpdate.Peers)
 		if err != nil {
 		if err != nil {
 			ncutils.Log("error syncing wg after peer update " + err.Error())
 			ncutils.Log("error syncing wg after peer update " + err.Error())
 			return
 			return
 		}
 		}
+		ncutils.Log(fmt.Sprintf("received peer update on network, %s", cfg.Network))
 	}()
 	}()
 }
 }