util.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package mq
  2. import (
  3. "bytes"
  4. "compress/gzip"
  5. "crypto/aes"
  6. "crypto/cipher"
  7. "crypto/rand"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "math"
  12. "strings"
  13. "time"
  14. "github.com/gravitl/netmaker/logic"
  15. "github.com/gravitl/netmaker/models"
  16. "github.com/gravitl/netmaker/netclient/ncutils"
  17. "golang.org/x/exp/slog"
  18. )
  19. func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
  20. if host.OS == models.OS_Types.IoT { // just pass along IoT messages
  21. return msg, nil
  22. }
  23. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey() // get server private key
  24. if trafficErr != nil {
  25. return nil, trafficErr
  26. }
  27. serverPrivTKey, err := ncutils.ConvertBytesToKey(trafficKey)
  28. if err != nil {
  29. return nil, err
  30. }
  31. nodePubTKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  32. if err != nil {
  33. return nil, err
  34. }
  35. return ncutils.DeChunk(msg, nodePubTKey, serverPrivTKey)
  36. }
  37. func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
  38. if len(msg) <= 24 { // make sure message is of appropriate length
  39. return nil, fmt.Errorf("received invalid message from broker %v", msg)
  40. }
  41. host, err := logic.GetHost(node.HostID.String())
  42. if err != nil {
  43. return nil, err
  44. }
  45. return decryptMsgWithHost(host, msg)
  46. }
  // BatchItems splits items into consecutive batches of batchSize elements.
// Every batch is full except possibly the last, which holds the remainder.
// Each batch is a freshly allocated copy, so mutating a batch never touches
// items. It returns nil when batchSize is not positive, and an empty
// (non-nil) slice when items is empty.
func BatchItems[T any](items []T, batchSize int) [][]T {
	if batchSize <= 0 {
		return nil
	}
	total := len(items)
	count := int(math.Ceil(float64(total) / float64(batchSize)))
	out := make([][]T, count)
	for b := 0; b < count; b++ {
		lo := b * batchSize
		hi := lo + batchSize
		if hi > total {
			// Final, partial batch.
			hi = total
		}
		chunk := make([]T, hi-lo)
		copy(chunk, items[lo:hi])
		out[b] = chunk
	}
	return out
}
  // compressPayload gzip-compresses data and returns the complete gzip
// stream (header, compressed body and footer).
//
// Errors from writing the data and from closing the gzip writer are both
// returned; on error the returned slice is nil.
func compressPayload(data []byte) ([]byte, error) {
	var buf bytes.Buffer
	zw := gzip.NewWriter(&buf)
	if _, err := zw.Write(data); err != nil {
		return nil, err
	}
	// Close flushes any pending compressed data and writes the gzip footer
	// (CRC-32 and input size). Ignoring its error could hand the caller a
	// truncated, undecodable payload as if it had succeeded.
	if err := zw.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}
  // encryptAESGCM seals plaintext with AES-GCM under key and returns the
// random nonce followed by the ciphertext and authentication tag. The key
// length selects the AES variant (16, 24 or 32 bytes); any other length is
// rejected by aes.NewCipher. No additional authenticated data is used.
func encryptAESGCM(key, plaintext []byte) ([]byte, error) {
	blk, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(blk)
	if err != nil {
		return nil, err
	}

	// A fresh random nonce per message: GCM must never reuse a nonce under
	// the same key.
	iv := make([]byte, gcm.NonceSize())
	if _, err = io.ReadFull(rand.Reader, iv); err != nil {
		return nil, err
	}

	// Seal appends to its first argument, so the result is nonce||ciphertext||tag.
	return gcm.Seal(iv, iv, plaintext, nil), nil
}
  95. func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
  96. if host.OS == models.OS_Types.IoT {
  97. return msg, nil
  98. }
  99. // fetch server public key to be certain hasn't changed in transit
  100. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
  101. if trafficErr != nil {
  102. return nil, trafficErr
  103. }
  104. serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
  105. if err != nil {
  106. return nil, err
  107. }
  108. nodePubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  109. if err != nil {
  110. return nil, err
  111. }
  112. if strings.Contains(host.Version, "0.10.0") {
  113. return ncutils.BoxEncrypt(msg, nodePubKey, serverPrivKey)
  114. }
  115. return ncutils.Chunk(msg, nodePubKey, serverPrivKey)
  116. }
  117. func publish(host *models.Host, dest string, msg []byte) error {
  118. var encrypted []byte
  119. var encryptErr error
  120. vlt, err := logic.VersionLessThan(host.Version, "v0.30.0")
  121. if err != nil {
  122. slog.Warn("error checking version less than", "error", err)
  123. return err
  124. }
  125. if vlt {
  126. encrypted, encryptErr = encryptMsg(host, msg)
  127. if encryptErr != nil {
  128. return encryptErr
  129. }
  130. } else {
  131. zipped, err := compressPayload(msg)
  132. if err != nil {
  133. return err
  134. }
  135. encrypted, encryptErr = encryptAESGCM(host.TrafficKeyPublic[0:32], zipped)
  136. if encryptErr != nil {
  137. return encryptErr
  138. }
  139. }
  140. if mqclient == nil || !mqclient.IsConnectionOpen() {
  141. return errors.New("cannot publish ... mqclient not connected")
  142. }
  143. if token := mqclient.Publish(dest, 0, true, encrypted); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
  144. var err error
  145. if token.Error() == nil {
  146. err = errors.New("connection timeout")
  147. } else {
  148. slog.Error("publish to mq error", "error", token.Error().Error())
  149. err = token.Error()
  150. }
  151. return err
  152. }
  153. return nil
  154. }
  // GetID decodes a message queue topic and returns the embedded node.ID,
// which is the last "/"-separated segment of the topic (for example,
// "ping/abc" yields "abc"). It returns an error if the topic contains no
// "/" or if the last segment is empty, so callers never receive an empty ID.
func GetID(topic string) (string, error) {
	sep := strings.LastIndex(topic, "/")
	if sep < 0 || sep == len(topic)-1 {
		return "", errors.New("invalid topic")
	}
	return topic[sep+1:], nil
}