| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431 | package functionsimport (	"bytes"	"encoding/json"	"fmt"	"io"	"log"	"net/http"	"net/url"	"os"	"os/signal"	"strings"	"time"	"github.com/gorilla/websocket"	"github.com/gravitl/netmaker/cli/config"	"github.com/gravitl/netmaker/logger"	"github.com/gravitl/netmaker/models"	"golang.org/x/exp/slog")const (	ambBaseUrl        = "https://api.accounts.netmaker.io"	TenantUrlTemplate = "https://api-%s.app.prod.netmaker.io"	ambOauthWssUrl    = "wss://api.accounts.netmaker.io/api/v1/auth/sso")func ssoLogin(endpoint string) string {	var (		authToken string		interrupt = make(chan os.Signal, 1)		url, _    = url.Parse(endpoint)		socketURL = fmt.Sprintf("wss://%s/api/oauth/headless", url.Host)	)	signal.Notify(interrupt, os.Interrupt)	conn, _, err := websocket.DefaultDialer.Dial(socketURL, nil)	if err != nil {		log.Fatal("error connecting to endpoint ", socketURL, err.Error())	}	defer conn.Close()	_, msg, err := conn.ReadMessage()	if err != nil {		log.Fatal("error reading from server: ", err.Error())	}	fmt.Printf("Please visit:\n %s \n to authenticate\n", string(msg))	done := make(chan struct{})	defer close(done)	go func() {		for {			msgType, msg, err := conn.ReadMessage()			if err != nil {				if msgType < 0 {					done <- struct{}{}					return				}				if !strings.Contains(err.Error(), "normal") {					log.Fatal("read error: ", err.Error())				}				return			}			if msgType == websocket.CloseMessage {				done <- struct{}{}				return			}			if strings.Contains(string(msg), "JWT: ") {				authToken = strings.TrimPrefix(string(msg), "JWT: ")			} else {				logger.Log(0, "Message from server:", string(msg))				return			}		}	}()	for {		select {		case <-done:			return authToken		case <-interrupt:			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))			if err != nil {				logger.Log(0, "write close:", err.Error())			}			return authToken		}	}}func getAuthToken(ctx config.Context, force bool) string {	if !force && ctx.AuthToken != "" {		return ctx.AuthToken	}	if !ctx.Saas {		if ctx.SSO {			authToken := ssoLogin(ctx.Endpoint)			config.SetAuthToken(authToken)			return authToken		}		authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password}		payload, err := json.Marshal(authParams)		if err != nil {			log.Fatal(err)		}		res, err := http.Post(ctx.Endpoint+"/api/users/adm/authenticate", "application/json", bytes.NewReader(payload))		if err != nil {			log.Fatal(err)		}		defer res.Body.Close()		resBodyBytes, err := io.ReadAll(res.Body)		if err != nil {			log.Fatalf("Client could not read response body: %s", err)		}		if res.StatusCode != http.StatusOK {			log.Fatalf("Error Status: %d Response: %s", res.StatusCode, string(resBodyBytes))		}		body := new(models.SuccessResponse)		if err := json.Unmarshal(resBodyBytes, body); err != nil {			log.Fatalf("Error unmarshalling JSON: %s", err)		}		authToken := body.Response.(map[string]any)["AuthToken"].(string)		config.SetAuthToken(authToken)		return authToken	}	if !ctx.SSO {		sToken, _, err := basicAuthSaasSignin(ctx.Username, ctx.Password)		if err != nil {			log.Fatal(err)		}		authToken, _, err := tenantLogin(ctx, sToken)		if err != nil {			log.Fatal(err)		}		config.SetAuthToken(authToken)		return authToken	}	accessToken, err := loginSaaSOauth(&models.SsoLoginReqDto{OauthProvider: "oidc"}, ctx.TenantId)	if err != nil {		log.Fatal(err)	}	config.SetAuthToken(accessToken)	return accessToken}func request[T any](method, route string, payload any) *T {	var (		_, ctx = config.GetCurrentContext()		req    *http.Request		err    error	)	if payload == nil {		req, err = http.NewRequest(method, ctx.Endpoint+route, nil)		if err != nil {			log.Fatalf("Client could not create request: %s", err)		}	} else {		payloadBytes, jsonErr := json.Marshal(payload)		if jsonErr != nil {			log.Fatalf("Error in request JSON marshalling: %s", err)		}		req, err = http.NewRequest(method, ctx.Endpoint+route, bytes.NewReader(payloadBytes))		if err != nil {			log.Fatalf("Client could not create request: %s", err)		}		req.Header.Set("Content-Type", "application/json")	}	if ctx.MasterKey != "" {		req.Header.Set("Authorization", "Bearer "+ctx.MasterKey)	} else {		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, false))	}	retried := falseretry:	res, err := http.DefaultClient.Do(req)	if err != nil {		log.Fatalf("Client error making http request: %s", err)	}	// refresh JWT token	if res.StatusCode == http.StatusUnauthorized && !retried && ctx.MasterKey == "" {		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, true))		retried = true		// TODO add a retry limit, drop goto		goto retry	}	resBodyBytes, err := io.ReadAll(res.Body)	if err != nil {		log.Fatalf("Client could not read response body: %s", err)	}	if res.StatusCode != http.StatusOK {		log.Fatalf("Error Status: %d Response: %s", res.StatusCode, string(resBodyBytes))	}	body := new(T)	if len(resBodyBytes) > 0 {		if err := json.Unmarshal(resBodyBytes, body); err != nil {			log.Fatalf("Error unmarshalling JSON: %s", err)		}	}	return body}func get(route string) string {	_, ctx := config.GetCurrentContext()	req, err := http.NewRequest(http.MethodGet, ctx.Endpoint+route, nil)	if err != nil {		log.Fatal(err)	}	if ctx.MasterKey != "" {		req.Header.Set("Authorization", "Bearer "+ctx.MasterKey)	} else {		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, true))	}	res, err := http.DefaultClient.Do(req)	if err != nil {		log.Fatal(err)	}	bodyBytes, err := io.ReadAll(res.Body)	if err != nil {		log.Fatal(err)	}	return string(bodyBytes)}func basicAuthSaasSignin(email, password string) (string, http.Header, error) {	payload := models.SignInReqDto{		FormFields: []models.FormField{			{				Id:    "email",				Value: email,			},			{				Id:    "password",				Value: password,			},		},	}	var res models.SignInResDto	// Create a new HTTP client with a timeout	client := &http.Client{		Timeout: 30 * time.Second,	}	// Create the request body	payloadBuf := new(bytes.Buffer)	json.NewEncoder(payloadBuf).Encode(payload)	// Create the request	req, err := http.NewRequest("POST", ambBaseUrl+"/auth/signin", payloadBuf)	if err != nil {		return "", http.Header{}, err	}	req.Header.Set("Content-Type", "application/json; charset=utf-8")	req.Header.Set("rid", "thirdpartyemailpassword")	// Send the request	resp, err := client.Do(req)	if err != nil {		return "", http.Header{}, err	}	defer resp.Body.Close()	// Check the response status code	if resp.StatusCode != http.StatusOK {		return "", http.Header{}, fmt.Errorf("error authenticating: %s", resp.Status)	}	// Copy the response headers	resHeaders := resp.Header	// Decode the response body	err = json.NewDecoder(resp.Body).Decode(&res)	if err != nil {		return "", http.Header{}, err	}	sToken := resHeaders.Get(models.ResHeaderKeyStAccessToken)	encodedAccessToken := url.QueryEscape(sToken)	return encodedAccessToken, resHeaders, nil}func tenantLogin(ctx config.Context, sToken string) (string, string, error) {	url := fmt.Sprintf("%s/api/v1/tenant/login?tenant_id=%s", ambBaseUrl, ctx.TenantId)	client := &http.Client{}	req, err := http.NewRequest(http.MethodPost, url, nil)	if err != nil {		return "", "", err	}	req.Header.Add("Cookie", fmt.Sprintf("sAccessToken=%s", sToken))	res, err := client.Do(req)	if err != nil {		return "", "", err	}	defer res.Body.Close()	body, err := io.ReadAll(res.Body)	if err != nil {		return "", "", err	}	data := models.TenantLoginResDto{}	json.Unmarshal(body, &data)	return data.Response.AuthToken, fmt.Sprintf(TenantUrlTemplate, ctx.TenantId), nil}func loginSaaSOauth(payload *models.SsoLoginReqDto, tenantId string) (string, error) {	socketUrl := ambOauthWssUrl	// Dial the netmaker server controller	conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil)	if err != nil {		slog.Error("error connecting to endpoint ", "url", socketUrl, "err", err)		return "", err	}	defer conn.Close()	return handleServerSSORegisterConn(payload, conn, tenantId)}func handleServerSSORegisterConn(payload *models.SsoLoginReqDto, conn *websocket.Conn, tenantId string) (string, error) {	reqData, err := json.Marshal(payload)	if err != nil {		return "", err	}	if err := conn.WriteMessage(websocket.TextMessage, reqData); err != nil {		return "", err	}	dataCh := make(chan string)	defer close(dataCh)	interrupt := make(chan os.Signal, 1)	signal.Notify(interrupt, os.Interrupt)	go func() {		for {			msgType, msg, err := conn.ReadMessage()			if err != nil {				if msgType < 0 {					slog.Info("received close message from server")					return				}				if !strings.Contains(err.Error(), "normal") { // Error reading a message from the server					slog.Error("error msg", "err", err)				}				return			}			if msgType == websocket.CloseMessage {				slog.Info("received close message from server")				return			}			if strings.Contains(string(msg), "auth/sso") {				fmt.Printf("Please visit:\n %s \nto authenticate\n", string(msg))			} else {				var res models.SsoLoginData				if err := json.Unmarshal(msg, &res); err != nil {					return				}				accessToken, _, err := tenantLoginV2(res.AmbAccessToken, tenantId, res.Username)				if err != nil {					slog.Error("error logging in tenant", "err", err)					dataCh <- ""					return				}				dataCh <- accessToken				return			}		}	}()	for {		select {		case accessToken := <-dataCh:			if accessToken == "" {				slog.Info("error getting access token")				return "", fmt.Errorf("error getting access token")			}			return accessToken, nil		case <-time.After(30 * time.Second):			slog.Error("authentiation timed out")			os.Exit(1)		case <-interrupt:			slog.Info("interrupt received, closing connection")			// Cleanly close the connection by sending a close message and then			// waiting (with timeout) for the server to close the connection.			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))			if err != nil {				log.Fatal(err)			}			os.Exit(1)		}	}}func tenantLoginV2(ambJwt, tenantId, email string) (string, string, error) {	url := fmt.Sprintf("%s/api/v1/tenant/login/custom", ambBaseUrl)	payload := models.LoginReqDto{		Email:    email,		TenantID: tenantId,	}	payloadBuf := new(bytes.Buffer)	json.NewEncoder(payloadBuf).Encode(payload)	client := &http.Client{}	req, err := http.NewRequest("POST", url, payloadBuf)	if err != nil {		slog.Error("error creating request", "err", err)		return "", "", err	}	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", ambJwt))	res, err := client.Do(req)	if err != nil {		slog.Error("error sending request", "err", err)		return "", "", err	}	defer res.Body.Close()	body, err := io.ReadAll(res.Body)	if err != nil {		slog.Error("error reading response body", "err", err)		return "", "", err	}	data := models.TenantLoginResDto{}	json.Unmarshal(body, &data)	return data.Response.AuthToken, fmt.Sprintf(TenantUrlTemplate, tenantId), nil}
 |