Sfoglia il codice sorgente

filter network based on user access

abhishek9686 1 anno fa
parent
commit
80d9dd6357
6 ha cambiato i file con 108 aggiunte e 7 eliminazioni
  1. 3 0
      controllers/middleware.go
  2. 40 0
      controllers/network.go
  3. 39 2
      controllers/user.go
  4. 20 4
      logic/nodes.go
  5. 1 1
      logic/security.go
  6. 5 0
      models/user_mgmt.go

+ 3 - 0
controllers/middleware.go

@@ -13,6 +13,9 @@ func userMiddleWare(handler http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		var params = mux.Vars(r)
 		r.Header.Set("IS_GLOBAL_ACCESS", "no")
+		r.Header.Set("TARGET_RSRC", "")
+		r.Header.Set("RSRC_TYPE", "")
+		r.Header.Set("TARGET_RSRC_ID", "")
 		r.Header.Set("NET_ID", params["network"])
 		if strings.Contains(r.URL.Path, "hosts") || strings.Contains(r.URL.Path, "nodes") {
 			r.Header.Set("TARGET_RSRC", models.HostRsrc.String())

+ 40 - 0
controllers/network.go

@@ -54,6 +54,46 @@ func getNetworks(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	username := r.Header.Get("user")
+	user, err := logic.GetUser(username)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	platformRole, err := logic.GetRole(user.PlatformRoleID)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	if !platformRole.FullAccess {
+		allNetworkRoles := make(map[models.NetworkID]struct{})
+		if len(user.NetworkRoles) > 0 {
+			for netID := range user.NetworkRoles {
+				allNetworkRoles[netID] = struct{}{}
+
+			}
+		}
+		if len(user.UserGroups) > 0 {
+			for userGID := range user.UserGroups {
+				userG, err := logic.GetUserGroup(userGID)
+				if err == nil {
+					if len(userG.NetworkRoles) > 0 {
+						for netID := range userG.NetworkRoles {
+							allNetworkRoles[netID] = struct{}{}
+
+						}
+					}
+				}
+			}
+		}
+		filteredNetworks := []models.Network{}
+		for _, networkI := range allnetworks {
+			if _, ok := allNetworkRoles[models.NetworkID(networkI.NetID)]; ok {
+				filteredNetworks = append(filteredNetworks, networkI)
+			}
+		}
+		allnetworks = filteredNetworks
+	}
 
 	logger.Log(2, r.Header.Get("user"), "fetched networks.")
 	logic.SortNetworks(allnetworks[:])

+ 39 - 2
controllers/user.go

@@ -34,6 +34,7 @@ func userHandlers(r *mux.Router) {
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceUsers, http.HandlerFunc(createUser)))).Methods(http.MethodPost)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods(http.MethodGet)
+	//r.HandleFunc("/api/v1/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(getPendingUsers))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(deleteAllPendingUsers))).Methods(http.MethodDelete)
@@ -42,7 +43,7 @@ func userHandlers(r *mux.Router) {
 
 	// User Role Handlers
 	r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(listRoles))).Methods(http.MethodGet)
-	r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(getRole))).Methods(http.MethodGet)
+	r.HandleFunc("/api/v1/users/role", getRole).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(createRole))).Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(updateRole))).Methods(http.MethodPut)
 	r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(deleteRole))).Methods(http.MethodDelete)
@@ -129,7 +130,7 @@ func getUserGroup(w http.ResponseWriter, r *http.Request) {
 //				200: userBodyResponse
 func createUserGroup(w http.ResponseWriter, r *http.Request) {
 	var userGroupReq models.CreateGroupReq
-	err := json.NewDecoder(r.Body).Decode(&userGroupReq.Group)
+	err := json.NewDecoder(r.Body).Decode(&userGroupReq)
 	if err != nil {
 		slog.Error("error decoding request body", "error",
 			err.Error())
@@ -536,6 +537,42 @@ func getUser(w http.ResponseWriter, r *http.Request) {
 	json.NewEncoder(w).Encode(user)
 }
 
+// swagger:route GET /api/v1/users/{username} user getUser
+//
+// Get an individual user with role info.
+//
+//			Schemes: https
+//
+//			Security:
+//	  		oauth
+//
+//			Responses:
+//				200: userBodyResponse
+func getUserV1(w http.ResponseWriter, r *http.Request) {
+	// set header.
+	w.Header().Set("Content-Type", "application/json")
+
+	var params = mux.Vars(r)
+	usernameFetched := params["username"]
+	user, err := logic.GetReturnUser(usernameFetched)
+	if err != nil {
+		logger.Log(0, usernameFetched, "failed to fetch user: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	userRoleTemplate, err := logic.GetRole(user.PlatformRoleID)
+	if err != nil {
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	resp := models.ReturnUserWithRolesAndGroups{
+		ReturnUser:   user,
+		PlatformRole: userRoleTemplate,
+	}
+	logger.Log(2, r.Header.Get("user"), "fetched user", usernameFetched)
+	logic.ReturnSuccessResponseWithJson(w, r, resp, "fetched user with role info")
+}
+
 // swagger:route GET /api/users user getUsers
 //
 // Get all users.

+ 20 - 4
logic/nodes.go

@@ -679,15 +679,31 @@ func GetFilteredNodesByUserAccess(user models.User, nodes []models.Node) (filter
 
 	nodesMap := make(map[string]struct{})
 	allNetworkRoles := []models.UserRole{}
-	for _, netRoles := range user.NetworkRoles {
-		for netRoleI := range netRoles {
-			allNetworkRoles = append(allNetworkRoles, netRoleI)
+	if len(user.NetworkRoles) > 0 {
+		for _, netRoles := range user.NetworkRoles {
+			for netRoleI := range netRoles {
+				allNetworkRoles = append(allNetworkRoles, netRoleI)
+			}
+		}
+	}
+	if len(user.UserGroups) > 0 {
+		for userGID := range user.UserGroups {
+			userG, err := GetUserGroup(userGID)
+			if err == nil {
+				if len(userG.NetworkRoles) > 0 {
+					for _, netRoles := range userG.NetworkRoles {
+						for netRoleI := range netRoles {
+							allNetworkRoles = append(allNetworkRoles, netRoleI)
+						}
+					}
+				}
+			}
 		}
 	}
 	for _, networkRoleID := range allNetworkRoles {
 		userPermTemplate, err := GetRole(networkRoleID)
 		if err != nil {
-			return
+			continue
 		}
 		networkNodes := GetNetworkNodesMemory(nodes, userPermTemplate.NetworkID)
 		if userPermTemplate.FullAccess {

+ 1 - 1
logic/security.go

@@ -165,7 +165,7 @@ func globalPermissionsCheck(username string, r *http.Request) error {
 	if targetRsrc == models.MetricRsrc.String() {
 		return nil
 	}
-	if targetRsrc == models.HostRsrc.String() && r.Method == http.MethodGet && targetRsrcID == "" {
+	if (targetRsrc == models.HostRsrc.String() || targetRsrc == models.NetworkRsrc.String()) && r.Method == http.MethodGet && targetRsrcID == "" {
 		return nil
 	}
 	if targetRsrc == models.UserRsrc.String() && username == targetRsrcID && (r.Method != http.MethodDelete) {

+ 5 - 0
models/user_mgmt.go

@@ -134,6 +134,11 @@ type User struct {
 	LastLoginTime  time.Time                           `json:"last_login_time"`
 }
 
+type ReturnUserWithRolesAndGroups struct {
+	ReturnUser
+	PlatformRole UserRolePermissionTemplate
+}
+
 // ReturnUser - return user struct
 type ReturnUser struct {
 	UserName       string                              `json:"username"`