Переглянути джерело

logic for removing internet gateway

Matthew R. Kasun 3 роки тому
батько
коміт
b4f8827304

+ 8 - 6
netclient/local/routes.go

@@ -63,10 +63,12 @@ func RemoveCIDRRoute(iface, currentAddr string, cidr *net.IPNet) {
 	removeCidr(iface, cidr, currentAddr)
 }
 
-// SetDefaultRoute - sets the default route when peer is internet gateway
-func SetDefaultRoute(iface string, peer wgtypes.PeerConfig) error {
-	if err := setDefaultRoute(iface, peer); err != nil {
-		return err
-	}
-	return nil
+// SetInternetGatewayRoute - sets the default route when peer is internet gateway
+func SetInternetGatewayRoute(iface, port string, peer wgtypes.PeerConfig) error {
+	return setInternetGatewayRoute(iface, port, peer)
+}
+
+// RemoveInternetGatewayRoute -- deletes routes when internet gateway is removed from peer
+func RemoveInternetGatewayRoute(iface, port string, peer wgtypes.PeerConfig) error {
+	return removeInternetGatewayRoute(iface, port, peer)
 }

+ 15 - 3
netclient/local/routes_linux.go

@@ -44,11 +44,23 @@ func removeCidr(iface string, addr *net.IPNet, address string) {
 	ncutils.RunCmd("ip route delete "+addr.String()+" dev "+iface, false)
 }
 
-func setDefaultRoute(iface string, peer wgtypes.PeerConfig) error {
-	cmd := "wg set " + iface + " fwmark 1234"
-	cmd += ";ip route add default dev " + iface + " table 2468"
+func setInternetGatewayRoute(iface, port string, peer wgtypes.PeerConfig) error {
+	cmd := "wg set " + iface + " fwmark " + port
+	cmd += ";ip route add default dev " + iface + " table " + port
 	cmd += ";ip rule add not fwmark 1234 table 2468"
 	cmd += ";ip rule add table main suppress_prefixlength 0"
+	cmd += ";iptables-restore -n"
+	if _, err := ncutils.RunCmd(cmd, true); err != nil {
+		return err
+	}
+	return nil
+}
+
+func removeInternetGatewayRoute(iface, port string, peer wgtypes.PeerConfig) error {
+	cmd := "ip -4 rule delete table " + port
+	cmd += ";ip -4 rule delete table main suppress_prefixlength 0"
+	cmd += ":ip link del dev " + iface
+	cmd += ";iptables-restore -n"
 	if _, err := ncutils.RunCmd(cmd, true); err != nil {
 		return err
 	}

+ 37 - 2
netclient/wireguard/common.go

@@ -111,6 +111,11 @@ func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error
 					if peer.PublicKey.String() == currentPeer.PublicKey.String() {
 						shouldDelete = false
 					}
+					if shouldDeleteInternetGateway(peer.AllowedIPs, currentPeer.AllowedIPs) {
+						if local.RemoveInternetGatewayRoute(node.Interface, strconv.Itoa(int(node.ListenPort)), peer); err != nil {
+							logger.Log(0, "failed to remove internet gateways routes", err.Error())
+						}
+					}
 				}
 				if shouldDelete {
 					output, err := ncutils.RunCmd("wg set "+iface+" peer "+currentPeer.PublicKey.String()+" remove", true)
@@ -124,6 +129,7 @@ func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error
 			}
 		}
 	}
+	//TODO === why only Mac/Linux????
 	if ncutils.IsMac() {
 		err = SetMacPeerRoutes(iface)
 		return err
@@ -134,11 +140,10 @@ func SetPeers(iface string, node *models.Node, peers []wgtypes.PeerConfig) error
 	}
 	//check if internet gateway
 	if internetGateway {
-		if err := local.SetDefaultRoute(iface, gateway); err != nil {
+		if err := local.SetInternetGatewayRoute(node.Interface, strconv.Itoa(int(node.ListenPort)), gateway); err != nil {
 			return err
 		}
 	}
-
 	return nil
 }
 
@@ -571,3 +576,33 @@ func GetDevicePeers(iface string) ([]wgtypes.Peer, error) {
 		return device.Peers, nil
 	}
 }
+
+func shouldDeleteInternetGateway(new, current []net.IPNet) bool {
+	oldv4gatewayExists := false
+	newv4gatewayExists := false
+	oldv6gatewayExists := false
+	newv6gatewayExists := false
+	for _, ip := range current {
+		if ip.String() == "0.0.0.0/0" {
+			oldv4gatewayExists = true
+		}
+		if ip.String() == "::/0" {
+			oldv6gatewayExists = true
+		}
+	}
+	for _, ip := range new {
+		if ip.String() == "0.0.0.0/0" {
+			newv4gatewayExists = true
+		}
+		if ip.String() == "::/0" {
+			newv6gatewayExists = true
+		}
+	}
+	if oldv4gatewayExists && !newv4gatewayExists {
+		return false
+	}
+	if oldv6gatewayExists && !newv6gatewayExists {
+		return false
+	}
+	return true
+}