浏览代码

fix middleware for auth

abhishek9686 1 年之前
父节点
当前提交
cd56333f04
共有 3 个文件被更改,包括 18 次插入2 次删除
  1. 4 0
      controllers/middleware.go
  2. 11 2
      pro/controllers/users.go
  3. 3 0
      pro/logic/security.go

+ 4 - 0
controllers/middleware.go

@@ -28,6 +28,7 @@ func userMiddleWare(handler http.Handler) http.Handler {
 		r.Header.Set("TARGET_RSRC", "")
 		r.Header.Set("RSRC_TYPE", "")
 		r.Header.Set("TARGET_RSRC_ID", "")
+		r.Header.Set("RAC", "")
 		r.Header.Set("NET_ID", params["network"])
 		if strings.Contains(route, "hosts") || strings.Contains(route, "nodes") {
 			r.Header.Set("TARGET_RSRC", models.HostRsrc.String())
@@ -36,6 +37,9 @@ func userMiddleWare(handler http.Handler) http.Handler {
 			r.Header.Set("TARGET_RSRC", models.DnsRsrc.String())
 		}
 		if strings.Contains(route, "users") {
+			if strings.Contains(route, "remote_access_gw") {
+				r.Header.Set("RAC", "true")
+			}
 			r.Header.Set("TARGET_RSRC", models.UserRsrc.String())
 		}
 		if strings.Contains(route, "ingress") {

+ 11 - 2
pro/controllers/users.go

@@ -59,7 +59,7 @@ func UserHandlers(r *mux.Router) {
 	r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", logic.SecurityCheck(true, http.HandlerFunc(attachUserToRemoteAccessGw))).Methods(http.MethodPost)
 	r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", logic.SecurityCheck(true, http.HandlerFunc(removeUserFromRemoteAccessGW))).Methods(http.MethodDelete)
 	r.HandleFunc("/api/users/{username}/remote_access_gw", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessGwsV1)))).Methods(http.MethodGet)
-	r.HandleFunc("/api/v1/users/{username}/remote_access_gw/networks", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessNetworks)))).Methods(http.MethodGet)
+	r.HandleFunc("/api/v1/users/{username}/remote_access_gw_network", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessNetworks)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/users/{username}/remote_access_gw/network/{network}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessNetworkGateways)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/users/{username}/remote_access_gw/{remote_access_gateway_id}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getRemoteAccessGatewayConf)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users/ingress/{ingress_id}", logic.SecurityCheck(true, http.HandlerFunc(ingressGatewayUsers))).Methods(http.MethodGet)
@@ -949,6 +949,11 @@ func getRemoteAccessGatewayConf(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to fetch user %s, error: %v", username, err), "badrequest"))
 		return
 	}
+	userGwNodes := proLogic.GetUserRAGNodes(*user)
+	if _, ok := userGwNodes[remoteGwID]; !ok {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("access denied"), "forbidden"))
+		return
+	}
 	node, err := logic.GetNodeByID(remoteGwID)
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to fetch gw node %s, error: %v", remoteGwID, err), "badrequest"))
@@ -998,7 +1003,11 @@ func getRemoteAccessGatewayConf(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 		listenPort := logic.GetPeerListenPort(host)
-		userConf.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort)
+		if host.EndpointIP.To4() == nil {
+			userConf.IngressGatewayEndpoint = fmt.Sprintf("[%s]:%d", host.EndpointIPv6.String(), listenPort)
+		} else {
+			userConf.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort)
+		}
 		userConf.Enabled = true
 		parentNetwork, err := logic.GetNetwork(node.Network)
 		if err == nil { // check if parent network default ACL is enabled (yes) or not (no)

+ 3 - 0
pro/logic/security.go

@@ -50,6 +50,9 @@ func NetworkPermissionsCheck(username string, r *http.Request) error {
 	if targetRsrc == "" {
 		return errors.New("target rsrc is missing")
 	}
+	if r.Header.Get("RAC") == "true" && r.Method == http.MethodGet {
+		return nil
+	}
 	if netID == "" {
 		return errors.New("network id is missing")
 	}