Abhishek Kondur il y a 2 ans
Parent
commit
c80f3d0146
5 fichiers modifiés avec 30 ajouts et 10 suppressions
  1. 5 2
      controllers/hosts.go
  2. 6 0
      mq/dynsec_helper.go
  3. 8 7
      mq/handlers.go
  4. 1 1
      mq/publishers.go
  5. 10 0
      mq/util.go

+ 5 - 2
controllers/hosts.go

@@ -2,6 +2,7 @@ package controller
 
 import (
 	"encoding/json"
+	"fmt"
 	"net/http"
 	"reflect"
 
@@ -203,7 +204,9 @@ func updateHostNetworks(w http.ResponseWriter, r *http.Request) {
 	}); err != nil {
 		logger.Log(0, r.Header.Get("user"), "failed to update host networks roles in DynSec:", err.Error())
 	}
-	go func() {
+	go func(newNets, delNets []string) {
+		logger.Log(0, fmt.Sprint("-----------> NEW NETS: ", newNets))
+		logger.Log(0, fmt.Sprint("-----------> DEL NETS: ", delNets))
 		for _, newNet := range newNets {
 			node, err := logic.GetNodeByNetwork(currHost.ID.String(), newNet)
 			if err != nil {
@@ -230,7 +233,7 @@ func updateHostNetworks(w http.ResponseWriter, r *http.Request) {
 				logger.Log(0, "failed to send mq msg to delete host from network: ", delNet, err.Error())
 			}
 		}
-	}()
+	}(newNets, delNets)
 
 	logger.Log(2, r.Header.Get("user"), "updated host networks", currHost.Name)
 	w.WriteHeader(http.StatusOK)

+ 6 - 0
mq/dynsec_helper.go

@@ -311,6 +311,12 @@ func fetchServerAcls() []Acl {
 			Priority: -1,
 			Allow:    true,
 		},
+		{
+			AclType:  "publishClientSend",
+			Topic:    "host/update/#",
+			Priority: -1,
+			Allow:    true,
+		},
 		{
 			AclType:  "publishClientSend",
 			Topic:    "metrics_exporter",

+ 8 - 7
mq/handlers.go

@@ -118,22 +118,22 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) {
 	}()
 }
 
-// UpdateHost  message Handler -- handles updates from client hosts
+// UpdateHost  message Handler -- handles host updates from clients
 func UpdateHost(client mqtt.Client, msg mqtt.Message) {
-	go func() {
-		id, err := getID(msg.Topic())
+	go func(msg mqtt.Message) {
+		id, err := getHostID(msg.Topic())
 		if err != nil {
 			logger.Log(1, "error getting host.ID sent on ", msg.Topic(), err.Error())
 			return
 		}
 		currentHost, err := logic.GetHost(id)
 		if err != nil {
-			logger.Log(1, "error getting node ", id, err.Error())
+			logger.Log(1, "error getting host ", id, err.Error())
 			return
 		}
 		decrypted, decryptErr := decryptMsgWithHost(currentHost, msg.Payload())
 		if decryptErr != nil {
-			logger.Log(1, "failed to decrypt message for node ", id, decryptErr.Error())
+			logger.Log(1, "failed to decrypt message for host ", id, decryptErr.Error())
 			return
 		}
 		var newHost models.Host
@@ -141,13 +141,14 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
 			logger.Log(1, "error unmarshaling payload ", err.Error())
 			return
 		}
+		logger.Log(0, fmt.Sprintf("--------> HOST Update: %+v\n", newHost))
 		// if servercfg.Is_EE && ifaceDelta {
 		// 	if err = logic.EnterpriseResetAllPeersFailovers(currentHost.ID.String(), currentHost.Network); err != nil {
 		// 		logger.Log(1, "failed to reset failover list during node update", currentHost.ID.String(), currentHost.Network)
 		// 	}
 		// }
 		sendPeerUpdate := logic.UpdateHostFromClient(&newHost, currentHost)
-		if err := logic.UpsertHost(&newHost); err != nil {
+		if err := logic.UpsertHost(currentHost); err != nil {
 			logger.Log(1, "error saving host", err.Error())
 			return
 		}
@@ -158,7 +159,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) {
 			}
 		}
 		logger.Log(1, "updated host", newHost.ID.String())
-	}()
+	}(msg)
 }
 
 // UpdateMetrics  message Handler -- handles updates from client nodes for metrics

+ 1 - 1
mq/publishers.go

@@ -97,7 +97,7 @@ func HostUpdate(hostUpdate *models.HostUpdate) error {
 	if !servercfg.IsMessageQueueBackend() {
 		return nil
 	}
-	logger.Log(3, "publishing host update to "+hostUpdate.Host.ID.String())
+	logger.Log(3, "----------> HEREEEEE publishing host update to "+hostUpdate.Host.ID.String())
 
 	data, err := json.Marshal(hostUpdate)
 	if err != nil {

+ 10 - 0
mq/util.go

@@ -94,3 +94,13 @@ func getID(topic string) (string, error) {
 	//the last part of the topic will be the node.ID
 	return parts[count-1], nil
 }
+
+// decodes a message queue topic and returns the embedded host.ID
+func getHostID(topic string) (string, error) {
+	parts := strings.Split(topic, "/")
+	count := len(parts)
+	if count < 4 {
+		return "", fmt.Errorf("invalid topic")
+	}
+	return parts[2], nil
+}