123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300 |
- package mq
- import (
- "bytes"
- "compress/gzip"
- "crypto/aes"
- "crypto/cipher"
- "crypto/rand"
- "crypto/sha256"
- "errors"
- "fmt"
- "io"
- "math"
- "strings"
- "time"
- "github.com/gravitl/netmaker/logic"
- "github.com/gravitl/netmaker/models"
- "github.com/gravitl/netmaker/netclient/ncutils"
- "golang.org/x/crypto/nacl/box"
- "golang.org/x/exp/slog"
- )
- func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
- if host.OS == models.OS_Types.IoT { // just pass along IoT messages
- return msg, nil
- }
- // Check version to determine decryption method
- vlt, err := logic.VersionLessThan(host.Version, "v0.30.0")
- if err != nil {
- slog.Warn("error checking version less than", "error", err)
- // Default to old method if version check fails
- vlt = true
- }
- if vlt {
- // Old decryption method for versions < v0.30.0
- trafficKey, trafficErr := logic.RetrievePrivateTrafficKey() // get server private key
- if trafficErr != nil {
- return nil, trafficErr
- }
- serverPrivTKey, err := ncutils.ConvertBytesToKey(trafficKey)
- if err != nil {
- return nil, err
- }
- nodePubTKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
- if err != nil {
- return nil, err
- }
- return ncutils.DeChunk(msg, nodePubTKey, serverPrivTKey)
- } else {
- // New AES-GCM decryption for versions >= v0.30.0
- // For client->server messages, the client encrypts using client private key + server public key
- // The server decrypts using server private key + client public key
- trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
- if trafficErr != nil {
- return nil, trafficErr
- }
- serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
- if err != nil {
- return nil, err
- }
- clientPubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
- if err != nil {
- return nil, err
- }
- // First decrypt, then decompress
- decrypted, err := decryptAESGCM(serverPrivKey, clientPubKey, msg)
- if err != nil {
- return nil, err
- }
- return decompressPayload(decrypted)
- }
- }
- func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
- if len(msg) <= 24 { // make sure message is of appropriate length
- return nil, fmt.Errorf("received invalid message from broker %v", msg)
- }
- host, err := logic.GetHost(node.HostID.String())
- if err != nil {
- return nil, err
- }
- return decryptMsgWithHost(host, msg)
- }
// BatchItems splits items into consecutive batches of batchSize elements. The
// final batch holds the remainder when len(items) is not a multiple of
// batchSize.
//
// Every batch is a fresh copy, so mutating a batch does not touch items. A
// non-positive batchSize returns nil. An empty items returns an empty, non-nil
// slice.
func BatchItems[T any](items []T, batchSize int) [][]T {
	if batchSize <= 0 {
		return nil
	}
	count := int(math.Ceil(float64(len(items)) / float64(batchSize)))
	out := make([][]T, count)
	for b := 0; b < count; b++ {
		start := b * batchSize
		end := start + batchSize
		if end > len(items) {
			end = len(items) // short trailing batch
		}
		out[b] = make([]T, end-start)
		copy(out[b], items[start:end])
	}
	return out
}
// compressPayload gzip-compresses data and returns the compressed bytes.
func compressPayload(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	if _, err := zw.Write(data); err != nil {
		_ = zw.Close() // best-effort cleanup; the Write error takes precedence
		return nil, err
	}
	// Close flushes buffered data and writes the gzip footer (CRC-32 and size).
	// If it fails, the stream is truncated and must not be returned as success.
	if err := zw.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
- // deriveSharedSecret derives a symmetric key from the server's private key and client's public key
- func deriveSharedSecret(serverPrivKey, clientPubKey *[32]byte) []byte {
- // Use NaCl box.Precompute to derive the shared secret
- var sharedSecret [32]byte
- box.Precompute(&sharedSecret, clientPubKey, serverPrivKey)
- // Hash the shared secret to get a 32-byte key for AES
- hash := sha256.Sum256(sharedSecret[:])
- return hash[:]
- }
- func encryptAESGCM(serverPrivKey, clientPubKey *[32]byte, plaintext []byte) ([]byte, error) {
- // Derive shared secret for symmetric encryption
- key := deriveSharedSecret(serverPrivKey, clientPubKey)
- // Create AES block cipher
- block, err := aes.NewCipher(key)
- if err != nil {
- return nil, err
- }
- // Create GCM (Galois/Counter Mode) cipher
- aesGCM, err := cipher.NewGCM(block)
- if err != nil {
- return nil, err
- }
- // Create a random nonce
- nonce := make([]byte, aesGCM.NonceSize())
- if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
- return nil, err
- }
- // Encrypt the data
- ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil)
- return ciphertext, nil
- }
- func decryptAESGCM(serverPubKey, clientPrivKey *[32]byte, ciphertext []byte) ([]byte, error) {
- // Derive shared secret for symmetric decryption
- key := deriveSharedSecret(clientPrivKey, serverPubKey)
- // Create AES block cipher
- block, err := aes.NewCipher(key)
- if err != nil {
- return nil, err
- }
- // Create GCM cipher
- aesGCM, err := cipher.NewGCM(block)
- if err != nil {
- return nil, err
- }
- // Extract nonce from ciphertext
- nonceSize := aesGCM.NonceSize()
- if len(ciphertext) < nonceSize {
- return nil, fmt.Errorf("ciphertext too short")
- }
- nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
- // Decrypt the data
- plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
- if err != nil {
- return nil, err
- }
- return plaintext, nil
- }
// decompressPayload gunzips data and returns the inflated bytes. It is the
// inverse of compressPayload. Malformed gzip input (including empty input)
// yields an error and a nil slice.
func decompressPayload(data []byte) ([]byte, error) {
	zr, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	defer zr.Close()
	inflated, err := io.ReadAll(zr)
	if err != nil {
		return nil, err
	}
	return inflated, nil
}
- func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
- if host.OS == models.OS_Types.IoT {
- return msg, nil
- }
- // fetch server public key to be certain hasn't changed in transit
- trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
- if trafficErr != nil {
- return nil, trafficErr
- }
- serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
- if err != nil {
- return nil, err
- }
- nodePubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
- if err != nil {
- return nil, err
- }
- if strings.Contains(host.Version, "0.10.0") {
- return ncutils.BoxEncrypt(msg, nodePubKey, serverPrivKey)
- }
- return ncutils.Chunk(msg, nodePubKey, serverPrivKey)
- }
- func publish(host *models.Host, dest string, msg []byte) error {
- var encrypted []byte
- var encryptErr error
- vlt, err := logic.VersionLessThan(host.Version, "v0.30.0")
- if err != nil {
- slog.Warn("error checking version less than", "error", err)
- return err
- }
- if vlt {
- encrypted, encryptErr = encryptMsg(host, msg)
- if encryptErr != nil {
- return encryptErr
- }
- } else {
- zipped, err := compressPayload(msg)
- if err != nil {
- return err
- }
- // Get server private key and client public key for AES-GCM encryption
- trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
- if trafficErr != nil {
- return trafficErr
- }
- serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
- if err != nil {
- return err
- }
- clientPubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
- if err != nil {
- return err
- }
- encrypted, encryptErr = encryptAESGCM(serverPrivKey, clientPubKey, zipped)
- if encryptErr != nil {
- return encryptErr
- }
- }
- if mqclient == nil || !mqclient.IsConnectionOpen() {
- return errors.New("cannot publish ... mqclient not connected")
- }
- if token := mqclient.Publish(dest, 0, true, encrypted); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
- var err error
- if token.Error() == nil {
- err = errors.New("connection timeout")
- } else {
- slog.Error("publish to mq error", "error", token.Error().Error())
- err = token.Error()
- }
- return err
- }
- return nil
- }
// GetID decodes a message queue topic and returns the ID embedded in it, which
// is the final "/"-separated segment (for example, "a/b/c" yields "c").
//
// It returns an error when the topic contains no "/" or when the final segment
// is empty (a trailing "/"), since neither carries an ID.
func GetID(topic string) (string, error) {
	sep := strings.LastIndex(topic, "/")
	if sep < 0 || sep == len(topic)-1 {
		return "", errors.New("invalid topic")
	}
	return topic[sep+1:], nil
}
|