Browse Source

add callback for headless sso

Anish Mukherjee 2 years ago
parent
commit
93fdf359b3
3 changed files with 103 additions and 16 deletions
  1. 13 4
      auth/auth.go
  2. 80 0
      auth/nodecallback.go
  3. 10 12
      cli/functions/http_client.go

+ 13 - 4
auth/auth.go

@@ -35,6 +35,7 @@ const (
 	auth_key               = "netmaker_auth"
 	auth_key               = "netmaker_auth"
 	user_signin_length     = 16
 	user_signin_length     = 16
 	node_signin_length     = 64
 	node_signin_length     = 64
+	headless_signin_length = 32
 )
 )
 
 
 // OAuthUser - generic OAuth strategy user
 // OAuthUser - generic OAuth strategy user
@@ -116,9 +117,17 @@ func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 	state, _ := getStateAndCode(r)
 	state, _ := getStateAndCode(r)
 	_, err := netcache.Get(state) // if in netcache proceeed with node registration login
 	_, err := netcache.Get(state) // if in netcache proceeed with node registration login
-	if err == nil || len(state) == node_signin_length || errors.Is(err, netcache.ErrExpired) {
-		logger.Log(0, "proceeding with node SSO callback")
-		HandleNodeSSOCallback(w, r)
+	if err == nil || errors.Is(err, netcache.ErrExpired) {
+		switch len(state) {
+		case node_signin_length:
+			logger.Log(0, "proceeding with node SSO callback")
+			HandleNodeSSOCallback(w, r)
+		case headless_signin_length:
+			logger.Log(0, "proceeding with headless SSO callback")
+			HandleHeadlessSSOCallback(w, r)
+		default:
+			logger.Log(1, "invalid state length: ", fmt.Sprintf("%d", len(state)))
+		}
 	} else { // handle normal login
 	} else { // handle normal login
 		functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
 		functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
 	}
 	}
@@ -174,7 +183,7 @@ func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
 	defer conn.Close()
 	defer conn.Close()
 
 
 	req := &netcache.CValue{User: "", Pass: ""}
 	req := &netcache.CValue{User: "", Pass: ""}
-	stateStr := hex.EncodeToString([]byte(logic.RandomString(node_signin_length)))
+	stateStr := hex.EncodeToString([]byte(logic.RandomString(headless_signin_length)))
 	if err = netcache.Set(stateStr, req); err != nil {
 	if err = netcache.Set(stateStr, req); err != nil {
 		logger.Log(0, "Failed to process sso request -", err.Error())
 		logger.Log(0, "Failed to process sso request -", err.Error())
 		return
 		return

+ 80 - 0
auth/nodecallback.go

@@ -122,6 +122,86 @@ func HandleNodeSSOCallback(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 }
 }
 
 
+// HandleHeadlessSSOCallback - handle OAuth callback for headless logins such as Netmaker CLI
+func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) {
+	functions := getCurrentAuthFunctions()
+	if functions == nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte("bad conf"))
+		logger.Log(0, "Missing Oauth config in HandleHeadlessSSOCallback")
+		return
+	}
+	state, code := getStateAndCode(r)
+
+	userClaims, err := functions[get_user_info].(func(string, string) (*OAuthUser, error))(state, code)
+	if err != nil {
+		logger.Log(0, "error when getting user info from callback:", err.Error())
+		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
+		return
+	}
+
+	if code == "" || state == "" {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte("Wrong params"))
+		logger.Log(0, "Missing params in HandleHeadlessSSOCallback")
+		return
+	}
+
+	// all responses should be in html format from here on out
+	w.Header().Add("content-type", "text/html; charset=utf-8")
+
+	// retrieve machinekey from state cache
+	reqKeyIf, machineKeyFoundErr := netcache.Get(state)
+	if machineKeyFoundErr != nil {
+		logger.Log(0, "requested machine state key expired before authorisation completed -", machineKeyFoundErr.Error())
+		response := returnErrTemplate("", "requested machine state key expired before authorisation completed", state, reqKeyIf)
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write(response)
+		return
+	}
+
+	_, err = logic.GetUser(userClaims.getUserName())
+	if err != nil { // user must not exists, so try to make one
+		if err = addUser(userClaims.getUserName()); err != nil {
+			logger.Log(1, "could not create new user: ", userClaims.getUserName())
+			return
+		}
+	}
+	newPass, fetchErr := fetchPassValue("")
+	if fetchErr != nil {
+		return
+	}
+	jwt, jwtErr := logic.VerifyAuthRequest(models.UserAuthParams{
+		UserName: userClaims.getUserName(),
+		Password: newPass,
+	})
+	if jwtErr != nil {
+		logger.Log(1, "could not parse jwt for user", userClaims.getUserName())
+		return
+	}
+
+	logger.Log(1, "headless SSO login by user:", userClaims.getUserName())
+
+	// Send OK to user in the browser
+	var response bytes.Buffer
+	if err := ssoCallbackTemplate.Execute(&response, ssoCallbackTemplateConfig{
+		User: userClaims.getUserName(),
+		Verb: "Authenticated",
+	}); err != nil {
+		logger.Log(0, "Could not render SSO callback template ", err.Error())
+		response := returnErrTemplate(userClaims.getUserName(), "Could not render SSO callback template", state, reqKeyIf)
+		w.WriteHeader(http.StatusInternalServerError)
+		w.Write(response)
+	} else {
+		w.WriteHeader(http.StatusOK)
+		w.Write(response.Bytes())
+	}
+	reqKeyIf.Pass = fmt.Sprintf("JWT: %s", jwt)
+	if err = netcache.Set(state, reqKeyIf); err != nil {
+		logger.Log(0, "failed to set netcache for user", reqKeyIf.User, "-", err.Error())
+	}
+}
+
 func setNetcache(ncache *netcache.CValue, state string) error {
 func setNetcache(ncache *netcache.CValue, state string) error {
 	if ncache == nil {
 	if ncache == nil {
 		return fmt.Errorf("cache miss")
 		return fmt.Errorf("cache miss")

+ 10 - 12
cli/functions/http_client.go

@@ -7,6 +7,7 @@ import (
 	"io"
 	"io"
 	"log"
 	"log"
 	"net/http"
 	"net/http"
+	"net/url"
 	"os"
 	"os"
 	"os/signal"
 	"os/signal"
 	"strings"
 	"strings"
@@ -15,19 +16,19 @@ import (
 	"github.com/gravitl/netmaker/cli/config"
 	"github.com/gravitl/netmaker/cli/config"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
-	nmconfig "github.com/gravitl/netmaker/netclient/config"
 )
 )
 
 
 func ssoLogin(endpoint string) string {
 func ssoLogin(endpoint string) string {
 	var (
 	var (
-		accessToken *models.AccessToken
-		interrupt   = make(chan os.Signal, 1)
-		socketURL   = fmt.Sprintf("wss://%s/api/oauth/headless", endpoint)
+		authToken string
+		interrupt = make(chan os.Signal, 1)
+		url, _    = url.Parse(endpoint)
+		socketURL = fmt.Sprintf("wss://%s/api/oauth/headless", url.Host)
 	)
 	)
 	signal.Notify(interrupt, os.Interrupt)
 	signal.Notify(interrupt, os.Interrupt)
 	conn, _, err := websocket.DefaultDialer.Dial(socketURL, nil)
 	conn, _, err := websocket.DefaultDialer.Dial(socketURL, nil)
 	if err != nil {
 	if err != nil {
-		log.Fatal("error connecting to endpoint: ", err.Error())
+		log.Fatal("error connecting to endpoint ", socketURL, err.Error())
 	}
 	}
 	defer conn.Close()
 	defer conn.Close()
 	_, msg, err := conn.ReadMessage()
 	_, msg, err := conn.ReadMessage()
@@ -54,12 +55,9 @@ func ssoLogin(endpoint string) string {
 				done <- struct{}{}
 				done <- struct{}{}
 				return
 				return
 			}
 			}
-			if strings.Contains(string(msg), "AccessToken: ") {
+			if strings.Contains(string(msg), "JWT: ") {
 				// Access was granted
 				// Access was granted
-				rxToken := strings.TrimPrefix(string(msg), "AccessToken: ")
-				if accessToken, err = nmconfig.ParseAccessToken(rxToken); err != nil {
-					log.Fatalf("failed to parse received access token %s,err=%s\n", accessToken, err.Error())
-				}
+				authToken = strings.TrimPrefix(string(msg), "JWT: ")
 			} else {
 			} else {
 				logger.Log(0, "Message from server:", string(msg))
 				logger.Log(0, "Message from server:", string(msg))
 				return
 				return
@@ -69,13 +67,13 @@ func ssoLogin(endpoint string) string {
 	for {
 	for {
 		select {
 		select {
 		case <-done:
 		case <-done:
-			return accessToken.Key
+			return authToken
 		case <-interrupt:
 		case <-interrupt:
 			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 			if err != nil {
 			if err != nil {
 				logger.Log(0, "write close:", err.Error())
 				logger.Log(0, "write close:", err.Error())
 			}
 			}
-			return accessToken.Key
+			return authToken
 		}
 		}
 	}
 	}
 }
 }