Browse Source

calculate allowed ips

Matthew R Kasun 3 years ago
parent
commit
fc86015c29
1 changed files with 135 additions and 42 deletions
  1. 135 42
      mq/mq.go

+ 135 - 42
mq/mq.go

@@ -3,6 +3,7 @@ package mq
 import (
 import (
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
+	"log"
 	"net"
 	"net"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
@@ -13,6 +14,7 @@ import (
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/servercfg"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 )
@@ -42,8 +44,8 @@ var Ping mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
 			return
 			return
 		}
 		}
 		node.SetLastCheckIn()
 		node.SetLastCheckIn()
-		if err := logic.UpdateNode(&node, &node) ; err != nil {
-			logger.Log(0, "error updating node "+ err.Error())
+		if err := logic.UpdateNode(&node, &node); err != nil {
+			logger.Log(0, "error updating node "+err.Error())
 		}
 		}
 		logger.Log(0, "ping processed")
 		logger.Log(0, "ping processed")
 		// --TODO --set client version once feature is implemented.
 		// --TODO --set client version once feature is implemented.
@@ -66,8 +68,8 @@ var PublicKeyUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Mess
 		}
 		}
 		node.PublicKey = key
 		node.PublicKey = key
 		node.SetLastCheckIn()
 		node.SetLastCheckIn()
-		if err := logic.UpdateNode(&node, &node) ; err != nil {
-			logger.Log(0, "error updating node "+ err.Error())
+		if err := logic.UpdateNode(&node, &node); err != nil {
+			logger.Log(0, "error updating node "+err.Error())
 		}
 		}
 		if err := UpdatePeers(client, node); err != nil {
 		if err := UpdatePeers(client, node); err != nil {
 			logger.Log(0, "error updating peers "+err.Error())
 			logger.Log(0, "error updating peers "+err.Error())
@@ -92,8 +94,8 @@ var IPUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
 		}
 		}
 		node.Endpoint = ip
 		node.Endpoint = ip
 		node.SetLastCheckIn()
 		node.SetLastCheckIn()
-		if err := logic.UpdateNode(&node, &node) ; err != nil {
-			logger.Log(0, "error updating node "+ err.Error())
+		if err := logic.UpdateNode(&node, &node); err != nil {
+			logger.Log(0, "error updating node "+err.Error())
 		}
 		}
 		if err != UpdatePeers(client, node) {
 		if err != UpdatePeers(client, node) {
 			logger.Log(0, "error updating peers "+err.Error())
 			logger.Log(0, "error updating peers "+err.Error())
@@ -106,50 +108,141 @@ func UpdatePeers(client mqtt.Client, newnode models.Node) error {
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-        keepalive, _ := time.ParseDuration(string(newnode.PersistentKeepalive)+"s")
-        for _, node := range  networkNodes {
-                var peers []wgtypes.PeerConfig
+	dualstack := false
+	keepalive, _ := time.ParseDuration(string(newnode.PersistentKeepalive) + "s")
+	defaultkeepalive, _ := time.ParseDuration("25s")
+	for _, node := range networkNodes {
+		var peers []wgtypes.PeerConfig
 		var peerUpdate models.PeerUpdate
 		var peerUpdate models.PeerUpdate
-                for _, peer := range  networkNodes{
-                        if peer.ID == node.ID {
-                                //skip
-                                continue
-                        }
-                        pubkey, err := wgtypes.ParseKey(peer.PublicKey)
-                        if err != nil {
+		var gateways []string
+
+		for _, peer := range networkNodes {
+			if peer.ID == node.ID {
+				//skip
+				continue
+			}
+			var allowedips []net.IPNet
+			var peeraddr = net.IPNet{
+				IP:   net.ParseIP(peer.Address),
+				Mask: net.CIDRMask(32, 32),
+			}
+			//hasGateway := false
+			pubkey, err := wgtypes.ParseKey(peer.PublicKey)
+			if err != nil {
 				return err
 				return err
-                        }
-                        if node.Endpoint == peer.Endpoint {
-                                if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
-                                        peer.Endpoint = peer.LocalAddress
-                                }else {
-                                        continue
-                                }
-                        }
-                        endpoint := peer.Endpoint + ":" + strconv.Itoa(int(peer.ListenPort))
-                        //fmt.Println("endpoint: ", endpoint, peer.Endpoint, peer.ListenPort)
-                        address, err := net.ResolveUDPAddr("udp", endpoint)
-                        if err != nil {
+			}
+			if node.Endpoint == peer.Endpoint {
+				if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
+					peer.Endpoint = peer.LocalAddress
+				} else {
+					continue
+				}
+			}
+			endpoint := peer.Endpoint + ":" + strconv.Itoa(int(peer.ListenPort))
+			//fmt.Println("endpoint: ", endpoint, peer.Endpoint, peer.ListenPort)
+			address, err := net.ResolveUDPAddr("udp", endpoint)
+			if err != nil {
 				return err
 				return err
-                        }
-                        //calculate Allowed IPs.
-                        var peerData wgtypes.PeerConfig
-                        peerData = wgtypes.PeerConfig{
-                                PublicKey: pubkey,
-                                Endpoint: address,
-                                PersistentKeepaliveInterval: &keepalive,
-                                //AllowedIPs: allowedIPs
-                        }
-                        peers = append (peers, peerData)
-                }
+			}
+			//calculate Allowed IPs.
+			allowedips = append(allowedips, peeraddr)
+			// handle manually set peers
+			for _, allowedIp := range node.AllowedIPs {
+				if _, ipnet, err := net.ParseCIDR(allowedIp); err == nil {
+					nodeEndpointArr := strings.Split(node.Endpoint, ":")
+					if !ipnet.Contains(net.IP(nodeEndpointArr[0])) && ipnet.IP.String() != node.Address { // don't need to add an allowed ip that already exists..
+						allowedips = append(allowedips, *ipnet)
+					}
+				} else if appendip := net.ParseIP(allowedIp); appendip != nil && allowedIp != node.Address {
+					ipnet := net.IPNet{
+						IP:   net.ParseIP(allowedIp),
+						Mask: net.CIDRMask(32, 32),
+					}
+					allowedips = append(allowedips, ipnet)
+				}
+			}
+			// handle egress gateway peers
+			if node.IsEgressGateway == "yes" {
+				//hasGateway = true
+				ranges := node.EgressGatewayRanges
+				for _, iprange := range ranges { // go through each cidr for egress gateway
+					_, ipnet, err := net.ParseCIDR(iprange) // confirming it's valid cidr
+					if err != nil {
+						ncutils.PrintLog("could not parse gateway IP range. Not adding "+iprange, 1)
+						continue // if can't parse CIDR
+					}
+					nodeEndpointArr := strings.Split(node.Endpoint, ":") // getting the public ip of node
+					if ipnet.Contains(net.ParseIP(nodeEndpointArr[0])) { // ensuring egress gateway range does not contain public ip of node
+						ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.Endpoint+", omitting", 2)
+						continue // skip adding egress range if overlaps with node's ip
+					}
+					if ipnet.Contains(net.ParseIP(node.LocalAddress)) { // ensuring egress gateway range does not contain public ip of node
+						ncutils.PrintLog("egress IP range of "+iprange+" overlaps with "+node.LocalAddress+", omitting", 2)
+						continue // skip adding egress range if overlaps with node's local ip
+					}
+					gateways = append(gateways, iprange)
+					if err != nil {
+						log.Println("ERROR ENCOUNTERED SETTING GATEWAY")
+					} else {
+						allowedips = append(allowedips, *ipnet)
+					}
+				}
+			}
+			var peerData wgtypes.PeerConfig
+			if node.Address6 != "" && dualstack {
+				var addr6 = net.IPNet{
+					IP:   net.ParseIP(node.Address6),
+					Mask: net.CIDRMask(128, 128),
+				}
+				allowedips = append(allowedips, addr6)
+			}
+			if node.IsServer == "yes" && !(node.IsServer == "yes") {
+				peerData = wgtypes.PeerConfig{
+					PublicKey:                   pubkey,
+					PersistentKeepaliveInterval: &defaultkeepalive,
+					ReplaceAllowedIPs:           true,
+					AllowedIPs:                  allowedips,
+				}
+			} else if keepalive != 0 {
+				peerData = wgtypes.PeerConfig{
+					PublicKey:                   pubkey,
+					PersistentKeepaliveInterval: &defaultkeepalive,
+					//Endpoint: &net.UDPAddr{
+					//	IP:   net.ParseIP(node.Endpoint),
+					//	Port: int(node.ListenPort),
+					//},
+					Endpoint:          address,
+					ReplaceAllowedIPs: true,
+					AllowedIPs:        allowedips,
+				}
+			} else {
+				peerData = wgtypes.PeerConfig{
+					PublicKey: pubkey,
+					//Endpoint: &net.UDPAddr{
+					//	IP:   net.ParseIP(node.Endpoint),
+					//	Port: int(node.ListenPort),
+					//},
+					Endpoint:          address,
+					ReplaceAllowedIPs: true,
+					AllowedIPs:        allowedips,
+				}
+			}
+			//peerData = wgtypes.PeerConfig{
+			//	PublicKey:                   pubkey,
+			//	Endpoint:                    address,
+			//	PersistentKeepaliveInterval: &keepalive,
+			//AllowedIPs: allowedIPs
+			//}
+			peers = append(peers, peerData)
+		}
 		peerUpdate.Network = node.Network
 		peerUpdate.Network = node.Network
-		peerUpdate.Peers = peers 
+		peerUpdate.Peers = peers
 		data, err := json.Marshal(&peerUpdate)
 		data, err := json.Marshal(&peerUpdate)
 		if err != nil {
 		if err != nil {
 			logger.Log(0, "error marshaling peer update "+err.Error())
 			logger.Log(0, "error marshaling peer update "+err.Error())
 			return err
 			return err
 		}
 		}
-			if token := client.Publish("/update/peers/"+node.ID, 0, false, data); token.Wait() && token.Error() != nil {
+		if token := client.Publish("/update/peers/"+node.ID, 0, false, data); token.Wait() && token.Error() != nil {
 			logger.Log(0, "error sending peer updatte to no")
 			logger.Log(0, "error sending peer updatte to no")
 			return err
 			return err
 		}
 		}
@@ -198,7 +291,7 @@ func NewPeer(node models.Node) error {
 	if token := client.Connect(); token.Wait() && token.Error() != nil {
 	if token := client.Connect(); token.Wait() && token.Error() != nil {
 		return token.Error()
 		return token.Error()
 	}
 	}
-	
+
 	if err := UpdatePeers(client, node); err != nil {
 	if err := UpdatePeers(client, node); err != nil {
 		return err
 		return err
 	}
 	}