Browse Source

feat(NET-678): add saas support to nmctl (#2687)

* feat(NET-678): add saas support to nmctl

* fix(NET-678): fix context endpoint for sso
Aceix 1 year ago
parent
commit
8aa185d880
4 changed files with 343 additions and 26 deletions
  1. 28 5
      cli/cmd/context/set.go
  2. 2 0
      cli/config/config.go
  3. 262 21
      cli/functions/http_client.go
  4. 51 0
      models/structs.go

+ 28 - 5
cli/cmd/context/set.go

@@ -1,9 +1,11 @@
 package context
 package context
 
 
 import (
 import (
+	"fmt"
 	"log"
 	"log"
 
 
 	"github.com/gravitl/netmaker/cli/config"
 	"github.com/gravitl/netmaker/cli/config"
+	"github.com/gravitl/netmaker/cli/functions"
 	"github.com/spf13/cobra"
 	"github.com/spf13/cobra"
 )
 )
 
 
@@ -13,6 +15,8 @@ var (
 	password  string
 	password  string
 	masterKey string
 	masterKey string
 	sso       bool
 	sso       bool
+	tenantId  string
+	saas      bool
 )
 )
 
 
 var contextSetCmd = &cobra.Command{
 var contextSetCmd = &cobra.Command{
@@ -27,10 +31,28 @@ var contextSetCmd = &cobra.Command{
 			Password:  password,
 			Password:  password,
 			MasterKey: masterKey,
 			MasterKey: masterKey,
 			SSO:       sso,
 			SSO:       sso,
+			TenantId:  tenantId,
+			Saas:      saas,
 		}
 		}
-		if ctx.Username == "" && ctx.MasterKey == "" && !ctx.SSO {
-			cmd.Usage()
-			log.Fatal("Either username/password or master key is required")
+		if !ctx.Saas {
+			if ctx.Username == "" && ctx.MasterKey == "" && !ctx.SSO {
+				log.Fatal("Either username/password or master key is required")
+				cmd.Usage()
+			}
+			if ctx.Endpoint == "" {
+				log.Fatal("Endpoint is required when for self-hosted tenants")
+				cmd.Usage()
+			}
+		} else {
+			if ctx.TenantId == "" {
+				log.Fatal("Tenant ID is required for SaaS tenants")
+				cmd.Usage()
+			}
+			ctx.Endpoint = fmt.Sprintf(functions.TenantUrlTemplate, tenantId)
+			if ctx.Username == "" && ctx.Password == "" && !ctx.SSO {
+				log.Fatal("Username/password is required for non-SSO SaaS contexts")
+				cmd.Usage()
+			}
 		}
 		}
 		config.SetContext(args[0], ctx)
 		config.SetContext(args[0], ctx)
 	},
 	},
@@ -38,11 +60,12 @@ var contextSetCmd = &cobra.Command{
 
 
 func init() {
 func init() {
 	contextSetCmd.Flags().StringVar(&endpoint, "endpoint", "", "Endpoint of the API Server")
 	contextSetCmd.Flags().StringVar(&endpoint, "endpoint", "", "Endpoint of the API Server")
-	contextSetCmd.MarkFlagRequired("endpoint")
 	contextSetCmd.Flags().StringVar(&username, "username", "", "Username")
 	contextSetCmd.Flags().StringVar(&username, "username", "", "Username")
 	contextSetCmd.Flags().StringVar(&password, "password", "", "Password")
 	contextSetCmd.Flags().StringVar(&password, "password", "", "Password")
 	contextSetCmd.MarkFlagsRequiredTogether("username", "password")
 	contextSetCmd.MarkFlagsRequiredTogether("username", "password")
-	contextSetCmd.Flags().BoolVar(&sso, "sso", false, "Login via Single Sign On (SSO) ?")
+	contextSetCmd.Flags().BoolVar(&sso, "sso", false, "Login via Single Sign On (SSO)?")
 	contextSetCmd.Flags().StringVar(&masterKey, "master_key", "", "Master Key")
 	contextSetCmd.Flags().StringVar(&masterKey, "master_key", "", "Master Key")
+	contextSetCmd.Flags().StringVar(&tenantId, "tenant_id", "", "Tenant ID")
+	contextSetCmd.Flags().BoolVar(&saas, "saas", false, "Is this context for a SaaS tenant?")
 	rootCmd.AddCommand(contextSetCmd)
 	rootCmd.AddCommand(contextSetCmd)
 }
 }

+ 2 - 0
cli/config/config.go

@@ -18,6 +18,8 @@ type Context struct {
 	Current   bool   `yaml:"current,omitempty"`
 	Current   bool   `yaml:"current,omitempty"`
 	AuthToken string `yaml:"auth_token,omitempty"`
 	AuthToken string `yaml:"auth_token,omitempty"`
 	SSO       bool   `yaml:"sso,omitempty"`
 	SSO       bool   `yaml:"sso,omitempty"`
+	TenantId  string `yaml:"tenant_id,omitempty"`
+	Saas      bool   `yaml:"saas,omitempty"`
 }
 }
 
 
 var (
 var (

+ 262 - 21
cli/functions/http_client.go

@@ -11,11 +11,19 @@ import (
 	"os"
 	"os"
 	"os/signal"
 	"os/signal"
 	"strings"
 	"strings"
+	"time"
 
 
 	"github.com/gorilla/websocket"
 	"github.com/gorilla/websocket"
 	"github.com/gravitl/netmaker/cli/config"
 	"github.com/gravitl/netmaker/cli/config"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
+	"golang.org/x/exp/slog"
+)
+
+const (
+	ambBaseUrl        = "https://api.accounts.netmaker.io"
+	TenantUrlTemplate = "https://api-%s.app.prod.netmaker.io"
+	ambOauthWssUrl    = "wss://api.accounts.netmaker.io/api/v1/auth/sso"
 )
 )
 
 
 func ssoLogin(endpoint string) string {
 func ssoLogin(endpoint string) string {
@@ -81,34 +89,57 @@ func getAuthToken(ctx config.Context, force bool) string {
 	if !force && ctx.AuthToken != "" {
 	if !force && ctx.AuthToken != "" {
 		return ctx.AuthToken
 		return ctx.AuthToken
 	}
 	}
-	if ctx.SSO {
-		authToken := ssoLogin(ctx.Endpoint)
+	if !ctx.Saas {
+		if ctx.SSO {
+			authToken := ssoLogin(ctx.Endpoint)
+			config.SetAuthToken(authToken)
+			return authToken
+		}
+		authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password}
+		payload, err := json.Marshal(authParams)
+		if err != nil {
+			log.Fatal(err)
+		}
+		res, err := http.Post(ctx.Endpoint+"/api/users/adm/authenticate", "application/json", bytes.NewReader(payload))
+		if err != nil {
+			log.Fatal(err)
+		}
+		defer res.Body.Close()
+		resBodyBytes, err := io.ReadAll(res.Body)
+		if err != nil {
+			log.Fatalf("Client could not read response body: %s", err)
+		}
+		if res.StatusCode != http.StatusOK {
+			log.Fatalf("Error Status: %d Response: %s", res.StatusCode, string(resBodyBytes))
+		}
+		body := new(models.SuccessResponse)
+		if err := json.Unmarshal(resBodyBytes, body); err != nil {
+			log.Fatalf("Error unmarshalling JSON: %s", err)
+		}
+		authToken := body.Response.(map[string]any)["AuthToken"].(string)
 		config.SetAuthToken(authToken)
 		config.SetAuthToken(authToken)
 		return authToken
 		return authToken
 	}
 	}
-	authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password}
-	payload, err := json.Marshal(authParams)
-	if err != nil {
-		log.Fatal(err)
+
+	if !ctx.SSO {
+		sToken, _, err := basicAuthSaasSignin(ctx.Username, ctx.Password)
+		if err != nil {
+			log.Fatal(err)
+		}
+		authToken, _, err := tenantLogin(ctx, sToken)
+		if err != nil {
+			log.Fatal(err)
+		}
+		config.SetAuthToken(authToken)
+		return authToken
 	}
 	}
-	res, err := http.Post(ctx.Endpoint+"/api/users/adm/authenticate", "application/json", bytes.NewReader(payload))
+
+	accessToken, err := loginSaaSOauth(&models.SsoLoginReqDto{OauthProvider: "oidc"}, ctx.TenantId)
 	if err != nil {
 	if err != nil {
 		log.Fatal(err)
 		log.Fatal(err)
 	}
 	}
-	resBodyBytes, err := io.ReadAll(res.Body)
-	if err != nil {
-		log.Fatalf("Client could not read response body: %s", err)
-	}
-	if res.StatusCode != http.StatusOK {
-		log.Fatalf("Error Status: %d Response: %s", res.StatusCode, string(resBodyBytes))
-	}
-	body := new(models.SuccessResponse)
-	if err := json.Unmarshal(resBodyBytes, body); err != nil {
-		log.Fatalf("Error unmarshalling JSON: %s", err)
-	}
-	authToken := body.Response.(map[string]any)["AuthToken"].(string)
-	config.SetAuthToken(authToken)
-	return authToken
+	config.SetAuthToken(accessToken)
+	return accessToken
 }
 }
 
 
 func request[T any](method, route string, payload any) *T {
 func request[T any](method, route string, payload any) *T {
@@ -188,3 +219,213 @@ func get(route string) string {
 	}
 	}
 	return string(bodyBytes)
 	return string(bodyBytes)
 }
 }
+
+func basicAuthSaasSignin(email, password string) (string, http.Header, error) {
+	payload := models.SignInReqDto{
+		FormFields: []models.FormField{
+			{
+				Id:    "email",
+				Value: email,
+			},
+			{
+				Id:    "password",
+				Value: password,
+			},
+		},
+	}
+
+	var res models.SignInResDto
+
+	// Create a new HTTP client with a timeout
+	client := &http.Client{
+		Timeout: 30 * time.Second,
+	}
+
+	// Create the request body
+	payloadBuf := new(bytes.Buffer)
+	json.NewEncoder(payloadBuf).Encode(payload)
+
+	// Create the request
+	req, err := http.NewRequest("POST", ambBaseUrl+"/auth/signin", payloadBuf)
+	if err != nil {
+		return "", http.Header{}, err
+	}
+	req.Header.Set("Content-Type", "application/json; charset=utf-8")
+	req.Header.Set("rid", "thirdpartyemailpassword")
+
+	// Send the request
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", http.Header{}, err
+	}
+	defer resp.Body.Close()
+
+	// Check the response status code
+	if resp.StatusCode != http.StatusOK {
+		return "", http.Header{}, fmt.Errorf("error authenticating: %s", resp.Status)
+	}
+
+	// Copy the response headers
+	resHeaders := resp.Header
+
+	// Decode the response body
+	err = json.NewDecoder(resp.Body).Decode(&res)
+	if err != nil {
+		return "", http.Header{}, err
+	}
+
+	sToken := resHeaders.Get(models.ResHeaderKeyStAccessToken)
+	encodedAccessToken := url.QueryEscape(sToken)
+
+	return encodedAccessToken, resHeaders, nil
+}
+
+func tenantLogin(ctx config.Context, sToken string) (string, string, error) {
+	url := fmt.Sprintf("%s/api/v1/tenant/login?tenant_id=%s", ambBaseUrl, ctx.TenantId)
+
+	client := &http.Client{}
+	req, err := http.NewRequest(http.MethodPost, url, nil)
+
+	if err != nil {
+		return "", "", err
+	}
+	req.Header.Add("Cookie", fmt.Sprintf("sAccessToken=%s", sToken))
+
+	res, err := client.Do(req)
+	if err != nil {
+		return "", "", err
+	}
+	defer res.Body.Close()
+
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		return "", "", err
+	}
+
+	data := models.TenantLoginResDto{}
+	json.Unmarshal(body, &data)
+
+	return data.Response.AuthToken, fmt.Sprintf(TenantUrlTemplate, ctx.TenantId), nil
+}
+
+func loginSaaSOauth(payload *models.SsoLoginReqDto, tenantId string) (string, error) {
+	socketUrl := ambOauthWssUrl
+	// Dial the netmaker server controller
+	conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil)
+	if err != nil {
+		slog.Error("error connecting to endpoint ", "url", socketUrl, "err", err)
+		return "", err
+	}
+
+	defer conn.Close()
+	return handleServerSSORegisterConn(payload, conn, tenantId)
+}
+
+func handleServerSSORegisterConn(payload *models.SsoLoginReqDto, conn *websocket.Conn, tenantId string) (string, error) {
+	reqData, err := json.Marshal(payload)
+	if err != nil {
+		return "", err
+	}
+	if err := conn.WriteMessage(websocket.TextMessage, reqData); err != nil {
+		return "", err
+	}
+	dataCh := make(chan string)
+	defer close(dataCh)
+	interrupt := make(chan os.Signal, 1)
+	signal.Notify(interrupt, os.Interrupt)
+
+	go func() {
+		for {
+			msgType, msg, err := conn.ReadMessage()
+			if err != nil {
+				if msgType < 0 {
+					slog.Info("received close message from server")
+					return
+				}
+				if !strings.Contains(err.Error(), "normal") { // Error reading a message from the server
+					slog.Error("error msg", "err", err)
+				}
+				return
+			}
+			if msgType == websocket.CloseMessage {
+				slog.Info("received close message from server")
+				return
+			}
+			if strings.Contains(string(msg), "auth/sso") {
+				fmt.Printf("Please visit:\n %s \nto authenticate\n", string(msg))
+			} else {
+				var res models.SsoLoginData
+				if err := json.Unmarshal(msg, &res); err != nil {
+					return
+				}
+				accessToken, _, err := tenantLoginV2(res.AmbAccessToken, tenantId, res.Username)
+				if err != nil {
+					slog.Error("error logging in tenant", "err", err)
+					dataCh <- ""
+					return
+				}
+				dataCh <- accessToken
+				return
+			}
+		}
+	}()
+
+	for {
+		select {
+		case accessToken := <-dataCh:
+			if accessToken == "" {
+				slog.Info("error getting access token")
+				return "", fmt.Errorf("error getting access token")
+			}
+			return accessToken, nil
+		case <-time.After(30 * time.Second):
+			slog.Error("authentiation timed out")
+			os.Exit(1)
+		case <-interrupt:
+			slog.Info("interrupt received, closing connection")
+			// Cleanly close the connection by sending a close message and then
+			// waiting (with timeout) for the server to close the connection.
+			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
+			if err != nil {
+				log.Fatal(err)
+			}
+			os.Exit(1)
+		}
+	}
+}
+
+func tenantLoginV2(ambJwt, tenantId, email string) (string, string, error) {
+	url := fmt.Sprintf("%s/api/v1/tenant/login/custom", ambBaseUrl)
+	payload := models.LoginReqDto{
+		Email:    email,
+		TenantID: tenantId,
+	}
+	payloadBuf := new(bytes.Buffer)
+	json.NewEncoder(payloadBuf).Encode(payload)
+
+	client := &http.Client{}
+	req, err := http.NewRequest("POST", url, payloadBuf)
+	if err != nil {
+		slog.Error("error creating request", "err", err)
+		return "", "", err
+	}
+	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", ambJwt))
+
+	res, err := client.Do(req)
+	if err != nil {
+		slog.Error("error sending request", "err", err)
+		return "", "", err
+	}
+	defer res.Body.Close()
+
+	body, err := io.ReadAll(res.Body)
+	if err != nil {
+		slog.Error("error reading response body", "err", err)
+		return "", "", err
+	}
+
+	data := models.TenantLoginResDto{}
+	json.Unmarshal(body, &data)
+
+	return data.Response.AuthToken, fmt.Sprintf(TenantUrlTemplate, tenantId), nil
+}

+ 51 - 0
models/structs.go

@@ -307,3 +307,54 @@ type LicenseLimits struct {
 	Clients  int `json:"clients"`
 	Clients  int `json:"clients"`
 	Networks int `json:"networks"`
 	Networks int `json:"networks"`
 }
 }
+
+type SignInReqDto struct {
+	FormFields FormFields `json:"formFields"`
+}
+
+type FormField struct {
+	Id    string `json:"id"`
+	Value any    `json:"value"`
+}
+
+type FormFields []FormField
+
+type SignInResDto struct {
+	Status string `json:"status"`
+	User   User   `json:"user"`
+}
+
+type TenantLoginResDto struct {
+	Code     int    `json:"code"`
+	Message  string `json:"message"`
+	Response struct {
+		UserName  string `json:"UserName"`
+		AuthToken string `json:"AuthToken"`
+	} `json:"response"`
+}
+
+type SsoLoginReqDto struct {
+	OauthProvider string `json:"oauthprovider"`
+}
+
+type SsoLoginResDto struct {
+	User      string `json:"UserName"`
+	AuthToken string `json:"AuthToken"`
+}
+
+type SsoLoginData struct {
+	Expiration     time.Time `json:"expiration"`
+	OauthProvider  string    `json:"oauthprovider,omitempty"`
+	OauthCode      string    `json:"oauthcode,omitempty"`
+	Username       string    `json:"username,omitempty"`
+	AmbAccessToken string    `json:"ambaccesstoken,omitempty"`
+}
+
+type LoginReqDto struct {
+	Email    string `json:"email"`
+	TenantID string `json:"tenant_id"`
+}
+
+const (
+	ResHeaderKeyStAccessToken = "St-Access-Token"
+)