Преглед изворни кода

PeerUpdates for refactored relays (#2333)

* revert relays

* remove debugging logs

* review comments

* revert relays (#2289)

* revert relays

* remove debugging logs

* review comments

* publish peer update for new relay

* publish peer update on relay delete

* single peer broadcast actions

* preliminary working peer updates for relays

* re-eable encryption of mq payloads

* remove unused files

* create separate file for relay related peer updates

* remove debugging logs

* remove unused func

* tidy go.mod

* remove debugging logs

* move repeated code to function

* remove erroneous comments

---------

Co-authored-by: Abhishek Kondur <[email protected]>
Matthew R Kasun пре 2 година
родитељ
комит
b19e204afb
7 измењених фајлова са 369 додато и 50 уклоњено
  1. 3 0
      auth/host_session.go
  2. 4 4
      controllers/node.go
  3. 51 21
      controllers/relay.go
  4. 17 0
      logic/nodes.go
  5. 33 25
      logic/relay.go
  6. 6 0
      models/host.go
  7. 255 0
      mq/relay.go

+ 3 - 0
auth/host_session.go

@@ -247,6 +247,9 @@ func CheckNetRegAndHostUpdate(networks []string, h *models.Host) {
 			Action: models.RequestAck,
 			Host:   *h,
 		})
+		if err := mq.PublishPeerUpdate(); err != nil {
+			logger.Log(0, "failed to publish peer update during registration -", err.Error())
+		}
 
 	}
 }

+ 4 - 4
controllers/node.go

@@ -662,10 +662,10 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	if relayupdate {
-		updatenodes := logic.UpdateRelayed(currentNode.ID.String(), currentNode.RelayedNodes, newNode.RelayedNodes)
-		if len(updatenodes) > 0 {
-			for _, relayedNode := range updatenodes {
-				runUpdates(&relayedNode, false)
+		updatedClients := logic.UpdateRelayed(currentNode.ID.String(), currentNode.RelayedNodes, newNode.RelayedNodes)
+		if len(updatedClients) > 0 {
+			for _, relayedClient := range updatedClients {
+				runUpdates(&relayedClient.Node, false)
 			}
 		}
 	}

+ 51 - 21
controllers/relay.go

@@ -24,38 +24,49 @@ import (
 //			Responses:
 //				200: nodeResponse
 func createRelay(w http.ResponseWriter, r *http.Request) {
-	var relay models.RelayRequest
+	var relayRequest models.RelayRequest
 	var params = mux.Vars(r)
 	w.Header().Set("Content-Type", "application/json")
-	err := json.NewDecoder(r.Body).Decode(&relay)
+	err := json.NewDecoder(r.Body).Decode(&relayRequest)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	relay.NetID = params["network"]
-	relay.NodeID = params["nodeid"]
-	updatenodes, node, err := logic.CreateRelay(relay)
+
+	relayRequest.NetID = params["network"]
+	relayRequest.NodeID = params["nodeid"]
+	_, relayNode, err := logic.CreateRelay(relayRequest)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relay.NodeID, relay.NetID, err))
+			fmt.Sprintf("failed to create relay on node [%s] on network [%s]: %v", relayRequest.NodeID, relayRequest.NetID, err))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-
-	logger.Log(1, r.Header.Get("user"), "created relay on node", relay.NodeID, "on network", relay.NetID)
-	for _, relayedNode := range updatenodes {
-
-		err = mq.NodeUpdate(&relayedNode)
-		if err != nil {
-			logger.Log(1, "error sending update to relayed node ", relayedNode.ID.String(), "on network", relay.NetID, ": ", err.Error())
-		}
+	relay := models.Client{
+		Host: *logic.GetHostByNodeID(params["nodeid"]),
+		Node: relayNode,
+	}
+	peers, err := logic.GetNetworkClients(relay.Node.Network)
+	if err != nil {
+		logger.Log(0, "error getting network nodes: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	//mq.PubPeersforRelay(relay, peers)
+	//for _, relayed := range relayedClients {
+	//mq.PubPeersForRelayedNode(relayed, relay, peers)
+	//}
+	clients := peers
+	for _, client := range clients {
+		mq.PubPeerUpdate(&client, &relay, &peers)
 	}
 
-	apiNode := node.ConvertToAPINode()
+	logger.Log(1, r.Header.Get("user"), "created relay on node", relayRequest.NodeID, "on network", relayRequest.NetID)
+	apiNode := relayNode.ConvertToAPINode()
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(apiNode)
-	runUpdates(&node, true)
+	//runUpdates(&node, true)
 }
 
 // swagger:route DELETE /api/nodes/{network}/{nodeid}/deleterelay nodes deleteRelay
@@ -74,19 +85,38 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) {
 	var params = mux.Vars(r)
 	nodeid := params["nodeid"]
 	netid := params["network"]
-	updatenodes, node, err := logic.DeleteRelay(netid, nodeid)
+	updateClients, node, err := logic.DeleteRelay(netid, nodeid)
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 	logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid)
-	for _, relayedNode := range updatenodes {
-		err = mq.NodeUpdate(&relayedNode)
+	go func() {
+		//update relayHost node
+		relayHost := logic.GetHostByNodeID(node.ID.String())
+		if err := mq.NodeUpdate(&node); err != nil {
+			logger.Log(1, "relay node update", relayHost.Name, "on network", node.Network, ": ", err.Error())
+		}
+		for _, relayedClient := range updateClients {
+			err = mq.NodeUpdate(&relayedClient.Node)
+			if err != nil {
+				logger.Log(1, "relayed node update ", relayedClient.Node.ID.String(), "on network", relayedClient.Node.Network, ": ", err.Error())
+
+			}
+		}
+		peers, err := logic.GetNetworkClients(node.Network)
 		if err != nil {
-			logger.Log(1, "error sending update to relayed node ", relayedNode.ID.String(), "on network", netid, ": ", err.Error())
+			logger.Log(0, "error getting network nodes: ", err.Error())
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+			return
 		}
-	}
+		clients := peers
+		for _, client := range clients {
+			mq.PubPeerUpdate(&client, nil, &peers)
+		}
+	}()
+	logger.Log(1, r.Header.Get("user"), "deleted relay on node", node.ID.String(), "on network", node.Network)
 	apiNode := node.ConvertToAPINode()
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(apiNode)

+ 17 - 0
logic/nodes.go

@@ -41,6 +41,23 @@ func GetNetworkNodes(network string) ([]models.Node, error) {
 	return GetNetworkNodesMemory(allnodes, network), nil
 }
 
+// GetNetworkClients - gets the clients of a network
+func GetNetworkClients(network string) ([]models.Client, error) {
+	clients := []models.Client{}
+	nodes, err := GetNetworkNodes(network)
+	if err != nil {
+		return []models.Client{}, err
+	}
+	for _, node := range nodes {
+		client := models.Client{
+			Node: node,
+			Host: *GetHostByNodeID(node.ID.String()),
+		}
+		clients = append(clients, client)
+	}
+	return clients, nil
+}
+
 // GetNetworkNodesMemory - gets all nodes belonging to a network from list in memory
 func GetNetworkNodesMemory(allNodes []models.Node, network string) []models.Node {
 	var nodes = []models.Node{}

+ 33 - 25
logic/relay.go

@@ -12,51 +12,51 @@ import (
 )
 
 // CreateRelay - creates a relay
-func CreateRelay(relay models.RelayRequest) ([]models.Node, models.Node, error) {
-	var returnnodes []models.Node
-
+func CreateRelay(relay models.RelayRequest) ([]models.Client, models.Node, error) {
+	var relayedClients []models.Client
 	node, err := GetNodeByID(relay.NodeID)
 	if err != nil {
-		return returnnodes, models.Node{}, err
+		return relayedClients, models.Node{}, err
 	}
 	host, err := GetHost(node.HostID.String())
 	if err != nil {
-		return returnnodes, models.Node{}, err
+		return relayedClients, models.Node{}, err
 	}
 	if host.OS != "linux" {
-		return returnnodes, models.Node{}, fmt.Errorf("only linux machines can be relay nodes")
+		return relayedClients, models.Node{}, fmt.Errorf("only linux machines can be relay nodes")
 	}
 	err = ValidateRelay(relay)
 	if err != nil {
-		return returnnodes, models.Node{}, err
+		return relayedClients, models.Node{}, err
 	}
 	node.IsRelay = true
+	node.RelayedNodes = relay.RelayedNodes
 	node.SetLastModified()
 	nodeData, err := json.Marshal(&node)
 	if err != nil {
-		return returnnodes, node, err
+		return relayedClients, node, err
 	}
 	if err = database.Insert(node.ID.String(), string(nodeData), database.NODES_TABLE_NAME); err != nil {
-		return returnnodes, models.Node{}, err
+		return relayedClients, models.Node{}, err
 	}
-	returnnodes = SetRelayedNodes(true, relay.NodeID, relay.RelayedNodes)
-	for _, relayedNode := range returnnodes {
-		data, err := json.Marshal(&relayedNode)
+	relayedClients = SetRelayedNodes(true, relay.NodeID, relay.RelayedNodes)
+	for _, relayed := range relayedClients {
+		data, err := json.Marshal(&relayed.Node)
 		if err != nil {
 			logger.Log(0, "marshalling relayed node", err.Error())
 			continue
 		}
-		if err := database.Insert(relayedNode.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil {
+		if err := database.Insert(relayed.Node.ID.String(), string(data), database.NODES_TABLE_NAME); err != nil {
 			logger.Log(0, "inserting relayed node", err.Error())
 			continue
 		}
 	}
-	return returnnodes, node, nil
+	return relayedClients, node, nil
 }
 
 // SetRelayedNodes- sets and saves node as relayed
-func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.Node {
-	var returnnodes []models.Node
+func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.Client {
+	var returnnodes []models.Client
 	for _, id := range relayed {
 		node, err := GetNodeByID(id)
 		if err != nil {
@@ -66,6 +66,8 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N
 		node.IsRelayed = setRelayed
 		if node.IsRelayed {
 			node.RelayedBy = relay
+		} else {
+			node.RelayedBy = ""
 		}
 		node.SetLastModified()
 		data, err := json.Marshal(&node)
@@ -77,7 +79,11 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.N
 			logger.Log(0, "setRelayedNodes.Insert", err.Error())
 			continue
 		}
-		returnnodes = append(returnnodes, node)
+		host := GetHostByNodeID(node.ID.String())
+		returnnodes = append(returnnodes, models.Client{
+			Host: *host,
+			Node: node,
+		})
 	}
 	return returnnodes
 }
@@ -94,29 +100,31 @@ func ValidateRelay(relay models.RelayRequest) error {
 }
 
 // UpdateRelayed - updates relay nodes
-func UpdateRelayed(relay string, oldNodes []string, newNodes []string) []models.Node {
+func UpdateRelayed(relay string, oldNodes []string, newNodes []string) []models.Client {
 	_ = SetRelayedNodes(false, relay, oldNodes)
 	return SetRelayedNodes(true, relay, newNodes)
 }
 
 // DeleteRelay - deletes a relay
-func DeleteRelay(network, nodeid string) ([]models.Node, models.Node, error) {
-	var returnnodes []models.Node
+func DeleteRelay(network, nodeid string) ([]models.Client, models.Node, error) {
+	var returnClients []models.Client
 	node, err := GetNodeByID(nodeid)
 	if err != nil {
-		return returnnodes, models.Node{}, err
+		return returnClients, models.Node{}, err
 	}
-	returnnodes = SetRelayedNodes(false, nodeid, node.RelayedNodes)
+
+	returnClients = SetRelayedNodes(false, nodeid, node.RelayedNodes)
 	node.IsRelay = false
+	node.RelayedNodes = []string{}
 	node.SetLastModified()
 	data, err := json.Marshal(&node)
 	if err != nil {
-		return returnnodes, models.Node{}, err
+		return returnClients, models.Node{}, err
 	}
 	if err = database.Insert(nodeid, string(data), database.NODES_TABLE_NAME); err != nil {
-		return returnnodes, models.Node{}, err
+		return returnClients, models.Node{}, err
 	}
-	return returnnodes, node, nil
+	return returnClients, node, nil
 }
 
 func getRelayedAddresses(id string) []net.IPNet {

+ 6 - 0
models/host.go

@@ -75,6 +75,12 @@ type Host struct {
 	TurnEndpoint     *netip.AddrPort  `json:"turn_endpoint,omitempty" yaml:"turn_endpoint,omitempty"`
 }
 
+// Client - represents a client on the network
+type Client struct {
+	Host Host `json:"host" yaml:"host"`
+	Node Node `json:"node" yaml:"node"`
+}
+
 // FormatBool converts a boolean to a [yes|no] string
 func FormatBool(b bool) string {
 	s := "no"

+ 255 - 0
mq/relay.go

@@ -0,0 +1,255 @@
+package mq
+
+import (
+	"encoding/json"
+	"fmt"
+	"net"
+
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/servercfg"
+	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+)
+
+// PubPeerUpdate publishes a peer update to the client
+// relay is set to a newly created relay node or nil for other peer updates
+func PubPeerUpdate(client, relay *models.Client, peers *[]models.Client) {
+	p := models.PeerAction{
+		Action: models.UpdatePeer,
+	}
+	if client.Node.IsRelay {
+		pubRelayUpdate(client, peers)
+		return
+	}
+	if relay != nil {
+		if logic.StringSliceContains(relay.Node.RelayedNodes, client.Node.ID.String()) {
+			pubRelayedUpdate(client, relay, peers)
+			return
+		}
+	}
+	for _, peer := range *peers {
+		if client.Host.ID == peer.Host.ID {
+			continue
+		}
+		update := wgtypes.PeerConfig{
+			PublicKey:         peer.Host.PublicKey,
+			ReplaceAllowedIPs: true,
+			Endpoint: &net.UDPAddr{
+				IP:   peer.Host.EndpointIP,
+				Port: peer.Host.ListenPort,
+			},
+			PersistentKeepaliveInterval: &peer.Node.PersistentKeepalive,
+		}
+		if relay != nil {
+			if peer.Node.IsRelayed && peer.Node.RelayedBy == relay.Node.ID.String() {
+				update.Remove = true
+			}
+		}
+		addAllowedIPs(peer, &update)
+		p.Peers = append(p.Peers, update)
+	}
+	data, err := json.Marshal(p)
+	if err != nil {
+		logger.Log(0, "marshal peer update", err.Error())
+		return
+	}
+	publish(&client.Host, fmt.Sprintf("peer/host/%s/%s", client.Host.ID.String(), servercfg.GetServer()), data)
+}
+
+// getRelayAllowedIPs returns the list of allowedips for a given peer that is a relay
+func getRelayAllowedIPs(peer models.Client) []net.IPNet {
+	var relayIPs []net.IPNet
+	for _, relayed := range peer.Node.RelayedNodes {
+		node, err := logic.GetNodeByID(relayed)
+		if err != nil {
+			logger.Log(0, "retrieve relayed node", err.Error())
+			continue
+		}
+		if node.Address.IP != nil {
+			node.Address.Mask = net.CIDRMask(32, 32)
+			relayIPs = append(relayIPs, node.Address)
+		}
+		if node.Address6.IP != nil {
+			node.Address.Mask = net.CIDRMask(128, 128)
+			relayIPs = append(relayIPs, node.Address6)
+		}
+		if node.IsRelay {
+			relayIPs = append(relayIPs, getRelayAllowedIPs(peer)...)
+		}
+		if node.IsEgressGateway {
+			relayIPs = append(relayIPs, getEgressIPs(peer)...)
+		}
+		if node.IsIngressGateway {
+			relayIPs = append(relayIPs, getIngressIPs(peer)...)
+		}
+	}
+	return relayIPs
+}
+
+// getEgressIPs returns the additional allowedips (egress ranges) that need
+// to be included for an egress gateway peer
+func getEgressIPs(peer models.Client) []net.IPNet {
+	var egressIPs []net.IPNet
+	for _, egressRange := range peer.Node.EgressGatewayRanges {
+		ip, cidr, err := net.ParseCIDR(egressRange)
+		if err != nil {
+			logger.Log(0, "parse egress range", err.Error())
+			continue
+		}
+		cidr.IP = ip
+		egressIPs = append(egressIPs, *cidr)
+	}
+	return egressIPs
+}
+
+// getIngressIPs returns the additional allowedips (ext client addresses) that need
+// to be included for an ingress gateway peer
+// TODO:  add ExtraAllowedIPs
+func getIngressIPs(peer models.Client) []net.IPNet {
+	var ingressIPs []net.IPNet
+	extclients, err := logic.GetNetworkExtClients(peer.Node.Network)
+	if err != nil {
+		return ingressIPs
+	}
+	for _, ec := range extclients {
+		if ec.IngressGatewayID == peer.Node.ID.String() {
+			if ec.Address != "" {
+				ip, cidr, err := net.ParseCIDR(ec.Address)
+				if err != nil {
+					continue
+				}
+				cidr.IP = ip
+				ingressIPs = append(ingressIPs, *cidr)
+			}
+			if ec.Address6 != "" {
+				ip, cidr, err := net.ParseCIDR(ec.Address6)
+				if err != nil {
+					continue
+				}
+				cidr.IP = ip
+				ingressIPs = append(ingressIPs, *cidr)
+			}
+		}
+	}
+	return ingressIPs
+}
+
+// pubRelayedUpdate - publish peer update to a node (client) that is relayed by the relay
+func pubRelayedUpdate(client, relay *models.Client, peers *[]models.Client) {
+	//verify
+	if !logic.StringSliceContains(relay.Node.RelayedNodes, client.Node.ID.String()) {
+		logger.Log(0, "invalid call to pubRelayed update", client.Host.Name, relay.Host.Name)
+		return
+	}
+	//remove all nodes except relay
+	p := models.PeerAction{
+		Action: models.RemovePeer,
+	}
+	for _, peer := range *peers {
+		if peer.Host.ID == relay.Host.ID || peer.Host.ID == client.Host.ID {
+			continue
+		}
+		update := wgtypes.PeerConfig{
+			PublicKey: peer.Host.PublicKey,
+			Remove:    true,
+		}
+		p.Peers = append(p.Peers, update)
+	}
+	data, err := json.Marshal(p)
+	if err != nil {
+		logger.Log(0, "marshal peer update", err.Error())
+		return
+	}
+	publish(&client.Host, fmt.Sprintf("peer/host/%s/%s", client.Host.ID.String(), servercfg.GetServer()), data)
+	//update the relay peer
+	p = models.PeerAction{
+		Action: models.UpdatePeer,
+	}
+	update := wgtypes.PeerConfig{
+		PublicKey:         relay.Host.PublicKey,
+		ReplaceAllowedIPs: true,
+		Endpoint: &net.UDPAddr{
+			IP:   relay.Host.EndpointIP,
+			Port: relay.Host.ListenPort,
+		},
+		PersistentKeepaliveInterval: &relay.Node.PersistentKeepalive,
+	}
+	if relay.Node.Address.IP != nil {
+		relay.Node.Address.Mask = net.CIDRMask(32, 32)
+		update.AllowedIPs = append(update.AllowedIPs, relay.Node.Address)
+	}
+	if relay.Node.Address6.IP != nil {
+		relay.Node.Address6.Mask = net.CIDRMask(128, 128)
+		update.AllowedIPs = append(update.AllowedIPs, relay.Node.Address6)
+	}
+	p.Peers = append(p.Peers, update)
+	// add all other peers to allowed ips
+	for _, peer := range *peers {
+		if peer.Host.ID == relay.Host.ID || peer.Host.ID == client.Host.ID {
+			continue
+		}
+		addAllowedIPs(peer, &update)
+	}
+	p.Peers = append(p.Peers, update)
+	data, err = json.Marshal(p)
+	if err != nil {
+		logger.Log(0, "marshal peer update", err.Error())
+		return
+	}
+	publish(&client.Host, fmt.Sprintf("peer/host/%s/%s", client.Host.ID.String(), servercfg.GetServer()), data)
+}
+
+// pubRelayUpdate - publish peer update to a relay
+func pubRelayUpdate(client *models.Client, peers *[]models.Client) {
+	if !client.Node.IsRelay {
+		return
+	}
+	// add all peers to allowedips
+	p := models.PeerAction{
+		Action: models.UpdatePeer,
+	}
+	for _, peer := range *peers {
+		if peer.Host.ID == client.Host.ID {
+			continue
+		}
+		update := wgtypes.PeerConfig{
+			PublicKey:         peer.Host.PublicKey,
+			ReplaceAllowedIPs: true,
+			Remove:            false,
+			Endpoint: &net.UDPAddr{
+				IP:   peer.Host.EndpointIP,
+				Port: peer.Host.ListenPort,
+			},
+			PersistentKeepaliveInterval: &peer.Node.PersistentKeepalive,
+		}
+		addAllowedIPs(peer, &update)
+		p.Peers = append(p.Peers, update)
+	}
+	data, err := json.Marshal(p)
+	if err != nil {
+		logger.Log(0, "marshal peer update", err.Error())
+		return
+	}
+	publish(&client.Host, fmt.Sprintf("peer/host/%s/%s", client.Host.ID.String(), servercfg.GetServer()), data)
+}
+
+func addAllowedIPs(peer models.Client, update *wgtypes.PeerConfig) {
+	if peer.Node.Address.IP != nil {
+		peer.Node.Address.Mask = net.CIDRMask(32, 32)
+		update.AllowedIPs = append(update.AllowedIPs, peer.Node.Address)
+	}
+	if peer.Node.Address6.IP != nil {
+		peer.Node.Address6.Mask = net.CIDRMask(128, 128)
+		update.AllowedIPs = append(update.AllowedIPs, peer.Node.Address6)
+	}
+	if peer.Node.IsRelay {
+		update.AllowedIPs = append(update.AllowedIPs, getRelayAllowedIPs(peer)...)
+	}
+	if peer.Node.IsEgressGateway {
+		update.AllowedIPs = append(update.AllowedIPs, getEgressIPs(peer)...)
+	}
+	if peer.Node.IsIngressGateway {
+		update.AllowedIPs = append(update.AllowedIPs, getIngressIPs(peer)...)
+	}
+}