Browse Source

add headless sso login

Anish Mukherjee 2 years ago
parent
commit
289bb3e5ec
5 changed files with 212 additions and 52 deletions
  1. 81 1
      auth/auth.go
  2. 5 1
      cli/cmd/context/set.go
  3. 1 0
      cli/config/config.go
  4. 74 0
      cli/functions/http_client.go
  5. 51 50
      controllers/user.go

+ 81 - 1
auth/auth.go

@@ -2,15 +2,18 @@ package auth
 
 
 import (
 import (
 	"encoding/base64"
 	"encoding/base64"
+	"encoding/hex"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
 	"strings"
 	"strings"
+	"time"
 
 
 	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/crypto/bcrypt"
 	"golang.org/x/oauth2"
 	"golang.org/x/oauth2"
 
 
+	"github.com/gorilla/websocket"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic/pro/netcache"
 	"github.com/gravitl/netmaker/logic/pro/netcache"
@@ -43,7 +46,10 @@ type OAuthUser struct {
 	AccessToken       string `json:"accesstoken" bson:"accesstoken"`
 	AccessToken       string `json:"accesstoken" bson:"accesstoken"`
 }
 }
 
 
-var auth_provider *oauth2.Config
+var (
+	auth_provider *oauth2.Config
+	upgrader      = websocket.Upgrader{}
+)
 
 
 func getCurrentAuthFunctions() map[string]interface{} {
 func getCurrentAuthFunctions() map[string]interface{} {
 	var authInfo = servercfg.GetAuthProviderInfo()
 	var authInfo = servercfg.GetAuthProviderInfo()
@@ -154,6 +160,80 @@ func IsOauthUser(user *models.User) error {
 	return bCryptErr
 	return bCryptErr
 }
 }
 
 
+// HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocket
+func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
+	conn, err := upgrader.Upgrade(w, r, nil)
+	if err != nil {
+		logger.Log(0, "error during connection upgrade for node sign-in:", err.Error())
+		return
+	}
+	if conn == nil {
+		logger.Log(0, "failed to establish web-socket connection during node sign-in")
+		return
+	}
+	defer conn.Close()
+
+	req := &netcache.CValue{User: "", Pass: ""}
+	stateStr := hex.EncodeToString([]byte(logic.RandomString(node_signin_length)))
+	if err = netcache.Set(stateStr, req); err != nil {
+		logger.Log(0, "Failed to process sso request -", err.Error())
+		return
+	}
+
+	timeout := make(chan bool, 1)
+	answer := make(chan string, 1)
+	defer close(answer)
+	defer close(timeout)
+
+	if auth_provider == nil {
+		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
+			logger.Log(0, "error during message writing:", err.Error())
+		}
+		return
+	}
+	redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
+	if err = conn.WriteMessage(websocket.TextMessage, []byte(redirectUrl)); err != nil {
+		logger.Log(0, "error during message writing:", err.Error())
+	}
+
+	go func() {
+		for {
+			cachedReq, err := netcache.Get(stateStr)
+			if err != nil {
+				if strings.Contains(err.Error(), "expired") {
+					logger.Log(0, "timeout occurred while waiting for SSO")
+					timeout <- true
+					break
+				}
+				continue
+			} else if cachedReq.Pass != "" {
+				logger.Log(0, "SSO process completed for user ", cachedReq.User)
+				answer <- cachedReq.Pass
+				break
+			}
+			time.Sleep(500) // try it 2 times per second to see if auth is completed
+		}
+	}()
+
+	select {
+	case result := <-answer:
+		if err = conn.WriteMessage(websocket.TextMessage, []byte(result)); err != nil {
+			logger.Log(0, "Error during message writing:", err.Error())
+		}
+	case <-timeout:
+		logger.Log(0, "Authentication server time out for headless SSO login")
+		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
+			logger.Log(0, "Error during message writing:", err.Error())
+		}
+	}
+	if err = netcache.Del(stateStr); err != nil {
+		logger.Log(0, "failed to remove node SSO cache entry", err.Error())
+	}
+	if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
+		logger.Log(0, "write close:", err.Error())
+	}
+}
+
 // == private methods ==
 // == private methods ==
 
 
 func addUser(email string) error {
 func addUser(email string) error {

+ 5 - 1
cli/cmd/context/set.go

@@ -12,6 +12,7 @@ var (
 	username  string
 	username  string
 	password  string
 	password  string
 	masterKey string
 	masterKey string
+	sso       bool
 )
 )
 
 
 var contextSetCmd = &cobra.Command{
 var contextSetCmd = &cobra.Command{
@@ -25,8 +26,9 @@ var contextSetCmd = &cobra.Command{
 			Username:  username,
 			Username:  username,
 			Password:  password,
 			Password:  password,
 			MasterKey: masterKey,
 			MasterKey: masterKey,
+			SSO:       sso,
 		}
 		}
-		if ctx.Username == "" && ctx.MasterKey == "" {
+		if ctx.Username == "" && ctx.MasterKey == "" && !ctx.SSO {
 			cmd.Usage()
 			cmd.Usage()
 			log.Fatal("Either username/password or master key is required")
 			log.Fatal("Either username/password or master key is required")
 		}
 		}
@@ -36,9 +38,11 @@ var contextSetCmd = &cobra.Command{
 
 
 func init() {
 func init() {
 	contextSetCmd.Flags().StringVar(&endpoint, "endpoint", "", "Endpoint of the API Server")
 	contextSetCmd.Flags().StringVar(&endpoint, "endpoint", "", "Endpoint of the API Server")
+	contextSetCmd.MarkFlagRequired("endpoint")
 	contextSetCmd.Flags().StringVar(&username, "username", "", "Username")
 	contextSetCmd.Flags().StringVar(&username, "username", "", "Username")
 	contextSetCmd.Flags().StringVar(&password, "password", "", "Password")
 	contextSetCmd.Flags().StringVar(&password, "password", "", "Password")
 	contextSetCmd.MarkFlagsRequiredTogether("username", "password")
 	contextSetCmd.MarkFlagsRequiredTogether("username", "password")
+	contextSetCmd.Flags().BoolVar(&sso, "sso", false, "Login via Single Sign On (SSO) ?")
 	contextSetCmd.Flags().StringVar(&masterKey, "master_key", "", "Master Key")
 	contextSetCmd.Flags().StringVar(&masterKey, "master_key", "", "Master Key")
 	rootCmd.AddCommand(contextSetCmd)
 	rootCmd.AddCommand(contextSetCmd)
 }
 }

+ 1 - 0
cli/config/config.go

@@ -17,6 +17,7 @@ type Context struct {
 	MasterKey string `yaml:"masterkey,omitempty"`
 	MasterKey string `yaml:"masterkey,omitempty"`
 	Current   bool   `yaml:"current,omitempty"`
 	Current   bool   `yaml:"current,omitempty"`
 	AuthToken string `yaml:"auth_token,omitempty"`
 	AuthToken string `yaml:"auth_token,omitempty"`
+	SSO       bool   `yaml:"sso,omitempty"`
 }
 }
 
 
 var (
 var (

+ 74 - 0
cli/functions/http_client.go

@@ -3,18 +3,92 @@ package functions
 import (
 import (
 	"bytes"
 	"bytes"
 	"encoding/json"
 	"encoding/json"
+	"fmt"
 	"io"
 	"io"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
+	"os"
+	"os/signal"
+	"strings"
 
 
+	"github.com/gorilla/websocket"
 	"github.com/gravitl/netmaker/cli/config"
 	"github.com/gravitl/netmaker/cli/config"
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
+	nmconfig "github.com/gravitl/netmaker/netclient/config"
 )
 )
 
 
+func ssoLogin(endpoint string) string {
+	var (
+		accessToken *models.AccessToken
+		interrupt   = make(chan os.Signal, 1)
+		socketURL   = fmt.Sprintf("wss://%s/api/oauth/headless", endpoint)
+	)
+	signal.Notify(interrupt, os.Interrupt)
+	conn, _, err := websocket.DefaultDialer.Dial(socketURL, nil)
+	if err != nil {
+		log.Fatal("error connecting to endpoint: ", err.Error())
+	}
+	defer conn.Close()
+	_, msg, err := conn.ReadMessage()
+	if err != nil {
+		log.Fatal("error reading from server: ", err.Error())
+	}
+	fmt.Printf("Please visit:\n %s \n to authenticate", string(msg))
+	done := make(chan struct{})
+	defer close(done)
+	go func() {
+		for {
+			msgType, msg, err := conn.ReadMessage()
+			if err != nil {
+				if msgType < 0 {
+					done <- struct{}{}
+					return
+				}
+				if !strings.Contains(err.Error(), "normal") {
+					log.Fatal("read error: ", err.Error())
+				}
+				return
+			}
+			if msgType == websocket.CloseMessage {
+				done <- struct{}{}
+				return
+			}
+			if strings.Contains(string(msg), "AccessToken: ") {
+				// Access was granted
+				rxToken := strings.TrimPrefix(string(msg), "AccessToken: ")
+				if accessToken, err = nmconfig.ParseAccessToken(rxToken); err != nil {
+					log.Fatalf("failed to parse received access token %s,err=%s\n", accessToken, err.Error())
+				}
+			} else {
+				logger.Log(0, "Message from server:", string(msg))
+				return
+			}
+		}
+	}()
+	for {
+		select {
+		case <-done:
+			return accessToken.Key
+		case <-interrupt:
+			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
+			if err != nil {
+				logger.Log(0, "write close:", err.Error())
+			}
+			return accessToken.Key
+		}
+	}
+}
+
 func getAuthToken(ctx config.Context, force bool) string {
 func getAuthToken(ctx config.Context, force bool) string {
 	if !force && ctx.AuthToken != "" {
 	if !force && ctx.AuthToken != "" {
 		return ctx.AuthToken
 		return ctx.AuthToken
 	}
 	}
+	if ctx.SSO {
+		authToken := ssoLogin(ctx.Endpoint)
+		config.SetAuthToken(authToken)
+		return authToken
+	}
 	authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password}
 	authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password}
 	payload, err := json.Marshal(authParams)
 	payload, err := json.Marshal(authParams)
 	if err != nil {
 	if err != nil {

+ 51 - 50
controllers/user.go

@@ -34,6 +34,7 @@ func userHandlers(r *mux.Router) {
 	r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods(http.MethodGet)
 	r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods(http.MethodGet)
 	r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods(http.MethodGet)
 	r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods(http.MethodGet)
 	r.HandleFunc("/api/oauth/node-handler", socketHandler)
 	r.HandleFunc("/api/oauth/node-handler", socketHandler)
+	r.HandleFunc("/api/oauth/headless", auth.HandleHeadlessSSO)
 	r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterNodeSSO).Methods(http.MethodGet)
 	r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterNodeSSO).Methods(http.MethodGet)
 }
 }
 
 
@@ -41,13 +42,13 @@ func userHandlers(r *mux.Router) {
 //
 //
 // Node authenticates using its password and retrieves a JWT for authorization.
 // Node authenticates using its password and retrieves a JWT for authorization.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: successResponse
+//			Responses:
+//				200: successResponse
 func authenticateUser(response http.ResponseWriter, request *http.Request) {
 func authenticateUser(response http.ResponseWriter, request *http.Request) {
 
 
 	// Auth request consists of Mac Address and Password (from node that is authorizing
 	// Auth request consists of Mac Address and Password (from node that is authorizing
@@ -113,13 +114,13 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 //
 //
 // Checks whether the server has an admin.
 // Checks whether the server has an admin.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: successResponse
+//			Responses:
+//				200: successResponse
 func hasAdmin(w http.ResponseWriter, r *http.Request) {
 func hasAdmin(w http.ResponseWriter, r *http.Request) {
 
 
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
@@ -139,13 +140,13 @@ func hasAdmin(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Get an individual user.
 // Get an individual user.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func getUser(w http.ResponseWriter, r *http.Request) {
 func getUser(w http.ResponseWriter, r *http.Request) {
 	// set header.
 	// set header.
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
@@ -167,13 +168,13 @@ func getUser(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Get all users.
 // Get all users.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func getUsers(w http.ResponseWriter, r *http.Request) {
 func getUsers(w http.ResponseWriter, r *http.Request) {
 	// set header.
 	// set header.
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
@@ -194,13 +195,13 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Make a user an admin.
 // Make a user an admin.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func createAdmin(w http.ResponseWriter, r *http.Request) {
 func createAdmin(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 
 
@@ -236,13 +237,13 @@ func createAdmin(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Create a user.
 // Create a user.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func createUser(w http.ResponseWriter, r *http.Request) {
 func createUser(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 
 
@@ -270,13 +271,13 @@ func createUser(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Updates the networks of the given user.
 // Updates the networks of the given user.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
@@ -319,13 +320,13 @@ func updateUserNetworks(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Update a user.
 // Update a user.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func updateUser(w http.ResponseWriter, r *http.Request) {
 func updateUser(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
@@ -369,13 +370,13 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Updates the given admin user's info (as long as the user is an admin).
 // Updates the given admin user's info (as long as the user is an admin).
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 	var params = mux.Vars(r)
 	var params = mux.Vars(r)
@@ -420,13 +421,13 @@ func updateUserAdm(w http.ResponseWriter, r *http.Request) {
 //
 //
 // Delete a user.
 // Delete a user.
 //
 //
-//		Schemes: https
+//			Schemes: https
 //
 //
-// 		Security:
-//   		oauth
+//			Security:
+//	  		oauth
 //
 //
-//		Responses:
-//			200: userBodyResponse
+//			Responses:
+//				200: userBodyResponse
 func deleteUser(w http.ResponseWriter, r *http.Request) {
 func deleteUser(w http.ResponseWriter, r *http.Request) {
 	// Set header
 	// Set header
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")