Explorar o código

IsLeader check for sending UDP peers

afeiszli %!s(int64=3) %!d(string=hai) anos
pai
achega
6dac265574

+ 1 - 1
controllers/nodeGrpcController.go

@@ -134,7 +134,7 @@ func (s *NodeServiceServer) GetPeers(ctx context.Context, req *nodepb.Object) (*
 		if err != nil {
 			return nil, err
 		}
-		if node.IsServer == "yes" {
+		if node.IsServer == "yes" && functions.IsLeader(&node) {
 			SetNetworkServerPeers(macAndNetwork[1])
 		}
 		excludeIsRelayed := node.IsRelay != "yes"

+ 2 - 25
controllers/nodeHttpController.go

@@ -267,7 +267,7 @@ func getNetworkNodes(w http.ResponseWriter, r *http.Request) {
 	var nodes []models.Node
 	var params = mux.Vars(r)
 	networkName := params["network"]
-	nodes, err := GetNetworkNodes(networkName)
+	nodes, err := functions.GetNetworkNodes(networkName)
 	if err != nil {
 		returnErrorResponse(w, r, formatError(err, "internal"))
 		return
@@ -279,29 +279,6 @@ func getNetworkNodes(w http.ResponseWriter, r *http.Request) {
 	json.NewEncoder(w).Encode(nodes)
 }
 
-func GetNetworkNodes(network string) ([]models.Node, error) {
-	var nodes []models.Node
-	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
-	if err != nil {
-		if database.IsEmptyRecord(err) {
-			return []models.Node{}, nil
-		}
-		return nodes, err
-	}
-	for _, value := range collection {
-
-		var node models.Node
-		err := json.Unmarshal([]byte(value), &node)
-		if err != nil {
-			continue
-		}
-		if node.Network == network {
-			nodes = append(nodes, node)
-		}
-	}
-	return nodes, nil
-}
-
 //A separate function to get all nodes, not just nodes for a particular network.
 //Not quite sure if this is necessary. Probably necessary based on front end but may want to review after iteration 1 if it's being used or not
 func getAllNodes(w http.ResponseWriter, r *http.Request) {
@@ -335,7 +312,7 @@ func getUsersNodes(user models.User) ([]models.Node, error) {
 	var nodes []models.Node
 	var err error
 	for _, networkName := range user.Networks {
-		tmpNodes, err := GetNetworkNodes(networkName)
+		tmpNodes, err := functions.GetNetworkNodes(networkName)
 		if err != nil {
 			continue
 		}

+ 63 - 1
functions/helpers.go

@@ -14,7 +14,7 @@ import (
 	"net"
 	"strings"
 	"time"
-
+	"sort"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/servercfg"
@@ -351,6 +351,68 @@ func UpdateNetworkLocalAddresses(networkName string) error {
 	return nil
 }
 
+func GetNetworkNodes(network string) ([]models.Node, error) {
+	var nodes []models.Node
+	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
+	if err != nil {
+		if database.IsEmptyRecord(err) {
+			return []models.Node{}, nil
+		}
+		return nodes, err
+	}
+	for _, value := range collection {
+
+		var node models.Node
+		err := json.Unmarshal([]byte(value), &node)
+		if err != nil {
+			continue
+		}
+		if node.Network == network {
+			nodes = append(nodes, node)
+		}
+	}
+	return nodes, nil
+}
+
+func GetSortedNetworkServerNodes(network string) ([]models.Node, error) {
+	var nodes []models.Node
+	collection, err := database.FetchRecords(database.NODES_TABLE_NAME)
+	if err != nil {
+		if database.IsEmptyRecord(err) {
+			return []models.Node{}, nil
+		}
+		return nodes, err
+	}
+	for _, value := range collection {
+
+		var node models.Node
+		err := json.Unmarshal([]byte(value), &node)
+		if err != nil {
+			continue
+		}
+		if node.Network == network && node.IsServer == "yes" {
+			nodes = append(nodes, node)
+		}
+	}
+	sort.Sort(models.NodesArray(nodes))
+	return nodes, nil
+}
+
+
+func IsLeader(node *models.Node) (bool) {
+	nodes, err := GetSortedNetworkServerNodes(node.Network)
+	if err != nil {
+		PrintUserLog("[netmaker]","ERROR: COULD NOT RETRIEVE SERVER NODES. THIS WILL BREAK HOLE PUNCHING.", 0)
+		return false
+	}
+	for _, n := range nodes {
+		if n.LastModified > time.Now().Add(-1 * time.Minute).Unix() {
+			return n.Address == node.Address
+		}
+	}
+	return nodes[1].Address == node.Address
+}
+
 func IsNetworkDisplayNameUnique(name string) (bool, error) {
 
 	isunique := true

+ 1 - 0
main.go

@@ -39,6 +39,7 @@ func initialize() { // Client Mode Prereq Check
 		log.Fatal(err)
 	}
 	log.Println("database successfully connected.")
+
 	if servercfg.IsClientMode() != "off" {
 		output, err := ncutils.RunCmd("id -u", true)
 		if err != nil {

+ 14 - 1
models/node.go

@@ -7,7 +7,7 @@ import (
 	"net"
 	"strings"
 	"time"
-
+	"bytes"
 	"github.com/go-playground/validator/v10"
 	"github.com/gravitl/netmaker/database"
 	"golang.org/x/crypto/bcrypt"
@@ -26,6 +26,19 @@ const NODE_NOOP = "noop"
 var seededRand *rand.Rand = rand.New(
 	rand.NewSource(time.Now().UnixNano()))
 
+
+type NodesArray []Node
+
+func (a NodesArray) Len() int           { return len(a) }
+func (a NodesArray) Less(i, j int) bool { return isLess(a[i].Address, a[j].Address) }
+func (a NodesArray) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
+
+func isLess(ipA string, ipB string) bool {
+	ipNetA := net.ParseIP(ipA)
+	ipNetB := net.ParseIP(ipB)
+	return bytes.Compare(ipNetA, ipNetB) < 0
+}
+
 // node struct
 type Node struct {
 	ID                  string   `json:"id,omitempty" bson:"id,omitempty"`

+ 29 - 0
servercfg/serverconf.go

@@ -372,3 +372,32 @@ func IsSplitDNS() bool {
 	}
 	return issplit
 }
+/*
+func GetServerNet() string {
+    cidr := "10.250.250.0/24"
+    if os.Getenv("SERVER_NET") != "" {
+        if _, _, err := net.ParseCIDR(os.Getenv("SERVER_NET")); err == nil {
+            return os.Getenv("SERVER_NET")
+        }
+    } else if config.Config.Server.ServerNet != "" {
+        if _, _, err := net.ParseCIDR(config.Config.Server.ServerNet); err == nil {
+            return config.Config.Server.ServerNet
+        }
+    }
+    return cidr
+}
+
+func GetRegistrationNet() string {
+    cidr := "10.250.251.0/24"
+    if os.Getenv("REG_NET") != "" {
+        if _, _, err := net.ParseCIDR(os.Getenv("REG_NET")); err == nil {
+            return os.Getenv("REG_NET")
+        }
+    } else if config.Config.Server.RegNet != "" {
+        if _, _, err := net.ParseCIDR(config.Config.Server.RegNet); err == nil {
+            return config.Config.Server.RegNet
+        }
+    }
+    return cidr
+}
+*/

+ 54 - 2
serverctl/serverctl.go

@@ -7,7 +7,7 @@ import (
 	"log"
 	"os"
 	"os/exec"
-
+	"time"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/functions"
 	"github.com/gravitl/netmaker/models"
@@ -15,6 +15,8 @@ import (
 	"github.com/gravitl/netmaker/servercfg"
 )
 
+const SERVER_NETID="servernet"
+
 func GetServerWGConf() (models.IntClient, error) {
 	var server models.IntClient
 	collection, err := database.FetchRecords(database.INT_CLIENTS_TABLE_NAME)
@@ -104,7 +106,12 @@ func RemoveNetwork(network string) (bool, error) {
 	return true, err
 
 }
-
+/*
+func InitServerNet() error {
+	func
+	return nil
+}
+*/
 func InitServerNetclient() error {
 	netclientDir := ncutils.GetNetclientPath()
 	_, err := os.Stat(netclientDir + "/config")
@@ -193,3 +200,48 @@ func AddNetwork(network string) (bool, error) {
 	log.Println("Server added to network " + network)
 	return true, err
 }
+
+func IsLeader(node *models.Node) (bool) {
+	nodes, err := functions.GetSortedNetworkServerNodes(node.Network)
+	if err != nil {
+		functions.PrintUserLog("[netmaker]","ERROR: COULD NOT RETRIEVE SERVER NODES. THIS WILL BREAK HOLE PUNCHING.", 0)
+		return false
+	}
+	for _, n := range nodes {
+		if n.LastModified > time.Now().Add(-1 * time.Minute).Unix() {
+			return n.Address == node.Address
+		}
+	}
+	return nodes[1].Address == node.Address
+}
+
+// == PRIVATE ==
+/*
+func getUniqueServerIP(currentAddrs []string) (string, error) {
+
+	ip, ipnet, err := net.ParseCIDR(servercfg.GetServerNet())
+	if err != nil {
+		return "", err
+	}
+	offset := true
+	for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); functions.Inc(ip) {
+		if offset {
+			offset = false
+			continue
+		}
+		if isIPunique(ip.String(), currentAddrs) {
+			return ip.String(), nil
+		}
+	}
+	return "", errors.New("failed to get unique server ip")
+}
+
+func isIPunique(addr string, currentAddrs []string) bool {
+	for _, currAddr := range currentAddrs {
+		if addr == currAddr {
+			return false
+		}
+	}
+	return true
+}
+*/