Jelajahi Sumber

allow required updates for oAuth user

Abhishek Kondur 2 tahun lalu
induk
melakukan
5cacd7b041
8 mengubah file dengan 53 tambahan dan 66 penghapusan
  1. 2 47
      auth/auth.go
  2. 1 1
      auth/azure-ad.go
  3. 1 1
      auth/github.go
  4. 1 1
      auth/google.go
  5. 1 1
      auth/headless_callback.go
  6. 1 1
      auth/oidc.go
  7. 0 12
      controllers/user.go
  8. 46 2
      logic/auth.go

+ 2 - 47
auth/auth.go

@@ -1,15 +1,12 @@
 package auth
 
 import (
-	"encoding/base64"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"net/http"
 	"strings"
 	"time"
 
-	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/oauth2"
 
 	"github.com/gorilla/websocket"
@@ -31,7 +28,6 @@ const (
 	github_provider_name   = "github"
 	oidc_provider_name     = "oidc"
 	verify_user            = "verifyuser"
-	auth_key               = "netmaker_auth"
 	user_signin_length     = 16
 	node_signin_length     = 64
 	headless_signin_length = 32
@@ -74,7 +70,7 @@ func InitializeAuthProvider() string {
 	if functions == nil {
 		return ""
 	}
-	var _, err = fetchPassValue(logic.RandomString(64))
+	var _, err = logic.FetchPassValue(logic.RandomString(64))
 	if err != nil {
 		logger.Log(0, err.Error())
 		return ""
@@ -151,16 +147,6 @@ func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
 	functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
 }
 
-// IsOauthUser - returns
-func IsOauthUser(user *models.User) error {
-	var currentValue, err = fetchPassValue("")
-	if err != nil {
-		return err
-	}
-	var bCryptErr = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentValue))
-	return bCryptErr
-}
-
 // HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocket
 func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
 	conn, err := upgrader.Upgrade(w, r, nil)
@@ -243,7 +229,7 @@ func addUser(email string) error {
 		logger.Log(1, "error checking for existence of admin user during OAuth login for", email, "; user not added")
 		return err
 	} // generate random password to adapt to current model
-	var newPass, fetchErr = fetchPassValue("")
+	var newPass, fetchErr = logic.FetchPassValue("")
 	if fetchErr != nil {
 		return fetchErr
 	}
@@ -269,37 +255,6 @@ func addUser(email string) error {
 	return nil
 }
 
-func fetchPassValue(newValue string) (string, error) {
-
-	type valueHolder struct {
-		Value string `json:"value" bson:"value"`
-	}
-	var b64NewValue = base64.StdEncoding.EncodeToString([]byte(newValue))
-	var newValueHolder = &valueHolder{
-		Value: b64NewValue,
-	}
-	var data, marshalErr = json.Marshal(newValueHolder)
-	if marshalErr != nil {
-		return "", marshalErr
-	}
-
-	var currentValue, err = logic.FetchAuthSecret(auth_key, string(data))
-	if err != nil {
-		return "", err
-	}
-	var unmarshErr = json.Unmarshal([]byte(currentValue), newValueHolder)
-	if unmarshErr != nil {
-		return "", unmarshErr
-	}
-
-	var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value)
-	if b64Err != nil {
-		logger.Log(0, "could not decode pass")
-		return "", nil
-	}
-	return string(b64CurrentValue), nil
-}
-
 func getStateAndCode(r *http.Request) (string, string) {
 	var state, code string
 	if r.FormValue("state") != "" && r.FormValue("code") != "" {

+ 1 - 1
auth/azure-ad.go

@@ -66,7 +66,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
-	var newPass, fetchErr = fetchPassValue("")
+	var newPass, fetchErr = logic.FetchPassValue("")
 	if fetchErr != nil {
 		return
 	}

+ 1 - 1
auth/github.go

@@ -66,7 +66,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
-	var newPass, fetchErr = fetchPassValue("")
+	var newPass, fetchErr = logic.FetchPassValue("")
 	if fetchErr != nil {
 		return
 	}

+ 1 - 1
auth/google.go

@@ -68,7 +68,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
-	var newPass, fetchErr = fetchPassValue("")
+	var newPass, fetchErr = logic.FetchPassValue("")
 	if fetchErr != nil {
 		return
 	}

+ 1 - 1
auth/headless_callback.go

@@ -57,7 +57,7 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
-	newPass, fetchErr := fetchPassValue("")
+	newPass, fetchErr := logic.FetchPassValue("")
 	if fetchErr != nil {
 		return
 	}

+ 1 - 1
auth/oidc.go

@@ -79,7 +79,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
-	var newPass, fetchErr = fetchPassValue("")
+	var newPass, fetchErr = logic.FetchPassValue("")
 	if fetchErr != nil {
 		return
 	}

+ 0 - 12
controllers/user.go

@@ -356,12 +356,6 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	if auth.IsOauthUser(user) == nil {
-		err := fmt.Errorf("cannot update user info for oauth user %s", username)
-		logger.Log(0, err.Error())
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
-		return
-	}
 	var userchange models.User
 	// we decode our body request params
 	err = json.NewDecoder(r.Body).Decode(&userchange)
@@ -409,12 +403,6 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	if auth.IsOauthUser(user) != nil {
-		err := fmt.Errorf("cannot update user info for oauth user %s", username)
-		logger.Log(0, err.Error())
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "forbidden"))
-		return
-	}
 	var userchange models.User
 	// we decode our body request params
 	err = json.NewDecoder(r.Body).Decode(&userchange)

+ 46 - 2
logic/auth.go

@@ -1,6 +1,7 @@
 package logic
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -17,6 +18,8 @@ import (
 	"github.com/gravitl/netmaker/servercfg"
 )
 
+const auth_key = "netmaker_auth"
+
 // HasAdmin - checks if server has an admin
 func HasAdmin() (bool, error) {
 
@@ -264,7 +267,7 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
 
 	queryUser := user.UserName
 
-	if userchange.UserName != "" {
+	if !IsOauthUser(user) && userchange.UserName != "" { // cannot update username for an oAuth user
 		user.UserName = userchange.UserName
 	}
 	if len(userchange.Networks) > 0 {
@@ -273,7 +276,7 @@ func UpdateUser(userchange, user *models.User) (*models.User, error) {
 	if len(userchange.Groups) > 0 {
 		user.Groups = userchange.Groups
 	}
-	if userchange.Password != "" {
+	if !IsOauthUser(user) && userchange.Password != "" { // cannot update password for an oAuth User
 		// encrypt that password so we never see it again
 		hash, err := bcrypt.GenerateFromPassword([]byte(userchange.Password), 5)
 
@@ -373,6 +376,47 @@ func FetchAuthSecret(key string, secret string) (string, error) {
 	return record, nil
 }
 
+// IsOauthUser - returns
+func IsOauthUser(user *models.User) bool {
+	var currentValue, err = FetchPassValue("")
+	if err != nil {
+		return false
+	}
+	var bCryptErr = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentValue))
+	return bCryptErr == nil
+}
+
+func FetchPassValue(newValue string) (string, error) {
+
+	type valueHolder struct {
+		Value string `json:"value" bson:"value"`
+	}
+	var b64NewValue = base64.StdEncoding.EncodeToString([]byte(newValue))
+	var newValueHolder = &valueHolder{
+		Value: b64NewValue,
+	}
+	var data, marshalErr = json.Marshal(newValueHolder)
+	if marshalErr != nil {
+		return "", marshalErr
+	}
+
+	var currentValue, err = FetchAuthSecret(auth_key, string(data))
+	if err != nil {
+		return "", err
+	}
+	var unmarshErr = json.Unmarshal([]byte(currentValue), newValueHolder)
+	if unmarshErr != nil {
+		return "", unmarshErr
+	}
+
+	var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value)
+	if b64Err != nil {
+		logger.Log(0, "could not decode pass")
+		return "", nil
+	}
+	return string(b64CurrentValue), nil
+}
+
 // GetState - gets an SsoState from DB, if expired returns error
 func GetState(state string) (*models.SsoState, error) {
 	var s models.SsoState