Browse Source

Merge pull request #1634 from gravitl/bugfix_mq_dyn_sec

node disconnect/connect fix, delete node api fix
Matthew R Kasun 2 years ago
parent
commit
c254e0af85

+ 23 - 12
controllers/node.go

@@ -83,12 +83,15 @@ func authenticate(response http.ResponseWriter, request *http.Request) {
 	var err error
 	var err error
 	result, err = logic.GetNodeByID(authRequest.ID)
 	result, err = logic.GetNodeByID(authRequest.ID)
 	if err != nil {
 	if err != nil {
-		errorResponse.Code = http.StatusBadRequest
-		errorResponse.Message = err.Error()
-		logger.Log(0, request.Header.Get("user"),
-			fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err))
-		logic.ReturnErrorResponse(response, request, errorResponse)
-		return
+		result, err = logic.GetDeletedNodeByID(authRequest.ID)
+		if err != nil {
+			errorResponse.Code = http.StatusBadRequest
+			errorResponse.Message = err.Error()
+			logger.Log(0, request.Header.Get("user"),
+				fmt.Sprintf("failed to get node info [%s]: %v", authRequest.ID, err))
+			logic.ReturnErrorResponse(response, request, errorResponse)
+			return
+		}
 	}
 	}
 
 
 	err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password))
 	err = bcrypt.CompareHashAndPassword([]byte(result.Password), []byte(authRequest.Password))
@@ -256,7 +259,6 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
 				logic.ReturnErrorResponse(w, r, errorResponse)
 				logic.ReturnErrorResponse(w, r, errorResponse)
 				return
 				return
 			}
 			}
-			r.Header.Set("requestfrom", "")
 			//check if node instead of user
 			//check if node instead of user
 			if nodesAllowed {
 			if nodesAllowed {
 				// TODO --- should ensure that node is only operating on itself
 				// TODO --- should ensure that node is only operating on itself
@@ -264,7 +266,6 @@ func authorize(nodesAllowed, networkCheck bool, authNetwork string, next http.Ha
 
 
 					// this indicates request is from a node
 					// this indicates request is from a node
 					// used for failover - if a getNode comes from node, this will trigger a metrics wipe
 					// used for failover - if a getNode comes from node, this will trigger a metrics wipe
-					r.Header.Set("requestfrom", "node")
 					next.ServeHTTP(w, r)
 					next.ServeHTTP(w, r)
 					return
 					return
 				}
 				}
@@ -1040,10 +1041,20 @@ func deleteNode(w http.ResponseWriter, r *http.Request) {
 	fromNode := r.Header.Get("requestfrom") == "node"
 	fromNode := r.Header.Get("requestfrom") == "node"
 	var node, err = logic.GetNodeByID(nodeid)
 	var node, err = logic.GetNodeByID(nodeid)
 	if err != nil {
 	if err != nil {
-		logger.Log(0, r.Header.Get("user"),
-			fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
-		return
+		if fromNode {
+			node, err = logic.GetDeletedNodeByID(nodeid)
+			if err != nil {
+				logger.Log(0, r.Header.Get("user"),
+					fmt.Sprintf("error fetching node from deleted nodes [ %s ] info: %v", nodeid, err))
+				logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+				return
+			}
+		} else {
+			logger.Log(0, r.Header.Get("user"),
+				fmt.Sprintf("error fetching node [ %s ] info: %v", nodeid, err))
+			logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
+			return
+		}
 	}
 	}
 	if isServer(&node) {
 	if isServer(&node) {
 		err := fmt.Errorf("cannot delete server node")
 		err := fmt.Errorf("cannot delete server node")

+ 3 - 1
logic/nodes.go

@@ -186,7 +186,9 @@ func DeleteNodeByID(node *models.Node, exterminate bool) error {
 		}
 		}
 	}
 	}
 	if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil {
 	if err = database.DeleteRecord(database.NODES_TABLE_NAME, key); err != nil {
-		return err
+		if !database.IsEmptyRecord(err) {
+			return err
+		}
 	}
 	}
 
 
 	if servercfg.IsDNSMode() {
 	if servercfg.IsDNSMode() {

+ 1 - 27
netclient/functions/common.go

@@ -192,37 +192,10 @@ func LeaveNetwork(network string) error {
 	if err := removeHostDNS(cfg.Node.Interface, ncutils.IsWindows()); err != nil {
 	if err := removeHostDNS(cfg.Node.Interface, ncutils.IsWindows()); err != nil {
 		logger.Log(0, "failed to delete dns entries for", cfg.Node.Interface, err.Error())
 		logger.Log(0, "failed to delete dns entries for", cfg.Node.Interface, err.Error())
 	}
 	}
-	logger.Log(2, "deleting broker keys as required")
-	if !brokerInUse(cfg.Server.Server) {
-		if err := deleteBrokerFiles(cfg.Server.Server); err != nil {
-			logger.Log(0, "failed to deleter certs for", cfg.Server.Server, err.Error())
-		}
-	}
 	logger.Log(2, "restarting daemon")
 	logger.Log(2, "restarting daemon")
 	return daemon.Restart()
 	return daemon.Restart()
 }
 }
 
 
-func brokerInUse(broker string) bool {
-	networks, _ := ncutils.GetSystemNetworks()
-	for _, net := range networks {
-		cfg := config.ClientConfig{}
-		cfg.Network = net
-		cfg.ReadConfig()
-		if cfg.Server.Server == broker {
-			return true
-		}
-	}
-	return false
-}
-
-func deleteBrokerFiles(broker string) error {
-	dir := ncutils.GetNetclientServerPath(broker)
-	if err := os.RemoveAll(dir); err != nil {
-		return err
-	}
-	return nil
-}
-
 func deleteNodeFromServer(cfg *config.ClientConfig) error {
 func deleteNodeFromServer(cfg *config.ClientConfig) error {
 	node := cfg.Node
 	node := cfg.Node
 	if node.IsServer == "yes" {
 	if node.IsServer == "yes" {
@@ -340,6 +313,7 @@ func API(data any, method, url, authorization string) (*http.Response, error) {
 	if authorization != "" {
 	if authorization != "" {
 		request.Header.Set("authorization", "Bearer "+authorization)
 		request.Header.Set("authorization", "Bearer "+authorization)
 	}
 	}
+	request.Header.Set("requestfrom", "node")
 	return HTTPClient.Do(request)
 	return HTTPClient.Do(request)
 }
 }
 
 

+ 7 - 31
netclient/functions/daemon.go

@@ -2,8 +2,6 @@ package functions
 
 
 import (
 import (
 	"context"
 	"context"
-	"crypto/tls"
-	"crypto/x509"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"os"
 	"os"
@@ -68,12 +66,18 @@ func Daemon() error {
 			cancel()
 			cancel()
 			logger.Log(0, "shutting down netclient daemon")
 			logger.Log(0, "shutting down netclient daemon")
 			wg.Wait()
 			wg.Wait()
+			if mqclient != nil {
+				mqclient.Disconnect(250)
+			}
 			logger.Log(0, "shutdown complete")
 			logger.Log(0, "shutdown complete")
 			return nil
 			return nil
 		case <-reset:
 		case <-reset:
 			logger.Log(0, "received reset")
 			logger.Log(0, "received reset")
 			cancel()
 			cancel()
 			wg.Wait()
 			wg.Wait()
+			if mqclient != nil {
+				mqclient.Disconnect(250)
+			}
 			logger.Log(0, "restarting daemon")
 			logger.Log(0, "restarting daemon")
 			cancel = startGoRoutines(&wg)
 			cancel = startGoRoutines(&wg)
 		}
 		}
@@ -198,34 +202,6 @@ func messageQueue(ctx context.Context, wg *sync.WaitGroup, cfg *config.ClientCon
 	logger.Log(0, "shutting down message queue for server", cfg.Server.Server)
 	logger.Log(0, "shutting down message queue for server", cfg.Server.Server)
 }
 }
 
 
-// NewTLSConf sets up tls configuration to connect to broker securely
-func NewTLSConfig(server string) (*tls.Config, error) {
-	file := ncutils.GetNetclientServerPath(server) + ncutils.GetSeparator() + "root.pem"
-	certpool := x509.NewCertPool()
-	ca, err := os.ReadFile(file)
-	if err != nil {
-		logger.Log(0, "could not read CA file", err.Error())
-	}
-	ok := certpool.AppendCertsFromPEM(ca)
-	if !ok {
-		logger.Log(0, "failed to append cert")
-	}
-	clientKeyPair, err := tls.LoadX509KeyPair(ncutils.GetNetclientServerPath(server)+ncutils.GetSeparator()+"client.pem", ncutils.GetNetclientPath()+ncutils.GetSeparator()+"client.key")
-	if err != nil {
-		logger.Log(0, "could not read client cert/key", err.Error())
-		return nil, err
-	}
-	certs := []tls.Certificate{clientKeyPair}
-	return &tls.Config{
-		RootCAs:            certpool,
-		ClientAuth:         tls.NoClientCert,
-		ClientCAs:          nil,
-		Certificates:       certs,
-		InsecureSkipVerify: false,
-	}, nil
-
-}
-
 // func setMQTTSingenton creates a connection to broker for single use (ie to publish a message)
 // func setMQTTSingenton creates a connection to broker for single use (ie to publish a message)
 // only to be called from cli (eg. connect/disconnect, join, leave) and not from daemon ---
 // only to be called from cli (eg. connect/disconnect, join, leave) and not from daemon ---
 func setupMQTTSingleton(cfg *config.ClientConfig) error {
 func setupMQTTSingleton(cfg *config.ClientConfig) error {
@@ -239,7 +215,7 @@ func setupMQTTSingleton(cfg *config.ClientConfig) error {
 	opts.AddBroker("mqtts://" + server + ":" + port)
 	opts.AddBroker("mqtts://" + server + ":" + port)
 	opts.SetUsername(cfg.Node.ID)
 	opts.SetUsername(cfg.Node.ID)
 	opts.SetPassword(string(pass))
 	opts.SetPassword(string(pass))
-	mqclient := mqtt.NewClient(opts)
+	mqclient = mqtt.NewClient(opts)
 	var connecterr error
 	var connecterr error
 	opts.SetClientID(ncutils.MakeRandomString(23))
 	opts.SetClientID(ncutils.MakeRandomString(23))
 	if token := mqclient.Connect(); !token.WaitTimeout(30*time.Second) || token.Error() != nil {
 	if token := mqclient.Connect(); !token.WaitTimeout(30*time.Second) || token.Error() != nil {

+ 7 - 3
netclient/functions/mqpublish.go

@@ -29,7 +29,6 @@ var metricsCache = new(sync.Map)
 func Checkin(ctx context.Context, wg *sync.WaitGroup) {
 func Checkin(ctx context.Context, wg *sync.WaitGroup) {
 	logger.Log(2, "starting checkin goroutine")
 	logger.Log(2, "starting checkin goroutine")
 	defer wg.Done()
 	defer wg.Done()
-	checkin()
 	ticker := time.NewTicker(time.Minute * ncutils.CheckInInterval)
 	ticker := time.NewTicker(time.Minute * ncutils.CheckInInterval)
 	defer ticker.Stop()
 	defer ticker.Stop()
 	for {
 	for {
@@ -38,7 +37,12 @@ func Checkin(ctx context.Context, wg *sync.WaitGroup) {
 			logger.Log(0, "checkin routine closed")
 			logger.Log(0, "checkin routine closed")
 			return
 			return
 		case <-ticker.C:
 		case <-ticker.C:
-			checkin()
+			if mqclient != nil && mqclient.IsConnected() {
+				checkin()
+			} else {
+				logger.Log(0, "MQ client is not connected, skipping checkin...")
+			}
+
 		}
 		}
 	}
 	}
 }
 }
@@ -107,7 +111,7 @@ func checkin() {
 			config.Write(&nodeCfg, nodeCfg.Network)
 			config.Write(&nodeCfg, nodeCfg.Network)
 		}
 		}
 		Hello(&nodeCfg)
 		Hello(&nodeCfg)
-		if nodeCfg.Server.Is_EE {
+		if nodeCfg.Server.Is_EE && nodeCfg.Node.Connected == "yes" {
 			logger.Log(0, "collecting metrics for node", nodeCfg.Node.Name)
 			logger.Log(0, "collecting metrics for node", nodeCfg.Node.Name)
 			publishMetrics(&nodeCfg)
 			publishMetrics(&nodeCfg)
 		}
 		}