Browse Source

Merge pull request #1887 from gravitl/feature_cli_sso

Add headless oauth login endpoint for CLI
dcarns 2 years ago
parent
commit
ae6a12b16b
7 changed files with 314 additions and 57 deletions
  1. 92 4
      auth/auth.go
  2. 93 0
      auth/headless_callback.go
  3. 1 2
      auth/nodesession.go
  4. 5 1
      cli/cmd/context/set.go
  5. 1 0
      cli/config/config.go
  6. 71 0
      cli/functions/http_client.go
  7. 51 50
      controllers/user.go

+ 92 - 4
auth/auth.go

@@ -6,10 +6,12 @@ import (
 	"errors"
 	"net/http"
 	"strings"
+	"time"
 
 	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/oauth2"
 
+	"github.com/gorilla/websocket"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic/pro/netcache"
@@ -31,6 +33,7 @@ const (
 	auth_key               = "netmaker_auth"
 	user_signin_length     = 16
 	node_signin_length     = 64
+	headless_signin_length = 32
 )
 
 // OAuthUser - generic OAuth strategy user
@@ -42,7 +45,10 @@ type OAuthUser struct {
 	AccessToken       string `json:"accesstoken" bson:"accesstoken"`
 }
 
-var auth_provider *oauth2.Config
+var (
+	auth_provider *oauth2.Config
+	upgrader      = websocket.Upgrader{}
+)
 
 func getCurrentAuthFunctions() map[string]interface{} {
 	var authInfo = servercfg.GetAuthProviderInfo()
@@ -104,9 +110,17 @@ func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
 	}
 	state, _ := getStateAndCode(r)
 	_, err := netcache.Get(state) // if in netcache proceeed with node registration login
-	if err == nil || len(state) == node_signin_length || errors.Is(err, netcache.ErrExpired) {
-		logger.Log(0, "proceeding with node SSO callback")
-		HandleNodeSSOCallback(w, r)
+	if err == nil || errors.Is(err, netcache.ErrExpired) {
+		switch len(state) {
+		case node_signin_length:
+			logger.Log(0, "proceeding with node SSO callback")
+			HandleNodeSSOCallback(w, r)
+		case headless_signin_length:
+			logger.Log(0, "proceeding with headless SSO callback")
+			HandleHeadlessSSOCallback(w, r)
+		default:
+			logger.Log(1, "invalid state length: ", fmt.Sprintf("%d", len(state)))
+		}
 	} else { // handle normal login
 		functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
 	}
@@ -146,6 +160,80 @@ func IsOauthUser(user *models.User) error {
 	return bCryptErr
 }
 
+// HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocket
+func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
+	conn, err := upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		logger.Log(0, "error during connection upgrade for headless sign-in:", err.Error())
+		return
+	}
+	if conn == nil {
+		logger.Log(0, "failed to establish web-socket connection during headless sign-in")
+		return
+	}
+	defer conn.Close()
+
+	req := &netcache.CValue{User: "", Pass: ""}
+	stateStr := logic.RandomString(headless_signin_length)
+	if err = netcache.Set(stateStr, req); err != nil {
+		logger.Log(0, "Failed to process sso request -", err.Error())
+		return
+	}
+
+	timeout := make(chan bool, 1)
+	answer := make(chan string, 1)
+	defer close(answer)
+	defer close(timeout)
+
+	if auth_provider == nil {
+		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
+			logger.Log(0, "error during message writing:", err.Error())
+		}
+		return
+	}
+	redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
+	if err = conn.WriteMessage(websocket.TextMessage, []byte(redirectUrl)); err != nil {
+		logger.Log(0, "error during message writing:", err.Error())
+	}
+
+	go func() {
+		for {
+			cachedReq, err := netcache.Get(stateStr)
+			if err != nil {
+				if strings.Contains(err.Error(), "expired") {
+					logger.Log(0, "timeout occurred while waiting for SSO")
+					timeout <- true
+					break
+				}
+				continue
+			} else if cachedReq.Pass != "" {
+				logger.Log(0, "SSO process completed for user ", cachedReq.User)
+				answer <- cachedReq.Pass
+				break
+			}
+			time.Sleep(500) // try it 2 times per second to see if auth is completed
+		}
+	}()
+
+	select {
+	case result := <-answer:
+		if err = conn.WriteMessage(websocket.TextMessage, []byte(result)); err != nil {
+			logger.Log(0, "Error during message writing:", err.Error())
+		}
+	case <-timeout:
+		logger.Log(0, "Authentication server time out for headless SSO login")
+		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
+			logger.Log(0, "Error during message writing:", err.Error())
+		}
+	}
+	if err = netcache.Del(stateStr); err != nil {
+		logger.Log(0, "failed to remove SSO cache entry", err.Error())
+	}
+	if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
+		logger.Log(0, "write close:", err.Error())
+	}
+}
+
 // == private methods ==
 
 func addUser(email string) error {

+ 93 - 0
auth/headless_callback.go

@@ -0,0 +1,93 @@
+package auth
+
+import (
+	"bytes"
+	"fmt"
+	"net/http"
+
+	"github.com/gravitl/netmaker/logger"
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/logic/pro/netcache"
+	"github.com/gravitl/netmaker/models"
+)
+
+// HandleHeadlessSSOCallback - handle OAuth callback for headless logins such as Netmaker CLI
+func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
+	functions := getCurrentAuthFunctions()
+	if functions == nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte("bad conf"))
+		logger.Log(0, "Missing Oauth config in HandleHeadlessSSOCallback")
+		return
+	}
+	state, code := getStateAndCode(r)
+
+	userClaims, err := functions[get_user_info].(func(string, string) (*OAuthUser, error))(state, code)
+	if err != nil {
+		logger.Log(0, "error when getting user info from callback:", err.Error())
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte("Failed to retrieve OAuth user claims"))
+		return
+	}
+
+	if code == "" || state == "" {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte("Wrong params"))
+		logger.Log(0, "Missing params in HandleHeadlessSSOCallback")
+		return
+	}
+
+	// all responses should be in html format from here on out
+	w.Header().Add("content-type", "text/html; charset=utf-8")
+
+	// retrieve machinekey from state cache
+	reqKeyIf, machineKeyFoundErr := netcache.Get(state)
+	if machineKeyFoundErr != nil {
+		logger.Log(0, "requested machine state key expired before authorisation completed -", machineKeyFoundErr.Error())
+		response := returnErrTemplate("", "requested machine state key expired before authorisation completed", state, reqKeyIf)
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write(response)
+		return
+	}
+
+	_, err = logic.GetUser(userClaims.getUserName())
+	if err != nil { // user must not exists, so try to make one
+		if err = addUser(userClaims.getUserName()); err != nil {
+			logger.Log(1, "could not create new user: ", userClaims.getUserName())
+			return
+		}
+	}
+	newPass, fetchErr := fetchPassValue("")
+	if fetchErr != nil {
+		return
+	}
+	jwt, jwtErr := logic.VerifyAuthRequest(models.UserAuthParams{
+		UserName: userClaims.getUserName(),
+		Password: newPass,
+	})
+	if jwtErr != nil {
+		logger.Log(1, "could not parse jwt for user", userClaims.getUserName())
+		return
+	}
+
+	logger.Log(1, "headless SSO login by user:", userClaims.getUserName())
+
+	// Send OK to user in the browser
+	var response bytes.Buffer
+	if err := ssoCallbackTemplate.Execute(&response, ssoCallbackTemplateConfig{
+		User: userClaims.getUserName(),
+		Verb: "Authenticated",
+	}); err != nil {
+		logger.Log(0, "Could not render SSO callback template ", err.Error())
+		response := returnErrTemplate(userClaims.getUserName(), "Could not render SSO callback template", state, reqKeyIf)
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write(response)
+	} else {
+		w.WriteHeader(http.StatusOK)
+		w.Write(response.Bytes())
+	}
+	reqKeyIf.Pass = fmt.Sprintf("JWT: %s", jwt)
+	if err = netcache.Set(state, reqKeyIf); err != nil {
+		logger.Log(0, "failed to set netcache for user", reqKeyIf.User, "-", err.Error())
+	}
+}

+ 1 - 2
auth/nodesession.go

@@ -1,7 +1,6 @@
 package auth
 
 import (
-	"encoding/hex"
 	"encoding/json"
 	"fmt"
 	"strings"
@@ -45,7 +44,7 @@ func SessionHandler(conn *websocket.Conn) {
 	req.Pass = ""
 	req.User = ""
 	// Add any extra parameter provided in the configuration to the Authorize Endpoint request??
-	stateStr := hex.EncodeToString([]byte(logic.RandomString(node_signin_length)))
+	stateStr := logic.RandomString(node_signin_length)
 	if err := netcache.Set(stateStr, req); err != nil {
 		logger.Log(0, "Failed to process sso request -", err.Error())
 		return

+ 5 - 1
cli/cmd/context/set.go

@@ -12,6 +12,7 @@ var (
 	username  string
 	password  string
 	masterKey string
+	sso       bool
 )
 
 var contextSetCmd = &cobra.Command{
@@ -25,8 +26,9 @@ var contextSetCmd = &cobra.Command{
 			Username:  username,
 			Password:  password,
 			MasterKey: masterKey,
+			SSO:       sso,
 		}
-		if ctx.Username == "" && ctx.MasterKey == "" {
+		if ctx.Username == "" && ctx.MasterKey == "" && !ctx.SSO {
 			cmd.Usage()
 			log.Fatal("Either username/password or master key is required")
 		}
@@ -36,9 +38,11 @@ var contextSetCmd = &cobra.Command{
 
 func init() {
 	contextSetCmd.Flags().StringVar(&endpoint, "endpoint", "", "Endpoint of the API Server")
+	contextSetCmd.MarkFlagRequired("endpoint")
 	contextSetCmd.Flags().StringVar(&username, "username", "", "Username")
 	contextSetCmd.Flags().StringVar(&password, "password", "", "Password")
 	contextSetCmd.MarkFlagsRequiredTogether("username", "password")
+	contextSetCmd.Flags().BoolVar(&sso, "sso", false, "Login via Single Sign On (SSO) ?")
 	contextSetCmd.Flags().StringVar(&masterKey, "master_key", "", "Master Key")
 	rootCmd.AddCommand(contextSetCmd)
 }

+ 1 - 0
cli/config/config.go

@@ -17,6 +17,7 @@ type Context struct {
 	MasterKey string `yaml:"masterkey,omitempty"`
 	Current   bool   `yaml:"current,omitempty"`
 	AuthToken string `yaml:"auth_token,omitempty"`
+	SSO       bool   `yaml:"sso,omitempty"`
 }
 
 var (

+ 71 - 0
cli/functions/http_client.go

@@ -3,18 +3,89 @@ package functions
 import (
 	"bytes"
 	"encoding/json"
+	"fmt"
 	"io"
 	"log"
 	"net/http"
+	"net/url"
+	"os"
+	"os/signal"
+	"strings"
 
+	"github.com/gorilla/websocket"
 	"github.com/gravitl/netmaker/cli/config"
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 )
 
+func ssoLogin(endpoint string) string {
+	var (
+		authToken string
+		interrupt = make(chan os.Signal, 1)
+		url, _    = url.Parse(endpoint)
+		socketURL = fmt.Sprintf("wss://%s/api/oauth/headless", url.Host)
+	)
+	signal.Notify(interrupt, os.Interrupt)
+	conn, _, err := websocket.DefaultDialer.Dial(socketURL, nil)
+	if err != nil {
+		log.Fatal("error connecting to endpoint ", socketURL, err.Error())
+	}
+	defer conn.Close()
+	_, msg, err := conn.ReadMessage()
+	if err != nil {
+		log.Fatal("error reading from server: ", err.Error())
+	}
+	fmt.Printf("Please visit:\n %s \n to authenticate\n", string(msg))
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		for {
+			msgType, msg, err := conn.ReadMessage()
+			if err != nil {
+				if msgType < 0 {
+					done <- struct{}{}
+					return
+				}
+				if !strings.Contains(err.Error(), "normal") {
+					log.Fatal("read error: ", err.Error())
+				}
+				return
+			}
+			if msgType == websocket.CloseMessage {
+				done <- struct{}{}
+				return
+			}
+			if strings.Contains(string(msg), "JWT: ") {
+				authToken = strings.TrimPrefix(string(msg), "JWT: ")
+			} else {
+				logger.Log(0, "Message from server:", string(msg))
+				return
+			}
+		}
+	}()
+	for {
+		select {
+		case <-done:
+			return authToken
+		case <-interrupt:
+			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
+			if err != nil {
+				logger.Log(0, "write close:", err.Error())
+			}
+			return authToken
+		}
+	}
+}
+
 func getAuthToken(ctx config.Context, force bool) string {
 	if !force && ctx.AuthToken != "" {
 		return ctx.AuthToken
 	}
+	if ctx.SSO {
+		authToken := ssoLogin(ctx.Endpoint)
+		config.SetAuthToken(authToken)
+		return authToken
+	}
 	authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password}
 	payload, err := json.Marshal(authParams)
 	if err != nil {

+ 51 - 50
controllers/user.go

@@ -34,6 +34,7 @@ func userHandlers(r *mux.Router) {
 	r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods(http.MethodGet)
 	r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods(http.MethodGet)
 	r.HandleFunc("/api/oauth/node-handler", socketHandler)
+	r.HandleFunc("/api/oauth/headless", auth.HandleHeadlessSSO)
 	r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterNodeSSO).Methods(http.MethodGet)
 }
 
@@ -41,13 +42,13 @@ func userHandlers(r *mux.Router) {
 //
 // Node authenticates using its password and retrieves a JWT for authorization.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: successResponse
+//			Responses:
+//				200: successResponse
 func authenticateUser(response http.ResponseWriter, request *http.Request) {
 
 	// Auth request consists of Mac Address and Password (from node that is authorizing
@@ -113,13 +114,13 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 //
 // Checks whether the server has an admin.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: successResponse
+//			Responses:
+//				200: successResponse
 func hasAdmin(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("Content-Type", "application/json")
@@ -139,13 +140,13 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
 //
 // Get an individual user.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func getUser(w http.ResponseWriter, r *http.Request) {
 	// set header.
 	w.Header().Set("Content-Type", "application/json")
@@ -167,13 +168,13 @@ func getUser(w http.ResponseWriter, r *http.Request) {
 //
 // Get all users.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func getUsers(w http.ResponseWriter, r *http.Request) {
 	// set header.
 	w.Header().Set("Content-Type", "application/json")
@@ -194,13 +195,13 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
 //
 // Make a user an admin.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func createAdmin(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 
@@ -236,13 +237,13 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
 //
 // Create a user.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func createUser(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 
@@ -270,13 +271,13 @@ func createUser(w http.ResponseWriter, r *http.Request) {
 //
 // Updates the networks of the given user.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	var params = mux.Vars(r)
@@ -319,13 +320,13 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 //
 // Update a user.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func updateUser(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	var params = mux.Vars(r)
@@ -369,13 +370,13 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 //
 // Updates the given admin user's info (as long as the user is an admin).
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	var params = mux.Vars(r)
@@ -420,13 +421,13 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 //
 // Delete a user.
 //
-//		Schemes: https
+//			Schemes: https
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func deleteUser(w http.ResponseWriter, r *http.Request) {
 	// Set header
 	w.Header().Set("Content-Type", "application/json")