Kaynağa Gözat

feat(go): use sql schema db for networks;

Vishal Dalwadi 7 ay önce
ebeveyn
işleme
151c7efe8c
8 değiştirilmiş dosya ile 82 ekleme ve 140 silme
  1. 2 2
      controllers/ext_client.go
  2. 2 2
      logic/dns.go
  3. 1 1
      logic/extpeers.go
  4. 1 1
      logic/gateway.go
  5. 60 96
      logic/networks.go
  6. 4 18
      logic/nodes.go
  7. 4 1
      logic/telemetry.go
  8. 8 19
      logic/util.go

+ 2 - 2
controllers/ext_client.go

@@ -189,7 +189,7 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	network, err := logic.GetParentNetwork(client.Network)
+	network, err := logic.GetNetwork(client.Network)
 	if err != nil {
 		logger.Log(
 			1,
@@ -399,7 +399,7 @@ func getExtClientHAConf(w http.ResponseWriter, r *http.Request) {
 
 	var params = mux.Vars(r)
 	networkid := params["network"]
-	network, err := logic.GetParentNetwork(networkid)
+	network, err := logic.GetNetwork(networkid)
 	if err != nil {
 		logger.Log(
 			1,

+ 2 - 2
logic/dns.go

@@ -254,7 +254,7 @@ func ValidateDNSCreate(entry models.DNSEntry) error {
 	})
 
 	_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
-		_, err := GetParentNetwork(entry.Network)
+		_, err := GetNetwork(entry.Network)
 		return err == nil
 	})
 
@@ -286,7 +286,7 @@ func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error {
 		return err == nil && num == 0
 	})
 	_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
-		_, err := GetParentNetwork(change.Network)
+		_, err := GetNetwork(change.Network)
 		return err == nil
 	})
 

+ 1 - 1
logic/extpeers.go

@@ -915,7 +915,7 @@ func GetExtclientAllowedIPs(client models.ExtClient) (allowedIPs []string) {
 		return
 	}
 
-	network, err := GetParentNetwork(client.Network)
+	network, err := GetNetwork(client.Network)
 	if err != nil {
 		logger.Log(1, "Could not retrieve Ingress Gateway Network", client.Network)
 		return

+ 1 - 1
logic/gateway.go

@@ -184,7 +184,7 @@ func CreateIngressGateway(netid string, nodeid string, ingress models.IngressReq
 		return models.Node{}, errors.New("gateway can only be created on linux based node")
 	}
 
-	network, err := GetParentNetwork(netid)
+	network, err := GetNetwork(netid)
 	if err != nil {
 		return models.Node{}, err
 	}

+ 60 - 96
logic/networks.go

@@ -1,9 +1,14 @@
 package logic
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"fmt"
+	"github.com/gravitl/netmaker/converters"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
+	"gorm.io/gorm"
 	"net"
 	"sort"
 	"strings"
@@ -162,25 +167,21 @@ func storeNetworkInCache(key string, network models.Network) {
 
 // GetNetworks - returns all networks from database
 func GetNetworks() ([]models.Network, error) {
-	var networks []models.Network
 	if servercfg.CacheEnabled() {
 		networks := getNetworksFromCache()
 		if len(networks) != 0 {
 			return networks, nil
 		}
 	}
-	collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
+
+	_networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO()))
 	if err != nil {
-		return networks, err
+		return nil, err
 	}
 
-	for _, value := range collection {
-		var network models.Network
-		if err := json.Unmarshal([]byte(value), &network); err != nil {
-			return networks, err
-		}
-		// add network our array
-		networks = append(networks, network)
+	networks := converters.ToModelNetworks(_networks)
+
+	for _, network := range networks {
 		if servercfg.CacheEnabled() {
 			storeNetworkInCache(network.NetID, network)
 		}
@@ -190,24 +191,27 @@ func GetNetworks() ([]models.Network, error) {
 }
 
 // DeleteNetwork - deletes a network
-func DeleteNetwork(network string, force bool, done chan struct{}) error {
-
-	nodeCount, err := GetNetworkNonServerNodeCount(network)
+func DeleteNetwork(netID string, force bool, done chan struct{}) error {
+	nodeCount, err := GetNetworkNonServerNodeCount(netID)
 	if nodeCount == 0 || database.IsEmptyRecord(err) {
 		// delete server nodes first then db records
-		err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
+		_network := &schema.Network{
+			ID: netID,
+		}
+		err = _network.Delete(db.WithContext(context.TODO()))
 		if err != nil {
 			return err
 		}
+
 		if servercfg.CacheEnabled() {
-			deleteNetworkFromCache(network)
+			deleteNetworkFromCache(netID)
 		}
 		return nil
 	}
 
 	// Remove All Nodes
 	go func() {
-		nodes, err := GetNetworkNodes(network)
+		nodes, err := GetNetworkNodes(netID)
 		if err == nil {
 			for _, node := range nodes {
 				node := node
@@ -219,17 +223,22 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
 			}
 		}
 		// remove ACL for network
-		err = nodeacls.DeleteACLContainer(nodeacls.NetworkID(network))
+		err = nodeacls.DeleteACLContainer(nodeacls.NetworkID(netID))
 		if err != nil {
-			logger.Log(1, "failed to remove the node acls during network delete for network,", network)
+			logger.Log(1, "failed to remove the node acls during network delete for network,", netID)
 		}
+
 		// delete server nodes first then db records
-		err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
+		_network := &schema.Network{
+			ID: netID,
+		}
+		err = _network.Delete(db.WithContext(context.TODO()))
 		if err != nil {
 			return
 		}
+
 		if servercfg.CacheEnabled() {
-			deleteNetworkFromCache(network)
+			deleteNetworkFromCache(netID)
 		}
 		done <- struct{}{}
 		close(done)
@@ -238,7 +247,7 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error {
 	// Delete default network enrollment key
 	keys, _ := GetAllEnrollmentKeys()
 	for _, key := range keys {
-		if key.Tags[0] == network {
+		if key.Tags[0] == netID {
 			if key.Default {
 				DeleteEnrollmentKey(key.Value, true)
 				break
@@ -281,14 +290,12 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 		return models.Network{}, err
 	}
 
-	data, err := json.Marshal(&network)
+	_network := converters.ToSchemaNetwork(network)
+	err = _network.Create(db.WithContext(context.TODO()))
 	if err != nil {
 		return models.Network{}, err
 	}
 
-	if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
-		return models.Network{}, err
-	}
 	if servercfg.CacheEnabled() {
 		storeNetworkInCache(network.NetID, network)
 	}
@@ -334,49 +341,11 @@ func intersect(n1, n2 *net.IPNet) bool {
 	return n2.Contains(n1.IP) || n1.Contains(n2.IP)
 }
 
-// GetParentNetwork - get parent network
-func GetParentNetwork(networkname string) (models.Network, error) {
-
-	var network models.Network
-	if servercfg.CacheEnabled() {
-		if network, ok := getNetworkFromCache(networkname); ok {
-			return network, nil
-		}
-	}
-	networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
-	if err != nil {
-		return network, err
-	}
-	if err = json.Unmarshal([]byte(networkData), &network); err != nil {
-		return models.Network{}, err
-	}
-	return network, nil
-}
-
-// GetNetworkSettings - get parent network
-func GetNetworkSettings(networkname string) (models.Network, error) {
-
-	var network models.Network
-	if servercfg.CacheEnabled() {
-		if network, ok := getNetworkFromCache(networkname); ok {
-			return network, nil
-		}
-	}
-	networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
-	if err != nil {
-		return network, err
-	}
-	if err = json.Unmarshal([]byte(networkData), &network); err != nil {
-		return models.Network{}, err
-	}
-	return network, nil
-}
-
 // UniqueAddress - get a unique ipv4 address
 func UniqueAddressCache(networkName string, reverse bool) (net.IP, error) {
 	add := net.IP{}
 	var network models.Network
-	network, err := GetParentNetwork(networkName)
+	network, err := GetNetwork(networkName)
 	if err != nil {
 		logger.Log(0, "UniqueAddressServer encountered  an error")
 		return add, err
@@ -419,7 +388,7 @@ func UniqueAddressCache(networkName string, reverse bool) (net.IP, error) {
 func UniqueAddressDB(networkName string, reverse bool) (net.IP, error) {
 	add := net.IP{}
 	var network models.Network
-	network, err := GetParentNetwork(networkName)
+	network, err := GetNetwork(networkName)
 	if err != nil {
 		logger.Log(0, "UniqueAddressServer encountered  an error")
 		return add, err
@@ -519,7 +488,7 @@ func UniqueAddress6(networkName string, reverse bool) (net.IP, error) {
 func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) {
 	add := net.IP{}
 	var network models.Network
-	network, err := GetParentNetwork(networkName)
+	network, err := GetNetwork(networkName)
 	if err != nil {
 		fmt.Println("Network Not Found")
 		return add, err
@@ -564,7 +533,7 @@ func UniqueAddress6DB(networkName string, reverse bool) (net.IP, error) {
 func UniqueAddress6Cache(networkName string, reverse bool) (net.IP, error) {
 	add := net.IP{}
 	var network models.Network
-	network, err := GetParentNetwork(networkName)
+	network, err := GetNetwork(networkName)
 	if err != nil {
 		fmt.Println("Network Not Found")
 		return add, err
@@ -635,12 +604,10 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (
 		hasrangeupdate4 := newNetwork.AddressRange != currentNetwork.AddressRange
 		hasrangeupdate6 := newNetwork.AddressRange6 != currentNetwork.AddressRange6
 		hasholepunchupdate := newNetwork.DefaultUDPHolePunch != currentNetwork.DefaultUDPHolePunch
-		data, err := json.Marshal(newNetwork)
-		if err != nil {
-			return false, false, false, err
-		}
 		newNetwork.SetNetworkLastModified()
-		err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
+
+		_network := converters.ToSchemaNetwork(*newNetwork)
+		err := _network.Update(db.WithContext(context.TODO()))
 		if err == nil {
 			if servercfg.CacheEnabled() {
 				storeNetworkInCache(newNetwork.NetID, *newNetwork)
@@ -653,22 +620,22 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (
 }
 
 // GetNetwork - gets a network from database
-func GetNetwork(networkname string) (models.Network, error) {
-
-	var network models.Network
+func GetNetwork(netID string) (models.Network, error) {
 	if servercfg.CacheEnabled() {
-		if network, ok := getNetworkFromCache(networkname); ok {
+		if network, ok := getNetworkFromCache(netID); ok {
 			return network, nil
 		}
 	}
-	networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
-	if err != nil {
-		return network, err
+
+	_network := &schema.Network{
+		ID: netID,
 	}
-	if err = json.Unmarshal([]byte(networkData), &network); err != nil {
+	err := _network.Get(db.WithContext(context.TODO()))
+	if err != nil {
 		return models.Network{}, err
 	}
-	return network, nil
+
+	return converters.ToModelNetwork(*_network), nil
 }
 
 // NetIDInNetworkCharSet - checks if a netid of a network uses valid characters
@@ -718,13 +685,12 @@ func ParseNetwork(value string) (models.Network, error) {
 
 // SaveNetwork - save network struct to database
 func SaveNetwork(network *models.Network) error {
-	data, err := json.Marshal(network)
+	_network := converters.ToSchemaNetwork(*network)
+	err := _network.Update(db.WithContext(context.TODO()))
 	if err != nil {
 		return err
 	}
-	if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
-		return err
-	}
+
 	if servercfg.CacheEnabled() {
 		storeNetworkInCache(network.NetID, *network)
 	}
@@ -732,19 +698,17 @@ func SaveNetwork(network *models.Network) error {
 }
 
 // NetworkExists - check if network exists
-func NetworkExists(name string) (bool, error) {
-
-	var network string
-	var err error
-	if servercfg.CacheEnabled() {
-		if _, ok := getNetworkFromCache(name); ok {
-			return ok, nil
+func NetworkExists(netID string) (bool, error) {
+	_, err := GetNetwork(netID)
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return false, nil
+		} else {
+			return false, err
 		}
 	}
-	if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
-		return false, err
-	}
-	return len(network) > 0, nil
+
+	return true, nil
 }
 
 // SortNetworks - Sorts slice of Networks by their NetID alphabetically with numbers first

+ 4 - 18
logic/nodes.go

@@ -197,7 +197,7 @@ func UpsertNode(newNode *models.Node) error {
 // UpdateNode - takes a node and updates another node with it's values
 func UpdateNode(currentNode *models.Node, newNode *models.Node) error {
 	if newNode.Address.IP.String() != currentNode.Address.IP.String() {
-		if network, err := GetParentNetwork(newNode.Network); err == nil {
+		if network, err := GetNetwork(newNode.Network); err == nil {
 			if !IsAddressInCIDR(newNode.Address.IP, network.AddressRange) {
 				return fmt.Errorf("invalid address provided; out of network range for node %s", newNode.ID)
 			}
@@ -394,7 +394,7 @@ func ValidateNode(node *models.Node, isUpdate bool) error {
 		return isFieldUnique
 	})
 	_ = v.RegisterValidation("network_exists", func(fl validator.FieldLevel) bool {
-		_, err := GetNetworkByNode(node)
+		_, err := GetNetwork(node.Network)
 		return err == nil
 	})
 	_ = v.RegisterValidation("checkyesornoorunset", func(f1 validator.FieldLevel) bool {
@@ -477,24 +477,10 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m
 	return
 }
 
-// GetNetworkByNode - gets the network model from a node
-func GetNetworkByNode(node *models.Node) (models.Network, error) {
-
-	var network = models.Network{}
-	networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, node.Network)
-	if err != nil {
-		return network, err
-	}
-	if err = json.Unmarshal([]byte(networkData), &network); err != nil {
-		return models.Network{}, err
-	}
-	return network, nil
-}
-
 // SetNodeDefaults - sets the defaults of a node to avoid empty fields
 func SetNodeDefaults(node *models.Node, resetConnected bool) {
 
-	parentNetwork, _ := GetNetworkByNode(node)
+	parentNetwork, _ := GetNetwork(node.Network)
 	_, cidr, err := net.ParseCIDR(parentNetwork.AddressRange)
 	if err == nil {
 		node.NetworkRange = *cidr
@@ -784,7 +770,7 @@ func ValidateNodeIp(currentNode *models.Node, newNode *models.ApiNode) error {
 }
 
 func ValidateEgressRange(gateway models.EgressGatewayRequest) error {
-	network, err := GetNetworkSettings(gateway.NetID)
+	network, err := GetNetwork(gateway.NetID)
 	if err != nil {
 		slog.Error("error getting network with netid", "error", gateway.NetID, err.Error)
 		return errors.New("error getting network with netid:  " + gateway.NetID + " " + err.Error())

+ 4 - 1
logic/telemetry.go

@@ -1,7 +1,10 @@
 package logic
 
 import (
+	"context"
 	"encoding/json"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"os"
 	"time"
 
@@ -89,7 +92,7 @@ func FetchTelemetryData() telemetryData {
 	data.IsPro = servercfg.IsPro
 	data.ExtClients = getDBLength(database.EXT_CLIENT_TABLE_NAME)
 	data.Users = getDBLength(database.USERS_TABLE_NAME)
-	data.Networks = getDBLength(database.NETWORKS_TABLE_NAME)
+	data.Networks, _ = (&schema.Network{}).Count(db.WithContext(context.TODO()))
 	data.Hosts = getDBLength(database.HOSTS_TABLE_NAME)
 	data.Version = servercfg.GetVersion()
 	data.Servers = getServerCount()

+ 8 - 19
logic/util.go

@@ -2,11 +2,13 @@
 package logic
 
 import (
+	"context"
 	"crypto/rand"
 	"encoding/base32"
 	"encoding/base64"
-	"encoding/json"
 	"fmt"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"log/slog"
 	"net"
 	"os"
@@ -16,7 +18,6 @@ import (
 
 	"github.com/blang/semver"
 	"github.com/c-robinson/iplib"
-	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
 )
 
@@ -52,24 +53,12 @@ func IsAddressInCIDR(address net.IP, cidr string) bool {
 }
 
 // SetNetworkNodesLastModified - sets the network nodes last modified
-func SetNetworkNodesLastModified(networkName string) error {
-
-	timestamp := time.Now().Unix()
-
-	network, err := GetParentNetwork(networkName)
-	if err != nil {
-		return err
-	}
-	network.NodesLastModified = timestamp
-	data, err := json.Marshal(&network)
-	if err != nil {
-		return err
-	}
-	err = database.Insert(networkName, string(data), database.NETWORKS_TABLE_NAME)
-	if err != nil {
-		return err
+func SetNetworkNodesLastModified(netID string) error {
+	_network := &schema.Network{
+		ID:                netID,
+		NodesLastModified: time.Now().Unix(),
 	}
-	return nil
+	return _network.UpdateNodesLastModified(db.WithContext(context.TODO()))
 }
 
 // RandomString - returns a random string in a charset