Kaynağa Gözat

feat(go): allow authenticated users to use preauth apis;

Vishal Dalwadi 4 ay önce
ebeveyn
işleme
18792d65ac
2 değiştirilmiş dosya ile 35 ekleme ve 30 silme
  1. 5 7
      logic/jwts.go
  2. 30 23
      logic/security.go

+ 5 - 7
logic/jwts.go

@@ -161,13 +161,11 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
 		return "", Unauthorized_Err
 	}
 
-	if len(claims.Audience) > 0 {
-		for _, aud := range claims.Audience {
-			// token created for mfa cannot be used for
-			// anything else.
-			if aud == "auth:mfa" {
-				return "", Unauthorized_Err
-			}
+	for _, aud := range claims.Audience {
+		// token created for mfa cannot be used for
+		// anything else.
+		if aud == "auth:mfa" {
+			return "", Unauthorized_Err
 		}
 	}
 

+ 30 - 23
logic/security.go

@@ -84,37 +84,44 @@ func PreAuthCheck(next http.Handler) http.HandlerFunc {
 
 		authToken := headerSplits[1]
 
-		var claims jwt.RegisteredClaims
-		token, err := jwt.ParseWithClaims(authToken, &claims, func(token *jwt.Token) (interface{}, error) {
-			return jwtSecretKey, nil
-		})
+		// first check is user is authenticated.
+		// if yes, allow the user to go through.
+		username, err := GetUserNameFromToken(authToken)
 		if err != nil {
-			ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
-			return
-		}
+			// if no, then check the user has a pre-auth token.
+			var claims jwt.RegisteredClaims
+			token, err := jwt.ParseWithClaims(authToken, &claims, func(token *jwt.Token) (interface{}, error) {
+				return jwtSecretKey, nil
+			})
+			if err != nil {
+				ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+				return
+			}
 
-		if token != nil && token.Valid {
-			if len(claims.Audience) > 0 {
-				var found bool
-				for _, aud := range claims.Audience {
-					// token created for mfa cannot be used for
-					// anything else.
-					if aud == "auth:mfa" {
-						found = true
+			if token != nil && token.Valid {
+				if len(claims.Audience) > 0 {
+					var found bool
+					for _, aud := range claims.Audience {
+						if aud == "auth:mfa" {
+							found = true
+						}
 					}
-				}
 
-				if !found {
+					if !found {
+						ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+					}
+				} else {
 					ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+					return
 				}
-
-				r.Header.Set("user", claims.Subject)
-				next.ServeHTTP(w, r)
+			} else {
+				ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
+				return
 			}
-		} else {
-			ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
-			return
 		}
+
+		r.Header.Set("user", username)
+		next.ServeHTTP(w, r)
 	}
 }