Procházet zdrojové kódy

add ctx to DB funcs (#3435)

Abhishek K před 4 měsíci
rodič
revize
262803c234
5 změnil soubory, kde provedl 39 přidání a 39 odebrání
  1. 5 5
      controllers/user.go
  2. 4 0
      db/db.go
  3. 3 1
      logic/auth.go
  4. 6 4
      logic/jwts.go
  5. 21 29
      schema/accessToken.go

+ 5 - 5
controllers/user.go

@@ -110,7 +110,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
 		)
 		)
 		return
 		return
 	}
 	}
-	err = req.Create()
+	err = req.Create(r.Context())
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(
 		logic.ReturnErrorResponse(
 			w,
 			w,
@@ -140,7 +140,7 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
 		return
 		return
 	}
 	}
-	logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username)
+	logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(r.Context()), "fetched api access tokens for user "+username)
 }
 }
 
 
 // @Summary     Authenticate a user to retrieve an authorization token
 // @Summary     Authenticate a user to retrieve an authorization token
@@ -161,7 +161,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 	a := schema.UserAccessToken{
 	a := schema.UserAccessToken{
 		ID: id,
 		ID: id,
 	}
 	}
-	err := a.Get()
+	err := a.Get(r.Context())
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
 		return
 		return
@@ -188,7 +188,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 		}
 		}
 	}
 	}
 
 
-	err = (&schema.UserAccessToken{ID: id}).Delete()
+	err = (&schema.UserAccessToken{ID: id}).Delete(r.Context())
 	if err != nil {
 	if err != nil {
 		logic.ReturnErrorResponse(
 		logic.ReturnErrorResponse(
 			w,
 			w,
@@ -754,7 +754,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 	logic.AddGlobalNetRolesToAdmins(&userchange)
 	logic.AddGlobalNetRolesToAdmins(&userchange)
 	if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
 	if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
-		(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens()
+		(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens(r.Context())
 	}
 	}
 	user, err = logic.UpdateUser(&userchange, user)
 	user, err = logic.UpdateUser(&userchange, user)
 	if err != nil {
 	if err != nil {

+ 4 - 0
db/db.go

@@ -75,6 +75,10 @@ func Middleware(next http.Handler) http.Handler {
 //
 //
 // The function panics, if a connection does not exist.
 // The function panics, if a connection does not exist.
 func FromContext(ctx context.Context) *gorm.DB {
 func FromContext(ctx context.Context) *gorm.DB {
+	db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
+	if !ok {
+		panic(ErrDBNotFound)
+	}
 
 
 	return db
 	return db
 }
 }

+ 3 - 1
logic/auth.go

@@ -1,6 +1,7 @@
 package logic
 package logic
 
 
 import (
 import (
+	"context"
 	"encoding/base64"
 	"encoding/base64"
 	"encoding/json"
 	"encoding/json"
 	"errors"
 	"errors"
@@ -12,6 +13,7 @@ import (
 	"golang.org/x/exp/slog"
 	"golang.org/x/exp/slog"
 
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/schema"
 	"github.com/gravitl/netmaker/schema"
@@ -361,7 +363,7 @@ func DeleteUser(user string) error {
 		return err
 		return err
 	}
 	}
 	go RemoveUserFromAclPolicy(user)
 	go RemoveUserFromAclPolicy(user)
-	return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens()
+	return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
 }
 }
 
 
 func SetAuthSecret(secret string) error {
 func SetAuthSecret(secret string) error {

+ 6 - 4
logic/jwts.go

@@ -1,6 +1,7 @@
 package logic
 package logic
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"strings"
 	"strings"
@@ -8,6 +9,7 @@ import (
 
 
 	"github.com/golang-jwt/jwt/v4"
 	"github.com/golang-jwt/jwt/v4"
 
 
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/schema"
 	"github.com/gravitl/netmaker/schema"
@@ -127,13 +129,13 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
 		if jti != "" {
 		if jti != "" {
 			a := schema.UserAccessToken{ID: jti}
 			a := schema.UserAccessToken{ID: jti}
 			// check if access token is active
 			// check if access token is active
-			err := a.Get()
+			err := a.Get(db.WithContext(context.TODO()))
 			if err != nil {
 			if err != nil {
 				err = errors.New("token revoked")
 				err = errors.New("token revoked")
 				return "", err
 				return "", err
 			}
 			}
 			a.LastUsed = time.Now()
 			a.LastUsed = time.Now()
-			a.Update()
+			a.Update(db.WithContext(context.TODO()))
 		}
 		}
 	}
 	}
 
 
@@ -171,13 +173,13 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
 		if jti != "" {
 		if jti != "" {
 			a := schema.UserAccessToken{ID: jti}
 			a := schema.UserAccessToken{ID: jti}
 			// check if access token is active
 			// check if access token is active
-			err := a.Get()
+			err := a.Get(db.WithContext(context.TODO()))
 			if err != nil {
 			if err != nil {
 				err = errors.New("token revoked")
 				err = errors.New("token revoked")
 				return "", false, false, err
 				return "", false, false, err
 			}
 			}
 			a.LastUsed = time.Now()
 			a.LastUsed = time.Now()
-			a.Update()
+			a.Update(db.WithContext(context.TODO()))
 		}
 		}
 	}
 	}
 	if token != nil && token.Valid {
 	if token != nil && token.Valid {

+ 21 - 29
schema/accessToken.go

@@ -7,54 +7,46 @@ import (
 	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/db"
 )
 )
 
 
-// accessTokenTableName - access tokens table
-const accessTokenTableName = "user_access_tokens"
-
 // UserAccessToken - token used to access netmaker
 // UserAccessToken - token used to access netmaker
 type UserAccessToken struct {
 type UserAccessToken struct {
-	ID        string    `gorm:"id,primary_key" json:"id"`
-	Name      string    `gorm:"name" json:"name"`
-	UserName  string    `gorm:"user_name" json:"user_name"`
-	ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
-	LastUsed  time.Time `gorm:"last_used" json:"last_used"`
-	CreatedBy string    `gorm:"created_by" json:"created_by"`
-	CreatedAt time.Time `gorm:"created_at" json:"created_at"`
-}
-
-func (a *UserAccessToken) Table() string {
-	return accessTokenTableName
+	ID        string    `gorm:"primaryKey" json:"id"`
+	Name      string    `json:"name"`
+	UserName  string    `json:"user_name"`
+	ExpiresAt time.Time `json:"expires_at"`
+	LastUsed  time.Time `json:"last_used"`
+	CreatedBy string    `json:"created_by"`
+	CreatedAt time.Time `json:"created_at"`
 }
 }
 
 
-func (a *UserAccessToken) Get() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
+func (a *UserAccessToken) Get(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).First(&a).Where("id = ?", a.ID).Error
 }
 }
 
 
-func (a *UserAccessToken) Update() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
+func (a *UserAccessToken) Update(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Updates(&a).Error
 }
 }
 
 
-func (a *UserAccessToken) Create() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
+func (a *UserAccessToken) Create(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Create(&a).Error
 }
 }
 
 
-func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
-	err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
+func (a *UserAccessToken) List(ctx context.Context) (ats []UserAccessToken, err error) {
+	err = db.FromContext(ctx).Model(&UserAccessToken{}).Find(&ats).Error
 	return
 	return
 }
 }
 
 
-func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) {
-	db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
+func (a *UserAccessToken) ListByUser(ctx context.Context) (ats []UserAccessToken) {
+	db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Find(&ats)
 	if ats == nil {
 	if ats == nil {
 		ats = []UserAccessToken{}
 		ats = []UserAccessToken{}
 	}
 	}
 	return
 	return
 }
 }
 
 
-func (a *UserAccessToken) Delete() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
+func (a *UserAccessToken) Delete(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Delete(&a).Error
 }
 }
 
 
-func (a *UserAccessToken) DeleteAllUserTokens() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error
-
+func (a *UserAccessToken) DeleteAllUserTokens(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Delete(&a).Error
 }
 }