Browse Source

fix AES mq encryption

abhishek9686 1 day ago
parent
commit
d99dbe523a
2 changed files with 147 additions and 15 deletions
  1. 21 2
      mq/migrate.go
  2. 126 13
      mq/util.go

+ 21 - 2
mq/migrate.go

@@ -14,6 +14,7 @@ import (
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/servercfg"
 	"golang.org/x/exp/slog"
 )
@@ -107,9 +108,27 @@ func SendPullSYN() error {
 				slog.Warn("error compressing message", "warn", err)
 				continue
 			}
-			encrypted, encryptErr = encryptAESGCM(host.TrafficKeyPublic[0:32], zipped)
+
+			// Get server private key and client public key for AES-GCM encryption
+			trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
+			if trafficErr != nil {
+				slog.Warn("error retrieving traffic key", "warn", trafficErr)
+				continue
+			}
+			serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
+			if err != nil {
+				slog.Warn("error converting server private key", "warn", err)
+				continue
+			}
+			clientPubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
+			if err != nil {
+				slog.Warn("error converting client public key", "warn", err)
+				continue
+			}
+
+			encrypted, encryptErr = encryptAESGCM(serverPrivKey, clientPubKey, zipped)
 			if encryptErr != nil {
-				slog.Warn("error encrypt with encryptMsg", "warn", encryptErr)
+				slog.Warn("error encrypt with encryptAESGCM", "warn", encryptErr)
 				continue
 			}
 		}

+ 126 - 13
mq/util.go

@@ -6,6 +6,7 @@ import (
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/rand"
+	"crypto/sha256"
 	"errors"
 	"fmt"
 	"io"
@@ -16,6 +17,7 @@ import (
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/netclient/ncutils"
+	"golang.org/x/crypto/nacl/box"
 	"golang.org/x/exp/slog"
 )
 
@@ -24,20 +26,54 @@ func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
 		return msg, nil
 	}
 
-	trafficKey, trafficErr := logic.RetrievePrivateTrafficKey() // get server private key
-	if trafficErr != nil {
-		return nil, trafficErr
-	}
-	serverPrivTKey, err := ncutils.ConvertBytesToKey(trafficKey)
-	if err != nil {
-		return nil, err
-	}
-	nodePubTKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
+	// Check version to determine decryption method
+	vlt, err := logic.VersionLessThan(host.Version, "v0.30.0")
 	if err != nil {
-		return nil, err
+		slog.Warn("error checking version less than", "error", err)
+		// Default to old method if version check fails
+		vlt = true
 	}
 
-	return ncutils.DeChunk(msg, nodePubTKey, serverPrivTKey)
+	if vlt {
+		// Old decryption method for versions < v0.30.0
+		trafficKey, trafficErr := logic.RetrievePrivateTrafficKey() // get server private key
+		if trafficErr != nil {
+			return nil, trafficErr
+		}
+		serverPrivTKey, err := ncutils.ConvertBytesToKey(trafficKey)
+		if err != nil {
+			return nil, err
+		}
+		nodePubTKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
+		if err != nil {
+			return nil, err
+		}
+
+		return ncutils.DeChunk(msg, nodePubTKey, serverPrivTKey)
+	} else {
+		// New AES-GCM decryption for versions >= v0.30.0
+		// For client->server messages, the client encrypts using client private key + server public key
+		// The server decrypts using server private key + client public key
+		trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
+		if trafficErr != nil {
+			return nil, trafficErr
+		}
+		serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
+		if err != nil {
+			return nil, err
+		}
+		clientPubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
+		if err != nil {
+			return nil, err
+		}
+
+		// First decrypt, then decompress
+		decrypted, err := decryptAESGCM(serverPrivKey, clientPubKey, msg)
+		if err != nil {
+			return nil, err
+		}
+		return decompressPayload(decrypted)
+	}
 }
 
 func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
@@ -81,7 +117,22 @@ func compressPayload(data []byte) ([]byte, error) {
 	zw.Close()
 	return buf.Bytes(), nil
 }
-func encryptAESGCM(key, plaintext []byte) ([]byte, error) {
+
+// deriveSharedSecret derives a symmetric key from the server's private key and client's public key
+func deriveSharedSecret(serverPrivKey, clientPubKey *[32]byte) []byte {
+	// Use NaCl box.Precompute to derive the shared secret
+	var sharedSecret [32]byte
+	box.Precompute(&sharedSecret, clientPubKey, serverPrivKey)
+
+	// Hash the shared secret to get a 32-byte key for AES
+	hash := sha256.Sum256(sharedSecret[:])
+	return hash[:]
+}
+
+func encryptAESGCM(serverPrivKey, clientPubKey *[32]byte, plaintext []byte) ([]byte, error) {
+	// Derive shared secret for symmetric encryption
+	key := deriveSharedSecret(serverPrivKey, clientPubKey)
+
 	// Create AES block cipher
 	block, err := aes.NewCipher(key)
 	if err != nil {
@@ -105,6 +156,53 @@ func encryptAESGCM(key, plaintext []byte) ([]byte, error) {
 	return ciphertext, nil
 }
 
+func decryptAESGCM(serverPubKey, clientPrivKey *[32]byte, ciphertext []byte) ([]byte, error) {
+	// Derive shared secret for symmetric decryption
+	key := deriveSharedSecret(clientPrivKey, serverPubKey)
+
+	// Create AES block cipher
+	block, err := aes.NewCipher(key)
+	if err != nil {
+		return nil, err
+	}
+
+	// Create GCM cipher
+	aesGCM, err := cipher.NewGCM(block)
+	if err != nil {
+		return nil, err
+	}
+
+	// Extract nonce from ciphertext
+	nonceSize := aesGCM.NonceSize()
+	if len(ciphertext) < nonceSize {
+		return nil, fmt.Errorf("ciphertext too short")
+	}
+
+	nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
+
+	// Decrypt the data
+	plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
+	if err != nil {
+		return nil, err
+	}
+
+	return plaintext, nil
+}
+
+func decompressPayload(data []byte) ([]byte, error) {
+	reader, err := gzip.NewReader(bytes.NewReader(data))
+	if err != nil {
+		return nil, err
+	}
+	defer reader.Close()
+
+	var buf bytes.Buffer
+	if _, err := io.Copy(&buf, reader); err != nil {
+		return nil, err
+	}
+	return buf.Bytes(), nil
+}
+
 func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
 	if host.OS == models.OS_Types.IoT {
 		return msg, nil
@@ -152,7 +250,22 @@ func publish(host *models.Host, dest string, msg []byte) error {
 		if err != nil {
 			return err
 		}
-		encrypted, encryptErr = encryptAESGCM(host.TrafficKeyPublic[0:32], zipped)
+
+		// Get server private key and client public key for AES-GCM encryption
+		trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
+		if trafficErr != nil {
+			return trafficErr
+		}
+		serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
+		if err != nil {
+			return err
+		}
+		clientPubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
+		if err != nil {
+			return err
+		}
+
+		encrypted, encryptErr = encryptAESGCM(serverPrivKey, clientPubKey, zipped)
 		if encryptErr != nil {
 			return encryptErr
 		}