Pārlūkot izejas kodu

NET-1996: Add Support for TOTP Authentication. (#3517)

* feat(git): ignore run configurations;

* feat(go): add support for TOTP authentication;

* fix(go): api docs;

* fix(go): static checks failing;

* fix(go): ignore mfa enforcement for user auth;

* feat(go): allow resetting mfa;

* feat(go): allow resetting mfa;

* feat(go): use library function;

* fix(go): signature;

* feat(go): allow only master user to unset user's mfa;

* feat(go): set caller when master to prevent panic;

* feat(go): make messages more user friendly;

* fix(go): run go mod tidy;

* fix(go): optimize imports;

* fix(go): return unauthorized on token expiry;

* fix(go): move mfa endpoints under username;

* fix(go): set is mfa enabled when converting;

* feat(go): allow authenticated users to use preauth apis;

* feat(go): set correct header value;

* feat(go): allow super-admins and admins to unset mfa;

* feat(go): allow user to unset mfa if not enforced;
Vishal Dalwadi 2 mēneši atpakaļ
vecāks
revīzija
3551e8e24e
12 mainītis faili ar 419 papildinājumiem un 45 dzēšanām
  1. 1 0
      .gitignore
  2. 237 22
      controllers/user.go
  3. 2 0
      go.mod
  4. 4 0
      go.sum
  5. 29 14
      logic/auth.go
  6. 44 0
      logic/jwts.go
  7. 59 0
      logic/security.go
  8. 5 0
      logic/settings.go
  9. 11 9
      logic/users.go
  10. 1 0
      models/settings.go
  11. 17 0
      models/structs.go
  12. 9 0
      models/user_mgmt.go

+ 1 - 0
.gitignore

@@ -22,6 +22,7 @@ controllers/data/
 data/
 .vscode/
 .idea/
+.run/
 netmaker.exe
 netmaker.code-workspace
 dist/

+ 237 - 22
controllers/user.go

@@ -1,11 +1,13 @@
 package controller
 
 import (
-	"context"
+	"bytes"
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gravitl/netmaker/db"
+	"github.com/pquerna/otp"
+	"image/png"
 	"net/http"
 	"reflect"
 	"time"
@@ -20,6 +22,7 @@ import (
 	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/schema"
 	"github.com/gravitl/netmaker/servercfg"
+	"github.com/pquerna/otp/totp"
 	"golang.org/x/exp/slog"
 )
 
@@ -35,6 +38,9 @@ func userHandlers(r *mux.Router) {
 	r.HandleFunc("/api/users/adm/transfersuperadmin/{username}", logic.SecurityCheck(true, http.HandlerFunc(transferSuperAdmin))).
 		Methods(http.MethodPost)
 	r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods(http.MethodPost)
+	r.HandleFunc("/api/users/{username}/auth/init-totp", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(initiateTOTPSetup)))).Methods(http.MethodPost)
+	r.HandleFunc("/api/users/{username}/auth/complete-totp", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(completeTOTPSetup)))).Methods(http.MethodPost)
+	r.HandleFunc("/api/users/{username}/auth/verify-totp", logic.PreAuthCheck(logic.ContinueIfUserMatch(http.HandlerFunc(verifyTOTP)))).Methods(http.MethodPost)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(updateUser))).Methods(http.MethodPut)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, checkFreeTierLimits(limitChoiceUsers, http.HandlerFunc(createUser)))).Methods(http.MethodPost)
 	r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete)
@@ -356,14 +362,28 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 		return
 	}
 
-	var successResponse = models.SuccessResponse{
-		Code:    http.StatusOK,
-		Message: "W1R3: Device " + username + " Authorized",
-		Response: models.SuccessfulUserLoginResponse{
-			AuthToken: jwt,
-			UserName:  username,
-		},
+	var successResponse models.SuccessResponse
+
+	if user.IsMFAEnabled {
+		successResponse = models.SuccessResponse{
+			Code:    http.StatusOK,
+			Message: "W1R3: TOTP required",
+			Response: models.PartialUserLoginResponse{
+				UserName:     username,
+				PreAuthToken: jwt,
+			},
+		}
+	} else {
+		successResponse = models.SuccessResponse{
+			Code:    http.StatusOK,
+			Message: "W1R3: Device " + username + " Authorized",
+			Response: models.SuccessfulUserLoginResponse{
+				UserName:  username,
+				AuthToken: jwt,
+			},
+		}
 	}
+
 	// Send back the JWT
 	successJSONResponse, jsonError := json.Marshal(successResponse)
 	if jsonError != nil {
@@ -414,6 +434,201 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	}()
 }
 
+// @Summary     Initiate setting up TOTP 2FA for a user.
+// @Router      /api/users/auth/init-totp [post]
+// @Tags        Auth
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func initiateTOTPSetup(w http.ResponseWriter, r *http.Request) {
+	username := r.Header.Get("user")
+
+	user, err := logic.GetUser(username)
+	if err != nil {
+		logger.Log(0, "failed to get user: ", err.Error())
+		err = fmt.Errorf("user not found: %v", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	if user.AuthType == models.OAuth {
+		err = fmt.Errorf("auth type is %s, cannot process totp setup", user.AuthType)
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	key, err := totp.Generate(totp.GenerateOpts{
+		Issuer:      "Netmaker",
+		AccountName: username,
+	})
+	if err != nil {
+		err = fmt.Errorf("failed to generate totp key: %v", err)
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	qrCodeImg, err := key.Image(200, 200)
+	if err != nil {
+		err = fmt.Errorf("failed to generate totp key: %v", err)
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	var qrCodePng bytes.Buffer
+	err = png.Encode(&qrCodePng, qrCodeImg)
+	if err != nil {
+		err = fmt.Errorf("failed to generate totp key: %v", err)
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+
+	qrCode := "data:image/png;base64," + base64.StdEncoding.EncodeToString(qrCodePng.Bytes())
+
+	logic.ReturnSuccessResponseWithJson(w, r, models.TOTPInitiateResponse{
+		OTPAuthURL:          key.URL(),
+		OTPAuthURLSignature: logic.GenerateOTPAuthURLSignature(key.URL()),
+		QRCode:              qrCode,
+	}, "totp setup initiated")
+}
+
+// @Summary     Verify and complete setting up TOTP 2FA for a user.
+// @Router      /api/users/auth/complete-totp [post]
+// @Tags        Auth
+// @Param       body body models.UserTOTPVerificationParams true "TOTP verification parameters"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func completeTOTPSetup(w http.ResponseWriter, r *http.Request) {
+	username := r.Header.Get("user")
+
+	var req models.UserTOTPVerificationParams
+	err := json.NewDecoder(r.Body).Decode(&req)
+	if err != nil {
+		logger.Log(0, "failed to decode request body: ", err.Error())
+		err = fmt.Errorf("invalid request body: %v", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	if !logic.VerifyOTPAuthURL(req.OTPAuthURL, req.OTPAuthURLSignature) {
+		err = fmt.Errorf("otp auth url signature mismatch")
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	user, err := logic.GetUser(username)
+	if err != nil {
+		logger.Log(0, "failed to get user: ", err.Error())
+		err = fmt.Errorf("user not found: %v", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	if user.AuthType == models.OAuth {
+		err = fmt.Errorf("auth type is %s, cannot process totp setup", user.AuthType)
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	otpAuthURL, err := otp.NewKeyFromURL(req.OTPAuthURL)
+	if err != nil {
+		err = fmt.Errorf("error parsing otp auth url: %v", err)
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	totpSecret := otpAuthURL.Secret()
+
+	if totp.Validate(req.TOTP, totpSecret) {
+		user.IsMFAEnabled = true
+		user.TOTPSecret = totpSecret
+		err = logic.UpsertUser(*user)
+		if err != nil {
+			err = fmt.Errorf("error upserting user: %v", err)
+			logger.Log(0, err.Error())
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+			return
+		}
+
+		logic.ReturnSuccessResponse(w, r, fmt.Sprintf("totp setup complete for user %s", username))
+	} else {
+		err = fmt.Errorf("cannot setup totp for user %s: invalid otp", username)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+	}
+}
+
+// @Summary     Verify a user's TOTP token.
+// @Router      /api/users/auth/verify-totp [post]
+// @Tags        Auth
+// @Accept      json
+// @Param       body body models.UserTOTPVerificationParams true "TOTP verification parameters"
+// @Success     200 {object} models.SuccessResponse
+// @Failure     400 {object} models.ErrorResponse
+// @Failure     401 {object} models.ErrorResponse
+// @Failure     500 {object} models.ErrorResponse
+func verifyTOTP(w http.ResponseWriter, r *http.Request) {
+	username := r.Header.Get("user")
+
+	var req models.UserTOTPVerificationParams
+	err := json.NewDecoder(r.Body).Decode(&req)
+	if err != nil {
+		logger.Log(0, "failed to decode request body: ", err.Error())
+		err = fmt.Errorf("invalid request body: %v", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	user, err := logic.GetUser(username)
+	if err != nil {
+		logger.Log(0, "failed to get user: ", err.Error())
+		err = fmt.Errorf("user not found: %v", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	if !user.IsMFAEnabled {
+		err = fmt.Errorf("mfa is disabled for user(%s), cannot process totp verification", username)
+		logger.Log(0, err.Error())
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+		return
+	}
+
+	if totp.Validate(req.TOTP, user.TOTPSecret) {
+		jwt, err := logic.CreateUserJWT(user.UserName, user.PlatformRoleID)
+		if err != nil {
+			err = fmt.Errorf("error creating token: %v", err)
+			logger.Log(0, err.Error())
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+			return
+		}
+
+		// update last login time
+		user.LastLoginTime = time.Now().UTC()
+		err = logic.UpsertUser(*user)
+		if err != nil {
+			err = fmt.Errorf("error upserting user: %v", err)
+			logger.Log(0, err.Error())
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+			return
+		}
+
+		logic.ReturnSuccessResponseWithJson(w, r, models.SuccessfulUserLoginResponse{
+			UserName:  username,
+			AuthToken: jwt,
+		}, "W1R3: User "+username+" Authorized")
+	} else {
+		err = fmt.Errorf("invalid otp")
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
+	}
+}
+
 // @Summary     Check if the server has a super admin
 // @Router      /api/users/adm/hassuperadmin [get]
 // @Tags        Users
@@ -586,18 +801,6 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	for i, user := range users {
-		// only setting num_access_tokens here, because only UI needs it.
-		user.NumAccessTokens, err = (&schema.UserAccessToken{
-			UserName: user.UserName,
-		}).CountByUser(db.WithContext(context.TODO()))
-		if err != nil {
-			continue
-		}
-
-		users[i] = user
-	}
-
 	logic.SortUsers(users[:])
 	logger.Log(2, r.Header.Get("user"), "fetched users")
 	json.NewEncoder(w).Encode(users)
@@ -884,6 +1087,14 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 			return
 
 		}
+
+		if logic.IsMFAEnforced() && user.IsMFAEnabled && !userchange.IsMFAEnabled {
+			err = errors.New("mfa is enforced, user cannot unset their own mfa")
+			slog.Error("failed to update user", "caller", caller.UserName, "attempted to update user", username, "error", err)
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
+			return
+		}
+
 		if servercfg.IsPro {
 			// user cannot update his own roles and groups
 			if len(user.NetworkRoles) != len(userchange.NetworkRoles) || !reflect.DeepEqual(user.NetworkRoles, userchange.NetworkRoles) {
@@ -900,7 +1111,6 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 		}
-
 	}
 	if ismaster {
 		if user.PlatformRoleID != models.SuperAdminRole && userchange.PlatformRoleID == models.SuperAdminRole {
@@ -920,6 +1130,11 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 		(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens(r.Context())
 	}
 	oldUser := *user
+	if ismaster {
+		caller = &models.User{
+			UserName: logic.MasterUser,
+		}
+	}
 	e := models.Event{
 		Action: models.Update,
 		Source: models.Subject{

+ 2 - 0
go.mod

@@ -46,6 +46,7 @@ require (
 	github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e
 	github.com/guumaster/tablewriter v0.0.10
 	github.com/matryer/is v1.4.1
+	github.com/pquerna/otp v1.5.0
 	github.com/spf13/cobra v1.9.1
 	google.golang.org/api v0.238.0
 	gopkg.in/mail.v2 v2.3.1
@@ -59,6 +60,7 @@ require (
 	cloud.google.com/go/auth v0.16.2 // indirect
 	cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
 	cloud.google.com/go/compute/metadata v0.7.0 // indirect
+	github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
 	github.com/gabriel-vasile/mimetype v1.4.8 // indirect
 	github.com/go-jose/go-jose/v4 v4.0.5 // indirect
 	github.com/go-logr/logr v1.4.2 // indirect

+ 4 - 0
go.sum

@@ -8,6 +8,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
 filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
 github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
 github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
+github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI=
+github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
 github.com/c-robinson/iplib v1.0.8 h1:exDRViDyL9UBLcfmlxxkY5odWX5092nPsQIykHXhIn4=
 github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo=
 github.com/coreos/go-oidc/v3 v3.14.1 h1:9ePWwfdwC4QKRlCXsJGou56adA/owXczOzwKdOumLqk=
@@ -107,6 +109,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/posthog/posthog-go v1.5.12 h1:nxK/z5QLCFxwzxV8GNvVd4Y1wJ++zJSWMGEtzU+/HLM=
 github.com/posthog/posthog-go v1.5.12/go.mod h1:ZPCind3bz8xDLK0Zhvpv1fQav6WfRcQDqTMfMXmna98=
+github.com/pquerna/otp v1.5.0 h1:NMMR+WrmaqXU4EzdGJEE1aUUI0AMRzsp96fFFWNPwxs=
+github.com/pquerna/otp v1.5.0/go.mod h1:dkJfzwRKNiegxyNb54X/3fLwhCynbMspSyWKnvi1AEg=
 github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
 github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
 github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=

+ 29 - 14
logic/auth.go

@@ -235,22 +235,32 @@ func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
 		return "", errors.New("incorrect credentials")
 	}
 
-	// Create a new JWT for the node
-	tokenString, err := CreateUserJWT(authRequest.UserName, result.PlatformRoleID)
-	if err != nil {
-		slog.Error("error creating jwt", "error", err)
-		return "", err
-	}
+	if result.IsMFAEnabled {
+		tokenString, err := CreatePreAuthToken(authRequest.UserName)
+		if err != nil {
+			slog.Error("error creating jwt", "error", err)
+			return "", err
+		}
 
-	// update last login time
-	result.LastLoginTime = time.Now().UTC()
-	err = UpsertUser(result)
-	if err != nil {
-		slog.Error("error upserting user", "error", err)
-		return "", err
-	}
+		return tokenString, nil
+	} else {
+		// Create a new JWT for the node
+		tokenString, err := CreateUserJWT(authRequest.UserName, result.PlatformRoleID)
+		if err != nil {
+			slog.Error("error creating jwt", "error", err)
+			return "", err
+		}
+
+		// update last login time
+		result.LastLoginTime = time.Now().UTC()
+		err = UpsertUser(result)
+		if err != nil {
+			slog.Error("error upserting user", "error", err)
+			return "", err
+		}
 
-	return tokenString, nil
+		return tokenString, nil
+	}
 }
 
 // UpsertUser - updates user in the db
@@ -359,6 +369,11 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
 		}
 	}
 
+	user.IsMFAEnabled = userchange.IsMFAEnabled
+	if !user.IsMFAEnabled {
+		user.TOTPSecret = ""
+	}
+
 	user.UserGroups = userchange.UserGroups
 	user.NetworkRoles = userchange.NetworkRoles
 	AddGlobalNetRolesToAdmins(user)

+ 44 - 0
logic/jwts.go

@@ -2,6 +2,9 @@ package logic
 
 import (
 	"context"
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
 	"errors"
 	"fmt"
 	"strings"
@@ -103,6 +106,38 @@ func CreateUserJWT(username string, role models.UserRoleID) (response string, er
 	return "", err
 }
 
+// CreatePreAuthToken generate a jwt token to be used as intermediate
+// token after primary-factor authentication but before secondary-factor
+// authentication.
+func CreatePreAuthToken(username string) (string, error) {
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.RegisteredClaims{
+		Issuer:    "Netmaker",
+		Subject:   username,
+		Audience:  []string{"auth:mfa"},
+		IssuedAt:  jwt.NewNumericDate(time.Now()),
+		ExpiresAt: jwt.NewNumericDate(time.Now().Add(5 * time.Minute)),
+	})
+
+	return token.SignedString(jwtSecretKey)
+}
+
+func GenerateOTPAuthURLSignature(url string) string {
+	signer := hmac.New(sha256.New, jwtSecretKey)
+	signer.Write([]byte(url))
+	return hex.EncodeToString(signer.Sum(nil))
+}
+
+func VerifyOTPAuthURL(url, signature string) bool {
+	signatureBytes, err := hex.DecodeString(signature)
+	if err != nil {
+		return false
+	}
+
+	signer := hmac.New(sha256.New, jwtSecretKey)
+	signer.Write([]byte(url))
+	return hmac.Equal(signatureBytes, signer.Sum(nil))
+}
+
 func GetUserNameFromToken(authtoken string) (username string, err error) {
 	claims := &models.UserClaims{}
 	var tokenSplit = strings.Split(authtoken, " ")
@@ -123,6 +158,15 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
 	if err != nil {
 		return "", Unauthorized_Err
 	}
+
+	for _, aud := range claims.Audience {
+		// token created for mfa cannot be used for
+		// anything else.
+		if aud == "auth:mfa" {
+			return "", Unauthorized_Err
+		}
+	}
+
 	if claims.TokenType == models.AccessTokenType {
 		jti := claims.ID
 		if jti != "" {

+ 59 - 0
logic/security.go

@@ -2,6 +2,7 @@ package logic
 
 import (
 	"errors"
+	"github.com/golang-jwt/jwt/v4"
 	"net/http"
 	"strings"
 
@@ -72,6 +73,64 @@ func SecurityCheck(reqAdmin bool, next http.Handler) http.HandlerFunc {
 	}
 }
 
+func PreAuthCheck(next http.Handler) http.HandlerFunc {
+	return func(w http.ResponseWriter, r *http.Request) {
+		authHeader := r.Header.Get("Authorization")
+		headerSplits := strings.Split(authHeader, " ")
+		if len(headerSplits) != 2 {
+			ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+			return
+		}
+
+		authToken := headerSplits[1]
+
+		// first check is user is authenticated.
+		// if yes, allow the user to go through.
+		username, err := GetUserNameFromToken(authHeader)
+		if err != nil {
+			// if no, then check the user has a pre-auth token.
+			var claims jwt.RegisteredClaims
+			token, err := jwt.ParseWithClaims(authToken, &claims, func(token *jwt.Token) (interface{}, error) {
+				return jwtSecretKey, nil
+			})
+			if err != nil {
+				ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+				return
+			}
+
+			if token != nil && token.Valid {
+				if len(claims.Audience) > 0 {
+					var found bool
+					for _, aud := range claims.Audience {
+						if aud == "auth:mfa" {
+							found = true
+						}
+					}
+
+					if !found {
+						ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+						return
+					}
+
+					r.Header.Set("user", claims.Subject)
+					next.ServeHTTP(w, r)
+					return
+				} else {
+					ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+					return
+				}
+			} else {
+				ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+				return
+			}
+		} else {
+			r.Header.Set("user", username)
+			next.ServeHTTP(w, r)
+			return
+		}
+	}
+}
+
 // UserPermissions - checks token stuff
 func UserPermissions(reqAdmin bool, token string) (string, error) {
 	var tokenSplit = strings.Split(token, " ")

+ 5 - 0
logic/settings.go

@@ -320,6 +320,11 @@ func IsBasicAuthEnabled() bool {
 	return GetServerSettings().BasicAuth
 }
 
+// IsMFAEnforced returns whether MFA has been enforced.
+func IsMFAEnforced() bool {
+	return GetServerSettings().MFAEnforced
+}
+
 // IsEndpointDetectionEnabled - returns true if endpoint detection enabled
 func IsEndpointDetectionEnabled() bool {
 	return GetServerSettings().EndpointDetection

+ 11 - 9
logic/users.go

@@ -41,15 +41,17 @@ func GetReturnUser(username string) (models.ReturnUser, error) {
 // ToReturnUser - gets a user as a return user
 func ToReturnUser(user models.User) models.ReturnUser {
 	return models.ReturnUser{
-		UserName:        user.UserName,
-		DisplayName:     user.DisplayName,
-		AccountDisabled: user.AccountDisabled,
-		AuthType:        user.AuthType,
-		RemoteGwIDs:     user.RemoteGwIDs,
-		UserGroups:      user.UserGroups,
-		PlatformRoleID:  user.PlatformRoleID,
-		NetworkRoles:    user.NetworkRoles,
-		LastLoginTime:   user.LastLoginTime,
+		UserName:                   user.UserName,
+		ExternalIdentityProviderID: user.ExternalIdentityProviderID,
+		IsMFAEnabled:               user.IsMFAEnabled,
+		DisplayName:                user.DisplayName,
+		AccountDisabled:            user.AccountDisabled,
+		AuthType:                   user.AuthType,
+		RemoteGwIDs:                user.RemoteGwIDs,
+		UserGroups:                 user.UserGroups,
+		PlatformRoleID:             user.PlatformRoleID,
+		NetworkRoles:               user.NetworkRoles,
+		LastLoginTime:              user.LastLoginTime,
 	}
 }
 

+ 1 - 0
models/settings.go

@@ -25,6 +25,7 @@ type ServerSettings struct {
 	Telemetry                      string   `json:"telemetry"`
 	BasicAuth                      bool     `json:"basic_auth"`
 	JwtValidityDuration            int      `json:"jwt_validity_duration"`
+	MFAEnforced                    bool     `json:"mfa_enforced"`
 	RacRestrictToSingleNetwork     bool     `json:"rac_restrict_to_single_network"`
 	EndpointDetection              bool     `json:"endpoint_detection"`
 	AllowedEmailDomains            string   `json:"allowed_email_domains"`

+ 17 - 0
models/structs.go

@@ -69,6 +69,23 @@ type SuccessfulUserLoginResponse struct {
 	AuthToken string
 }
 
+// PartialUserLoginResponse represents the response returned to the client
+// after successful username and password authentication, but before the
+// completion of TOTP authentication.
+//
+// This response includes a temporary token required to complete
+// the authentication process.
+type PartialUserLoginResponse struct {
+	UserName     string `json:"user_name"`
+	PreAuthToken string `json:"pre_auth_token"`
+}
+
+type TOTPInitiateResponse struct {
+	OTPAuthURL          string `json:"otp_auth_url"`
+	OTPAuthURLSignature string `json:"otp_auth_url_signature"`
+	QRCode              string `json:"qr_code"`
+}
+
 // Claims is  a struct that will be encoded to a JWT.
 // jwt.StandardClaims is an embedded type to provide expiry time
 type Claims struct {

+ 9 - 0
models/user_mgmt.go

@@ -157,6 +157,8 @@ type UserGroup struct {
 type User struct {
 	UserName                   string                                `json:"username" bson:"username" validate:"min=3,in_charset|email"`
 	ExternalIdentityProviderID string                                `json:"external_identity_provider_id"`
+	IsMFAEnabled               bool                                  `json:"is_mfa_enabled"`
+	TOTPSecret                 string                                `json:"totp_secret"`
 	DisplayName                string                                `json:"display_name"`
 	AccountDisabled            bool                                  `json:"account_disabled"`
 	Password                   string                                `json:"password" bson:"password" validate:"required,min=5"`
@@ -180,6 +182,7 @@ type ReturnUserWithRolesAndGroups struct {
 type ReturnUser struct {
 	UserName                   string                                `json:"username"`
 	ExternalIdentityProviderID string                                `json:"external_identity_provider_id"`
+	IsMFAEnabled               bool                                  `json:"is_mfa_enabled"`
 	DisplayName                string                                `json:"display_name"`
 	AccountDisabled            bool                                  `json:"account_disabled"`
 	IsAdmin                    bool                                  `json:"isadmin"`
@@ -199,6 +202,12 @@ type UserAuthParams struct {
 	Password string `json:"password"`
 }
 
+type UserTOTPVerificationParams struct {
+	OTPAuthURL          string `json:"otp_auth_url"`
+	OTPAuthURLSignature string `json:"otp_auth_url_signature"`
+	TOTP                string `json:"totp"`
+}
+
 // UserClaims - user claims struct
 type UserClaims struct {
 	Role           UserRoleID