浏览代码

feat(go): support disabling user accounts;

1. Add api endpoints to enable and disable user accounts.
2. Add checks in authenticators to prevent disabled users from logging in.
3. Add checks in middleware to prevent api usage by disabled users.
Vishal Dalwadi 5 月之前
父节点
当前提交
29841ffa26
共有 8 个文件被更改,包括 116 次插入0 次删除
  1. 70 0
      controllers/user.go
  2. 14 0
      logic/security.go
  3. 1 0
      models/user_mgmt.go
  4. 6 0
      pro/auth/azure-ad.go
  5. 8 0
      pro/auth/error.go
  6. 6 0
      pro/auth/github.go
  7. 5 0
      pro/auth/google.go
  8. 6 0
      pro/auth/oidc.go

+ 70 - 0
controllers/user.go

@@ -34,6 +34,8 @@ func userHandlers(r *mux.Router) {
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceUsers, http.HandlerFunc(createUser)))).Methods(http.MethodPost)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceUsers, http.HandlerFunc(createUser)))).Methods(http.MethodPost)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods(http.MethodGet)
+	r.HandleFunc("/api/users/{username}/enable", logic.SecurityCheck(true, http.HandlerFunc(enableUserAccount))).Methods(http.MethodPost)
+	r.HandleFunc("/api/users/{username}/disable", logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount))).Methods(http.MethodPost)
 	r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
 	r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(ListRoles))).Methods(http.MethodGet)
 	r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(ListRoles))).Methods(http.MethodGet)
@@ -95,15 +97,24 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 			return
 			return
 		}
 		}
 	}
 	}
+
 	user, err := logic.GetUser(authRequest.UserName)
 	user, err := logic.GetUser(authRequest.UserName)
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(response, request, logic.FormatError(err, "unauthorized"))
 		logic.ReturnErrorResponse(response, request, logic.FormatError(err, "unauthorized"))
 		return
 		return
 	}
 	}
+
+	if user.AccountDisabled {
+		err = errors.New("user account disabled")
+		logic.ReturnErrorResponse(response, request, logic.FormatError(err, "unauthorized"))
+		return
+	}
+
 	if logic.IsOauthUser(user) == nil {
 	if logic.IsOauthUser(user) == nil {
 		logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("user is registered via SSO"), "badrequest"))
 		logic.ReturnErrorResponse(response, request, logic.FormatError(errors.New("user is registered via SSO"), "badrequest"))
 		return
 		return
 	}
 	}
+
 	username := authRequest.UserName
 	username := authRequest.UserName
 	jwt, err := logic.VerifyAuthRequest(authRequest)
 	jwt, err := logic.VerifyAuthRequest(authRequest)
 	if err != nil {
 	if err != nil {
@@ -225,6 +236,65 @@ func getUser(w http.ResponseWriter, r *http.Request) {
 	json.NewEncoder(w).Encode(user)
 	json.NewEncoder(w).Encode(user)
 }
 }
 
 
+// @Summary     Enable a user's account
+// @Router      /api/users/{username}/enable [post]
+// @Tags        Users
+// @Param       username path string true "Username of the user to enable"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func enableUserAccount(w http.ResponseWriter, r *http.Request) {
+	username := mux.Vars(r)["username"]
+	user, err := logic.GetUser(username)
+	if err != nil {
+		logger.Log(0, "failed to fetch user: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	user.AccountDisabled = false
+	err = logic.UpsertUser(*user)
+	if err != nil {
+		logger.Log(0, "failed to enable user account: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+	}
+
+	logic.ReturnSuccessResponse(w, r, "user account enabled")
+}
+
+// @Summary     Disable a user's account
+// @Router      /api/users/{username}/disable [post]
+// @Tags        Users
+// @Param       username path string true "Username of the user to disable"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func disableUserAccount(w http.ResponseWriter, r *http.Request) {
+	username := mux.Vars(r)["username"]
+	user, err := logic.GetUser(username)
+	if err != nil {
+		logger.Log(0, "failed to fetch user: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	if user.PlatformRoleID == models.SuperAdminRole {
+		err = errors.New("cannot disable super-admin user account")
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	user.AccountDisabled = true
+	err = logic.UpsertUser(*user)
+	if err != nil {
+		logger.Log(0, "failed to disable user account: ", err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+	}
+
+	logic.ReturnSuccessResponse(w, r, "user account disabled")
+}
+
 // swagger:route GET /api/v1/users user getUserV1
 // swagger:route GET /api/v1/users user getUserV1
 //
 //
 // Get an individual user with role info.
 // Get an individual user with role info.

+ 14 - 0
logic/security.go

@@ -1,6 +1,7 @@
 package logic
 package logic
 
 
 import (
 import (
+	"errors"
 	"net/http"
 	"net/http"
 	"strings"
 	"strings"
 
 
@@ -32,6 +33,19 @@ func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
 			ReturnErrorResponse(w, r, FormatError(err, "unauthorized"))
 			ReturnErrorResponse(w, r, FormatError(err, "unauthorized"))
 			return
 			return
 		}
 		}
+
+		user, err := GetUser(username)
+		if err != nil {
+			ReturnErrorResponse(w, r, FormatError(err, "unauthorized"))
+			return
+		}
+
+		if user.AccountDisabled {
+			err = errors.New("user account disabled")
+			ReturnErrorResponse(w, r, FormatError(err, "unauthorized"))
+			return
+		}
+
 		// detect masteradmin
 		// detect masteradmin
 		if username == MasterUser {
 		if username == MasterUser {
 			r.Header.Set("ismaster", "yes")
 			r.Header.Set("ismaster", "yes")

+ 1 - 0
models/user_mgmt.go

@@ -146,6 +146,7 @@ type UserGroup struct {
 type User struct {
 type User struct {
 	UserName                   string                                `json:"username" bson:"username" validate:"min=3,in_charset|email"`
 	UserName                   string                                `json:"username" bson:"username" validate:"min=3,in_charset|email"`
 	ExternalIdentityProviderID string                                `json:"external_identity_provider_id"`
 	ExternalIdentityProviderID string                                `json:"external_identity_provider_id"`
+	AccountDisabled            bool                                  `json:"account_disabled"`
 	Password                   string                                `json:"password" bson:"password" validate:"required,min=5"`
 	Password                   string                                `json:"password" bson:"password" validate:"required,min=5"`
 	IsAdmin                    bool                                  `json:"isadmin" bson:"isadmin"` // deprecated
 	IsAdmin                    bool                                  `json:"isadmin" bson:"isadmin"` // deprecated
 	IsSuperAdmin               bool                                  `json:"issuperadmin"`           // deprecated
 	IsSuperAdmin               bool                                  `json:"issuperadmin"`           // deprecated

+ 6 - 0
pro/auth/azure-ad.go

@@ -152,6 +152,12 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
 		handleOauthUserNotFound(w)
 		handleOauthUserNotFound(w)
 		return
 		return
 	}
 	}
+
+	if user.AccountDisabled {
+		handleUserAccountDisabled(w)
+		return
+	}
+
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	if err != nil {
 	if err != nil {
 		handleSomethingWentWrong(w)
 		handleSomethingWentWrong(w)

+ 8 - 0
pro/auth/error.go

@@ -113,6 +113,8 @@ var notallowedtosignup = fmt.Sprintf(htmlBaseTemplate, `<h2>Your email is not al
 var authTypeMismatch = fmt.Sprintf(htmlBaseTemplate, `<h2>It looks like you already have an account with us using Basic Authentication.</h2>
 var authTypeMismatch = fmt.Sprintf(htmlBaseTemplate, `<h2>It looks like you already have an account with us using Basic Authentication.</h2>
 <p>To continue, please log in with your existing credentials or reset your password if needed.</p>`)
 <p>To continue, please log in with your existing credentials or reset your password if needed.</p>`)
 
 
+var userAccountDisabled = fmt.Sprintf(htmlBaseTemplate, `<h2>Your account has been disabled. Please contact your administrator for more information about your account.</h2>`)
+
 func handleOauthUserNotFound(response http.ResponseWriter) {
 func handleOauthUserNotFound(response http.ResponseWriter) {
 	response.Header().Set("Content-Type", "text/html; charset=utf-8")
 	response.Header().Set("Content-Type", "text/html; charset=utf-8")
 	response.WriteHeader(http.StatusNotFound)
 	response.WriteHeader(http.StatusNotFound)
@@ -166,3 +168,9 @@ func handleAuthTypeMismatch(response http.ResponseWriter) {
 	response.WriteHeader(http.StatusBadRequest)
 	response.WriteHeader(http.StatusBadRequest)
 	response.Write([]byte(authTypeMismatch))
 	response.Write([]byte(authTypeMismatch))
 }
 }
+
+func handleUserAccountDisabled(response http.ResponseWriter) {
+	response.Header().Set("Content-Type", "text/html; charset=utf-8")
+	response.WriteHeader(http.StatusUnauthorized)
+	response.Write([]byte(userAccountDisabled))
+}

+ 6 - 0
pro/auth/github.go

@@ -143,6 +143,12 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
 		handleOauthUserNotFound(w)
 		handleOauthUserNotFound(w)
 		return
 		return
 	}
 	}
+
+	if user.AccountDisabled {
+		handleUserAccountDisabled(w)
+		return
+	}
+
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	if err != nil {
 	if err != nil {
 		handleSomethingWentWrong(w)
 		handleSomethingWentWrong(w)

+ 5 - 0
pro/auth/google.go

@@ -135,6 +135,11 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
+	if user.AccountDisabled {
+		handleUserAccountDisabled(w)
+		return
+	}
+
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	if err != nil {
 	if err != nil {
 		handleSomethingWentWrong(w)
 		handleSomethingWentWrong(w)

+ 6 - 0
pro/auth/oidc.go

@@ -143,6 +143,12 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
 		handleOauthUserNotFound(w)
 		handleOauthUserNotFound(w)
 		return
 		return
 	}
 	}
+
+	if user.AccountDisabled {
+		handleUserAccountDisabled(w)
+		return
+	}
+
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	userRole, err := logic.GetRole(user.PlatformRoleID)
 	if err != nil {
 	if err != nil {
 		handleSomethingWentWrong(w)
 		handleSomethingWentWrong(w)