Browse Source

moved peer determination to server

Matthew R Kasun 3 years ago
parent
commit
a86b9bd380
5 changed files with 32 additions and 230 deletions
  1. 4 3
      models/mqtt.go
  2. 22 45
      mq/mq.go
  3. 5 14
      netclient/functions/daemon.go
  4. 0 167
      netclient/functions/peers.go
  5. 1 1
      netclient/wireguard/common.go

+ 4 - 3
models/mqtt.go

@@ -1,9 +1,10 @@
 package models
 package models
 
 
+import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
 type PeerUpdate struct {
 type PeerUpdate struct {
-	Network  string
-	Nodes    []Node
-	ExtPeers []ExtPeersResponse
+	Network string
+	Peers   []wgtypes.PeerConfig
 }
 }
 
 
 type KeyUpdate struct {
 type KeyUpdate struct {

+ 22 - 45
mq/mq.go

@@ -1,8 +1,8 @@
 package mq
 package mq
 
 
 import (
 import (
+	"encoding/json"
 	"errors"
 	"errors"
-	"fmt"
 	"strings"
 	"strings"
 
 
 	mqtt "github.com/eclipse/paho.mqtt.golang"
 	mqtt "github.com/eclipse/paho.mqtt.golang"
@@ -86,61 +86,38 @@ var IPUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
 		}
 		}
 		node.Endpoint = ip
 		node.Endpoint = ip
 		node.SetLastCheckIn()
 		node.SetLastCheckIn()
-		if err := UpdatePeers(client, node); err != nil {
+		if err != UpdatePeers(client, node) {
 			logger.Log(0, "error updating peers "+err.Error())
 			logger.Log(0, "error updating peers "+err.Error())
 		}
 		}
 	}()
 	}()
 }
 }
 
 
 func UpdatePeers(client mqtt.Client, node models.Node) error {
 func UpdatePeers(client mqtt.Client, node models.Node) error {
-	var peerUpdate models.PeerUpdate
-	peerUpdate.Network = node.Network
-
-	nodes, err := logic.GetNetworkNodes(node.Network)
+	peersToUpdate, err := logic.GetNetworkNodes(node.Network)
 	if err != nil {
 	if err != nil {
-		return fmt.Errorf("unable to get network nodes %v: ", err)
-	}
-	if token := client.Connect(); token.Wait() && token.Error() != nil {
-		return token.Error()
+		logger.Log(0, "error retrieving peers to be updated "+err.Error())
+		return err
 	}
 	}
-	for _, peer := range nodes {
-		//don't need to update the initiatiing client
-		if peer.ID == node.ID {
-			continue
-		}
-		peerUpdate.Nodes = append(peerUpdate.Nodes, peer)
-		peerUpdate.ExtPeers, err = logic.GetExtPeersList(&node)
-
+	for _, peerToUpdate := range peersToUpdate {
+		peers, _, _, err := logic.GetServerPeers(&peerToUpdate)
 		if err != nil {
 		if err != nil {
-			logger.Log(0)
-		}
-		if token := client.Publish("update/peers/"+peer.ID, 0, false, nodes); token.Wait() && token.Error() != nil {
-			logger.Log(0, "error publishing peer update "+peer.ID+" "+token.Error().Error())
+			logger.Log(0, "error retrieving peers "+err.Error())
+			return err
 		}
 		}
-	}
-
-	return nil
-}
-
-func UpdateLocalPeers(client mqtt.Client, node models.Node) error {
-	nodes, err := logic.GetNetworkNodes(node.Network)
-	if err != nil {
-		return fmt.Errorf("unable to get network nodes %v: ", err)
-	}
-	if token := client.Connect(); token.Wait() && token.Error() != nil {
-		return token.Error()
-	}
-	for _, peer := range nodes {
-		//don't need to update the initiatiing client
-		if peer.ID == node.ID {
+		if peerToUpdate.ID == node.ID {
 			continue
 			continue
 		}
 		}
-		//if peer.Endpoint is on same lan as node.LocalAddress
-		//if TODO{
-		//continue
-		//}
-		if token := client.Publish("update/peers/"+peer.ID, 0, false, nodes); token.Wait() && token.Error() != nil {
-			logger.Log(0, "error publishing peer update "+peer.ID+" "+token.Error().Error())
+		var peerUpdate models.PeerUpdate
+		peerUpdate.Network = node.Network
+		peerUpdate.Peers = peers
+		data, err := json.Marshal(peerUpdate)
+		if err != nil {
+			logger.Log(0, "error marshaling peer update "+err.Error())
+			return err
+		}
+		if token := client.Publish("/update/peers/"+peerToUpdate.ID, 0, false, data); token.Wait() && token.Error() != nil {
+			logger.Log(0, "error sending peer updatte to no")
+			return err
 		}
 		}
 	}
 	}
 	return nil
 	return nil
@@ -162,7 +139,7 @@ var LocalAddressUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.M
 		}
 		}
 		node.LocalAddress = string(msg.Payload())
 		node.LocalAddress = string(msg.Payload())
 		node.SetLastCheckIn()
 		node.SetLastCheckIn()
-		if err := UpdateLocalPeers(client, node); err != nil {
+		if err := UpdatePeers(client, node); err != nil {
 			logger.Log(0, "error updating peers "+err.Error())
 			logger.Log(0, "error updating peers "+err.Error())
 		}
 		}
 	}()
 	}()

+ 5 - 14
netclient/functions/daemon.go

@@ -28,7 +28,7 @@ func Daemon() error {
 		return err
 		return err
 	}
 	}
 	for _, network := range networks {
 	for _, network := range networks {
-		go Netclient(ctx, network)
+		go MessageQueue(ctx, network)
 	}
 	}
 	quit := make(chan os.Signal, 1)
 	quit := make(chan os.Signal, 1)
 	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
 	signal.Notify(quit, syscall.SIGTERM, os.Interrupt)
@@ -51,8 +51,8 @@ func SetupMQTT(cfg config.ClientConfig) mqtt.Client {
 	return client
 	return client
 }
 }
 
 
-// Netclient sets up Message Queue and subsribes/publishes updates to/from server
-func Netclient(ctx context.Context, network string) {
+// MessageQueue sets up Message Queue and subsribes/publishes updates to/from server
+func MessageQueue(ctx context.Context, network string) {
 	ncutils.Log("netclient go routine started for " + network)
 	ncutils.Log("netclient go routine started for " + network)
 	var cfg config.ClientConfig
 	var cfg config.ClientConfig
 	cfg.Network = network
 	cfg.Network = network
@@ -150,17 +150,7 @@ var UpdatePeers mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message)
 		var cfg config.ClientConfig
 		var cfg config.ClientConfig
 		cfg.Network = peerUpdate.Network
 		cfg.Network = peerUpdate.Network
 		cfg.ReadConfig()
 		cfg.ReadConfig()
-		peers, err := CalculatePeers(cfg.Node, peerUpdate.Nodes, cfg.Node.IsDualStack, cfg.Node.IsEgressGateway, cfg.Node.IsServer)
-		if err != nil {
-			ncutils.Log("error calculating Peers " + err.Error())
-			return
-		}
-		extpeers, err := CalculateExtPeers(cfg.Node, peerUpdate.ExtPeers)
-		if err != nil {
-			ncutils.Log("error updated external wireguard peers " + err.Error())
-		}
-		peers = append(peers, extpeers...)
-		err = wireguard.UpdateWgPeers(cfg.Node.Interface, peers)
+		err = wireguard.UpdateWgPeers(cfg.Node.Interface, peerUpdate.Peers)
 		if err != nil {
 		if err != nil {
 			ncutils.Log("error updating wireguard peers" + err.Error())
 			ncutils.Log("error updating wireguard peers" + err.Error())
 			return
 			return
@@ -315,6 +305,7 @@ func Metrics(ctx context.Context, cfg config.ClientConfig, network string) {
 				ncutils.Log("error publishing metrics " + token.Error().Error())
 				ncutils.Log("error publishing metrics " + token.Error().Error())
 			}
 			}
 			ncutils.Log("metrics collection complete")
 			ncutils.Log("metrics collection complete")
+			client.Disconnect(250)
 		}
 		}
 	}
 	}
 }
 }

+ 0 - 167
netclient/functions/peers.go

@@ -1,167 +0,0 @@
-package functions
-
-import (
-	"log"
-	"net"
-	"strconv"
-	"strings"
-	"time"
-
-	"github.com/gravitl/netmaker/models"
-	"github.com/gravitl/netmaker/netclient/ncutils"
-	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
-)
-
-func CalculatePeers(thisNode models.Node, peernodes []models.Node, dualstack, egressgateway, server string) ([]wgtypes.Peer, error) {
-	//hasGateway := false
-	var gateways []string
-	var peers []wgtypes.Peer
-
-	keepalive := thisNode.PersistentKeepalive
-	keepalivedur, _ := time.ParseDuration(strconv.FormatInt(int64(keepalive), 10) + "s")
-	keepaliveserver, err := time.ParseDuration(strconv.FormatInt(int64(5), 10) + "s")
-	if err != nil {
-		log.Fatalf("Issue with format of keepalive value. Please update netconfig: %v", err)
-	}
-	for _, node := range peernodes {
-		pubkey, err := wgtypes.ParseKey(node.PublicKey)
-		if err != nil {
-			log.Println("error parsing key")
-			//return peers, hasGateway, gateways, err
-		}
-
-		if thisNode.PublicKey == node.PublicKey {
-			continue
-		}
-		if thisNode.Endpoint == node.Endpoint {
-			if thisNode.LocalAddress != node.LocalAddress && node.LocalAddress != "" {
-				node.Endpoint = node.LocalAddress
-			} else {
-				continue
-			}
-		}
-
-		var peer wgtypes.Peer
-		var peeraddr = net.IPNet{
-			IP:   net.ParseIP(node.Address),
-			Mask: net.CIDRMask(32, 32),
-		}
-		var allowedips []net.IPNet
-		allowedips = append(allowedips, peeraddr)
-		// handle manually set peers
-		for _, allowedIp := range node.AllowedIPs {
-			if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
-				nodeEndpointArr := strings.Split(node.Endpoint, ":")
-				if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists..
-					allowedips = append(allowedips, *ipnet)
-				}
-			} else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address {
-				ipnet := net.IPNet{
-					IP:   net.ParseIP(allowedIp),
-					Mask: net.CIDRMask(32, 32),
-				}
-				allowedips = append(allowedips, ipnet)
-			}
-		}
-		// handle egress gateway peers
-		if node.IsEgressGateway == "yes" {
-			//hasGateway = true
-			ranges := node.EgressGatewayRanges
-			for _, iprange := range ranges { // go through each cidr for egress gateway
-				_, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
-				if err != nil {
-					ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1)
-					continue // if can't parse CIDR
-				}
-				nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node
-				if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node
-					ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2)
-					continue // skip adding egress range if overlaps with node's ip
-				}
-				if ipnet.Contains(net.ParseIP(thisNode.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
-					ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+thisNode.LocalAddress+", omitting", 2)
-					continue // skip adding egress range if overlaps with node's local ip
-				}
-				gateways = append(gateways, iprange)
-				if err != nil {
-					log.Println("ERROR ENCOUNTERED SETTING GATEWAY")
-				} else {
-					allowedips = append(allowedips, *ipnet)
-				}
-			}
-		}
-		if node.Address6 != "" && dualstack == "yes" {
-			var addr6 = net.IPNet{
-				IP:   net.ParseIP(node.Address6),
-				Mask: net.CIDRMask(128, 128),
-			}
-			allowedips = append(allowedips, addr6)
-		}
-		if thisNode.IsServer == "yes" && !(node.IsServer == "yes") {
-			peer = wgtypes.Peer{
-				PublicKey:                   pubkey,
-				PersistentKeepaliveInterval: keepaliveserver,
-				AllowedIPs:                  allowedips,
-			}
-		} else if keepalive != 0 {
-			peer = wgtypes.Peer{
-				PublicKey:                   pubkey,
-				PersistentKeepaliveInterval: keepalivedur,
-				Endpoint: &net.UDPAddr{
-					IP:   net.ParseIP(node.Endpoint),
-					Port: int(node.ListenPort),
-				},
-				AllowedIPs: allowedips,
-			}
-		} else {
-			peer = wgtypes.Peer{
-				PublicKey: pubkey,
-				Endpoint: &net.UDPAddr{
-					IP:   net.ParseIP(node.Endpoint),
-					Port: int(node.ListenPort),
-				},
-				AllowedIPs: allowedips,
-			}
-		}
-		peers = append(peers, peer)
-	}
-	return peers, nil
-}
-
-func CalculateExtPeers(thisNode models.Node, extPeers []models.ExtPeersResponse) ([]wgtypes.Peer, error) {
-	var peers []wgtypes.Peer
-	var err error
-	for _, extPeer := range extPeers {
-		pubkey, err := wgtypes.ParseKey(extPeer.PublicKey)
-		if err != nil {
-			log.Println("error parsing key")
-			return peers, err
-		}
-
-		if thisNode.PublicKey == extPeer.PublicKey {
-			continue
-		}
-
-		var peer wgtypes.Peer
-		var peeraddr = net.IPNet{
-			IP:   net.ParseIP(extPeer.Address),
-			Mask: net.CIDRMask(32, 32),
-		}
-		var allowedips []net.IPNet
-		allowedips = append(allowedips, peeraddr)
-
-		if extPeer.Address6 != "" && thisNode.IsDualStack == "yes" {
-			var addr6 = net.IPNet{
-				IP:   net.ParseIP(extPeer.Address6),
-				Mask: net.CIDRMask(128, 128),
-			}
-			allowedips = append(allowedips, addr6)
-		}
-		peer = wgtypes.Peer{
-			PublicKey:  pubkey,
-			AllowedIPs: allowedips,
-		}
-		peers = append(peers, peer)
-	}
-	return peers, err
-}

+ 1 - 1
netclient/wireguard/common.go

@@ -343,7 +343,7 @@ func WriteWgConfig(cfg config.ClientConfig, privateKey string, peers []wgtypes.P
 }
 }
 
 
 // UpdateWgPeers - updates the peers of a network
 // UpdateWgPeers - updates the peers of a network
-func UpdateWgPeers(wgInterface string, peers []wgtypes.Peer) error {
+func UpdateWgPeers(wgInterface string, peers []wgtypes.PeerConfig) error {
 	//update to get path properly
 	//update to get path properly
 	file := ncutils.GetNetclientPathSpecific() + wgInterface + ".conf"
 	file := ncutils.GetNetclientPathSpecific() + wgInterface + ".conf"
 	wireguard, err := ini.ShadowLoad(file)
 	wireguard, err := ini.ShadowLoad(file)