Browse Source

use new node model on mq update node handler

Abhishek Kondur 2 years ago
parent
commit
b52a47d8aa
2 changed files with 41 additions and 42 deletions
  1. 36 29
      models/node.go
  2. 5 13
      mq/handlers.go

+ 36 - 29
models/node.go

@@ -8,6 +8,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/google/uuid"
 	"github.com/google/uuid"
+	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 )
 
 
 const (
 const (
@@ -482,37 +483,43 @@ func (node *Node) DoesACLDeny() bool {
 	return node.DefaultACL == "no"
 	return node.DefaultACL == "no"
 }
 }
 
 
-func (ln *LegacyNode) ConvertToNewNode(host *Host) (*Host, *Node) {
+func (ln *LegacyNode) ConvertToNewNode() (*Host, *Node) {
 	var node Node
 	var node Node
-	host.FirewallInUse = ln.FirewallInUse
-	host.Version = ln.Version
-	host.IPForwarding = parseBool(ln.IPForwarding)
-	//host.HostPass = ln.Password
-	host.Name = ln.Name
-	host.ListenPort = int(ln.ListenPort)
-	if _, cidr, err := net.ParseCIDR(ln.LocalAddress); err == nil {
-		host.LocalRange = *cidr
-	} else {
-		if _, cidr, err := net.ParseCIDR(ln.LocalRange); err == nil {
+	//host:= logic.GetHost(node.HostID)
+	var host Host
+	if host.ID.String() == "" {
+		host.ID = uuid.New()
+		host.FirewallInUse = ln.FirewallInUse
+		host.Version = ln.Version
+		host.IPForwarding = parseBool(ln.IPForwarding)
+		host.HostPass = ln.Password
+		host.Name = ln.Name
+		host.ListenPort = int(ln.ListenPort)
+		if _, cidr, err := net.ParseCIDR(ln.LocalAddress); err == nil {
 			host.LocalRange = *cidr
 			host.LocalRange = *cidr
+		} else {
+			if _, cidr, err := net.ParseCIDR(ln.LocalRange); err == nil {
+				host.LocalRange = *cidr
+			}
 		}
 		}
-	}
-	host.LocalListenPort = int(ln.LocalListenPort)
-	host.ProxyListenPort = int(ln.ProxyListenPort)
-	host.MTU = int(ln.MTU)
-	// host.PublicKey, _ = wgtypes.ParseKey(ln.PublicKey)
-	// host.MacAddress, _ = net.ParseMAC(ln.MacAddress)
-	// host.TrafficKeyPublic = ln.TrafficKeys.Mine
-	gateway, err := net.ResolveUDPAddr("udp", ln.InternetGateway)
-	if err == nil {
-		host.InternetGateway = *gateway
-	}
-	nodeID, _ := uuid.Parse(ln.ID)
-	host.Nodes = append(host.Nodes, nodeID.String())
-	host.Interfaces = ln.Interfaces
-	host.EndpointIP = net.ParseIP(ln.Endpoint)
-	host.ProxyEnabled = ln.Proxy
-	node.ID = nodeID
+		host.LocalListenPort = int(ln.LocalListenPort)
+		host.ProxyListenPort = int(ln.ProxyListenPort)
+		host.MTU = int(ln.MTU)
+		host.PublicKey, _ = wgtypes.ParseKey(ln.PublicKey)
+		host.MacAddress, _ = net.ParseMAC(ln.MacAddress)
+		host.TrafficKeyPublic = ln.TrafficKeys.Mine
+		gateway, err := net.ResolveUDPAddr("udp", ln.InternetGateway)
+		if err == nil {
+			host.InternetGateway = *gateway
+		}
+		id, _ := uuid.Parse(ln.ID)
+		host.Nodes = append(host.Nodes, id.String())
+		host.Interfaces = ln.Interfaces
+		host.EndpointIP = net.ParseIP(ln.Endpoint)
+		// host.ProxyEnabled = ln.Proxy // this will always be false..
+	}
+	id, _ := uuid.Parse(ln.ID)
+	node.ID = id
 	node.Network = ln.Network
 	node.Network = ln.Network
 	if _, cidr, err := net.ParseCIDR(ln.NetworkSettings.AddressRange); err == nil {
 	if _, cidr, err := net.ParseCIDR(ln.NetworkSettings.AddressRange); err == nil {
 		node.NetworkRange = *cidr
 		node.NetworkRange = *cidr
@@ -542,7 +549,7 @@ func (ln *LegacyNode) ConvertToNewNode(host *Host) (*Host, *Node) {
 	node.IsIngressGateway = parseBool(ln.IsIngressGateway)
 	node.IsIngressGateway = parseBool(ln.IsIngressGateway)
 	node.DNSOn = parseBool(ln.DNSOn)
 	node.DNSOn = parseBool(ln.DNSOn)
 
 
-	return host, &node
+	return &host, &node
 }
 }
 
 
 // Node.Legacy converts node to legacy format
 // Node.Legacy converts node to legacy format

+ 5 - 13
mq/handlers.go

@@ -90,28 +90,20 @@ func UpdateNode(client mqtt.Client, msg mqtt.Message) {
 			logger.Log(1, "failed to decrypt message for node ", id, decryptErr.Error())
 			logger.Log(1, "failed to decrypt message for node ", id, decryptErr.Error())
 			return
 			return
 		}
 		}
-		var oldNode models.LegacyNode
-		if err := json.Unmarshal(decrypted, &oldNode); err != nil {
+		var newNode models.Node
+		if err := json.Unmarshal(decrypted, &newNode); err != nil {
 			logger.Log(1, "error unmarshaling payload ", err.Error())
 			logger.Log(1, "error unmarshaling payload ", err.Error())
 			return
 			return
 		}
 		}
-		host, err := logic.GetHost(oldNode.HostID)
-		if err != nil && database.IsEmptyRecord(err) {
-			return
-		}
-		host, newNode := oldNode.ConvertToNewNode(host)
-		err = logic.UpsertHost(host)
-		if err != nil {
-			logger.Log(0, "failed to update host: ", err.Error())
-		}
-		ifaceDelta := logic.IfaceDelta(&currentNode, newNode)
+
+		ifaceDelta := logic.IfaceDelta(&currentNode, &newNode)
 		if servercfg.Is_EE && ifaceDelta {
 		if servercfg.Is_EE && ifaceDelta {
 			if err = logic.EnterpriseResetAllPeersFailovers(currentNode.ID.String(), currentNode.Network); err != nil {
 			if err = logic.EnterpriseResetAllPeersFailovers(currentNode.ID.String(), currentNode.Network); err != nil {
 				logger.Log(1, "failed to reset failover list during node update", currentNode.ID.String(), currentNode.Network)
 				logger.Log(1, "failed to reset failover list during node update", currentNode.ID.String(), currentNode.Network)
 			}
 			}
 		}
 		}
 		newNode.SetLastCheckIn()
 		newNode.SetLastCheckIn()
-		if err := logic.UpdateNode(&currentNode, newNode); err != nil {
+		if err := logic.UpdateNode(&currentNode, &newNode); err != nil {
 			logger.Log(1, "error saving node", err.Error())
 			logger.Log(1, "error saving node", err.Error())
 			return
 			return
 		}
 		}