2
0
Эх сурвалжийг харах

Net 137p - Pull requests (#2349)

* pull response for refactored relays

* remove debugging logs

* statticcheck

* review comments

* use GetHost vice GetHostByNodeID
Matthew R Kasun 2 жил өмнө
parent
commit
80c2fad9bf
8 өөрчлөгдсөн 401 нэмэгдсэн , 104 устгасан
  1. 9 9
      controllers/hosts.go
  2. 29 20
      controllers/relay.go
  3. 7 4
      logic/nodes.go
  4. 215 20
      logic/peers.go
  5. 121 14
      logic/relay.go
  6. 3 3
      models/node.go
  7. 3 3
      models/structs.go
  8. 14 31
      mq/relay.go

+ 9 - 9
controllers/hosts.go

@@ -1,7 +1,6 @@
 package controller
 
 import (
-	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -82,12 +81,14 @@ func pull(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	hPU, err := logic.GetPeerUpdateForHost(context.Background(), "", host, nil, nil)
-	if err != nil {
-		logger.Log(0, "could not pull peers for host", hostID)
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-		return
-	}
+	peers := logic.GetPeerUpdate(host)
+
+	//hPU, err := logic.GetPeerUpdateForHost(context.Background(), "", host, nil, nil)
+	//if err != nil {
+	//logger.Log(0, "could not pull peers for host", hostID)
+	//logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+	//return
+	//}
 	serverConf := servercfg.GetServerInfo()
 	if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
 		serverConf.MQUserName = hostID
@@ -102,8 +103,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
 	response := models.HostPull{
 		Host:         *host,
 		ServerConfig: serverConf,
-		Peers:        hPU.Peers,
-		PeerIDs:      hPU.PeerIDs,
+		Peers:        peers,
 	}
 
 	logger.Log(1, hostID, "completed a pull")

+ 29 - 20
controllers/relay.go

@@ -43,8 +43,15 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	relayHost, err := logic.GetHost(relayNode.HostID.String())
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"),
+			fmt.Sprintf("failed to retrieve host for node [%s] on network [%s]: %v", relayRequest.NodeID, relayRequest.NetID, err))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
 	relay := models.Client{
-		Host: *logic.GetHostByNodeID(params["nodeid"]),
+		Host: *relayHost,
 		Node: relayNode,
 	}
 	peers, err := logic.GetNetworkClients(relay.Node.Network)
@@ -59,7 +66,7 @@ func createRelay(w http.ResponseWriter, r *http.Request) {
 	//}
 	clients := peers
 	for _, client := range clients {
-		mq.PubPeerUpdate(&client, &relay, &peers)
+		mq.PubPeerUpdate(&client, &relay, peers)
 	}
 
 	logger.Log(1, r.Header.Get("user"), "created relay on node", relayRequest.NodeID, "on network", relayRequest.NetID)
@@ -94,26 +101,28 @@ func deleteRelay(w http.ResponseWriter, r *http.Request) {
 	logger.Log(1, r.Header.Get("user"), "deleted relay server", nodeid, "on network", netid)
 	go func() {
 		//update relayHost node
-		relayHost := logic.GetHostByNodeID(node.ID.String())
-		if err := mq.NodeUpdate(&node); err != nil {
-			logger.Log(1, "relay node update", relayHost.Name, "on network", node.Network, ": ", err.Error())
-		}
-		for _, relayedClient := range updateClients {
-			err = mq.NodeUpdate(&relayedClient.Node)
-			if err != nil {
-				logger.Log(1, "relayed node update ", relayedClient.Node.ID.String(), "on network", relayedClient.Node.Network, ": ", err.Error())
+		relayHost, err := logic.GetHost(node.HostID.String())
+		if err == nil {
+			if err := mq.NodeUpdate(&node); err != nil {
+				logger.Log(1, "relay node update", relayHost.Name, "on network", node.Network, ": ", err.Error())
+			}
+			for _, relayedClient := range updateClients {
+				err = mq.NodeUpdate(&relayedClient.Node)
+				if err != nil {
+					logger.Log(1, "relayed node update ", relayedClient.Node.ID.String(), "on network", relayedClient.Node.Network, ": ", err.Error())
 
+				}
+			}
+			peers, err := logic.GetNetworkClients(node.Network)
+			if err != nil {
+				logger.Log(0, "error getting network nodes: ", err.Error())
+				logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+				return
+			}
+			clients := peers
+			for _, client := range clients {
+				mq.PubPeerUpdate(&client, nil, peers)
 			}
-		}
-		peers, err := logic.GetNetworkClients(node.Network)
-		if err != nil {
-			logger.Log(0, "error getting network nodes: ", err.Error())
-			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-			return
-		}
-		clients := peers
-		for _, client := range clients {
-			mq.PubPeerUpdate(&client, nil, &peers)
 		}
 	}()
 	logger.Log(1, r.Header.Get("user"), "deleted relay on node", node.ID.String(), "on network", node.Network)

+ 7 - 4
logic/nodes.go

@@ -49,11 +49,14 @@ func GetNetworkClients(network string) ([]models.Client, error) {
 		return []models.Client{}, err
 	}
 	for _, node := range nodes {
-		client := models.Client{
-			Node: node,
-			Host: *GetHostByNodeID(node.ID.String()),
+		host, err := GetHost(node.HostID.String())
+		if err == nil {
+			client := models.Client{
+				Node: node,
+				Host: *host,
+			}
+			clients = append(clients, client)
 		}
-		clients = append(clients, client)
 	}
 	return clients, nil
 }

+ 215 - 20
logic/peers.go

@@ -193,7 +193,14 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 					}
 				}
 				if peer.IsEgressGateway {
-					allowedips = append(allowedips, getEgressIPs(&node, &peer)...)
+					host, err := GetHost(peer.HostID.String())
+					if err == nil {
+						allowedips = append(allowedips, getEgressIPs(
+							&models.Client{
+								Host: *host,
+								Node: peer,
+							})...)
+					}
 				}
 				if peer.Action != models.NODE_DELETE &&
 					!peer.PendingDelete &&
@@ -570,42 +577,36 @@ func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet
 	return allowedips
 }
 
-func getEgressIPs(node, peer *models.Node) []net.IPNet {
-	host, err := GetHost(node.HostID.String())
-	if err != nil {
-		logger.Log(0, "error retrieving host for node", node.ID.String(), err.Error())
-	}
-	peerHost, err := GetHost(peer.HostID.String())
-	if err != nil {
-		logger.Log(0, "error retrieving host for peer", peer.ID.String(), err.Error())
-	}
+// getEgressIPs - gets the egress IPs for a client
+func getEgressIPs(client *models.Client) []net.IPNet {
 
 	//check for internet gateway
 	internetGateway := false
-	if slices.Contains(peer.EgressGatewayRanges, "0.0.0.0/0") || slices.Contains(peer.EgressGatewayRanges, "::/0") {
+	if slices.Contains(client.Node.EgressGatewayRanges, "0.0.0.0/0") || slices.Contains(client.Node.EgressGatewayRanges, "::/0") {
 		internetGateway = true
 	}
 	allowedips := []net.IPNet{}
-	for _, iprange := range peer.EgressGatewayRanges { // go through each cidr for egress gateway
-		_, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
+	for _, iprange := range client.Node.EgressGatewayRanges { // go through each cidr for egress gateway
+		ip, cidr, err := net.ParseCIDR(iprange) // confirming it's valid cidr
 		if err != nil {
 			logger.Log(1, "could not parse gateway IP range. Not adding ", iprange)
 			continue // if can't parse CIDR
 		}
+		cidr.IP = ip
 		// getting the public ip of node
-		if ipnet.Contains(peerHost.EndpointIP) && !internetGateway { // ensuring egress gateway range does not contain endpoint of node
-			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", host.EndpointIP.String(), ", omitting")
+		if cidr.Contains(client.Host.EndpointIP) && !internetGateway { // ensuring egress gateway range does not contain endpoint of node
+			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", client.Host.EndpointIP.String(), ", omitting")
 			continue // skip adding egress range if overlaps with node's ip
 		}
 		// TODO: Could put in a lot of great logic to avoid conflicts / bad routes
-		if ipnet.Contains(node.LocalAddress.IP) && !internetGateway { // ensuring egress gateway range does not contain public ip of node
-			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", node.LocalAddress.String(), ", omitting")
+		if cidr.Contains(client.Node.LocalAddress.IP) && !internetGateway { // ensuring egress gateway range does not contain public ip of node
+			logger.Log(2, "egress IP range of ", iprange, " overlaps with ", client.Node.LocalAddress.String(), ", omitting")
 			continue // skip adding egress range if overlaps with node's local ip
 		}
 		if err != nil {
 			logger.Log(1, "error encountered when setting egress range", err.Error())
 		} else {
-			allowedips = append(allowedips, *ipnet)
+			allowedips = append(allowedips, *cidr)
 		}
 	}
 	return allowedips
@@ -630,8 +631,15 @@ func getNodeAllowedIPs(peer, node *models.Node) []net.IPNet {
 	// handle egress gateway peers
 	if peer.IsEgressGateway {
 		//hasGateway = true
-		egressIPs := getEgressIPs(node, peer)
-		allowedips = append(allowedips, egressIPs...)
+		host, err := GetHost(peer.HostID.String())
+		if err == nil {
+			egressIPs := getEgressIPs(
+				&models.Client{
+					Host: *host,
+					Node: *peer,
+				})
+			allowedips = append(allowedips, egressIPs...)
+		}
 	}
 	if peer.IsRelay {
 		for _, relayed := range peer.RelayedNodes {
@@ -679,3 +687,190 @@ func filterNodeMapForClientACLs(publicKey, network string, nodePeerMap map[strin
 	}
 	return nodePeerMap
 }
+
+func GetPeerUpdate(host *models.Host) []wgtypes.PeerConfig {
+	peerUpdate := []wgtypes.PeerConfig{}
+	for _, nodeStr := range host.Nodes {
+		node, err := GetNodeByID(nodeStr)
+		if err != nil {
+			continue
+		}
+		client := models.Client{Host: *host, Node: node}
+		peers, err := GetNetworkClients(node.Network)
+		if err != nil {
+			continue
+		}
+		if node.IsRelayed {
+			peerUpdate = append(peerUpdate, peerUpdateForRelayed(&client, peers)...)
+			continue
+		}
+		if node.IsRelay {
+			peerUpdate = append(peerUpdate, peerUpdateForRelay(&client, peers)...)
+			continue
+		}
+		for _, peer := range peers {
+			if peer.Host.ID == client.Host.ID {
+				continue
+			}
+			// if peer is relayed by some other node, remove it from the peer list,  it
+			// will be added to allowedips of relay peer
+			if peer.Node.IsRelayed {
+				update := wgtypes.PeerConfig{
+					PublicKey: peer.Host.PublicKey,
+					Remove:    true,
+				}
+				peerUpdate = append(peerUpdate, update)
+				continue
+			}
+			update := wgtypes.PeerConfig{
+				PublicKey:         peer.Host.PublicKey,
+				ReplaceAllowedIPs: true,
+				Endpoint: &net.UDPAddr{
+					IP:   peer.Host.EndpointIP,
+					Port: peer.Host.ListenPort,
+				},
+				PersistentKeepaliveInterval: &peer.Node.PersistentKeepalive,
+			}
+			// if peer is a relay that relays us, don't do anything
+			if peer.Node.IsRelay && client.Node.RelayedBy == peer.Node.ID.String() {
+				continue
+			} else {
+				update.AllowedIPs = append(update.AllowedIPs, getRelayAllowedIPs(&peer)...)
+			}
+			//normal peer
+			update.AllowedIPs = append(update.AllowedIPs, AddAllowedIPs(&peer)...)
+			peerUpdate = append(peerUpdate, update)
+		}
+	}
+	return peerUpdate
+}
+
+func AddAllowedIPs(peer *models.Client) []net.IPNet {
+	allowedIPs := []net.IPNet{}
+	if peer.Node.Address.IP != nil {
+		peer.Node.Address.Mask = net.CIDRMask(32, 32)
+		allowedIPs = append(allowedIPs, peer.Node.Address)
+	}
+	if peer.Node.Address6.IP != nil {
+		peer.Node.Address6.Mask = net.CIDRMask(128, 128)
+		allowedIPs = append(allowedIPs, peer.Node.Address6)
+	}
+	if peer.Node.IsEgressGateway {
+		allowedIPs = append(allowedIPs, getEgressIPs(peer)...)
+	}
+	if peer.Node.IsIngressGateway {
+		allowedIPs = append(allowedIPs, getIngressIPs(peer)...)
+	}
+	return allowedIPs
+}
+
+// getRelayAllowedIPs returns the list of allowedips for a peer that is a relay
+func getRelayAllowedIPs(peer *models.Client) []net.IPNet {
+	var relayIPs []net.IPNet
+	if !peer.Node.IsRelay {
+		logger.Log(0, "getRelayAllowedIPs called for a non-relay node", peer.Host.Name)
+		return relayIPs
+	}
+	//if !client.Node.IsRelayed || client.Node.RelayedBy != peer.Node.ID.String() {
+	//logger.Log(0, "getRelayAllowedIPs called for non-relayed node", client.Host.Name, peer.Host.Name)
+	//return relayIPs
+	//}
+	for _, relayed := range peer.Node.RelayedNodes {
+		relayedNode, err := GetNodeByID(relayed)
+		if err != nil {
+			logger.Log(0, "retrieve relayed node", err.Error())
+			continue
+		}
+		if relayedNode.Address.IP != nil {
+			relayedNode.Address.Mask = net.CIDRMask(32, 32)
+			relayIPs = append(relayIPs, relayedNode.Address)
+		}
+		if relayedNode.Address6.IP != nil {
+			relayedNode.Address.Mask = net.CIDRMask(128, 128)
+			relayIPs = append(relayIPs, relayedNode.Address6)
+		}
+		host, err := GetHost(relayedNode.HostID.String())
+		if err == nil {
+			if relayedNode.IsRelay {
+				relayIPs = append(relayIPs, getRelayAllowedIPs(
+					&models.Client{
+						Host: *host,
+						Node: relayedNode,
+					})...)
+			}
+			if relayedNode.IsEgressGateway {
+				relayIPs = append(relayIPs, getEgressIPs(
+					&models.Client{
+						Host: *host,
+						Node: relayedNode,
+					})...)
+			}
+			if relayedNode.IsIngressGateway {
+				relayIPs = append(relayIPs, getIngressIPs(
+					&models.Client{
+						Host: *host,
+						Node: relayedNode,
+					})...)
+			}
+		}
+	}
+	return relayIPs
+}
+
+// getIngressIPs returns the additional allowedips (ext client addresses) that need
+// to be included for an ingress gateway peer
+// TODO:  add ExtraAllowedIPs
+func getIngressIPs(peer *models.Client) []net.IPNet {
+	var ingressIPs []net.IPNet
+	extclients, err := GetNetworkExtClients(peer.Node.Network)
+	if err != nil {
+		return ingressIPs
+	}
+	for _, ec := range extclients {
+		if ec.IngressGatewayID == peer.Node.ID.String() {
+			if ec.Address != "" {
+				ip, cidr, err := net.ParseCIDR(ec.Address)
+				if err != nil {
+					continue
+				}
+				cidr.IP = ip
+				ingressIPs = append(ingressIPs, *cidr)
+			}
+			if ec.Address6 != "" {
+				ip, cidr, err := net.ParseCIDR(ec.Address6)
+				if err != nil {
+					continue
+				}
+				cidr.IP = ip
+				ingressIPs = append(ingressIPs, *cidr)
+			}
+		}
+	}
+	return ingressIPs
+}
+
+// GetPeerUpdateForRelay - returns the peer update for a relay node
+func GetPeerUpdateForRelay(client *models.Client, peers []models.Client) []wgtypes.PeerConfig {
+	peerConfig := []wgtypes.PeerConfig{}
+	if !client.Node.IsRelay {
+		return []wgtypes.PeerConfig{}
+	}
+	for _, peer := range peers {
+		if peer.Host.ID == client.Host.ID {
+			continue
+		}
+		update := wgtypes.PeerConfig{
+			PublicKey:         peer.Host.PublicKey,
+			ReplaceAllowedIPs: true,
+			Remove:            false,
+			Endpoint: &net.UDPAddr{
+				IP:   peer.Host.EndpointIP,
+				Port: peer.Host.ListenPort,
+			},
+			PersistentKeepaliveInterval: &peer.Node.PersistentKeepalive,
+		}
+		update.AllowedIPs = append(update.AllowedIPs, AddAllowedIPs(&peer)...)
+		peerConfig = append(peerConfig, update)
+	}
+	return peerConfig
+}

+ 121 - 14
logic/relay.go

@@ -9,6 +9,7 @@ import (
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
+	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
 // CreateRelay - creates a relay
@@ -79,11 +80,13 @@ func SetRelayedNodes(setRelayed bool, relay string, relayed []string) []models.C
 			logger.Log(0, "setRelayedNodes.Insert", err.Error())
 			continue
 		}
-		host := GetHostByNodeID(node.ID.String())
-		returnnodes = append(returnnodes, models.Client{
-			Host: *host,
-			Node: node,
-		})
+		host, err := GetHost(node.HostID.String())
+		if err == nil {
+			returnnodes = append(returnnodes, models.Client{
+				Host: *host,
+				Node: node,
+			})
+		}
 	}
 	return returnnodes
 }
@@ -94,7 +97,7 @@ func ValidateRelay(relay models.RelayRequest) error {
 	//isIp := functions.IsIpCIDR(gateway.RangeString)
 	empty := len(relay.RelayedNodes) == 0
 	if empty {
-		err = errors.New("IP Ranges Cannot Be Empty")
+		err = errors.New("relayed nodes cannot be empty")
 	}
 	return err
 }
@@ -135,16 +138,120 @@ func getRelayedAddresses(id string) []net.IPNet {
 		return addrs
 	}
 	if node.Address.IP != nil {
-		addrs = append(addrs, net.IPNet{
-			IP:   node.Address.IP,
-			Mask: net.CIDRMask(32, 32),
-		})
+		node.Address.Mask = net.CIDRMask(32, 32)
+		addrs = append(addrs, node.Address)
 	}
 	if node.Address6.IP != nil {
-		addrs = append(addrs, net.IPNet{
-			IP:   node.Address6.IP,
-			Mask: net.CIDRMask(128, 128),
-		})
+		node.Address6.Mask = net.CIDRMask(128, 128)
+		addrs = append(addrs, node.Address6)
 	}
 	return addrs
 }
+
+// peerUpdateForRelayed - returns the peerConfig for a relayed node
+func peerUpdateForRelayed(client *models.Client, peers []models.Client) []wgtypes.PeerConfig {
+	peerConfig := []wgtypes.PeerConfig{}
+	if !client.Node.IsRelayed {
+		logger.Log(0, "GetPeerUpdateForRelayed called for non-relayed node ", client.Host.Name)
+		return []wgtypes.PeerConfig{}
+	}
+	relayNode, err := GetNodeByID(client.Node.RelayedBy)
+	if err != nil {
+		logger.Log(0, "error retrieving relay node", err.Error())
+		return []wgtypes.PeerConfig{}
+	}
+	host, err := GetHost(relayNode.HostID.String())
+	if err != nil {
+		return []wgtypes.PeerConfig{}
+	}
+	relay := models.Client{
+		Host: *host,
+		Node: relayNode,
+	}
+	for _, peer := range peers {
+		if peer.Host.ID == client.Host.ID {
+			continue
+		}
+		if peer.Host.ID == relay.Host.ID { // add relay as a peer
+			update := peerUpdateForRelayedByRelay(client, &relay)
+			peerConfig = append(peerConfig, update)
+			continue
+		}
+		update := wgtypes.PeerConfig{
+			PublicKey: peer.Host.PublicKey,
+			Remove:    true,
+		}
+		peerConfig = append(peerConfig, update)
+	}
+	return peerConfig
+}
+
+// peerUpdateForRelayedByRelay - returns the peerConfig for a node relayed by relay
+func peerUpdateForRelayedByRelay(relayed, relay *models.Client) wgtypes.PeerConfig {
+	if relayed.Node.RelayedBy != relay.Node.ID.String() {
+		logger.Log(0, "peerUpdateForRelayedByRelay called with invalid parameters")
+		return wgtypes.PeerConfig{}
+	}
+	update := wgtypes.PeerConfig{
+		PublicKey:         relay.Host.PublicKey,
+		ReplaceAllowedIPs: true,
+		Endpoint: &net.UDPAddr{
+			IP:   relay.Host.EndpointIP,
+			Port: relay.Host.ListenPort,
+		},
+		PersistentKeepaliveInterval: &relay.Node.PersistentKeepalive,
+	}
+	if relay.Node.Address.IP != nil {
+		relay.Node.Address.Mask = net.CIDRMask(32, 32)
+		update.AllowedIPs = append(update.AllowedIPs, relay.Node.Address)
+	}
+	if relay.Node.Address6.IP != nil {
+		relay.Node.Address6.Mask = net.CIDRMask(128, 128)
+		update.AllowedIPs = append(update.AllowedIPs, relay.Node.Address6)
+	}
+	if relay.Node.IsEgressGateway {
+		update.AllowedIPs = append(update.AllowedIPs, getEgressIPs(relay)...)
+	}
+	if relay.Node.IsIngressGateway {
+		update.AllowedIPs = append(update.AllowedIPs, getIngressIPs(relay)...)
+	}
+	peers, err := GetNetworkClients(relay.Node.Network)
+	if err != nil {
+		logger.Log(0, "error getting network clients", err.Error())
+		return update
+	}
+	for _, peer := range peers {
+		if peer.Host.ID == relayed.Host.ID || peer.Host.ID == relay.Host.ID {
+			continue
+		}
+		update.AllowedIPs = append(update.AllowedIPs, AddAllowedIPs(&peer)...)
+	}
+	return update
+}
+
+// peerUpdateForRelay - returns the peerConfig for a relay
+func peerUpdateForRelay(relay *models.Client, peers []models.Client) []wgtypes.PeerConfig {
+	peerConfig := []wgtypes.PeerConfig{}
+	if !relay.Node.IsRelay {
+		logger.Log(0, "GetPeerUpdateForRelay called for non-relay node ", relay.Host.Name)
+		return []wgtypes.PeerConfig{}
+	}
+	for _, peer := range peers {
+		if peer.Host.ID == relay.Host.ID {
+			continue
+		}
+		update := wgtypes.PeerConfig{
+			PublicKey:         peer.Host.PublicKey,
+			ReplaceAllowedIPs: true,
+			Remove:            false,
+			Endpoint: &net.UDPAddr{
+				IP:   peer.Host.EndpointIP,
+				Port: peer.Host.ListenPort,
+			},
+			PersistentKeepaliveInterval: &peer.Node.PersistentKeepalive,
+		}
+		update.AllowedIPs = append(update.AllowedIPs, AddAllowedIPs(&peer)...)
+		peerConfig = append(peerConfig, update)
+	}
+	return peerConfig
+}

+ 3 - 3
models/node.go

@@ -70,9 +70,9 @@ type CommonNode struct {
 	EgressGatewayRanges []string      `json:"egressgatewayranges" bson:"egressgatewayranges" yaml:"egressgatewayranges"`
 	IsIngressGateway    bool          `json:"isingressgateway" yaml:"isingressgateway"`
 	IngressDNS          string        `json:"ingressdns" yaml:"ingressdns"`
-	IsRelayed           bool          `json:"isrelayed" bson:"isrelayed" yaml:"isrelayed"`
-	RelayedBy           string        `json:"relayedby" bson:"relayedby" yaml:"relayedby"`
-	IsRelay             bool          `json:"isrelay" bson:"isrelay" yaml:"isrelay"`
+	IsRelayed           bool          `json:"isrelayed" yaml:"isrelayed"`
+	RelayedBy           string        `json:"relayedby" yaml:"relayedby"`
+	IsRelay             bool          `json:"isrelay" yaml:"isrelay"`
 	RelayedNodes        []string      `json:"relaynodes" yaml:"relayedNodes"`
 	DNSOn               bool          `json:"dnson" yaml:"dnson"`
 	PersistentKeepalive time.Duration `json:"persistentkeepalive" yaml:"persistentkeepalive"`

+ 3 - 3
models/structs.go

@@ -153,9 +153,9 @@ type EgressGatewayRequest struct {
 
 // RelayRequest - relay request struct
 type RelayRequest struct {
-	NodeID       string   `json:"nodeid" bson:"nodeid"`
-	NetID        string   `json:"netid" bson:"netid"`
-	RelayedNodes []string `json:"relayaddrs" bson:"relayaddrs"`
+	NodeID       string   `json:"nodeid"`
+	NetID        string   `json:"netid"`
+	RelayedNodes []string `json:"relayednodes"`
 }
 
 // HostRelayRequest - struct for host relay creation

+ 14 - 31
mq/relay.go

@@ -14,7 +14,7 @@ import (
 
 // PubPeerUpdate publishes a peer update to the client
 // relay is set to a newly created relay node or nil for other peer updates
-func PubPeerUpdate(client, relay *models.Client, peers *[]models.Client) {
+func PubPeerUpdate(client, relay *models.Client, peers []models.Client) {
 	p := models.PeerAction{
 		Action: models.UpdatePeer,
 	}
@@ -23,12 +23,12 @@ func PubPeerUpdate(client, relay *models.Client, peers *[]models.Client) {
 		return
 	}
 	if relay != nil {
-		if logic.StringSliceContains(relay.Node.RelayedNodes, client.Node.ID.String()) {
+		if client.Node.RelayedBy == relay.Node.ID.String() {
 			pubRelayedUpdate(client, relay, peers)
 			return
 		}
 	}
-	for _, peer := range *peers {
+	for _, peer := range peers {
 		if client.Host.ID == peer.Host.ID {
 			continue
 		}
@@ -41,12 +41,15 @@ func PubPeerUpdate(client, relay *models.Client, peers *[]models.Client) {
 			},
 			PersistentKeepaliveInterval: &peer.Node.PersistentKeepalive,
 		}
+		update.AllowedIPs = append(update.AllowedIPs, logic.AddAllowedIPs(&peer)...)
 		if relay != nil {
 			if peer.Node.IsRelayed && peer.Node.RelayedBy == relay.Node.ID.String() {
 				update.Remove = true
 			}
 		}
-		addAllowedIPs(peer, &update)
+		if peer.Node.IsRelay {
+			update.AllowedIPs = append(update.AllowedIPs, getRelayAllowedIPs(peer)...)
+		}
 		p.Peers = append(p.Peers, update)
 	}
 	data, err := json.Marshal(p)
@@ -136,7 +139,7 @@ func getIngressIPs(peer models.Client) []net.IPNet {
 }
 
 // pubRelayedUpdate - publish peer update to a node (client) that is relayed by the relay
-func pubRelayedUpdate(client, relay *models.Client, peers *[]models.Client) {
+func pubRelayedUpdate(client, relay *models.Client, peers []models.Client) {
 	//verify
 	if !logic.StringSliceContains(relay.Node.RelayedNodes, client.Node.ID.String()) {
 		logger.Log(0, "invalid call to pubRelayed update", client.Host.Name, relay.Host.Name)
@@ -146,7 +149,7 @@ func pubRelayedUpdate(client, relay *models.Client, peers *[]models.Client) {
 	p := models.PeerAction{
 		Action: models.RemovePeer,
 	}
-	for _, peer := range *peers {
+	for _, peer := range peers {
 		if peer.Host.ID == relay.Host.ID || peer.Host.ID == client.Host.ID {
 			continue
 		}
@@ -185,11 +188,11 @@ func pubRelayedUpdate(client, relay *models.Client, peers *[]models.Client) {
 	}
 	p.Peers = append(p.Peers, update)
 	// add all other peers to allowed ips
-	for _, peer := range *peers {
+	for _, peer := range peers {
 		if peer.Host.ID == relay.Host.ID || peer.Host.ID == client.Host.ID {
 			continue
 		}
-		addAllowedIPs(peer, &update)
+		update.AllowedIPs = append(update.AllowedIPs, logic.AddAllowedIPs(&peer)...)
 	}
 	p.Peers = append(p.Peers, update)
 	data, err = json.Marshal(p)
@@ -201,7 +204,7 @@ func pubRelayedUpdate(client, relay *models.Client, peers *[]models.Client) {
 }
 
 // pubRelayUpdate - publish peer update to a relay
-func pubRelayUpdate(client *models.Client, peers *[]models.Client) {
+func pubRelayUpdate(client *models.Client, peers []models.Client) {
 	if !client.Node.IsRelay {
 		return
 	}
@@ -209,7 +212,7 @@ func pubRelayUpdate(client *models.Client, peers *[]models.Client) {
 	p := models.PeerAction{
 		Action: models.UpdatePeer,
 	}
-	for _, peer := range *peers {
+	for _, peer := range peers {
 		if peer.Host.ID == client.Host.ID {
 			continue
 		}
@@ -223,7 +226,7 @@ func pubRelayUpdate(client *models.Client, peers *[]models.Client) {
 			},
 			PersistentKeepaliveInterval: &peer.Node.PersistentKeepalive,
 		}
-		addAllowedIPs(peer, &update)
+		update.AllowedIPs = append(update.AllowedIPs, logic.AddAllowedIPs(&peer)...)
 		p.Peers = append(p.Peers, update)
 	}
 	data, err := json.Marshal(p)
@@ -233,23 +236,3 @@ func pubRelayUpdate(client *models.Client, peers *[]models.Client) {
 	}
 	publish(&client.Host, fmt.Sprintf("peer/host/%s/%s", client.Host.ID.String(), servercfg.GetServer()), data)
 }
-
-func addAllowedIPs(peer models.Client, update *wgtypes.PeerConfig) {
-	if peer.Node.Address.IP != nil {
-		peer.Node.Address.Mask = net.CIDRMask(32, 32)
-		update.AllowedIPs = append(update.AllowedIPs, peer.Node.Address)
-	}
-	if peer.Node.Address6.IP != nil {
-		peer.Node.Address6.Mask = net.CIDRMask(128, 128)
-		update.AllowedIPs = append(update.AllowedIPs, peer.Node.Address6)
-	}
-	if peer.Node.IsRelay {
-		update.AllowedIPs = append(update.AllowedIPs, getRelayAllowedIPs(peer)...)
-	}
-	if peer.Node.IsEgressGateway {
-		update.AllowedIPs = append(update.AllowedIPs, getEgressIPs(peer)...)
-	}
-	if peer.Node.IsIngressGateway {
-		update.AllowedIPs = append(update.AllowedIPs, getIngressIPs(peer)...)
-	}
-}