Explorar o código

Merge branch 'NET-1950' of https://github.com/gravitl/netmaker into NET-1950

abhishek9686 hai 4 meses
pai
achega
4ae6b3a407
Modificáronse 9 ficheiros con 81 adicións e 83 borrados
  1. 2 0
      controllers/controller.go
  2. 5 4
      controllers/user.go
  3. 4 0
      db/db.go
  4. 4 1
      logic/auth.go
  5. 9 6
      logic/jwts.go
  6. 0 60
      models/accessToken.go
  7. 4 9
      schema/job.go
  8. 1 3
      schema/models.go
  9. 52 0
      schema/user_access_token.go

+ 2 - 0
controllers/controller.go

@@ -3,6 +3,7 @@ package controller
 import (
 	"context"
 	"fmt"
+	"github.com/gravitl/netmaker/db"
 	"net/http"
 	"os"
 	"strings"
@@ -18,6 +19,7 @@ import (
 
 // HttpMiddlewares - middleware functions for REST interactions
 var HttpMiddlewares = []mux.MiddlewareFunc{
+	db.Middleware,
 	userMiddleWare,
 }
 

+ 5 - 4
controllers/user.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/gravitl/netmaker/schema"
 	"net/http"
 	"reflect"
 	"time"
@@ -57,7 +58,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
 
 	// Auth request consists of Mac Address and Password (from node that is authorizing
 	// in case of Master, auth is ignored and mac is set to "mastermac"
-	var req models.UserAccessToken
+	var req schema.UserAccessToken
 
 	err := json.NewDecoder(r.Body).Decode(&req)
 	if err != nil {
@@ -97,7 +98,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
 		)
 		return
 	}
-	err = req.Create()
+	err = req.Create(r.Context())
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,
@@ -127,7 +128,7 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
 		return
 	}
-	logic.ReturnSuccessResponseWithJson(w, r, (&models.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username)
+	logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(r.Context()), "fetched api access tokens for user "+username)
 }
 
 // @Summary     Authenticate a user to retrieve an authorization token
@@ -146,7 +147,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	err := (&models.UserAccessToken{ID: id}).Delete()
+	err := (&schema.UserAccessToken{ID: id}).Delete(r.Context())
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,

+ 4 - 0
db/db.go

@@ -75,6 +75,10 @@ func Middleware(next http.Handler) http.Handler {
 //
 // The function panics, if a connection does not exist.
 func FromContext(ctx context.Context) *gorm.DB {
+	db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
+	if !ok {
+		panic(ErrDBNotFound)
+	}
 
 	return db
 }

+ 4 - 1
logic/auth.go

@@ -1,10 +1,13 @@
 package logic
 
 import (
+	"context"
 	"encoding/base64"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"time"
 
 	"github.com/go-playground/validator/v10"
@@ -360,7 +363,7 @@ func DeleteUser(user string) error {
 		return err
 	}
 	go RemoveUserFromAclPolicy(user)
-	return (&models.UserAccessToken{UserName: user}).DeleteAllUserTokens()
+	return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
 }
 
 func SetAuthSecret(secret string) error {

+ 9 - 6
logic/jwts.go

@@ -1,8 +1,11 @@
 package logic
 
 import (
+	"context"
 	"errors"
 	"fmt"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"strings"
 	"time"
 
@@ -125,15 +128,15 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
 	if claims.TokenType == models.AccessTokenType {
 		jti := claims.ID
 		if jti != "" {
-			a := models.UserAccessToken{ID: jti}
+			a := schema.UserAccessToken{ID: jti}
 			// check if access token is active
-			err := a.Get()
+			err := a.Get(db.WithContext(context.TODO()))
 			if err != nil {
 				err = errors.New("token revoked")
 				return "", err
 			}
 			a.LastUsed = time.Now()
-			a.Update()
+			a.Update(db.WithContext(context.TODO()))
 		}
 	}
 
@@ -169,15 +172,15 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
 	if claims.TokenType == models.AccessTokenType {
 		jti := claims.ID
 		if jti != "" {
-			a := models.UserAccessToken{ID: jti}
+			a := schema.UserAccessToken{ID: jti}
 			// check if access token is active
-			err := a.Get()
+			err := a.Get(db.WithContext(context.TODO()))
 			if err != nil {
 				err = errors.New("token revoked")
 				return "", false, false, err
 			}
 			a.LastUsed = time.Now()
-			a.Update()
+			a.Update(db.WithContext(context.TODO()))
 		}
 	}
 	if token != nil && token.Valid {

+ 0 - 60
models/accessToken.go

@@ -1,60 +0,0 @@
-package models
-
-import (
-	"context"
-	"time"
-
-	"github.com/gravitl/netmaker/db"
-)
-
-// accessTokenTableName - access tokens table
-const accessTokenTableName = "user_access_tokens"
-
-// UserAccessToken - token used to access netmaker
-type UserAccessToken struct {
-	ID        string    `gorm:"id,primary_key" json:"id"`
-	Name      string    `gorm:"name" json:"name"`
-	UserName  string    `gorm:"user_name" json:"user_name"`
-	ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
-	LastUsed  time.Time `gorm:"last_used" json:"last_used"`
-	CreatedBy string    `gorm:"created_by" json:"created_by"`
-	CreatedAt time.Time `gorm:"created_at" json:"created_at"`
-}
-
-func (a *UserAccessToken) Table() string {
-	return accessTokenTableName
-}
-
-func (a *UserAccessToken) Get() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
-}
-
-func (a *UserAccessToken) Update() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
-}
-
-func (a *UserAccessToken) Create() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
-}
-
-func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
-	err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
-	return
-}
-
-func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) {
-	db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
-	if ats == nil {
-		ats = []UserAccessToken{}
-	}
-	return
-}
-
-func (a *UserAccessToken) Delete() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
-}
-
-func (a *UserAccessToken) DeleteAllUserTokens() error {
-	return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Delete(&a).Error
-
-}

+ 4 - 9
schema/jobs.go → schema/job.go

@@ -16,21 +16,16 @@ import (
 // that it is easier to prevent a task from
 // being executed again.
 type Job struct {
-	ID        string    `gorm:"id;primary_key"`
-	CreatedAt time.Time `gorm:"created_at"`
-}
-
-// TableName returns the name of the jobs table.
-func (j *Job) TableName() string {
-	return "jobs"
+	ID        string `gorm:"primaryKey"`
+	CreatedAt time.Time
 }
 
 // Create creates a job record in the jobs table.
 func (j *Job) Create(ctx context.Context) error {
-	return db.FromContext(ctx).Table(j.TableName()).Create(j).Error
+	return db.FromContext(ctx).Model(&Job{}).Create(j).Error
 }
 
 // Get returns a job record with the given Job.ID.
 func (j *Job) Get(ctx context.Context) error {
-	return db.FromContext(ctx).Table(j.TableName()).Where("id = ?", j.ID).First(j).Error
+	return db.FromContext(ctx).Model(&Job{}).Where("id = ?", j.ID).First(j).Error
 }

+ 1 - 3
schema/models.go

@@ -1,11 +1,9 @@
 package schema
 
-import "github.com/gravitl/netmaker/models"
-
 // ListModels lists all the models in this schema.
 func ListModels() []interface{} {
 	return []interface{}{
 		&Job{},
-		&models.UserAccessToken{},
+		&UserAccessToken{},
 	}
 }

+ 52 - 0
schema/user_access_token.go

@@ -0,0 +1,52 @@
+package schema
+
+import (
+	"context"
+	"time"
+
+	"github.com/gravitl/netmaker/db"
+)
+
+// UserAccessToken - token used to access netmaker
+type UserAccessToken struct {
+	ID        string    `gorm:"primaryKey" json:"id"`
+	Name      string    `json:"name"`
+	UserName  string    `json:"user_name"`
+	ExpiresAt time.Time `json:"expires_at"`
+	LastUsed  time.Time `json:"last_used"`
+	CreatedBy string    `json:"created_by"`
+	CreatedAt time.Time `json:"created_at"`
+}
+
+func (a *UserAccessToken) Get(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).First(&a).Where("id = ?", a.ID).Error
+}
+
+func (a *UserAccessToken) Update(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Updates(&a).Error
+}
+
+func (a *UserAccessToken) Create(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Create(&a).Error
+}
+
+func (a *UserAccessToken) List(ctx context.Context) (ats []UserAccessToken, err error) {
+	err = db.FromContext(ctx).Model(&UserAccessToken{}).Find(&ats).Error
+	return
+}
+
+func (a *UserAccessToken) ListByUser(ctx context.Context) (ats []UserAccessToken) {
+	db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Find(&ats)
+	if ats == nil {
+		ats = []UserAccessToken{}
+	}
+	return
+}
+
+func (a *UserAccessToken) Delete(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Delete(&a).Error
+}
+
+func (a *UserAccessToken) DeleteAllUserTokens(ctx context.Context) error {
+	return db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Delete(&a).Error
+}