Bläddra i källkod

add network name validation

abhishek9686 9 månader sedan
förälder
incheckning
2685033730
6 ändrade filer med 41 tillägg och 41 borttagningar
  1. 3 3
      controllers/dns.go
  2. 15 12
      logic/dns.go
  3. 10 19
      logic/networks.go
  4. 5 0
      logic/tags.go
  5. 6 5
      models/tags.go
  6. 2 2
      mq/publishers.go

+ 3 - 3
controllers/dns.go

@@ -49,7 +49,7 @@ func getNodeDNS(w http.ResponseWriter, r *http.Request) {
 	var dns []models.DNSEntry
 	var params = mux.Vars(r)
 	network := params["network"]
-	dns, err := logic.GetNodeDNS(network)
+	dns, err := logic.GetNodeDNS(models.NetworkID(network))
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err))
@@ -125,7 +125,7 @@ func getDNS(w http.ResponseWriter, r *http.Request) {
 	var dns []models.DNSEntry
 	var params = mux.Vars(r)
 	network := params["network"]
-	dns, err := logic.GetDNS(network)
+	dns, err := logic.GetDNS(models.NetworkID(network))
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
 			fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error()))
@@ -298,7 +298,7 @@ func syncDNS(w http.ResponseWriter, r *http.Request) {
 	}
 	var params = mux.Vars(r)
 	netID := params["network"]
-	k, err := logic.GetDNS(netID)
+	k, err := logic.GetDNS(models.NetworkID(netID))
 	if err == nil && len(k) > 0 {
 		err = mq.PushSyncDNS(k)
 	}

+ 15 - 12
logic/dns.go

@@ -28,7 +28,7 @@ func SetDNS() error {
 
 	for _, net := range networks {
 		corefilestring = corefilestring + net.NetID + " "
-		dns, err := GetDNS(net.NetID)
+		dns, err := GetDNS(models.NetworkID(net.NetID))
 		if err != nil && !database.IsEmptyRecord(err) {
 			return err
 		}
@@ -58,13 +58,13 @@ func SetDNS() error {
 }
 
 // GetDNS - gets the DNS of a current network
-func GetDNS(network string) ([]models.DNSEntry, error) {
+func GetDNS(networkID models.NetworkID) ([]models.DNSEntry, error) {
 
-	dns, err := GetNodeDNS(network)
+	dns, err := GetNodeDNS(networkID)
 	if err != nil && !database.IsEmptyRecord(err) {
 		return dns, err
 	}
-	customdns, err := GetCustomDNS(network)
+	customdns, err := GetCustomDNS(networkID.String())
 	if err != nil && !database.IsEmptyRecord(err) {
 		return dns, err
 	}
@@ -96,17 +96,20 @@ func GetExtclientDNS() []models.DNSEntry {
 }
 
 // GetNodeDNS - gets the DNS of a network node
-func GetNodeDNS(network string) ([]models.DNSEntry, error) {
+func GetNodeDNS(networkID models.NetworkID) ([]models.DNSEntry, error) {
 
 	var dns []models.DNSEntry
-
-	nodes, err := GetNetworkNodes(network)
+	net, err := GetNetwork(networkID.String())
+	if err != nil {
+		return []models.DNSEntry{}, err
+	}
+	nodes, err := GetNetworkNodes(networkID.String())
 	if err != nil {
 		return dns, err
 	}
 
 	for _, node := range nodes {
-		if node.Network != network {
+		if node.Network != networkID.String() {
 			continue
 		}
 		host, err := GetHost(node.HostID.String())
@@ -114,8 +117,8 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) {
 			continue
 		}
 		var entry = models.DNSEntry{}
-		entry.Name = fmt.Sprintf("%s.%s", host.Name, network)
-		entry.Network = network
+		entry.Name = fmt.Sprintf("%s.%s", host.Name, net.Name)
+		entry.Network = net.NetID
 		if node.Address.IP != nil {
 			entry.Address = node.Address.IP.String()
 		}
@@ -188,7 +191,7 @@ func GetAllDNS() ([]models.DNSEntry, error) {
 		return []models.DNSEntry{}, err
 	}
 	for _, net := range networks {
-		netdns, err := GetDNS(net.Name)
+		netdns, err := GetDNS(models.NetworkID(net.NetID))
 		if err != nil {
 			return []models.DNSEntry{}, nil
 		}
@@ -202,7 +205,7 @@ func GetDNSEntryNum(domain string, network string) (int, error) {
 
 	num := 0
 
-	entries, err := GetDNS(network)
+	entries, err := GetDNS(models.NetworkID(network))
 	if err != nil {
 		return 0, err
 	}

+ 10 - 19
logic/networks.go

@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"net"
+	"regexp"
 	"sort"
 	"strings"
 	"sync"
@@ -230,7 +231,7 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 	network.SetDefaults()
 	network.SetNodesLastModified()
 	network.SetNetworkLastModified()
-
+	network.Name = strings.ReplaceAll(network.Name, " ", "-")
 	err := ValidateNetwork(&network, false)
 	if err != nil {
 		//logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
@@ -483,6 +484,7 @@ func IsNetworkNameUnique(network *models.Network) (bool, error) {
 
 // UpdateNetwork - updates a network with another network's fields
 func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (bool, bool, bool, error) {
+	newNetwork.Name = strings.ReplaceAll(newNetwork.Name, " ", "-")
 	if err := ValidateNetwork(newNetwork, true); err != nil {
 		return false, false, false, err
 	}
@@ -561,16 +563,10 @@ func GetNetwork(networkID string) (models.Network, error) {
 	return network, nil
 }
 
-// NetIDInNetworkCharSet - checks if a netid of a network uses valid characters
-func NetIDInNetworkCharSet(network *models.Network) bool {
-	charset := "abcdefghijklmnopqrstuvwxyz1234567890-_"
-
-	for _, char := range network.NetID {
-		if !strings.Contains(charset, string(char)) {
-			return false
-		}
-	}
-	return true
+// IsNetworkNameValid - checks if a netid of a network uses valid characters
+func IsNetworkNameValid(network *models.Network) bool {
+	re := regexp.MustCompile(`^[A-Za-z0-9-]+$`)
+	return re.MatchString(network.Name)
 }
 
 // Validate - validates fields of an network struct
@@ -580,7 +576,9 @@ func ValidateNetwork(network *models.Network, isUpdate bool) error {
 	if !isFieldUnique {
 		return errors.New("duplicate network name")
 	}
-	//
+	if !IsNetworkNameValid(network) {
+		return errors.New("invalid input. Only uppercase letters (A-Z), lowercase letters (a-z), numbers (0-9), and the minus sign (-) are allowed")
+	}
 	_ = v.RegisterValidation("checkyesorno", func(fl validator.FieldLevel) bool {
 		return validation.CheckYesOrNo(fl)
 	})
@@ -594,13 +592,6 @@ func ValidateNetwork(network *models.Network, isUpdate bool) error {
 	return err
 }
 
-// ParseNetwork - parses a network into a model
-func ParseNetwork(value string) (models.Network, error) {
-	var network models.Network
-	err := json.Unmarshal([]byte(value), &network)
-	return network, err
-}
-
 // SaveNetwork - save network struct to database
 func SaveNetwork(network *models.Network) error {
 	data, err := json.Marshal(network)

+ 5 - 0
logic/tags.go

@@ -85,9 +85,14 @@ func ListTagsWithNodes(netID models.NetworkID) ([]models.TagListResp, error) {
 	if err != nil {
 		return []models.TagListResp{}, err
 	}
+	network, err := GetNetwork(netID.String())
+	if err != nil {
+		return []models.TagListResp{}, err
+	}
 	tagsNodeMap := GetTagMapWithNodesByNetwork(netID)
 	resp := []models.TagListResp{}
 	for _, tagI := range tags {
+		tagI.NetworkName = network.Name
 		tagRespI := models.TagListResp{
 			Tag:         tagI,
 			UsedByCnt:   len(tagsNodeMap[tagI.ID]),

+ 6 - 5
models/tags.go

@@ -20,11 +20,12 @@ func (t Tag) GetIDFromName() string {
 }
 
 type Tag struct {
-	ID        TagID     `json:"id"`
-	TagName   string    `json:"tag_name"`
-	Network   NetworkID `json:"network"`
-	CreatedBy string    `json:"created_by"`
-	CreatedAt time.Time `json:"created_at"`
+	ID          TagID     `json:"id"`
+	TagName     string    `json:"tag_name"`
+	Network     NetworkID `json:"network"`
+	NetworkName string    `json:"network_name"`
+	CreatedBy   string    `json:"created_by"`
+	CreatedAt   time.Time `json:"created_at"`
 }
 
 type CreateTagReq struct {

+ 2 - 2
mq/publishers.go

@@ -256,7 +256,7 @@ func sendPeers() {
 
 func SendDNSSyncByNetwork(network string) error {
 
-	k, err := logic.GetDNS(network)
+	k, err := logic.GetDNS(models.NetworkID(network))
 	if err == nil && len(k) > 0 {
 		err = PushSyncDNS(k)
 		if err != nil {
@@ -272,7 +272,7 @@ func sendDNSSync() error {
 	networks, err := logic.GetNetworks()
 	if err == nil && len(networks) > 0 {
 		for _, v := range networks {
-			k, err := logic.GetDNS(v.NetID)
+			k, err := logic.GetDNS(models.NetworkID(v.NetID))
 			if err == nil && len(k) > 0 {
 				err = PushSyncDNS(k)
 				if err != nil {