Browse Source

enforce unique names for ext client names (#2476)

* enforce unique names for ext client names

* only check for unique id on creation

* check for unique id if changed
Matthew R Kasun 2 years ago
parent
commit
0c70c4daba
4 changed files with 90 additions and 68 deletions
  1. 69 59
      controllers/ext_client.go
  2. 1 0
      controllers/regex.go
  3. 15 0
      logic/clients.go
  4. 5 9
      logic/extpeers.go

+ 69 - 59
controllers/ext_client.go

@@ -17,6 +17,7 @@ import (
 	"github.com/gravitl/netmaker/models/promodels"
 	"github.com/gravitl/netmaker/mq"
 	"github.com/skip2/go-qrcode"
+	"golang.org/x/exp/slog"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
@@ -308,31 +309,28 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Content-Type", "application/json")
 
 	var params = mux.Vars(r)
-	networkName := params["network"]
 	nodeid := params["nodeid"]
 
 	ingressExists := checkIngressExists(nodeid)
 	if !ingressExists {
 		err := errors.New("ingress does not exist")
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to create extclient on network [%s]: %v", networkName, err))
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		slog.Error("failed to create extclient", "user", r.Header.Get("user"), "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
 
-	var extclient models.ExtClient
 	var customExtClient models.CustomExtClient
 
 	if err := json.NewDecoder(r.Body).Decode(&customExtClient); err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
-	if err := validateExtClient(&extclient, &customExtClient); err != nil {
+	if err := validateCustomExtClient(&customExtClient, true); err != nil {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
+	extclient := logic.UpdateExtClient(&models.ExtClient{}, &customExtClient)
 
-	extclient.Network = networkName
 	extclient.IngressGatewayID = nodeid
 	node, err := logic.GetNodeByID(nodeid)
 	if err != nil {
@@ -341,6 +339,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
+	extclient.Network = node.Network
 	host, err := logic.GetHost(node.HostID.String())
 	if err != nil {
 		logger.Log(0, r.Header.Get("user"),
@@ -351,21 +350,19 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 	listenPort := logic.GetPeerListenPort(host)
 	extclient.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort)
 	extclient.Enabled = true
-	parentNetwork, err := logic.GetNetwork(networkName)
+	parentNetwork, err := logic.GetNetwork(node.Network)
 	if err == nil { // check if parent network default ACL is enabled (yes) or not (no)
 		extclient.Enabled = parentNetwork.DefaultACL == "yes"
 	}
 
 	if err := logic.SetClientDefaultACLs(&extclient); err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to assign ACLs to new ext client on network [%s]: %v", networkName, err))
+		slog.Error("failed to set default acls for extclient", "user", r.Header.Get("user"), "network", node.Network, "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
 
 	if err = logic.CreateExtClient(&extclient); err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to create new ext client on network [%s]: %v", networkName, err))
+		slog.Error("failed to create extclient", "user", r.Header.Get("user"), "network", node.Network, "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
@@ -374,13 +371,13 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 	if r.Header.Get("ismaster") != "yes" {
 		userID := r.Header.Get("user")
 		if isAdmin, err = checkProClientAccess(userID, extclient.ClientID, &parentNetwork); err != nil {
-			logger.Log(0, userID, "attempted to create a client on network", networkName, "but they lack access")
-			logic.DeleteExtClient(networkName, extclient.ClientID)
+			slog.Error("pro client access check failed", "user", userID, "network", node.Network, "error", err)
+			logic.DeleteExtClient(node.Network, extclient.ClientID)
 			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 			return
 		}
 		if !isAdmin {
-			if err = pro.AssociateNetworkUserClient(userID, networkName, extclient.ClientID); err != nil {
+			if err = pro.AssociateNetworkUserClient(userID, node.Network, extclient.ClientID); err != nil {
 				logger.Log(0, "failed to associate client", extclient.ClientID, "to user", userID)
 			}
 			extclient.OwnerID = userID
@@ -390,7 +387,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 
-	logger.Log(0, r.Header.Get("user"), "created new ext client on network", networkName)
+	slog.Info("created extclient", "user", r.Header.Get("user"), "network", node.Network, "clientid", extclient.ClientID)
 	w.WriteHeader(http.StatusOK)
 	go func() {
 		if err := mq.PublishPeerUpdate(); err != nil {
@@ -419,7 +416,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 	var params = mux.Vars(r)
 
 	var update models.CustomExtClient
-	var oldExtClient models.ExtClient
+	//var oldExtClient models.ExtClient
 	var sendPeerUpdate bool
 	err := json.NewDecoder(r.Body).Decode(&update)
 	if err != nil {
@@ -429,50 +426,40 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	clientid := params["clientid"]
-	network := params["network"]
-	key, err := logic.GetRecordKey(clientid, network)
+	oldExtClient, err := logic.GetExtClientByName(clientid)
 	if err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to get record key for client [%s], network [%s]: %v",
-				clientid, network, err))
+		slog.Error("failed to retrieve extclient", "user", r.Header.Get("user"), "id", clientid, "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
-	if err := validateExtClient(&oldExtClient, &update); err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
-		return
-	}
-	data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key)
-	if err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to fetch  ext client record key [%s] from db for client [%s], network [%s]: %v",
-				key, clientid, network, err))
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-		return
-	}
-	if err = json.Unmarshal([]byte(data), &oldExtClient); err != nil {
-		logger.Log(0, "error unmarshalling extclient: ",
-			err.Error())
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
-		return
+	if oldExtClient.ClientID == update.ClientID {
+		if err := validateCustomExtClient(&update, false); err != nil {
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+			return
+		}
+	} else {
+		if err := validateCustomExtClient(&update, true); err != nil {
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+			return
+		}
 	}
 
 	// == PRO ==
-	networkName := params["network"]
+	//networkName := params["network"]
 	var changedID = update.ClientID != oldExtClient.ClientID
 	if r.Header.Get("ismaster") != "yes" {
 		userID := r.Header.Get("user")
-		_, doesOwn := doesUserOwnClient(userID, params["clientid"], networkName)
+		_, doesOwn := doesUserOwnClient(userID, params["clientid"], oldExtClient.Network)
 		if !doesOwn {
 			logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("user not permitted"), "internal"))
 			return
 		}
 	}
 	if changedID && oldExtClient.OwnerID != "" {
-		if err := pro.DissociateNetworkUserClient(oldExtClient.OwnerID, networkName, oldExtClient.ClientID); err != nil {
+		if err := pro.DissociateNetworkUserClient(oldExtClient.OwnerID, oldExtClient.Network, oldExtClient.ClientID); err != nil {
 			logger.Log(0, "failed to dissociate client", oldExtClient.ClientID, "from user", oldExtClient.OwnerID)
 		}
-		if err := pro.AssociateNetworkUserClient(oldExtClient.OwnerID, networkName, update.ClientID); err != nil {
+		if err := pro.AssociateNetworkUserClient(oldExtClient.OwnerID, oldExtClient.Network, update.ClientID); err != nil {
 			logger.Log(0, "failed to associate client", update.ClientID, "to user", oldExtClient.OwnerID)
 		}
 	}
@@ -485,13 +472,15 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 	if update.Enabled != oldExtClient.Enabled {
 		sendPeerUpdate = true
 	}
-	// extra var need as logic.Update changes oldExtClient
-	currentClient := oldExtClient
-	newclient, err := logic.UpdateExtClient(&oldExtClient, &update)
-	if err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("failed to update ext client [%s], network [%s]: %v",
-				clientid, network, err))
+	newclient := logic.UpdateExtClient(&oldExtClient, &update)
+	if err := logic.DeleteExtClient(oldExtClient.Network, oldExtClient.ClientID); err != nil {
+
+		slog.Error("failed to delete ext client", "user", r.Header.Get("user"), "id", oldExtClient.ClientID, "network", oldExtClient.Network, "error", err)
+		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
+		return
+	}
+	if err := logic.SaveExtClient(&newclient); err != nil {
+		slog.Error("failed to save ext client", "user", r.Header.Get("user"), "id", newclient.ClientID, "network", newclient.Network, "error", err)
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
 	}
@@ -507,7 +496,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
 	json.NewEncoder(w).Encode(newclient)
 	if changedID {
 		go func() {
-			if err := mq.PublishExtClientDNSUpdate(currentClient, *newclient, networkName); err != nil {
+			if err := mq.PublishExtClientDNSUpdate(oldExtClient, newclient, oldExtClient.Network); err != nil {
 				logger.Log(1, "error pubishing dns update for extcient update", err.Error())
 			}
 		}()
@@ -647,18 +636,20 @@ func doesUserOwnClient(username, clientID, network string) (bool, bool) {
 	return false, logic.StringSliceContains(netUser.Clients, clientID)
 }
 
-// validateExtClient	Validates the extclient object
-func validateExtClient(extclient *models.ExtClient, customExtClient *models.CustomExtClient) error {
+// validateCustomExtClient	Validates the extclient object
+func validateCustomExtClient(customExtClient *models.CustomExtClient, checkID bool) error {
 	//validate clientid
-	if customExtClient.ClientID != "" && !validName(customExtClient.ClientID) {
-		return errInvalidExtClientID
+	if customExtClient.ClientID != "" {
+		if err := isValid(customExtClient.ClientID, checkID); err != nil {
+			return fmt.Errorf("client validatation: %v", err)
+		}
 	}
-	extclient.ClientID = customExtClient.ClientID
+	//extclient.ClientID = customExtClient.ClientID
 	if len(customExtClient.PublicKey) > 0 {
 		if _, err := wgtypes.ParseKey(customExtClient.PublicKey); err != nil {
 			return errInvalidExtClientPubKey
 		}
-		extclient.PublicKey = customExtClient.PublicKey
+		//extclient.PublicKey = customExtClient.PublicKey
 	}
 	//validate extra ips
 	if len(customExtClient.ExtraAllowedIPs) > 0 {
@@ -667,14 +658,33 @@ func validateExtClient(extclient *models.ExtClient, customExtClient *models.Cust
 				return errInvalidExtClientExtraIP
 			}
 		}
-		extclient.ExtraAllowedIPs = customExtClient.ExtraAllowedIPs
+		//extclient.ExtraAllowedIPs = customExtClient.ExtraAllowedIPs
 	}
 	//validate DNS
 	if customExtClient.DNS != "" {
 		if ip := net.ParseIP(customExtClient.DNS); ip == nil {
 			return errInvalidExtClientDNS
 		}
-		extclient.DNS = customExtClient.DNS
+		//extclient.DNS = customExtClient.DNS
+	}
+	return nil
+}
+
+// isValid	Checks if the clientid is valid
+func isValid(clientid string, checkID bool) error {
+	if !validName(clientid) {
+		return errInvalidExtClientID
+	}
+	if checkID {
+		extclients, err := logic.GetAllExtClients()
+		if err != nil {
+			return fmt.Errorf("extclients isValid: %v", err)
+		}
+		for _, extclient := range extclients {
+			if clientid == extclient.ClientID {
+				return errDuplicateExtClientName
+			}
+		}
 	}
 	return nil
 }

+ 1 - 0
controllers/regex.go

@@ -10,6 +10,7 @@ var (
 	errInvalidExtClientID      = errors.New("ext client ID must be alphanumderic and/or dashes and less that 15 chars")
 	errInvalidExtClientExtraIP = errors.New("ext client extra ip must be a valid cidr")
 	errInvalidExtClientDNS     = errors.New("ext client dns must be a valid ip address")
+	errDuplicateExtClientName  = errors.New("duplicate client name")
 )
 
 // allow only dashes and alphaneumeric for ext client and node names

+ 15 - 0
logic/clients.go

@@ -1,6 +1,7 @@
 package logic
 
 import (
+	"errors"
 	"sort"
 
 	"github.com/gravitl/netmaker/models"
@@ -70,3 +71,17 @@ func SortExtClient(unsortedExtClient []models.ExtClient) {
 		return unsortedExtClient[i].ClientID < unsortedExtClient[j].ClientID
 	})
 }
+
+// GetExtClientByName - gets an ext client by name
+func GetExtClientByName(ID string) (models.ExtClient, error) {
+	clients, err := GetAllExtClients()
+	if err != nil {
+		return models.ExtClient{}, err
+	}
+	for i := range clients {
+		if clients[i].ClientID == ID {
+			return clients[i], nil
+		}
+	}
+	return models.ExtClient{}, errors.New("client not found")
+}

+ 5 - 9
logic/extpeers.go

@@ -152,9 +152,9 @@ func GetExtClientByPubKey(publicKey string, network string) (*models.ExtClient,
 	return nil, fmt.Errorf("no client found")
 }
 
-// CreateExtClient - creates an extclient
+// CreateExtClient - creates and saves an extclient
 func CreateExtClient(extclient *models.ExtClient) error {
-	// lock because we need unique IPs and having it concurrent makes parallel calls result in same "unique" IPs
+	// lock because we may need unique IPs and having it concurrent makes parallel calls result in same "unique" IPs
 	addressLock.Lock()
 	defer addressLock.Unlock()
 
@@ -219,12 +219,8 @@ func SaveExtClient(extclient *models.ExtClient) error {
 }
 
 // UpdateExtClient - updates an ext client with new values
-func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) (*models.ExtClient, error) {
-	new := old
-	err := DeleteExtClient(old.Network, old.ClientID)
-	if err != nil {
-		return new, err
-	}
+func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) models.ExtClient {
+	new := *old
 	new.ClientID = update.ClientID
 	if update.PublicKey != "" && old.PublicKey != update.PublicKey {
 		new.PublicKey = update.PublicKey
@@ -241,7 +237,7 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) (*mo
 	if update.DeniedACLs != nil && !reflect.DeepEqual(old.DeniedACLs, update.DeniedACLs) {
 		new.DeniedACLs = update.DeniedACLs
 	}
-	return new, CreateExtClient(new)
+	return new
 }
 
 // GetExtClientsByID - gets the clients of attached gateway