Ver código fonte

range update

Matthew R. Kasun 3 anos atrás
pai
commit
381a3880f2
4 arquivos alterados com 180 adições e 2 exclusões
  1. 88 2
      controllers/network.go
  2. 6 0
      models/mqtt.go
  3. 14 0
      mq/publishers.go
  4. 72 0
      netclient/functions/daemon.go

+ 88 - 2
controllers/network.go

@@ -3,7 +3,9 @@ package controller
 import (
 	"encoding/json"
 	"errors"
+	"net"
 	"net/http"
+	"strconv"
 	"strings"
 
 	"github.com/gorilla/mux"
@@ -11,7 +13,10 @@ import (
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/servercfg"
+	"github.com/gravitl/netmaker/serverctl"
+	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
 // ALL_NETWORK_ACCESS - represents all networks
@@ -173,14 +178,61 @@ func updateNetwork(w http.ResponseWriter, r *http.Request) {
 			returnErrorResponse(w, r, formatError(err, "internal"))
 			return
 		}
+		var serverAddrs = make([]models.ServerAddr, 0)
+		if rangeupdate {
+			serverAddrs = preCalculateServerAddrs(network.NetID)
+		}
+		leaderServerNode, err := logic.GetNetworkServerLeader(netname)
+		if err != nil {
+			logger.Log(1, "failed to update peers for server node address on network", netname)
+		}
+
 		for _, node := range nodes {
-			runUpdates(&node, true)
+			if node.IsServer != "yes" {
+				if rangeupdate {
+					applyServerAddr(&node, serverAddrs, network)
+					var rangeUpdate models.RangeUpdate
+					rangeUpdate.Node = node
+					rangeUpdate.Peers.Network = node.Network
+					rangeUpdate.Peers.ServerAddrs = serverAddrs
+					var peer wgtypes.PeerConfig
+					peer.PublicKey, err = wgtypes.ParseKey(leaderServerNode.PublicKey)
+					if err != nil {
+						returnErrorResponse(w, r, formatError(err, "internal"))
+						return
+					}
+					peer.ReplaceAllowedIPs = true
+					for _, server := range serverAddrs {
+						if server.IsLeader {
+							_, address, err := net.ParseCIDR(server.Address + "/32")
+							if err != nil {
+								returnErrorResponse(w, r, formatError(err, "internal"))
+								return
+							}
+							peer.AllowedIPs = append(peer.AllowedIPs, *address)
+						}
+					}
+					rangeUpdate.Peers.Peers = append(rangeUpdate.Peers.Peers, peer)
+					if err := mq.PublishRangeUpdate(&rangeUpdate); err != nil {
+						returnErrorResponse(w, r, formatError(err, "internal"))
+						return
+					}
+					if err := mq.NodeUpdate(&node); err != nil {
+						logger.Log(1, "could not update range when network", netname, "changed cidr for node", node.Name, node.ID, err.Error())
+					}
+				}
+			}
 		}
 	}
-
 	logger.Log(1, r.Header.Get("user"), "updated network", netname)
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(newNetwork)
+	//currentServerNode, err := logic.GetNetworkServerLocal(netname)
+	//if err != nil {
+	//	logger.Log(1, "failed to update peers for server node address on network", netname)
+	//} else {
+	//	runUpdates(&currentServerNode, true)
+	//}
 }
 
 func updateNetworkNodeLimit(w http.ResponseWriter, r *http.Request) {
@@ -331,3 +383,37 @@ func deleteAccessKey(w http.ResponseWriter, r *http.Request) {
 	logger.Log(1, r.Header.Get("user"), "deleted access key", keyname, "on network,", netname)
 	w.WriteHeader(http.StatusOK)
 }
+
+// used for network address changes
+func applyServerAddr(node *models.Node, serverAddrs []models.ServerAddr, network models.Network) {
+	node.NetworkSettings = network
+	node.NetworkSettings.DefaultServerAddrs = serverAddrs
+}
+
+func preCalculateServerAddrs(netname string) []models.ServerAddr {
+	var serverAddrs = make([]models.ServerAddr, 0)
+	serverNodes := logic.GetServerNodes(netname)
+	if len(serverNodes) == 0 {
+		if err := serverctl.SyncServerNetwork(netname); err != nil {
+			return serverAddrs
+		}
+	}
+
+	address, err := logic.UniqueAddressServer(netname)
+	if err != nil {
+		return serverAddrs
+	}
+	for i := range serverNodes {
+		addrParts := strings.Split(address, ".")                      // get the numbers
+		lastNum, lastErr := strconv.Atoi(addrParts[len(addrParts)-1]) // get the last number as an int
+		if lastErr == nil {
+			lastNum = lastNum - i
+			addrParts[len(addrParts)-1] = strconv.Itoa(lastNum)
+			serverAddrs = append(serverAddrs, models.ServerAddr{
+				IsLeader: logic.IsLeader(&serverNodes[i]),
+				Address:  strings.Join(addrParts, "."),
+			})
+		}
+	}
+	return serverAddrs
+}

+ 6 - 0
models/mqtt.go

@@ -14,3 +14,9 @@ type KeyUpdate struct {
 	Network   string `json:"network" bson:"network"`
 	Interface string `json:"interface" bson:"interface"`
 }
+
+// RangeUpdate  - structure for network range updates
+type RangeUpdate struct {
+	Node  Node
+	Peers PeerUpdate
+}

+ 14 - 0
mq/publishers.go

@@ -87,6 +87,20 @@ func NodeUpdate(node *models.Node) error {
 	return nil
 }
 
+// PublishRangeUpdate - publishes a network range update
+func PublishRangeUpdate(update *models.RangeUpdate) error {
+	data, err := json.Marshal(update)
+	if err != nil {
+		logger.Log(2, "error marshalling range update ", err.Error())
+		return err
+	}
+	if err = publish(&update.Node, fmt.Sprintf("rangeupdate/%s/%s", update.Node.Network, update.Node.ID), data); err != nil {
+		logger.Log(2, "error publishing range update to peer ", update.Node.ID, err.Error())
+		return err
+	}
+	return nil
+}
+
 // sendPeers - retrieve networks, send peer ports to all peers
 func sendPeers() {
 	var force bool

+ 72 - 0
netclient/functions/daemon.go

@@ -118,6 +118,74 @@ var All mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
 	//ncutils.Log("Message: " + string(msg.Payload()))
 }
 
+// RangeUpdate -- mqtt message handler for rangeupdate/<network>/<nodeid>
+func RangeUpdate(client mqtt.Client, msg mqtt.Message) {
+	var rangeUpdate models.RangeUpdate
+	var cfg config.ClientConfig
+	var network = parseNetworkFromTopic(msg.Topic())
+	cfg.Network = network
+	cfg.ReadConfig()
+	data, dataErr := decryptMsg(&cfg, msg.Payload())
+	if dataErr != nil {
+		return
+	}
+	err := json.Unmarshal([]byte(data), &rangeUpdate)
+	if err != nil {
+		ncutils.Log("error unmarshalling node update data" + err.Error())
+		return
+	}
+
+	ncutils.Log("received message to do range update " + network)
+	rangeUpdate.Node.PullChanges = "no"
+	//ensure that OS never changes
+	rangeUpdate.Node.OS = runtime.GOOS
+	// check if interface needs to delta
+	// Save new config
+	cfg.Node.Action = models.NODE_NOOP
+	if err := config.Write(&cfg, cfg.Network); err != nil {
+		ncutils.PrintLog("error updating node configuration: "+err.Error(), 1)
+	}
+	nameserver := cfg.Server.CoreDNSAddr
+	privateKey, err := wireguard.RetrievePrivKey(rangeUpdate.Node.Network)
+	if err != nil {
+		ncutils.Log("error reading PrivateKey " + err.Error())
+		return
+	}
+	file := ncutils.GetNetclientPathSpecific() + cfg.Node.Interface + ".conf"
+	if err := wireguard.UpdateWgInterface(file, privateKey, nameserver, rangeUpdate.Node); err != nil {
+		ncutils.Log("error updating wireguard config " + err.Error())
+		return
+	}
+	ncutils.Log("applying WG conf to " + file)
+	err = wireguard.ApplyConf(&cfg.Node, cfg.Node.Interface, file)
+	if err != nil {
+		ncutils.Log("error restarting wg after node update " + err.Error())
+		return
+	}
+
+	spew.Dump(rangeUpdate.Peers)
+	err = wireguard.UpdateWgPeers(file, rangeUpdate.Peers.Peers)
+	if err != nil {
+		ncutils.Log("error updating wireguard peers" + err.Error())
+		return
+	}
+	//err = wireguard.SyncWGQuickConf(cfg.Node.Interface, file)
+	var iface = cfg.Node.Interface
+	if ncutils.IsMac() {
+		iface, err = local.GetMacIface(cfg.Node.Address)
+		if err != nil {
+			ncutils.Log("error retrieving mac iface: " + err.Error())
+			return
+		}
+	}
+	err = wireguard.SetPeers(iface, cfg.Node.Address, cfg.Node.PersistentKeepalive, rangeUpdate.Peers.Peers)
+	if err != nil {
+		ncutils.Log("error syncing wg after peer update: " + err.Error())
+		return
+	}
+
+}
+
 // NodeUpdate -- mqtt message handler for /update/<NodeID> topic
 func NodeUpdate(client mqtt.Client, msg mqtt.Message) {
 	var newNode models.Node
@@ -435,6 +503,10 @@ func setupMQTT(cfg *config.ClientConfig, publish bool) mqtt.Client {
 			if cfg.DebugOn {
 				ncutils.Log(fmt.Sprintf("subscribed to peer updates for node %s peers/%s/%s", cfg.Node.Name, cfg.Node.Network, cfg.Node.ID))
 			}
+			if token := client.Subscribe(fmt.Sprintf("rangeupdate/%s/%s", cfg.Node.Network, cfg.Node.ID), 0, mqtt.MessageHandler(RangeUpdate)); token.Wait() && token.Error() != nil {
+				ncutils.Log(token.Error().Error())
+				return
+			}
 			opts.SetOrderMatters(true)
 			opts.SetResumeSubs(true)
 		}