Explorar o código

move egress model to schema

abhishek9686 hai 4 meses
pai
achega
a7581248bf
Modificáronse 6 ficheiros con 42 adicións e 32 borrados
  1. 10 8
      controllers/egress.go
  2. 4 2
      logic/acls.go
  3. 5 3
      logic/egress.go
  4. 4 2
      logic/extpeers.go
  5. 4 2
      migrate/migrate.go
  6. 15 15
      schema/egress.go

+ 10 - 8
controllers/egress.go

@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"net/http"
@@ -8,6 +9,7 @@ import (
 
 	"github.com/google/uuid"
 	"github.com/gorilla/mux"
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
@@ -73,7 +75,7 @@ func createEgress(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid egress request"), "badrequest"))
 		return
 	}
-	err = e.Create()
+	err = e.Create(db.WithContext(context.TODO()))
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,
@@ -103,7 +105,7 @@ func listEgress(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	e := schema.Egress{Network: network}
-	list, err := e.ListByNetwork()
+	list, err := e.ListByNetwork(db.WithContext(context.TODO()))
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,
@@ -146,7 +148,7 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
 	}
 
 	e := schema.Egress{ID: req.ID}
-	err = e.Get()
+	err = e.Get(db.WithContext(context.TODO()))
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
@@ -173,7 +175,7 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("invalid egress request"), "badrequest"))
 		return
 	}
-	err = e.Update()
+	err = e.Update(db.WithContext(context.TODO()))
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,
@@ -184,14 +186,14 @@ func updateEgress(w http.ResponseWriter, r *http.Request) {
 	}
 	if updateNat {
 		e.Nat = req.Nat
-		e.UpdateNatStatus()
+		e.UpdateNatStatus(db.WithContext(context.TODO()))
 	}
 	if updateInetGw {
 		e.IsInetGw = req.IsInetGw
-		e.UpdateINetGwStatus()
+		e.UpdateINetGwStatus(db.WithContext(context.TODO()))
 	}
 	go mq.PublishPeerUpdate(false)
-	logic.ReturnSuccessResponseWithJson(w, r, req, "updated egress resource")
+	logic.ReturnSuccessResponseWithJson(w, r, e, "updated egress resource")
 }
 
 // @Summary     Delete Egress Resource
@@ -211,7 +213,7 @@ func deleteEgress(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	e := schema.Egress{ID: id}
-	err := e.Delete()
+	err := e.Delete(db.WithContext(context.TODO()))
 	if err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return

+ 4 - 2
logic/acls.go

@@ -1,6 +1,7 @@
 package logic
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -12,6 +13,7 @@ import (
 
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/schema"
 	"github.com/gravitl/netmaker/servercfg"
@@ -295,7 +297,7 @@ func checkIfAclTagisValid(t models.AclPolicyTag, netID models.NetworkID, policyT
 		e := schema.Egress{
 			ID: t.Value,
 		}
-		err := e.Get()
+		err := e.Get(db.WithContext(context.TODO()))
 		if err != nil {
 			return false
 		}
@@ -1251,7 +1253,7 @@ func getEgressUserRulesForNode(targetnode *models.Node,
 			for _, dstI := range acl.Dst {
 				if dstI.ID == models.EgressID {
 					e := schema.Egress{ID: dstI.Value}
-					err := e.Get()
+					err := e.Get(db.WithContext(context.TODO()))
 					if err != nil {
 						continue
 					}

+ 5 - 3
logic/egress.go

@@ -1,9 +1,11 @@
 package logic
 
 import (
+	"context"
 	"encoding/json"
 	"net"
 
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/schema"
 )
@@ -48,7 +50,7 @@ func GetInetClientsFromAclPolicies(node *models.Node) (inetClientIDs []string) {
 				e := schema.Egress{
 					ID: dstI.Value,
 				}
-				err := e.Get()
+				err := e.Get(db.WithContext(context.TODO()))
 				if err != nil {
 					continue
 				}
@@ -76,7 +78,7 @@ func IsNodeUsingInternetGw(node *models.Node) {
 		for _, dstI := range acl.Dst {
 			if dstI.ID == models.EgressID {
 				e := schema.Egress{ID: dstI.Value}
-				err := e.Get()
+				err := e.Get(db.WithContext(context.TODO()))
 				if err != nil {
 					continue
 				}
@@ -102,7 +104,7 @@ func IsNodeUsingInternetGw(node *models.Node) {
 }
 
 func GetNodeEgressInfo(targetNode *models.Node) {
-	eli, _ := (&schema.Egress{Network: targetNode.Network}).ListByNetwork()
+	eli, _ := (&schema.Egress{Network: targetNode.Network}).ListByNetwork(db.WithContext(context.TODO()))
 	req := models.EgressGatewayRequest{
 		NodeID: targetNode.ID.String(),
 		NetID:  targetNode.Network,

+ 4 - 2
logic/extpeers.go

@@ -1,6 +1,7 @@
 package logic
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -13,6 +14,7 @@ import (
 
 	"github.com/goombaio/namegenerator"
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic/acls"
 	"github.com/gravitl/netmaker/models"
@@ -631,7 +633,7 @@ func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode
 			if dstI.ID == models.EgressID {
 
 				e := schema.Egress{ID: dstI.Value}
-				err := e.Get()
+				err := e.Get(db.WithContext(context.TODO()))
 				if err != nil {
 					continue
 				}
@@ -720,7 +722,7 @@ func getFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules []
 							if dstI.ID == models.EgressID {
 
 								e := schema.Egress{ID: dstI.Value}
-								err := e.Get()
+								err := e.Get(db.WithContext(context.TODO()))
 								if err != nil {
 									continue
 								}

+ 4 - 2
migrate/migrate.go

@@ -1,6 +1,7 @@
 package migrate
 
 import (
+	"context"
 	"encoding/json"
 	"fmt"
 	"log"
@@ -11,6 +12,7 @@ import (
 
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/logic/acls"
@@ -525,7 +527,7 @@ func migrateToEgressV1() {
 					CreatedBy: user.UserName,
 					CreatedAt: time.Now().UTC(),
 				}
-				err = e.Create()
+				err = e.Create(db.WithContext(context.TODO()))
 				if err == nil {
 					node.IsEgressGateway = false
 					node.EgressGatewayRequest = models.EgressGatewayRequest{}
@@ -551,7 +553,7 @@ func migrateToEgressV1() {
 				CreatedBy: user.UserName,
 				CreatedAt: time.Now().UTC(),
 			}
-			err = e.Create()
+			err = e.Create(db.WithContext(context.TODO()))
 			if err == nil {
 				node.IsEgressGateway = false
 				node.EgressGatewayRequest = models.EgressGatewayRequest{}

+ 15 - 15
schema/egress.go

@@ -19,7 +19,7 @@ type Egress struct {
 	Tags        datatypes.JSONMap `gorm:"tags" json:"tags"`
 	Range       string            `gorm:"range" json:"range"`
 	Nat         bool              `gorm:"nat" json:"nat"`
-	IsInetGw    bool              `gorm:"is_internet_gateway"`
+	IsInetGw    bool              `gorm:"is_internet_gateway" json:"is_internet_gateway"`
 	CreatedBy   string            `gorm:"created_by" json:"created_by"`
 	CreatedAt   time.Time         `gorm:"created_at" json:"created_at"`
 	UpdatedAt   time.Time         `gorm:"updated_at" json:"updated_at"`
@@ -29,35 +29,35 @@ func (e *Egress) Table() string {
 	return egressTable
 }
 
-func (e *Egress) Get() error {
-	return db.FromContext(context.TODO()).Table(e.Table()).First(&e).Where("id = ?", e.ID).Error
+func (e *Egress) Get(ctx context.Context) error {
+	return db.FromContext(ctx).Table(e.Table()).First(&e).Where("id = ?", e.ID).Error
 }
 
-func (e *Egress) Update() error {
-	return db.FromContext(context.TODO()).Table(e.Table()).Where("id = ?", e.ID).Updates(&e).Error
+func (e *Egress) Update(ctx context.Context) error {
+	return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Updates(&e).Error
 }
 
-func (e *Egress) UpdateNatStatus() error {
-	return db.FromContext(context.TODO()).Table(e.Table()).Where("id = ?", e.ID).Updates(map[string]any{
+func (e *Egress) UpdateNatStatus(ctx context.Context) error {
+	return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Updates(map[string]any{
 		"nat": e.Nat,
 	}).Error
 }
 
-func (e *Egress) UpdateINetGwStatus() error {
-	return db.FromContext(context.TODO()).Table(e.Table()).Where("id = ?", e.ID).Updates(map[string]any{
+func (e *Egress) UpdateINetGwStatus(ctx context.Context) error {
+	return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Updates(map[string]any{
 		"is_internet_gateway": e.IsInetGw,
 	}).Error
 }
 
-func (e *Egress) Create() error {
-	return db.FromContext(context.TODO()).Table(e.Table()).Create(&e).Error
+func (e *Egress) Create(ctx context.Context) error {
+	return db.FromContext(ctx).Table(e.Table()).Create(&e).Error
 }
 
-func (e *Egress) ListByNetwork() (egs []Egress, err error) {
-	err = db.FromContext(context.TODO()).Table(e.Table()).Where("network = ?", e.Network).Find(&egs).Error
+func (e *Egress) ListByNetwork(ctx context.Context) (egs []Egress, err error) {
+	err = db.FromContext(ctx).Table(e.Table()).Where("network = ?", e.Network).Find(&egs).Error
 	return
 }
 
-func (e *Egress) Delete() error {
-	return db.FromContext(context.TODO()).Table(e.Table()).Where("id = ?", e.ID).Delete(&e).Error
+func (e *Egress) Delete(ctx context.Context) error {
+	return db.FromContext(ctx).Table(e.Table()).Where("id = ?", e.ID).Delete(&e).Error
 }