浏览代码

add host version check for message encrypt

Max Ma 9 月之前
父节点
当前提交
1576e4ce10
共有 3 个文件被更改,包括 80 次插入4 次删除
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 77 4
      mq/util.go

+ 1 - 0
go.mod

@@ -3,6 +3,7 @@ module github.com/gravitl/netmaker
 go 1.23
 
 require (
+	github.com/blang/semver v3.5.1+incompatible
 	github.com/eclipse/paho.mqtt.golang v1.4.3
 	github.com/go-playground/validator/v10 v10.23.0
 	github.com/golang-jwt/jwt/v4 v4.5.1

+ 2 - 0
go.sum

@@ -2,6 +2,8 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2Qx
 cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
 filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
 filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
+github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ=
+github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk=
 github.com/c-robinson/iplib v1.0.8 h1:exDRViDyL9UBLcfmlxxkY5odWX5092nPsQIykHXhIn4=
 github.com/c-robinson/iplib v1.0.8/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo=
 github.com/coreos/go-oidc/v3 v3.9.0 h1:0J/ogVOd4y8P0f0xUh8l9t07xRP/d8tccvjHl2dcsSo=

+ 77 - 4
mq/util.go

@@ -12,7 +12,9 @@ import (
 	"math"
 	"strings"
 	"time"
+	"unicode"
 
+	"github.com/blang/semver"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/netclient/ncutils"
@@ -105,14 +107,59 @@ func encryptAESGCM(key, plaintext []byte) ([]byte, error) {
 	return ciphertext, nil
 }
 
+func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
+	if host.OS == models.OS_Types.IoT {
+		return msg, nil
+	}
+
+	// fetch server public key to be certain hasn't changed in transit
+	trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
+	if trafficErr != nil {
+		return nil, trafficErr
+	}
+
+	serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
+	if err != nil {
+		return nil, err
+	}
+
+	nodePubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
+	if err != nil {
+		return nil, err
+	}
+
+	if strings.Contains(host.Version, "0.10.0") {
+		return ncutils.BoxEncrypt(msg, nodePubKey, serverPrivKey)
+	}
+
+	return ncutils.Chunk(msg, nodePubKey, serverPrivKey)
+}
+
 func publish(host *models.Host, dest string, msg []byte) error {
-	zipped, err := compressPayload(msg)
+
+	var encrypted []byte
+	var encryptErr error
+	vlt, err := versionLessThan(host.Version, "v0.30.0")
 	if err != nil {
+		slog.Warn("error checking version less than", "error", err)
 		return err
 	}
-	encrypted, encryptErr := encryptAESGCM(host.TrafficKeyPublic[0:32], zipped)
-	if encryptErr != nil {
-		return encryptErr
+	slog.Error("host.Version: ", "Debug", host.Version)
+	slog.Error("host.Version less than v0.30.0: ", "Debug", vlt)
+	if vlt {
+		encrypted, encryptErr = encryptMsg(host, msg)
+		if encryptErr != nil {
+			return encryptErr
+		}
+	} else {
+		zipped, err := compressPayload(msg)
+		if err != nil {
+			return err
+		}
+		encrypted, encryptErr = encryptAESGCM(host.TrafficKeyPublic[0:32], zipped)
+		if encryptErr != nil {
+			return encryptErr
+		}
 	}
 
 	if mqclient == nil || !mqclient.IsConnectionOpen() {
@@ -142,3 +189,29 @@ func GetID(topic string) (string, error) {
 	//the last part of the topic will be the node.ID
 	return parts[count-1], nil
 }
+
+// versionLessThan checks if v1 < v2 semantically
+// dev is the latest version
+func versionLessThan(v1, v2 string) (bool, error) {
+	if v1 == "dev" {
+		return false, nil
+	}
+	if v2 == "dev" {
+		return true, nil
+	}
+	semVer1 := strings.TrimFunc(v1, func(r rune) bool {
+		return !unicode.IsNumber(r)
+	})
+	semVer2 := strings.TrimFunc(v2, func(r rune) bool {
+		return !unicode.IsNumber(r)
+	})
+	sv1, err := semver.Parse(semVer1)
+	if err != nil {
+		return false, fmt.Errorf("failed to parse semver1 (%s): %w", semVer1, err)
+	}
+	sv2, err := semver.Parse(semVer2)
+	if err != nil {
+		return false, fmt.Errorf("failed to parse semver2 (%s): %w", semVer2, err)
+	}
+	return sv1.LT(sv2), nil
+}