Explorar o código

NET-1440 scale test changes

Max Ma hai 1 ano
pai
achega
18d48024dd
Modificáronse 13 ficheiros con 440 adicións e 24 borrados
  1. 6 0
      controllers/network.go
  2. 45 2
      logic/enrollmentkey.go
  3. 2 2
      logic/extpeers.go
  4. 283 0
      logic/ippool.go
  5. 1 1
      logic/jwts.go
  6. 83 2
      logic/networks.go
  7. 2 2
      logic/nodes.go
  8. 2 0
      logic/zombie.go
  9. 2 0
      main.go
  10. 6 2
      mq/emqx_on_prem.go
  11. 5 9
      mq/mq.go
  12. 1 1
      mq/publishers.go
  13. 2 3
      mq/util.go

+ 6 - 0
controllers/network.go

@@ -392,6 +392,8 @@ func deleteNetwork(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, errtype))
 		return
 	}
+	//delete network from ip pool
+	go logic.RemoveNetworkFromIpPool(network)
 
 	logger.Log(1, r.Header.Get("user"), "deleted network", network)
 	w.WriteHeader(http.StatusOK)
@@ -467,6 +469,10 @@ func createNetwork(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
 		return
 	}
+
+	//add new network to ip pool
+	go logic.AddNetworkToIpPool(network.NetID)
+
 	go func() {
 		defaultHosts := logic.GetDefaultHosts()
 		for i := range defaultHosts {

+ 45 - 2
logic/enrollmentkey.go

@@ -5,11 +5,13 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"sync"
 	"time"
 
 	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/servercfg"
 	"golang.org/x/exp/slices"
 )
 
@@ -29,6 +31,10 @@ var EnrollmentErrors = struct {
 	FailedToTokenize:   fmt.Errorf("failed to tokenize"),
 	FailedToDeTokenize: fmt.Errorf("failed to detokenize"),
 }
+var (
+	enrollmentkeyCacheMutex = &sync.RWMutex{}
+	enrollmentkeyCacheMap   = make(map[string]*models.EnrollmentKey)
+)
 
 // CreateEnrollmentKey - creates a new enrollment key in db
 func CreateEnrollmentKey(uses int, expiration time.Time, networks, tags []string, unlimited bool, relay uuid.UUID) (*models.EnrollmentKey, error) {
@@ -138,13 +144,25 @@ func GetEnrollmentKey(value string) (*models.EnrollmentKey, error) {
 	return nil, EnrollmentErrors.NoKeyFound
 }
 
+// deleteEnrollmentkeyFromCache - evicts the enrollment key stored under key
+// from the in-memory cache while holding the cache write lock
+func deleteEnrollmentkeyFromCache(key string) {
+	enrollmentkeyCacheMutex.Lock()
+	defer enrollmentkeyCacheMutex.Unlock()
+	delete(enrollmentkeyCacheMap, key)
+}
+
 // DeleteEnrollmentKey - delete's a given enrollment key by value
 func DeleteEnrollmentKey(value string) error {
 	_, err := GetEnrollmentKey(value)
 	if err != nil {
 		return err
 	}
-	return database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value)
+	if err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value); err != nil {
+		return err
+	}
+	// the record is gone from the database, so drop the cached copy as well
+	if servercfg.CacheEnabled() {
+		deleteEnrollmentkeyFromCache(value)
+	}
+	return nil
 }
 
 // TryToUseEnrollmentKey - checks first if key can be decremented
@@ -230,7 +248,13 @@ func upsertEnrollmentKey(k *models.EnrollmentKey) error {
 	if err != nil {
 		return err
 	}
-	return database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
+	err = database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME)
+	if err == nil {
+		if servercfg.CacheEnabled() {
+			storeEnrollmentkeyInCache(k.Value, k)
+		}
+	}
+	// propagate the insert error; returning nil here would report a failed write as success
+	return err
 }
 
 func getUniqueEnrollmentID() (string, error) {
@@ -245,7 +269,23 @@ func getUniqueEnrollmentID() (string, error) {
 	return newID, nil
 }
 
+// getEnrollmentkeysFromCache - returns a snapshot of the cached enrollment keys.
+// The snapshot is a fresh map built under the read lock, so callers can range
+// over it while other goroutines store or delete cache entries; the key
+// pointers themselves are shared with the cache.
+func getEnrollmentkeysFromCache() map[string]*models.EnrollmentKey {
+	enrollmentkeyCacheMutex.RLock()
+	defer enrollmentkeyCacheMutex.RUnlock()
+	keys := make(map[string]*models.EnrollmentKey, len(enrollmentkeyCacheMap))
+	for id, key := range enrollmentkeyCacheMap {
+		keys[id] = key
+	}
+	return keys
+}
+
+// storeEnrollmentkeyInCache - records enrollmentkey under key in the in-memory
+// cache while holding the cache write lock, replacing any previous entry
+func storeEnrollmentkeyInCache(key string, enrollmentkey *models.EnrollmentKey) {
+	enrollmentkeyCacheMutex.Lock()
+	defer enrollmentkeyCacheMutex.Unlock()
+	enrollmentkeyCacheMap[key] = enrollmentkey
+}
+
 func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
+	if servercfg.CacheEnabled() {
+		keys := getEnrollmentkeysFromCache()
+		if len(keys) != 0 {
+			return keys, nil
+		}
+	}
 	records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME)
 	if err != nil {
 		if !database.IsEmptyRecord(err) {
@@ -263,6 +303,9 @@ func getEnrollmentKeysMap() (map[string]*models.EnrollmentKey, error) {
 				continue
 			}
 			currentKeys[k] = &currentKey
+			if servercfg.CacheEnabled() {
+				storeEnrollmentkeyInCache(currentKey.Value, &currentKey)
+			}
 		}
 	}
 	return currentKeys, nil

+ 2 - 2
logic/extpeers.go

@@ -245,7 +245,7 @@ func CreateExtClient(extclient *models.ExtClient) error {
 	}
 	if extclient.Address == "" {
 		if parentNetwork.IsIPv4 == "yes" {
-			newAddress, err := UniqueAddress(extclient.Network, true)
+			newAddress, err := GetUniqueAddress(extclient.Network)
 			if err != nil {
 				return err
 			}
@@ -255,7 +255,7 @@ func CreateExtClient(extclient *models.ExtClient) error {
 
 	if extclient.Address6 == "" {
 		if parentNetwork.IsIPv6 == "yes" {
-			addr6, err := UniqueAddress6(extclient.Network, true)
+			addr6, err := GetUniqueAddress6(extclient.Network)
 			if err != nil {
 				return err
 			}

+ 283 - 0
logic/ippool.go

@@ -0,0 +1,283 @@
+package logic
+
+import (
+	"container/heap"
+	"errors"
+	"net"
+	"net/netip"
+	"sync"
+
+	"github.com/c-robinson/iplib"
+	"github.com/gravitl/netmaker/database"
+	"github.com/gravitl/netmaker/models"
+	"golang.org/x/exp/slog"
+)
+
+var (
+	ipPool      map[string]PoolMap
+	ipPoolMutex = &sync.RWMutex{}
+)
+
+// IpHeap - min-heap of ip addresses implementing container/heap.Interface,
+// the lowest address is popped first
+type IpHeap []net.IP
+
+// PoolMap - free address heaps of a single network, one per ip family
+type PoolMap struct {
+	V4 *IpHeap
+	V6 *IpHeap
+}
+
+// Len - number of addresses currently held in the heap
+func (h IpHeap) Len() int { return len(h) }
+
+// Less - reports whether the address at index i sorts before the one at index j.
+// Addresses are converted straight from their byte form instead of being
+// formatted and re-parsed as strings on every comparison. Unmap makes a
+// 16-byte IPv4-mapped net.IP compare as plain IPv4, and a nil or malformed
+// entry becomes the zero Addr, which sorts before every valid address.
+func (h IpHeap) Less(i, j int) bool {
+	addr1, _ := netip.AddrFromSlice(h[i])
+	addr2, _ := netip.AddrFromSlice(h[j])
+	return addr1.Unmap().Less(addr2.Unmap())
+}
+
+// Swap - exchanges the addresses at indexes i and j
+func (h IpHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+
+// Push - appends x (a net.IP) to the slice; call heap.Push rather than this directly
+func (h *IpHeap) Push(x any) {
+	// Push and Pop use pointer receivers because they modify the slice's length,
+	// not just its contents.
+	*h = append(*h, x.(net.IP))
+}
+
+// Pop - removes and returns the last element; call heap.Pop rather than this directly
+func (h *IpHeap) Pop() any {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	old[n-1] = nil // drop the reference so the backing array does not retain the popped ip
+	*h = old[0 : n-1]
+	return x
+}
+
+// ReleaseV4IpForNetwork - release ip back to ip pool after node is deleted, IPV4
+func ReleaseV4IpForNetwork(networkName string, ip net.IP) error {
+	return releaseIpForNetwork(networkName, ip, "v4")
+}
+
+// ReleaseV6IpForNetwork - release ip back to ip pool after node is deleted, IPv6
+func ReleaseV6IpForNetwork(networkName string, ip net.IP) error {
+	return releaseIpForNetwork(networkName, ip, "v6")
+}
+
+// releaseIpForNetwork - release ip back to ip pool after node is deleted.
+// v4v6Type selects the heap ("v4" or "v6"); any other value leaves the pool
+// untouched and returns nil. An error is returned when the network has no
+// pool entry or when ip is nil.
+func releaseIpForNetwork(networkName string, ip net.IP, v4v6Type string) error {
+	// take the lock before the first read of ipPool and look the entry up once,
+	// so a concurrent RemoveNetworkFromIpPool cannot slip in between the
+	// existence check and the push
+	ipPoolMutex.Lock()
+	defer ipPoolMutex.Unlock()
+
+	pool, ok := ipPool[networkName]
+	if !ok {
+		return errors.New("network does not exist")
+	}
+	if ip == nil {
+		return errors.New("ip is nil, it does not need to return")
+	}
+	switch v4v6Type {
+	case "v4":
+		heap.Push(pool.V4, ip)
+	case "v6":
+		heap.Push(pool.V6, ip)
+	}
+	return nil
+}
+
+// AddNetworkToIpPool - add network to ip pool when network is added.
+// The network is looked up with GetParentNetwork; for each enabled ip family
+// its range is validated and every address produced by the iplib walk is
+// pushed onto a fresh heap. The finished pool is stored under networkName.
+// Returns the lookup or CIDR parse error, nil otherwise.
+func AddNetworkToIpPool(networkName string) error {
+	network, err := GetParentNetwork(networkName)
+	if err != nil {
+		slog.Error("network name is not found ", "network", networkName, "error", err)
+		return err
+	}
+	ipv4List := &IpHeap{}
+	heap.Init(ipv4List)
+	ipv6List := &IpHeap{}
+	heap.Init(ipv6List)
+	if network.IsIPv4 != "no" {
+		//ensure AddressRange is valid
+		if _, _, err := net.ParseCIDR(network.AddressRange); err != nil {
+			slog.Error("ParseCIDR error ", "network", networkName, "range", network.AddressRange, "error", err)
+			return err
+		}
+		net4 := iplib.Net4FromStr(network.AddressRange)
+		newAddrs := net4.FirstAddress()
+
+		for {
+			heap.Push(ipv4List, newAddrs)
+			var walkErr error
+			newAddrs, walkErr = net4.NextIP(newAddrs)
+			if walkErr != nil {
+				// NextIP failing ends the walk, it is not reported to the caller
+				break
+			}
+		}
+	}
+
+	if network.IsIPv6 != "no" {
+		// ensure AddressRange6 is valid
+		if _, _, err := net.ParseCIDR(network.AddressRange6); err != nil {
+			slog.Error("ParseCIDR error ", "network", networkName, "range", network.AddressRange6, "error", err)
+			return err
+		}
+		net6 := iplib.Net6FromStr(network.AddressRange6)
+
+		// the v6 walk starts one past FirstAddress, unlike the v4 walk above
+		// NOTE(review): every address of the range is materialised; confirm this
+		// is acceptable for large IPv6 prefixes
+		newAddrs, walkErr := net6.NextIP(net6.FirstAddress())
+		for walkErr == nil {
+			heap.Push(ipv6List, newAddrs)
+			newAddrs, walkErr = net6.NextIP(newAddrs)
+		}
+	}
+
+	ipPoolMutex.Lock()
+	defer ipPoolMutex.Unlock()
+	// writing to a nil map panics; create it on demand the same way SetIpPool does
+	if ipPool == nil {
+		ipPool = map[string]PoolMap{}
+	}
+	ipPool[networkName] = PoolMap{V4: ipv4List, V6: ipv6List}
+	return nil
+}
+
+// RemoveNetworkFromIpPool - drops the pool entry of networkName while holding
+// the pool lock; used when a network is deleted
+func RemoveNetworkFromIpPool(networkName string) {
+	ipPoolMutex.Lock()
+	defer ipPoolMutex.Unlock()
+	delete(ipPool, networkName)
+}
+
+// GetUniqueAddress - Allocate unique ipv4 address.
+// Pops the lowest free address from the network's v4 heap. Returns an error
+// when the pool is not initialized, the network has no pool entry, or the v4
+// heap is empty.
+func GetUniqueAddress(networkName string) (ip net.IP, err error) {
+	// every read of ipPool, including the nil check, happens under the lock
+	ipPoolMutex.Lock()
+	defer ipPoolMutex.Unlock()
+
+	if ipPool == nil {
+		return nil, errors.New("ip pool is not initialized")
+	}
+	pool, ok := ipPool[networkName]
+	if !ok {
+		return nil, errors.New("network does not exist")
+	}
+	if pool.V4 == nil || pool.V4.Len() == 0 {
+		return nil, errors.New("ip v4 pool for network " + networkName + " is empty")
+	}
+	return heap.Pop(pool.V4).(net.IP), nil
+}
+
+// GetUniqueAddress6 - Allocate unique ipv6 address.
+// Pops the lowest free address from the network's v6 heap. Returns an error
+// when the pool is not initialized, the network has no pool entry, or the v6
+// heap is empty.
+func GetUniqueAddress6(networkName string) (ip net.IP, err error) {
+	// every read of ipPool, including the nil check, happens under the lock
+	ipPoolMutex.Lock()
+	defer ipPoolMutex.Unlock()
+
+	if ipPool == nil {
+		return nil, errors.New("ip pool is not initialized")
+	}
+	pool, ok := ipPool[networkName]
+	if !ok {
+		return nil, errors.New("network does not exist")
+	}
+	if pool.V6 == nil || pool.V6.Len() == 0 {
+		return nil, errors.New("ip v6 pool for network " + networkName + " is empty")
+	}
+	return heap.Pop(pool.V6).(net.IP), nil
+}
+
+// ClearIpPool - set ipPool to nil.
+// The write is done under the pool lock so it cannot race with goroutines
+// that are allocating or releasing addresses.
+func ClearIpPool() {
+	ipPoolMutex.Lock()
+	defer ipPoolMutex.Unlock()
+	ipPool = nil
+}
+
+// SetIpPool - set available ip pool for network.
+// Rebuilds the free-address heaps of every network returned by GetNetworks
+// and stores them in ipPool, creating the map when it is nil. Returns the
+// GetNetworks error, nil otherwise.
+func SetIpPool() error {
+	ipPoolMutex.Lock()
+	if ipPool == nil {
+		ipPool = map[string]PoolMap{}
+	}
+	ipPoolMutex.Unlock()
+
+	currentNetworks, err := GetNetworks()
+	if err != nil {
+		return err
+	}
+
+	for i := range currentNetworks {
+		network := &currentNetworks[i]
+
+		// the scan for free addresses runs without the lock so allocations are
+		// not blocked while it is in progress; only the map write is guarded
+		pMap := PoolMap{
+			V4: getAvailableIpV4Pool(network),
+			V6: getAvailableIpV6Pool(network),
+		}
+
+		ipPoolMutex.Lock()
+		if ipPool == nil {
+			// ClearIpPool may have run since the check at the top
+			ipPool = map[string]PoolMap{}
+		}
+		ipPool[network.NetID] = pMap
+		ipPoolMutex.Unlock()
+	}
+	return nil
+}
+
+// getAvailableIpV4Pool - builds the heap of v4 addresses of network that are
+// reported unique by IsIPUnique against both the nodes and the ext client
+// tables. An empty heap is returned when IPv4 is disabled or the range does
+// not parse.
+func getAvailableIpV4Pool(network *models.Network) *IpHeap {
+	pool := &IpHeap{}
+	heap.Init(pool)
+
+	if network.IsIPv4 == "no" {
+		return pool
+	}
+	//ensure AddressRange is valid
+	if _, _, parseErr := net.ParseCIDR(network.AddressRange); parseErr != nil {
+		slog.Debug("UniqueAddress encountered  an error")
+		return pool
+	}
+
+	subnet := iplib.Net4FromStr(network.AddressRange)
+	var walkErr error
+	// walk the range from FirstAddress until NextIP reports an error
+	for candidate := subnet.FirstAddress(); walkErr == nil; candidate, walkErr = subnet.NextIP(candidate) {
+		addr := candidate.String()
+		if !IsIPUnique(network.NetID, addr, database.NODES_TABLE_NAME, false) {
+			continue
+		}
+		if !IsIPUnique(network.NetID, addr, database.EXT_CLIENT_TABLE_NAME, false) {
+			continue
+		}
+		heap.Push(pool, candidate)
+	}
+
+	return pool
+}
+
+// getAvailableIpV6Pool - builds the heap of v6 addresses of network that are
+// reported unique by IsIPUnique against both the nodes and the ext client
+// tables. The walk starts one address past FirstAddress. An empty heap is
+// returned when IPv6 is disabled, the range does not parse, or the first
+// NextIP call fails.
+func getAvailableIpV6Pool(network *models.Network) *IpHeap {
+	pool := &IpHeap{}
+	heap.Init(pool)
+
+	if network.IsIPv6 == "no" {
+		return pool
+	}
+
+	//ensure AddressRange is valid
+	if _, _, parseErr := net.ParseCIDR(network.AddressRange6); parseErr != nil {
+		return pool
+	}
+
+	subnet := iplib.Net6FromStr(network.AddressRange6)
+	candidate, walkErr := subnet.NextIP(subnet.FirstAddress())
+	// continue until NextIP reports an error
+	for ; walkErr == nil; candidate, walkErr = subnet.NextIP(candidate) {
+		addr := candidate.String()
+		if !IsIPUnique(network.NetID, addr, database.NODES_TABLE_NAME, true) {
+			continue
+		}
+		if !IsIPUnique(network.NetID, addr, database.EXT_CLIENT_TABLE_NAME, true) {
+			continue
+		}
+		heap.Push(pool, candidate)
+	}
+
+	return pool
+}

+ 1 - 1
logic/jwts.go

@@ -31,7 +31,7 @@ func SetJWTSecret() {
 
 // CreateJWT func will used to create the JWT while signing in and signing out
 func CreateJWT(uuid string, macAddress string, network string) (response string, err error) {
-	expirationTime := time.Now().Add(5 * time.Minute)
+	expirationTime := time.Now().Add(15 * time.Minute)
 	claims := &models.Claims{
 		ID:         uuid,
 		Network:    network,

+ 83 - 2
logic/networks.go

@@ -15,13 +15,52 @@ import (
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic/acls/nodeacls"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/validation"
 )
 
+var (
+	networkCacheMutex = &sync.RWMutex{}
+	networkCacheMap   = make(map[string]models.Network)
+)
+
+// getNetworksFromCache - copies every cached network into a slice while
+// holding the cache read lock; the result is nil when the cache is empty
+func getNetworksFromCache() (networks []models.Network) {
+	networkCacheMutex.RLock()
+	defer networkCacheMutex.RUnlock()
+	for id := range networkCacheMap {
+		networks = append(networks, networkCacheMap[id])
+	}
+	return networks
+}
+
+// deleteNetworkFromCache - evicts the network stored under key from the
+// in-memory cache while holding the cache write lock
+func deleteNetworkFromCache(key string) {
+	networkCacheMutex.Lock()
+	defer networkCacheMutex.Unlock()
+	delete(networkCacheMap, key)
+}
+
+// getNetworkFromCache - looks up the network cached under key while holding
+// the cache read lock; ok is false when there is no such entry
+func getNetworkFromCache(key string) (network models.Network, ok bool) {
+	networkCacheMutex.RLock()
+	defer networkCacheMutex.RUnlock()
+	cached, found := networkCacheMap[key]
+	return cached, found
+}
+
+// storeNetworkInCache - records network under key in the in-memory cache
+// while holding the cache write lock, replacing any previous entry
+func storeNetworkInCache(key string, network models.Network) {
+	networkCacheMutex.Lock()
+	defer networkCacheMutex.Unlock()
+	networkCacheMap[key] = network
+}
+
 // GetNetworks - returns all networks from database
 func GetNetworks() ([]models.Network, error) {
 	var networks []models.Network
-
+	if servercfg.CacheEnabled() {
+		networks := getNetworksFromCache()
+		if len(networks) != 0 {
+			return networks, nil
+		}
+	}
 	collection, err := database.FetchRecords(database.NETWORKS_TABLE_NAME)
 	if err != nil {
 		return networks, err
@@ -34,6 +73,9 @@ func GetNetworks() ([]models.Network, error) {
 		}
 		// add network our array
 		networks = append(networks, network)
+		if servercfg.CacheEnabled() {
+			storeNetworkInCache(network.NetID, network)
+		}
 	}
 
 	return networks, err
@@ -49,7 +91,14 @@ func DeleteNetwork(network string) error {
 	nodeCount, err := GetNetworkNonServerNodeCount(network)
 	if nodeCount == 0 || database.IsEmptyRecord(err) {
 		// delete server nodes first then db records
-		return database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
+		err = database.DeleteRecord(database.NETWORKS_TABLE_NAME, network)
+		if err != nil {
+			return err
+		}
+		if servercfg.CacheEnabled() {
+			deleteNetworkFromCache(network)
+		}
+		return nil
 	}
 	return errors.New("node check failed. All nodes must be deleted before deleting network")
 }
@@ -93,6 +142,9 @@ func CreateNetwork(network models.Network) (models.Network, error) {
 	if err = database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
 		return models.Network{}, err
 	}
+	if servercfg.CacheEnabled() {
+		storeNetworkInCache(network.NetID, network)
+	}
 
 	return network, nil
 }
@@ -128,6 +180,11 @@ func intersect(n1, n2 *net.IPNet) bool {
 func GetParentNetwork(networkname string) (models.Network, error) {
 
 	var network models.Network
+	if servercfg.CacheEnabled() {
+		if network, ok := getNetworkFromCache(networkname); ok {
+			return network, nil
+		}
+	}
 	networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
 	if err != nil {
 		return network, err
@@ -142,6 +199,11 @@ func GetParentNetwork(networkname string) (models.Network, error) {
 func GetNetworkSettings(networkname string) (models.Network, error) {
 
 	var network models.Network
+	if servercfg.CacheEnabled() {
+		if network, ok := getNetworkFromCache(networkname); ok {
+			return network, nil
+		}
+	}
 	networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
 	if err != nil {
 		return network, err
@@ -320,6 +382,12 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (
 		}
 		newNetwork.SetNetworkLastModified()
 		err = database.Insert(newNetwork.NetID, string(data), database.NETWORKS_TABLE_NAME)
+		if err == nil {
+			if servercfg.CacheEnabled() {
+				// evict the old key first, then store: when both NetIDs are equal,
+				// deleting after storing would remove the entry that was just written
+				deleteNetworkFromCache(currentNetwork.NetID)
+				storeNetworkInCache(newNetwork.NetID, *newNetwork)
+			}
+		}
 		return hasrangeupdate4, hasrangeupdate6, hasholepunchupdate, err
 	}
 	// copy values
@@ -330,6 +398,11 @@ func UpdateNetwork(currentNetwork *models.Network, newNetwork *models.Network) (
 func GetNetwork(networkname string) (models.Network, error) {
 
 	var network models.Network
+	if servercfg.CacheEnabled() {
+		if network, ok := getNetworkFromCache(networkname); ok {
+			return network, nil
+		}
+	}
 	networkData, err := database.FetchRecord(database.NETWORKS_TABLE_NAME, networkname)
 	if err != nil {
 		return network, err
@@ -394,6 +467,9 @@ func SaveNetwork(network *models.Network) error {
 	if err := database.Insert(network.NetID, string(data), database.NETWORKS_TABLE_NAME); err != nil {
 		return err
 	}
+	if servercfg.CacheEnabled() {
+		storeNetworkInCache(network.NetID, *network)
+	}
 	return nil
 }
 
@@ -402,6 +478,11 @@ func NetworkExists(name string) (bool, error) {
 
 	var network string
 	var err error
+	if servercfg.CacheEnabled() {
+		if _, ok := getNetworkFromCache(name); ok {
+			return ok, nil
+		}
+	}
 	if network, err = database.FetchRecord(database.NETWORKS_TABLE_NAME, name); err != nil {
 		return false, err
 	}

+ 2 - 2
logic/nodes.go

@@ -535,7 +535,7 @@ func createNode(node *models.Node) error {
 
 	if node.Address.IP == nil {
 		if parentNetwork.IsIPv4 == "yes" {
-			if node.Address.IP, err = UniqueAddress(node.Network, false); err != nil {
+			if node.Address.IP, err = GetUniqueAddress(node.Network); err != nil {
 				return err
 			}
 			_, cidr, err := net.ParseCIDR(parentNetwork.AddressRange)
@@ -549,7 +549,7 @@ func createNode(node *models.Node) error {
 	}
 	if node.Address6.IP == nil {
 		if parentNetwork.IsIPv6 == "yes" {
-			if node.Address6.IP, err = UniqueAddress6(node.Network, false); err != nil {
+			if node.Address6.IP, err = GetUniqueAddress6(node.Network); err != nil {
 				return err
 			}
 			_, cidr, err := net.ParseCIDR(parentNetwork.AddressRange6)

+ 2 - 0
logic/zombie.go

@@ -135,6 +135,8 @@ func ManageZombies(ctx context.Context, peerUpdate chan *models.Node) {
 					}
 				}
 			}
+			//reset the ip pool
+			SetIpPool()
 		}
 	}
 }

+ 2 - 0
main.go

@@ -37,6 +37,8 @@ func main() {
 	servercfg.SetVersion(version)
 	fmt.Println(models.RetrieveLogo()) // print the logo
 	initialize()                       // initial db and acls
+	logic.SetIpPool()
+	defer logic.ClearIpPool()
 	setGarbageCollection()
 	setVerbosity()
 	if servercfg.DeployedByOperator() && !servercfg.IsPro {

+ 6 - 2
mq/emqx_on_prem.go

@@ -206,7 +206,9 @@ func (e *EmqxOnPrem) CreateEmqxDefaultAuthenticator() error {
 		if err != nil {
 			return err
 		}
-		return fmt.Errorf("error creating default EMQX authenticator %v", string(msg))
+		// substring match: ContainsAny would be true for any body holding a single
+		// character of the set (e.g. 'A' or '_') and so hide unrelated failures
+		if !strings.Contains(string(msg), "ALREADY_EXISTS") {
+			return fmt.Errorf("error creating default EMQX authenticator %v", string(msg))
+		}
 	}
 	return nil
 }
@@ -240,7 +242,9 @@ func (e *EmqxOnPrem) CreateEmqxDefaultAuthorizer() error {
 		if err != nil {
 			return err
 		}
-		return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg))
+		// substring match: ContainsAny would be true for any body holding a single
+		// character of the set (e.g. 'e' or '_') and so hide unrelated failures
+		if !strings.Contains(string(msg), "duplicated_authz_source_type") {
+			return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg))
+		}
 	}
 	return nil
 }

+ 5 - 9
mq/mq.go

@@ -34,8 +34,8 @@ func setMqOptions(user, password string, opts *mqtt.ClientOptions) {
 	opts.SetAutoReconnect(true)
 	opts.SetConnectRetry(true)
 	opts.SetCleanSession(true)
-	opts.SetConnectRetryInterval(time.Second * 4)
-	opts.SetKeepAlive(time.Minute)
+	opts.SetConnectRetryInterval(time.Second * 1)
+	opts.SetKeepAlive(time.Second * 10)
 	opts.SetCleanSession(true)
 	opts.SetWriteTimeout(time.Minute)
 }
@@ -75,19 +75,15 @@ func SetupMQTT(fatal bool) {
 	opts.SetOnConnectHandler(func(client mqtt.Client) {
 		serverName := servercfg.GetServer()
 		if token := client.Subscribe(fmt.Sprintf("update/%s/#", serverName), 0, mqtt.MessageHandler(UpdateNode)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
-			client.Disconnect(240)
 			logger.Log(0, "node update subscription failed")
 		}
 		if token := client.Subscribe(fmt.Sprintf("host/serverupdate/%s/#", serverName), 0, mqtt.MessageHandler(UpdateHost)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
-			client.Disconnect(240)
 			logger.Log(0, "host update subscription failed")
 		}
 		if token := client.Subscribe(fmt.Sprintf("signal/%s/#", serverName), 0, mqtt.MessageHandler(ClientPeerUpdate)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
-			client.Disconnect(240)
 			logger.Log(0, "node client subscription failed")
 		}
 		if token := client.Subscribe(fmt.Sprintf("metrics/%s/#", serverName), 0, mqtt.MessageHandler(UpdateMetrics)); token.WaitTimeout(MQ_TIMEOUT*time.Second) && token.Error() != nil {
-			client.Disconnect(240)
 			logger.Log(0, "node metrics subscription failed")
 		}
 
@@ -96,15 +92,15 @@ func SetupMQTT(fatal bool) {
 	})
 	opts.SetConnectionLostHandler(func(c mqtt.Client, e error) {
 		slog.Warn("detected broker connection lost", "err", e.Error())
-		c.Disconnect(250)
+		//c.Disconnect(250)
 		slog.Info("re-initiating MQ connection")
-		SetupMQTT(false)
+		//SetupMQTT(false)
 
 	})
 	mqclient = mqtt.NewClient(opts)
 	tperiod := time.Now().Add(10 * time.Second)
 	for {
-		if token := mqclient.Connect(); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
+		if token := mqclient.Connect(); token.Wait() && token.Error() != nil {
 			logger.Log(2, "unable to connect to broker, retrying ...")
 			if time.Now().After(tperiod) {
 				if token.Error() == nil {

+ 1 - 1
mq/publishers.go

@@ -194,7 +194,7 @@ func PushMetricsToExporter(metrics models.Metrics) error {
 	if err != nil {
 		return errors.New("failed to marshal metrics: " + err.Error())
 	}
-	if mqclient == nil || !mqclient.IsConnectionOpen() {
+	if mqclient == nil || !mqclient.IsConnected() {
 		return errors.New("cannot publish ... mqclient not connected")
 	}
 	if token := mqclient.Publish("metrics_exporter", 0, true, data); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {

+ 2 - 3
mq/util.go

@@ -4,7 +4,6 @@ import (
 	"errors"
 	"fmt"
 	"strings"
-	"time"
 
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
@@ -78,11 +77,11 @@ func publish(host *models.Host, dest string, msg []byte) error {
 	if encryptErr != nil {
 		return encryptErr
 	}
-	if mqclient == nil || !mqclient.IsConnectionOpen() {
+	if mqclient == nil || !mqclient.IsConnected() {
 		return errors.New("cannot publish ... mqclient not connected")
 	}
 
-	if token := mqclient.Publish(dest, 0, true, encrypted); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
+	if token := mqclient.Publish(dest, 0, true, encrypted); token.Wait() && token.Error() != nil {
 		var err error
 		if token.Error() == nil {
 			err = errors.New("connection timeout")