Browse Source

use cached JWT token and refresh on expiry

Anish Mukherjee 2 years ago
parent
commit
b2d0a6dfe9
4 changed files with 39 additions and 17 deletions
  1. 15 4
      cli/config/config.go
  2. 2 2
      cli/functions/ext_client.go
  3. 18 7
      cli/functions/http_client.go
  4. 4 4
      cli/functions/server.go

+ 15 - 4
cli/config/config.go

@@ -16,6 +16,7 @@ type Context struct {
 	Password  string `yaml:"password,omitempty"`
 	MasterKey string `yaml:"masterkey,omitempty"`
 	Current   bool   `yaml:"current,omitempty"`
+	AuthToken string `yaml:"auth_token,omitempty"`
 }
 
 var (
@@ -75,10 +76,10 @@ func saveContext() {
 }
 
 // GetCurrentContext - returns current set context
-func GetCurrentContext() (ret Context) {
-	for _, ctx := range contextMap {
-		if ctx.Current {
-			ret = ctx
+func GetCurrentContext() (name string, ctx Context) {
+	for n, c := range contextMap {
+		if c.Current {
+			name, ctx = n, c
 			return
 		}
 	}
@@ -107,6 +108,16 @@ func SetContext(ctxName string, ctx Context) {
 	saveContext()
 }
 
+// SetAuthToken - saves the auth token
+func SetAuthToken(authToken string) {
+	ctxName, _ := GetCurrentContext()
+	if ctx, ok := contextMap[ctxName]; ok {
+		ctx.AuthToken = authToken
+		contextMap[ctxName] = ctx
+		saveContext()
+	}
+}
+
 // DeleteContext - deletes a context
 func DeleteContext(ctxName string) {
 	if _, ok := contextMap[ctxName]; ok {

+ 2 - 2
cli/functions/ext_client.go

@@ -27,7 +27,7 @@ func GetExtClient(networkName, clientID string) *models.ExtClient {
 
 // GetExtClientConfig - fetch a wireguard config of an external client
 func GetExtClientConfig(networkName, clientID string) string {
-	ctx := config.GetCurrentContext()
+	_, ctx := config.GetCurrentContext()
 	req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("%s/api/extclients/%s/%s/file", ctx.Endpoint, networkName, clientID), nil)
 	if err != nil {
 		log.Fatal(err)
@@ -35,7 +35,7 @@ func GetExtClientConfig(networkName, clientID string) string {
 	if ctx.MasterKey != "" {
 		req.Header.Set("Authorization", "Bearer "+ctx.MasterKey)
 	} else {
-		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx))
+		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, true))
 	}
 	res, err := http.DefaultClient.Do(req)
 	if err != nil {

+ 18 - 7
cli/functions/http_client.go

@@ -11,7 +11,10 @@ import (
 	"github.com/gravitl/netmaker/models"
 )
 
-func getAuthToken(ctx config.Context) string {
+func getAuthToken(ctx config.Context, force bool) string {
+	if !force && ctx.AuthToken != "" {
+		return ctx.AuthToken
+	}
 	authParams := &models.UserAuthParams{UserName: ctx.Username, Password: ctx.Password}
 	payload, err := json.Marshal(authParams)
 	if err != nil {
@@ -32,14 +35,16 @@ func getAuthToken(ctx config.Context) string {
 	if err := json.Unmarshal(resBodyBytes, body); err != nil {
 		log.Fatalf("Error unmarshalling JSON: %s", err)
 	}
-	return body.Response.(map[string]any)["AuthToken"].(string)
+	authToken := body.Response.(map[string]any)["AuthToken"].(string)
+	config.SetAuthToken(authToken)
+	return authToken
 }
 
 func request[T any](method, route string, payload any) *T {
 	var (
-		ctx = config.GetCurrentContext()
-		req *http.Request
-		err error
+		_, ctx = config.GetCurrentContext()
+		req    *http.Request
+		err    error
 	)
 	if payload == nil {
 		req, err = http.NewRequest(method, ctx.Endpoint+route, nil)
@@ -60,18 +65,24 @@ func request[T any](method, route string, payload any) *T {
 	if ctx.MasterKey != "" {
 		req.Header.Set("Authorization", "Bearer "+ctx.MasterKey)
 	} else {
-		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx))
+		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, false))
 	}
+retry:
 	res, err := http.DefaultClient.Do(req)
 	if err != nil {
 		log.Fatalf("Client error making http request: %s", err)
 	}
+	// refresh JWT token
+	if res.StatusCode == http.StatusUnauthorized {
+		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, true))
+		goto retry
+	}
 	resBodyBytes, err := io.ReadAll(res.Body)
 	if err != nil {
 		log.Fatalf("Client could not read response body: %s", err)
 	}
 	if res.StatusCode != http.StatusOK {
-		log.Fatalf("Error response: %s", string(resBodyBytes))
+		log.Fatalf("Error Status: %d Response: %s", http.StatusOK, string(resBodyBytes))
 	}
 	body := new(T)
 	if len(resBodyBytes) > 0 {

+ 4 - 4
cli/functions/server.go

@@ -12,7 +12,7 @@ import (
 
 // GetLogs - fetch Netmaker server logs
 func GetLogs() string {
-	ctx := config.GetCurrentContext()
+	_, ctx := config.GetCurrentContext()
 	req, err := http.NewRequest(http.MethodGet, ctx.Endpoint+"/api/logs", nil)
 	if err != nil {
 		log.Fatal(err)
@@ -20,7 +20,7 @@ func GetLogs() string {
 	if ctx.MasterKey != "" {
 		req.Header.Set("Authorization", "Bearer "+ctx.MasterKey)
 	} else {
-		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx))
+		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, true))
 	}
 	res, err := http.DefaultClient.Do(req)
 	if err != nil {
@@ -45,7 +45,7 @@ func GetServerConfig() *cfg.ServerConfig {
 
 // GetServerHealth - fetch server current health status
 func GetServerHealth() string {
-	ctx := config.GetCurrentContext()
+	_, ctx := config.GetCurrentContext()
 	req, err := http.NewRequest(http.MethodGet, ctx.Endpoint+"/api/server/health", nil)
 	if err != nil {
 		log.Fatal(err)
@@ -53,7 +53,7 @@ func GetServerHealth() string {
 	if ctx.MasterKey != "" {
 		req.Header.Set("Authorization", "Bearer "+ctx.MasterKey)
 	} else {
-		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx))
+		req.Header.Set("Authorization", "Bearer "+getAuthToken(ctx, true))
 	}
 	res, err := http.DefaultClient.Do(req)
 	if err != nil {