|  | @@ -84,37 +84,44 @@ func PreAuthCheck(next http.Handler) http.HandlerFunc {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		authToken := headerSplits[1]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		var claims jwt.RegisteredClaims
 | 
	
		
			
				|  |  | -		token, err := jwt.ParseWithClaims(authToken, &claims, func(token *jwt.Token) (interface{}, error) {
 | 
	
		
			
				|  |  | -			return jwtSecretKey, nil
 | 
	
		
			
				|  |  | -		})
 | 
	
		
			
				|  |  | +		// first check is user is authenticated.
 | 
	
		
			
				|  |  | +		// if yes, allow the user to go through.
 | 
	
		
			
				|  |  | +		username, err := GetUserNameFromToken(authToken)
 | 
	
		
			
				|  |  |  		if err != nil {
 | 
	
		
			
				|  |  | -			ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
 | 
	
		
			
				|  |  | -			return
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | +			// if no, then check the user has a pre-auth token.
 | 
	
		
			
				|  |  | +			var claims jwt.RegisteredClaims
 | 
	
		
			
				|  |  | +			token, err := jwt.ParseWithClaims(authToken, &claims, func(token *jwt.Token) (interface{}, error) {
 | 
	
		
			
				|  |  | +				return jwtSecretKey, nil
 | 
	
		
			
				|  |  | +			})
 | 
	
		
			
				|  |  | +			if err != nil {
 | 
	
		
			
				|  |  | +				ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
 | 
	
		
			
				|  |  | +				return
 | 
	
		
			
				|  |  | +			}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -		if token != nil && token.Valid {
 | 
	
		
			
				|  |  | -			if len(claims.Audience) > 0 {
 | 
	
		
			
				|  |  | -				var found bool
 | 
	
		
			
				|  |  | -				for _, aud := range claims.Audience {
 | 
	
		
			
				|  |  | -					// token created for mfa cannot be used for
 | 
	
		
			
				|  |  | -					// anything else.
 | 
	
		
			
				|  |  | -					if aud == "auth:mfa" {
 | 
	
		
			
				|  |  | -						found = true
 | 
	
		
			
				|  |  | +			if token != nil && token.Valid {
 | 
	
		
			
				|  |  | +				if len(claims.Audience) > 0 {
 | 
	
		
			
				|  |  | +					var found bool
 | 
	
		
			
				|  |  | +					for _, aud := range claims.Audience {
 | 
	
		
			
				|  |  | +						if aud == "auth:mfa" {
 | 
	
		
			
				|  |  | +							found = true
 | 
	
		
			
				|  |  | +						}
 | 
	
		
			
				|  |  |  					}
 | 
	
		
			
				|  |  | -				}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -				if !found {
 | 
	
		
			
				|  |  | +					if !found {
 | 
	
		
			
				|  |  | +						ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
 | 
	
		
			
				|  |  | +					}
 | 
	
		
			
				|  |  | +				} else {
 | 
	
		
			
				|  |  |  					ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
 | 
	
		
			
				|  |  | +					return
 | 
	
		
			
				|  |  |  				}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -				r.Header.Set("user", claims.Subject)
 | 
	
		
			
				|  |  | -				next.ServeHTTP(w, r)
 | 
	
		
			
				|  |  | +			} else {
 | 
	
		
			
				|  |  | +				ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
 | 
	
		
			
				|  |  | +				return
 | 
	
		
			
				|  |  |  			}
 | 
	
		
			
				|  |  | -		} else {
 | 
	
		
			
				|  |  | -			ReturnErrorResponse(w, r, FormatError(Unauthorized_Err, "unauthorized"))
 | 
	
		
			
				|  |  | -			return
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		r.Header.Set("user", username)
 | 
	
		
			
				|  |  | +		next.ServeHTTP(w, r)
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 |