Browse Source

stop context cancel on peer updates

Abhishek Kondur 2 years ago
parent
commit
4efbe6256f
6 changed files with 183 additions and 145 deletions
  1. 1 1
      Dockerfile
  2. 7 1
      controllers/hosts.go
  3. 13 1
      controllers/node.go
  4. 130 133
      logic/peers.go
  5. 10 2
      mq/handlers.go
  6. 22 7
      mq/publishers.go

+ 1 - 1
Dockerfile

@@ -4,7 +4,7 @@ ARG tags
 WORKDIR /app
 COPY . .
 
-RUN GOOS=linux CGO_ENABLED=1 go build -ldflags="-s -w " -tags ${tags} .
+RUN GOOS=linux CGO_ENABLED=1 go build -race -ldflags="-s -w " -tags ${tags} .
 # RUN go build -tags=ee . -o netmaker main.go
 FROM alpine:3.18.2
 

+ 7 - 1
controllers/hosts.go

@@ -81,7 +81,13 @@ func pull(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	hPU, err := logic.GetPeerUpdateForHost(context.Background(), "", host, nil, nil)
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		logger.Log(0, "could not pull peers for host", hostID)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	hPU, err := logic.GetPeerUpdateForHost(context.Background(), "", host, allNodes, nil, nil)
 	if err != nil {
 		logger.Log(0, "could not pull peers for host", hostID)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))

+ 13 - 1
controllers/node.go

@@ -388,7 +388,14 @@ func getNode(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	hostPeerUpdate, err := logic.GetPeerUpdateForHost(context.Background(), node.Network, host, nil, nil)
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		logger.Log(0, r.Header.Get("user"),
+			fmt.Sprintf("error fetching wg peers config for host [ %s ]: %v", host.ID.String(), err))
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	hostPeerUpdate, err := logic.GetPeerUpdateForHost(context.Background(), node.Network, host, allNodes, nil, nil)
 	if err != nil && !database.IsEmptyRecord(err) {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("error fetching wg peers config for host [ %s ]: %v", host.ID.String(), err))
@@ -583,9 +590,14 @@ func deleteIngressGateway(w http.ResponseWriter, r *http.Request) {
 	if len(removedClients) > 0 {
 		host, err := logic.GetHost(node.HostID.String())
 		if err == nil {
+			allNodes, err := logic.GetAllNodes()
+			if err != nil {
+				return
+			}
 			go mq.PublishSingleHostPeerUpdate(
 				context.Background(),
 				host,
+				allNodes,
 				nil,
 				removedClients[:],
 			)

+ 130 - 133
logic/peers.go

@@ -3,7 +3,6 @@ package logic
 import (
 	"context"
 	"errors"
-	"fmt"
 	"net"
 	"net/netip"
 
@@ -87,6 +86,7 @@ func GetProxyUpdateForHost(ctx context.Context, host *models.Host) (models.Proxy
 
 // ResetPeerUpdateContext - kills any current peer updates and resets the context
 func ResetPeerUpdateContext() {
+	return
 	if PeerUpdateCtx != nil && PeerUpdateStop != nil {
 		PeerUpdateStop() // tell any current peer updates to stop
 	}
@@ -95,14 +95,11 @@ func ResetPeerUpdateContext() {
 }
 
 // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks
-func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host, deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
+func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient) (models.HostPeerUpdate, error) {
 	if host == nil {
 		return models.HostPeerUpdate{}, errors.New("host is nil")
 	}
-	allNodes, err := GetAllNodes()
-	if err != nil {
-		return models.HostPeerUpdate{}, err
-	}
+
 	// track which nodes are deleted
 	// after peer calculation, if peer not in list, add delete config of peer
 	hostPeerUpdate := models.HostPeerUpdate{
@@ -141,150 +138,150 @@ func GetPeerUpdateForHost(ctx context.Context, network string, host *models.Host
 			nodePeerMap = make(map[string]models.PeerRouteInfo)
 		}
 		for _, peer := range currentPeers {
-			select {
-			case <-ctx.Done():
-				logger.Log(2, "cancelled peer update for host", host.Name, host.ID.String())
-				return models.HostPeerUpdate{}, fmt.Errorf("peer update cancelled")
-			default:
-				peer := peer
-				if peer.ID.String() == node.ID.String() {
-					logger.Log(2, "peer update, skipping self")
-					//skip yourself
-					continue
-				}
+			//select {
+			// case <-ctx.Done():
+			// 	logger.Log(2, "cancelled peer update for host", host.Name, host.ID.String())
+			// 	return models.HostPeerUpdate{}, fmt.Errorf("peer update cancelled")
+			//default:
+			peer := peer
+			if peer.ID.String() == node.ID.String() {
+				logger.Log(2, "peer update, skipping self")
+				//skip yourself
+				continue
+			}
 
-				peerHost, err := GetHost(peer.HostID.String())
-				if err != nil {
-					logger.Log(1, "no peer host", peer.HostID.String(), err.Error())
-					return models.HostPeerUpdate{}, err
-				}
-				peerConfig := wgtypes.PeerConfig{
-					PublicKey:                   peerHost.PublicKey,
-					PersistentKeepaliveInterval: &peer.PersistentKeepalive,
-					ReplaceAllowedIPs:           true,
-				}
-				if node.IsIngressGateway || node.IsEgressGateway {
-					if peer.IsIngressGateway {
-						_, extPeerIDAndAddrs, err := getExtPeers(&peer)
-						if err == nil {
-							for _, extPeerIdAndAddr := range extPeerIDAndAddrs {
-								extPeerIdAndAddr := extPeerIdAndAddr
-								nodePeerMap[extPeerIdAndAddr.ID] = models.PeerRouteInfo{
-									PeerAddr: net.IPNet{
-										IP:   net.ParseIP(extPeerIdAndAddr.Address),
-										Mask: getCIDRMaskFromAddr(extPeerIdAndAddr.Address),
-									},
-									PeerKey: extPeerIdAndAddr.ID,
-									Allow:   true,
-									ID:      extPeerIdAndAddr.ID,
-								}
+			peerHost, err := GetHost(peer.HostID.String())
+			if err != nil {
+				logger.Log(1, "no peer host", peer.HostID.String(), err.Error())
+				return models.HostPeerUpdate{}, err
+			}
+			peerConfig := wgtypes.PeerConfig{
+				PublicKey:                   peerHost.PublicKey,
+				PersistentKeepaliveInterval: &peer.PersistentKeepalive,
+				ReplaceAllowedIPs:           true,
+			}
+			if node.IsIngressGateway || node.IsEgressGateway {
+				if peer.IsIngressGateway {
+					_, extPeerIDAndAddrs, err := getExtPeers(&peer)
+					if err == nil {
+						for _, extPeerIdAndAddr := range extPeerIDAndAddrs {
+							extPeerIdAndAddr := extPeerIdAndAddr
+							nodePeerMap[extPeerIdAndAddr.ID] = models.PeerRouteInfo{
+								PeerAddr: net.IPNet{
+									IP:   net.ParseIP(extPeerIdAndAddr.Address),
+									Mask: getCIDRMaskFromAddr(extPeerIdAndAddr.Address),
+								},
+								PeerKey: extPeerIdAndAddr.ID,
+								Allow:   true,
+								ID:      extPeerIdAndAddr.ID,
 							}
 						}
 					}
-					if node.IsIngressGateway && peer.IsEgressGateway {
-						hostPeerUpdate.IngressInfo.EgressRanges = append(hostPeerUpdate.IngressInfo.EgressRanges,
-							peer.EgressGatewayRanges...)
-					}
-					nodePeerMap[peerHost.PublicKey.String()] = models.PeerRouteInfo{
-						PeerAddr: net.IPNet{
-							IP:   net.ParseIP(peer.PrimaryAddress()),
-							Mask: getCIDRMaskFromAddr(peer.PrimaryAddress()),
-						},
-						PeerKey: peerHost.PublicKey.String(),
-						Allow:   true,
-						ID:      peer.ID.String(),
-					}
 				}
-				if (node.IsRelayed && node.RelayedBy != peer.ID.String()) || (peer.IsRelayed && peer.RelayedBy != node.ID.String()) {
-					// if node is relayed and peer is not the relay, set remove to true
-					if _, ok := hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()]; ok {
-						continue
-					}
-					peerConfig.Remove = true
-					hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
-					peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
+				if node.IsIngressGateway && peer.IsEgressGateway {
+					hostPeerUpdate.IngressInfo.EgressRanges = append(hostPeerUpdate.IngressInfo.EgressRanges,
+						peer.EgressGatewayRanges...)
+				}
+				nodePeerMap[peerHost.PublicKey.String()] = models.PeerRouteInfo{
+					PeerAddr: net.IPNet{
+						IP:   net.ParseIP(peer.PrimaryAddress()),
+						Mask: getCIDRMaskFromAddr(peer.PrimaryAddress()),
+					},
+					PeerKey: peerHost.PublicKey.String(),
+					Allow:   true,
+					ID:      peer.ID.String(),
+				}
+			}
+			if (node.IsRelayed && node.RelayedBy != peer.ID.String()) || (peer.IsRelayed && peer.RelayedBy != node.ID.String()) {
+				// if node is relayed and peer is not the relay, set remove to true
+				if _, ok := hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()]; ok {
 					continue
 				}
+				peerConfig.Remove = true
+				hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
+				peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
+				continue
+			}
 
-				uselocal := false
-				if host.EndpointIP.String() == peerHost.EndpointIP.String() {
-					// peer is on same network
-					// set to localaddress
-					uselocal = true
-					if node.LocalAddress.IP == nil {
-						// use public endpint
-						uselocal = false
-					}
-					if node.LocalAddress.String() == peer.LocalAddress.String() {
-						uselocal = false
-					}
+			uselocal := false
+			if host.EndpointIP.String() == peerHost.EndpointIP.String() {
+				// peer is on same network
+				// set to localaddress
+				uselocal = true
+				if node.LocalAddress.IP == nil {
+					// use public endpint
+					uselocal = false
 				}
-				peerConfig.Endpoint = &net.UDPAddr{
-					IP:   peerHost.EndpointIP,
-					Port: getPeerWgListenPort(peerHost),
+				if node.LocalAddress.String() == peer.LocalAddress.String() {
+					uselocal = false
 				}
+			}
+			peerConfig.Endpoint = &net.UDPAddr{
+				IP:   peerHost.EndpointIP,
+				Port: getPeerWgListenPort(peerHost),
+			}
 
-				if uselocal {
-					peerConfig.Endpoint.IP = peer.LocalAddress.IP
-					peerConfig.Endpoint.Port = peerHost.ListenPort
+			if uselocal {
+				peerConfig.Endpoint.IP = peer.LocalAddress.IP
+				peerConfig.Endpoint.Port = peerHost.ListenPort
+			}
+			allowedips := GetAllowedIPs(&node, &peer, nil)
+			if peer.Action != models.NODE_DELETE &&
+				!peer.PendingDelete &&
+				peer.Connected &&
+				nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
+				(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
+				peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
+			}
+
+			peerProxyPort := GetProxyListenPort(peerHost)
+			var nodePeer wgtypes.PeerConfig
+			if _, ok := hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()]; !ok {
+				hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()] = make(map[string]models.IDandAddr)
+				hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
+				peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
+				hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()][peer.ID.String()] = models.IDandAddr{
+					ID:              peer.ID.String(),
+					Address:         peer.PrimaryAddress(),
+					Name:            peerHost.Name,
+					Network:         peer.Network,
+					ProxyListenPort: peerProxyPort,
 				}
-				allowedips := GetAllowedIPs(&node, &peer, nil)
-				if peer.Action != models.NODE_DELETE &&
-					!peer.PendingDelete &&
-					peer.Connected &&
-					nodeacls.AreNodesAllowed(nodeacls.NetworkID(node.Network), nodeacls.NodeID(node.ID.String()), nodeacls.NodeID(peer.ID.String())) &&
-					(deletedNode == nil || (deletedNode != nil && peer.ID.String() != deletedNode.ID.String())) {
-					peerConfig.AllowedIPs = allowedips // only append allowed IPs if valid connection
+				hostPeerUpdate.HostNetworkInfo[peerHost.PublicKey.String()] = models.HostNetworkInfo{
+					Interfaces:      peerHost.Interfaces,
+					ProxyListenPort: peerProxyPort,
 				}
-
-				peerProxyPort := GetProxyListenPort(peerHost)
-				var nodePeer wgtypes.PeerConfig
-				if _, ok := hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()]; !ok {
-					hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()] = make(map[string]models.IDandAddr)
-					hostPeerUpdate.Peers = append(hostPeerUpdate.Peers, peerConfig)
-					peerIndexMap[peerHost.PublicKey.String()] = len(hostPeerUpdate.Peers) - 1
-					hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()][peer.ID.String()] = models.IDandAddr{
-						ID:              peer.ID.String(),
-						Address:         peer.PrimaryAddress(),
-						Name:            peerHost.Name,
-						Network:         peer.Network,
-						ProxyListenPort: peerProxyPort,
-					}
-					hostPeerUpdate.HostNetworkInfo[peerHost.PublicKey.String()] = models.HostNetworkInfo{
-						Interfaces:      peerHost.Interfaces,
-						ProxyListenPort: peerProxyPort,
-					}
-					nodePeer = peerConfig
-				} else {
-					peerAllowedIPs := hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].AllowedIPs
-					peerAllowedIPs = append(peerAllowedIPs, peerConfig.AllowedIPs...)
-					hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].AllowedIPs = peerAllowedIPs
-					hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].Remove = false
-					hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()][peer.ID.String()] = models.IDandAddr{
-						ID:              peer.ID.String(),
-						Address:         peer.PrimaryAddress(),
-						Name:            peerHost.Name,
-						Network:         peer.Network,
-						ProxyListenPort: GetProxyListenPort(peerHost),
-					}
-					hostPeerUpdate.HostNetworkInfo[peerHost.PublicKey.String()] = models.HostNetworkInfo{
-						Interfaces:      peerHost.Interfaces,
-						ProxyListenPort: peerProxyPort,
-					}
-					nodePeer = hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]]
+				nodePeer = peerConfig
+			} else {
+				peerAllowedIPs := hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].AllowedIPs
+				peerAllowedIPs = append(peerAllowedIPs, peerConfig.AllowedIPs...)
+				hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].AllowedIPs = peerAllowedIPs
+				hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]].Remove = false
+				hostPeerUpdate.HostPeerIDs[peerHost.PublicKey.String()][peer.ID.String()] = models.IDandAddr{
+					ID:              peer.ID.String(),
+					Address:         peer.PrimaryAddress(),
+					Name:            peerHost.Name,
+					Network:         peer.Network,
+					ProxyListenPort: GetProxyListenPort(peerHost),
+				}
+				hostPeerUpdate.HostNetworkInfo[peerHost.PublicKey.String()] = models.HostNetworkInfo{
+					Interfaces:      peerHost.Interfaces,
+					ProxyListenPort: peerProxyPort,
 				}
+				nodePeer = hostPeerUpdate.Peers[peerIndexMap[peerHost.PublicKey.String()]]
+			}
 
-				if node.Network == network { // add to peers map for metrics
-					hostPeerUpdate.PeerIDs[peerHost.PublicKey.String()] = models.IDandAddr{
-						ID:              peer.ID.String(),
-						Address:         peer.PrimaryAddress(),
-						Name:            peerHost.Name,
-						Network:         peer.Network,
-						ProxyListenPort: peerHost.ProxyListenPort,
-					}
-					hostPeerUpdate.NodePeers = append(hostPeerUpdate.NodePeers, nodePeer)
+			if node.Network == network { // add to peers map for metrics
+				hostPeerUpdate.PeerIDs[peerHost.PublicKey.String()] = models.IDandAddr{
+					ID:              peer.ID.String(),
+					Address:         peer.PrimaryAddress(),
+					Name:            peerHost.Name,
+					Network:         peer.Network,
+					ProxyListenPort: peerHost.ProxyListenPort,
 				}
+				hostPeerUpdate.NodePeers = append(hostPeerUpdate.NodePeers, nodePeer)
 			}
+			//}
 		}
 		var extPeers []wgtypes.PeerConfig
 		var extPeerIDAndAddrs []models.IDandAddr

+ 10 - 2
mq/handlers.go

@@ -107,7 +107,11 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
 						return
 					}
 				}
-				if err = PublishSingleHostPeerUpdate(context.Background(), currentHost, nil, nil); err != nil {
+				nodes, err := logic.GetAllNodes()
+				if err != nil {
+					return
+				}
+				if err = PublishSingleHostPeerUpdate(context.Background(), currentHost, nodes, nil, nil); err != nil {
 					slog.Error("failed peers publish after join acknowledged", "name", hostUpdate.Host.Name, "id", currentHost.ID, "error", err)
 					return
 				}
@@ -235,7 +239,11 @@ func UpdateMetrics(client mqtt.Client, msg mqtt.Message) {
 			slog.Info("updating peers after node detected connectivity issues", "id", currentNode.ID, "network", currentNode.Network)
 			host, err := logic.GetHost(currentNode.HostID.String())
 			if err == nil {
-				if err = PublishSingleHostPeerUpdate(context.Background(), host, nil, nil); err != nil {
+				nodes, err := logic.GetAllNodes()
+				if err != nil {
+					return
+				}
+				if err = PublishSingleHostPeerUpdate(context.Background(), host, nodes, nil, nil); err != nil {
 					slog.Warn("failed to publish update after failover peer change for node", "id", currentNode.ID, "network", currentNode.Network, "error", err)
 				}
 			}

+ 22 - 7
mq/publishers.go

@@ -24,10 +24,14 @@ func PublishPeerUpdate() error {
 		logger.Log(1, "err getting all hosts", err.Error())
 		return err
 	}
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		return err
+	}
 	logic.ResetPeerUpdateContext()
 	for _, host := range hosts {
 		host := host
-		if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, nil, nil); err != nil {
+		if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, allNodes, nil, nil); err != nil {
 			logger.Log(1, "failed to publish peer update to host", host.ID.String(), ": ", err.Error())
 		}
 	}
@@ -46,10 +50,14 @@ func PublishDeletedNodePeerUpdate(delNode *models.Node) error {
 		logger.Log(1, "err getting all hosts", err.Error())
 		return err
 	}
+	allNodes, err := logic.GetAllNodes()
+	if err != nil {
+		return err
+	}
 	logic.ResetPeerUpdateContext()
 	for _, host := range hosts {
 		host := host
-		if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, delNode, nil); err != nil {
+		if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, allNodes, delNode, nil); err != nil {
 			logger.Log(1, "failed to publish peer update to host", host.ID.String(), ": ", err.Error())
 		}
 	}
@@ -68,10 +76,14 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error {
 		logger.Log(1, "err getting all hosts", err.Error())
 		return err
 	}
+	nodes, err := logic.GetAllNodes()
+	if err != nil {
+		return err
+	}
 	logic.ResetPeerUpdateContext()
 	for _, host := range hosts {
 		host := host
-		if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, nil, []models.ExtClient{*delClient}); err != nil {
+		if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, nodes, nil, []models.ExtClient{*delClient}); err != nil {
 			logger.Log(1, "failed to publish peer update to host", host.ID.String(), ": ", err.Error())
 		}
 	}
@@ -79,9 +91,9 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error {
 }
 
 // PublishSingleHostPeerUpdate --- determines and publishes a peer update to one host
-func PublishSingleHostPeerUpdate(ctx context.Context, host *models.Host, deletedNode *models.Node, deletedClients []models.ExtClient) error {
+func PublishSingleHostPeerUpdate(ctx context.Context, host *models.Host, allNodes []models.Node, deletedNode *models.Node, deletedClients []models.ExtClient) error {
 
-	peerUpdate, err := logic.GetPeerUpdateForHost(ctx, "", host, deletedNode, deletedClients)
+	peerUpdate, err := logic.GetPeerUpdateForHost(ctx, "", host, allNodes, deletedNode, deletedClients)
 	if err != nil {
 		return err
 	}
@@ -436,7 +448,10 @@ func sendPeers() {
 	if err != nil && len(hosts) > 0 {
 		logger.Log(1, "error retrieving networks for keepalive", err.Error())
 	}
-
+	nodes, err := logic.GetAllNodes()
+	if err != nil {
+		return
+	}
 	var force bool
 	peer_force_send++
 	if peer_force_send == 5 {
@@ -455,7 +470,7 @@ func sendPeers() {
 		for _, host := range hosts {
 			host := host
 			logger.Log(2, "sending scheduled peer update (5 min)")
-			if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, nil, nil); err != nil {
+			if err = PublishSingleHostPeerUpdate(logic.PeerUpdateCtx, &host, nodes, nil, nil); err != nil {
 				logger.Log(1, "error publishing peer updates for host: ", host.ID.String(), " Err: ", err.Error())
 			}
 		}