Selaa lähdekoodia

feat(NET-1106): support additional RAG endpoint IPs (#2907)

Aceix 1 vuosi sitten
vanhempi
commit
abf3f4f55d
7 muutettua tiedostoa jossa 81 lisäystä ja 4 poistoa
  1. 28 3
      controllers/ext_client.go
  2. 7 1
      controllers/node.go
  3. 6 0
      models/api_host.go
  4. 14 0
      models/api_node.go
  5. 1 0
      models/node.go
  6. 5 0
      models/structs.go
  7. 20 0
      pro/controllers/users.go

+ 28 - 3
controllers/ext_client.go

@@ -21,6 +21,7 @@ import (
 
 	"github.com/gravitl/netmaker/mq"
 	"github.com/skip2/go-qrcode"
+	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slog"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
@@ -199,6 +200,24 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	preferredIp := strings.TrimSpace(r.URL.Query().Get("preferredip"))
+	if preferredIp != "" {
+		allowedPreferredIps := []string{}
+		for i := range gwnode.AdditionalRagIps {
+			allowedPreferredIps = append(allowedPreferredIps, gwnode.AdditionalRagIps[i].String())
+		}
+		allowedPreferredIps = append(allowedPreferredIps, host.EndpointIP.String())
+		allowedPreferredIps = append(allowedPreferredIps, host.EndpointIPv6.String())
+		if !slices.Contains(allowedPreferredIps, preferredIp) {
+			slog.Warn("preferred endpoint ip is not associated with the RAG. proceeding with preferred ip", "preferred ip", preferredIp)
+			logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("preferred endpoint ip is not associated with the RAG"), "badrequest"))
+			return
+		}
+		if net.ParseIP(preferredIp).To4() == nil {
+			preferredIp = fmt.Sprintf("[%s]", preferredIp)
+		}
+	}
+
 	addrString := client.Address
 	if addrString != "" {
 		addrString += "/32"
@@ -214,12 +233,18 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
 	if network.DefaultKeepalive != 0 {
 		keepalive = "PersistentKeepalive = " + strconv.Itoa(int(network.DefaultKeepalive))
 	}
+
 	gwendpoint := ""
-	if host.EndpointIP.To4() == nil {
-		gwendpoint = fmt.Sprintf("[%s]:%d", host.EndpointIP.String(), host.ListenPort)
+	if preferredIp == "" {
+		if host.EndpointIP.To4() == nil {
+			gwendpoint = fmt.Sprintf("[%s]:%d", host.EndpointIPv6.String(), host.ListenPort)
+		} else {
+			gwendpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), host.ListenPort)
+		}
 	} else {
-		gwendpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), host.ListenPort)
+		gwendpoint = fmt.Sprintf("%s:%d", preferredIp, host.ListenPort)
 	}
+
 	var newAllowedIPs string
 	if logic.IsInternetGw(gwnode) || gwnode.InternetGwID != "" {
 		egressrange := "0.0.0.0/0"

+ 7 - 1
controllers/node.go

@@ -635,14 +635,20 @@ func updateNode(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("metadata cannot be longer than 255 characters"), "badrequest"))
 		return
 	}
+	if !servercfg.IsPro {
+		newData.AdditionalRagIps = []string{}
+	}
 	newNode := newData.ConvertToServerNode(&currentNode)
+	if newNode == nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("error converting node"), "badrequest"))
+		return
+	}
 	if newNode.IsInternetGateway != currentNode.IsInternetGateway {
 		if newNode.IsInternetGateway {
 			logic.SetInternetGw(newNode, models.InetNodeReq{})
 		} else {
 			logic.UnsetInternetGw(newNode)
 		}
-
 	}
 	relayUpdate := logic.RelayUpdates(&currentNode, newNode)
 	if relayUpdate && newNode.IsRelay {

+ 6 - 0
models/api_host.go

@@ -44,7 +44,13 @@ func (h *Host) ConvertNMHostToAPI() *ApiHost {
 	a := ApiHost{}
 	a.Debug = h.Debug
 	a.EndpointIP = h.EndpointIP.String()
+	if a.EndpointIP == "<nil>" {
+		a.EndpointIP = ""
+	}
 	a.EndpointIPv6 = h.EndpointIPv6.String()
+	if a.EndpointIPv6 == "<nil>" {
+		a.EndpointIPv6 = ""
+	}
 	a.FirewallInUse = h.FirewallInUse
 	a.ID = h.ID.String()
 	a.Interfaces = make([]ApiIface, len(h.Interfaces))

+ 14 - 0
models/api_node.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/google/uuid"
+	"golang.org/x/exp/slog"
 )
 
 // ApiNode is a stripped down Node DTO that exposes only required fields to external systems
@@ -44,6 +45,7 @@ type ApiNode struct {
 	IsInternetGateway bool                `json:"isinternetgateway" yaml:"isinternetgateway"`
 	InetNodeReq       InetNodeReq         `json:"inet_node_req" yaml:"inet_node_req"`
 	InternetGwID      string              `json:"internetgw_node_id" yaml:"internetgw_node_id"`
+	AdditionalRagIps  []string            `json:"additional_rag_ips" yaml:"additional_rag_ips"`
 }
 
 // ApiNode.ConvertToServerNode - converts an api node to a server node
@@ -109,6 +111,14 @@ func (a *ApiNode) ConvertToServerNode(currentNode *Node) *Node {
 	convertedNode.LastPeerUpdate = time.Unix(a.LastPeerUpdate, 0)
 	convertedNode.ExpirationDateTime = time.Unix(a.ExpirationDateTime, 0)
 	convertedNode.Metadata = a.Metadata
+	for _, ip := range a.AdditionalRagIps {
+		ragIp := net.ParseIP(ip)
+		if ragIp == nil {
+			slog.Error("error parsing additional rag ip", "error", err, "ip", ip)
+			return nil
+		}
+		convertedNode.AdditionalRagIps = append(convertedNode.AdditionalRagIps, ragIp)
+	}
 	return &convertedNode
 }
 
@@ -163,6 +173,10 @@ func (nm *Node) ConvertToAPINode() *ApiNode {
 	apiNode.FailOverPeers = nm.FailOverPeers
 	apiNode.FailedOverBy = nm.FailedOverBy
 	apiNode.Metadata = nm.Metadata
+	apiNode.AdditionalRagIps = []string{}
+	for _, ip := range nm.AdditionalRagIps {
+		apiNode.AdditionalRagIps = append(apiNode.AdditionalRagIps, ip.String())
+	}
 	return &apiNode
 }
 

+ 1 - 0
models/node.go

@@ -96,6 +96,7 @@ type Node struct {
 	IsInternetGateway bool                `json:"isinternetgateway" yaml:"isinternetgateway"`
 	InetNodeReq       InetNodeReq         `json:"inet_node_req" yaml:"inet_node_req"`
 	InternetGwID      string              `json:"internetgw_node_id" yaml:"internetgw_node_id"`
+	AdditionalRagIps  []net.IP            `json:"additional_rag_ips" yaml:"additional_rag_ips"`
 }
 
 // LegacyNode - legacy struct for node model

+ 5 - 0
models/structs.go

@@ -73,6 +73,7 @@ type UserRemoteGws struct {
 	GwClient          ExtClient `json:"gw_client"`
 	GwPeerPublicKey   string    `json:"gw_peer_public_key"`
 	Metadata          string    `json:"metadata"`
+	AllowedEndpoints  []string  `json:"allowed_endpoints"`
 }
 
 // UserRemoteGwsReq - struct to hold user remote acccess gws req
@@ -372,3 +373,7 @@ type LoginReqDto struct {
 const (
 	ResHeaderKeyStAccessToken = "St-Access-Token"
 )
+
+type GetClientConfReqDto struct {
+	PreferredIp string `json:"preferred_ip"`
+}

+ 20 - 0
pro/controllers/users.go

@@ -227,6 +227,7 @@ func getUserRemoteAccessGws(w http.ResponseWriter, r *http.Request) {
 					IsInternetGateway: node.IsInternetGateway,
 					GwPeerPublicKey:   host.PublicKey.String(),
 					Metadata:          node.Metadata,
+					AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
 				})
 				userGws[node.Network] = gws
 				delete(user.RemoteGwIDs, node.ID.String())
@@ -242,6 +243,7 @@ func getUserRemoteAccessGws(w http.ResponseWriter, r *http.Request) {
 					IsInternetGateway: node.IsInternetGateway,
 					GwPeerPublicKey:   host.PublicKey.String(),
 					Metadata:          node.Metadata,
+					AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
 				})
 				userGws[node.Network] = gws
 				processedAdminNodeIds[node.ID.String()] = struct{}{}
@@ -275,6 +277,7 @@ func getUserRemoteAccessGws(w http.ResponseWriter, r *http.Request) {
 				IsInternetGateway: node.IsInternetGateway,
 				GwPeerPublicKey:   host.PublicKey.String(),
 				Metadata:          node.Metadata,
+				AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
 			})
 			userGws[node.Network] = gws
 		}
@@ -302,6 +305,7 @@ func getUserRemoteAccessGws(w http.ResponseWriter, r *http.Request) {
 					IsInternetGateway: node.IsInternetGateway,
 					GwPeerPublicKey:   host.PublicKey.String(),
 					Metadata:          node.Metadata,
+					AllowedEndpoints:  getAllowedRagEndpoints(&node, host),
 				})
 				userGws[node.Network] = gws
 			}
@@ -352,3 +356,19 @@ func ingressGatewayUsers(w http.ResponseWriter, r *http.Request) {
 	w.WriteHeader(http.StatusOK)
 	json.NewEncoder(w).Encode(gwUsers)
 }
+
+func getAllowedRagEndpoints(ragNode *models.Node, ragHost *models.Host) []string {
+	endpoints := []string{}
+	if len(ragHost.EndpointIP) > 0 {
+		endpoints = append(endpoints, ragHost.EndpointIP.String())
+	}
+	if len(ragHost.EndpointIPv6) > 0 {
+		endpoints = append(endpoints, ragHost.EndpointIPv6.String())
+	}
+	if servercfg.IsPro {
+		for _, ip := range ragNode.AdditionalRagIps {
+			endpoints = append(endpoints, ip.String())
+		}
+	}
+	return endpoints
+}