util.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. package mq
  2. import (
  3. "bytes"
  4. "compress/gzip"
  5. "crypto/aes"
  6. "crypto/cipher"
  7. "crypto/rand"
  8. "crypto/sha256"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "math"
  13. "strings"
  14. "time"
  15. "github.com/gravitl/netmaker/logic"
  16. "github.com/gravitl/netmaker/models"
  17. "github.com/gravitl/netmaker/netclient/ncutils"
  18. "golang.org/x/crypto/nacl/box"
  19. "golang.org/x/exp/slog"
  20. )
  21. func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
  22. if host.OS == models.OS_Types.IoT { // just pass along IoT messages
  23. return msg, nil
  24. }
  25. // Check version to determine decryption method
  26. vlt, err := logic.VersionLessThan(host.Version, "v0.30.0")
  27. if err != nil {
  28. slog.Warn("error checking version less than", "error", err)
  29. // Default to old method if version check fails
  30. vlt = true
  31. }
  32. if vlt {
  33. // Old decryption method for versions < v0.30.0
  34. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey() // get server private key
  35. if trafficErr != nil {
  36. return nil, trafficErr
  37. }
  38. serverPrivTKey, err := ncutils.ConvertBytesToKey(trafficKey)
  39. if err != nil {
  40. return nil, err
  41. }
  42. nodePubTKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  43. if err != nil {
  44. return nil, err
  45. }
  46. return ncutils.DeChunk(msg, nodePubTKey, serverPrivTKey)
  47. } else {
  48. // New AES-GCM decryption for versions >= v0.30.0
  49. // For client->server messages, the client encrypts using client private key + server public key
  50. // The server decrypts using server private key + client public key
  51. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
  52. if trafficErr != nil {
  53. return nil, trafficErr
  54. }
  55. serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
  56. if err != nil {
  57. return nil, err
  58. }
  59. clientPubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  60. if err != nil {
  61. return nil, err
  62. }
  63. // First decrypt, then decompress
  64. decrypted, err := decryptAESGCM(serverPrivKey, clientPubKey, msg)
  65. if err != nil {
  66. return nil, err
  67. }
  68. return decompressPayload(decrypted)
  69. }
  70. }
  71. func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
  72. if len(msg) <= 24 { // make sure message is of appropriate length
  73. return nil, fmt.Errorf("received invalid message from broker %v", msg)
  74. }
  75. host, err := logic.GetHost(node.HostID.String())
  76. if err != nil {
  77. return nil, err
  78. }
  79. return decryptMsgWithHost(host, msg)
  80. }
  // BatchItems splits items into consecutive, order-preserving batches of
// batchSize elements. When len(items) is not a multiple of batchSize, the
// final batch holds the remainder.
//
// Each batch is a freshly allocated copy, so modifying a batch never affects
// items. A non-positive batchSize yields nil. An empty items slice yields an
// empty, non-nil result.
func BatchItems[T any](items []T, batchSize int) [][]T {
	if batchSize <= 0 {
		return nil
	}
	total := len(items)
	count := int(math.Ceil(float64(total) / float64(batchSize)))
	out := make([][]T, count)
	for b := 0; b < count; b++ {
		start := b * batchSize
		end := start + batchSize
		if end > total {
			end = total // last, partial batch
		}
		chunk := make([]T, end-start)
		copy(chunk, items[start:end])
		out[b] = chunk
	}
	return out
}
  // compressPayload gzip-compresses data at the default compression level and
// returns the complete gzip stream (header, compressed data and footer).
//
// An error from writing or from finalizing the stream is returned. In that
// case no partial output is returned.
func compressPayload(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	if _, err := zw.Write(data); err != nil {
		_ = zw.Close() // best effort; the Write error is the one worth reporting
		return nil, err
	}
	// Close flushes buffered deflate output and writes the gzip footer
	// (CRC-32 and length). If it fails, the stream is truncated and must not
	// be handed out as a valid payload.
	if err := zw.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
  109. // deriveSharedSecret derives a symmetric key from the server's private key and client's public key
  110. func deriveSharedSecret(serverPrivKey, clientPubKey *[32]byte) []byte {
  111. // Use NaCl box.Precompute to derive the shared secret
  112. var sharedSecret [32]byte
  113. box.Precompute(&sharedSecret, clientPubKey, serverPrivKey)
  114. // Hash the shared secret to get a 32-byte key for AES
  115. hash := sha256.Sum256(sharedSecret[:])
  116. return hash[:]
  117. }
  118. func encryptAESGCM(serverPrivKey, clientPubKey *[32]byte, plaintext []byte) ([]byte, error) {
  119. // Derive shared secret for symmetric encryption
  120. key := deriveSharedSecret(serverPrivKey, clientPubKey)
  121. // Create AES block cipher
  122. block, err := aes.NewCipher(key)
  123. if err != nil {
  124. return nil, err
  125. }
  126. // Create GCM (Galois/Counter Mode) cipher
  127. aesGCM, err := cipher.NewGCM(block)
  128. if err != nil {
  129. return nil, err
  130. }
  131. // Create a random nonce
  132. nonce := make([]byte, aesGCM.NonceSize())
  133. if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
  134. return nil, err
  135. }
  136. // Encrypt the data
  137. ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil)
  138. return ciphertext, nil
  139. }
  140. func decryptAESGCM(serverPubKey, clientPrivKey *[32]byte, ciphertext []byte) ([]byte, error) {
  141. // Derive shared secret for symmetric decryption
  142. key := deriveSharedSecret(clientPrivKey, serverPubKey)
  143. // Create AES block cipher
  144. block, err := aes.NewCipher(key)
  145. if err != nil {
  146. return nil, err
  147. }
  148. // Create GCM cipher
  149. aesGCM, err := cipher.NewGCM(block)
  150. if err != nil {
  151. return nil, err
  152. }
  153. // Extract nonce from ciphertext
  154. nonceSize := aesGCM.NonceSize()
  155. if len(ciphertext) < nonceSize {
  156. return nil, fmt.Errorf("ciphertext too short")
  157. }
  158. nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]
  159. // Decrypt the data
  160. plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil)
  161. if err != nil {
  162. return nil, err
  163. }
  164. return plaintext, nil
  165. }
  // decompressPayload inflates a gzip stream, such as one produced by
// compressPayload, and returns the decompressed bytes buffered in memory.
//
// Invalid headers, corrupt data and checksum mismatches are returned as
// errors, with no partial output.
func decompressPayload(data []byte) ([]byte, error) {
	zr, err := gzip.NewReader(bytes.NewReader(data))
	if err != nil {
		return nil, err
	}
	defer zr.Close()
	// NOTE(review): output size is unbounded; a small crafted payload can
	// inflate to a very large buffer. Consider io.LimitReader if senders
	// cannot be trusted.
	out := new(bytes.Buffer)
	if _, err = io.Copy(out, zr); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
  178. func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
  179. if host.OS == models.OS_Types.IoT {
  180. return msg, nil
  181. }
  182. // fetch server public key to be certain hasn't changed in transit
  183. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
  184. if trafficErr != nil {
  185. return nil, trafficErr
  186. }
  187. serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
  188. if err != nil {
  189. return nil, err
  190. }
  191. nodePubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  192. if err != nil {
  193. return nil, err
  194. }
  195. if strings.Contains(host.Version, "0.10.0") {
  196. return ncutils.BoxEncrypt(msg, nodePubKey, serverPrivKey)
  197. }
  198. return ncutils.Chunk(msg, nodePubKey, serverPrivKey)
  199. }
  200. func publish(host *models.Host, dest string, msg []byte) error {
  201. var encrypted []byte
  202. var encryptErr error
  203. vlt, err := logic.VersionLessThan(host.Version, "v0.30.0")
  204. if err != nil {
  205. slog.Warn("error checking version less than", "error", err)
  206. return err
  207. }
  208. if vlt {
  209. encrypted, encryptErr = encryptMsg(host, msg)
  210. if encryptErr != nil {
  211. return encryptErr
  212. }
  213. } else {
  214. zipped, err := compressPayload(msg)
  215. if err != nil {
  216. return err
  217. }
  218. // Get server private key and client public key for AES-GCM encryption
  219. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
  220. if trafficErr != nil {
  221. return trafficErr
  222. }
  223. serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
  224. if err != nil {
  225. return err
  226. }
  227. clientPubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  228. if err != nil {
  229. return err
  230. }
  231. encrypted, encryptErr = encryptAESGCM(serverPrivKey, clientPubKey, zipped)
  232. if encryptErr != nil {
  233. return encryptErr
  234. }
  235. }
  236. if mqclient == nil || !mqclient.IsConnectionOpen() {
  237. return errors.New("cannot publish ... mqclient not connected")
  238. }
  239. if token := mqclient.Publish(dest, 0, true, encrypted); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
  240. var err error
  241. if token.Error() == nil {
  242. err = errors.New("connection timeout")
  243. } else {
  244. slog.Error("publish to mq error", "error", token.Error().Error())
  245. err = token.Error()
  246. }
  247. return err
  248. }
  249. return nil
  250. }
  // GetID decodes a message queue topic and returns the ID embedded as its
// final "/"-separated segment (the node.ID for node topics).
//
// It returns an "invalid topic" error when the topic contains no "/" or when
// the final segment is empty (e.g. a trailing slash), since an empty string
// is never a usable ID.
func GetID(topic string) (string, error) {
	idx := strings.LastIndex(topic, "/")
	if idx < 0 {
		return "", errors.New("invalid topic")
	}
	// the last part of the topic will be the node.ID
	id := topic[idx+1:]
	if id == "" {
		return "", errors.New("invalid topic")
	}
	return id, nil
}