util.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. package mq
  2. import (
  3. "bytes"
  4. "compress/gzip"
  5. "crypto/aes"
  6. "crypto/cipher"
  7. "crypto/rand"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "math"
  12. "strings"
  13. "time"
  14. "unicode"
  15. "github.com/blang/semver"
  16. "github.com/gravitl/netmaker/logic"
  17. "github.com/gravitl/netmaker/models"
  18. "github.com/gravitl/netmaker/netclient/ncutils"
  19. "golang.org/x/exp/slog"
  20. )
  21. func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
  22. if host.OS == models.OS_Types.IoT { // just pass along IoT messages
  23. return msg, nil
  24. }
  25. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey() // get server private key
  26. if trafficErr != nil {
  27. return nil, trafficErr
  28. }
  29. serverPrivTKey, err := ncutils.ConvertBytesToKey(trafficKey)
  30. if err != nil {
  31. return nil, err
  32. }
  33. nodePubTKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  34. if err != nil {
  35. return nil, err
  36. }
  37. return ncutils.DeChunk(msg, nodePubTKey, serverPrivTKey)
  38. }
  39. func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
  40. if len(msg) <= 24 { // make sure message is of appropriate length
  41. return nil, fmt.Errorf("received invalid message from broker %v", msg)
  42. }
  43. host, err := logic.GetHost(node.HostID.String())
  44. if err != nil {
  45. return nil, err
  46. }
  47. return decryptMsgWithHost(host, msg)
  48. }
  49. func BatchItems[T any](items []T, batchSize int) [][]T {
  50. if batchSize <= 0 {
  51. return nil
  52. }
  53. remainderBatchSize := len(items) % batchSize
  54. nBatches := int(math.Ceil(float64(len(items)) / float64(batchSize)))
  55. batches := make([][]T, nBatches)
  56. for i := range batches {
  57. if i == nBatches-1 && remainderBatchSize > 0 {
  58. batches[i] = make([]T, remainderBatchSize)
  59. } else {
  60. batches[i] = make([]T, batchSize)
  61. }
  62. for j := range batches[i] {
  63. batches[i][j] = items[i*batchSize+j]
  64. }
  65. }
  66. return batches
  67. }
  68. func compressPayload(data []byte) ([]byte, error) {
  69. var buf bytes.Buffer
  70. zw := gzip.NewWriter(&buf)
  71. if _, err := zw.Write(data); err != nil {
  72. return nil, err
  73. }
  74. zw.Close()
  75. return buf.Bytes(), nil
  76. }
  77. func encryptAESGCM(key, plaintext []byte) ([]byte, error) {
  78. // Create AES block cipher
  79. block, err := aes.NewCipher(key)
  80. if err != nil {
  81. return nil, err
  82. }
  83. // Create GCM (Galois/Counter Mode) cipher
  84. aesGCM, err := cipher.NewGCM(block)
  85. if err != nil {
  86. return nil, err
  87. }
  88. // Create a random nonce
  89. nonce := make([]byte, aesGCM.NonceSize())
  90. if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
  91. return nil, err
  92. }
  93. // Encrypt the data
  94. ciphertext := aesGCM.Seal(nonce, nonce, plaintext, nil)
  95. return ciphertext, nil
  96. }
  97. func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
  98. if host.OS == models.OS_Types.IoT {
  99. return msg, nil
  100. }
  101. // fetch server public key to be certain hasn't changed in transit
  102. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey()
  103. if trafficErr != nil {
  104. return nil, trafficErr
  105. }
  106. serverPrivKey, err := ncutils.ConvertBytesToKey(trafficKey)
  107. if err != nil {
  108. return nil, err
  109. }
  110. nodePubKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  111. if err != nil {
  112. return nil, err
  113. }
  114. if strings.Contains(host.Version, "0.10.0") {
  115. return ncutils.BoxEncrypt(msg, nodePubKey, serverPrivKey)
  116. }
  117. return ncutils.Chunk(msg, nodePubKey, serverPrivKey)
  118. }
  119. func publish(host *models.Host, dest string, msg []byte) error {
  120. var encrypted []byte
  121. var encryptErr error
  122. vlt, err := versionLessThan(host.Version, "v0.30.0")
  123. if err != nil {
  124. slog.Warn("error checking version less than", "error", err)
  125. return err
  126. }
  127. if vlt {
  128. encrypted, encryptErr = encryptMsg(host, msg)
  129. if encryptErr != nil {
  130. return encryptErr
  131. }
  132. } else {
  133. zipped, err := compressPayload(msg)
  134. if err != nil {
  135. return err
  136. }
  137. encrypted, encryptErr = encryptAESGCM(host.TrafficKeyPublic[0:32], zipped)
  138. if encryptErr != nil {
  139. return encryptErr
  140. }
  141. }
  142. if mqclient == nil || !mqclient.IsConnectionOpen() {
  143. return errors.New("cannot publish ... mqclient not connected")
  144. }
  145. if token := mqclient.Publish(dest, 0, true, encrypted); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
  146. var err error
  147. if token.Error() == nil {
  148. err = errors.New("connection timeout")
  149. } else {
  150. slog.Error("publish to mq error", "error", token.Error().Error())
  151. err = token.Error()
  152. }
  153. return err
  154. }
  155. return nil
  156. }
  157. // decodes a message queue topic and returns the embedded node.ID
  158. func GetID(topic string) (string, error) {
  159. parts := strings.Split(topic, "/")
  160. count := len(parts)
  161. if count == 1 {
  162. return "", fmt.Errorf("invalid topic")
  163. }
  164. //the last part of the topic will be the node.ID
  165. return parts[count-1], nil
  166. }
  167. // versionLessThan checks if v1 < v2 semantically
  168. // dev is the latest version
  169. func versionLessThan(v1, v2 string) (bool, error) {
  170. if v1 == "dev" {
  171. return false, nil
  172. }
  173. if v2 == "dev" {
  174. return true, nil
  175. }
  176. semVer1 := strings.TrimFunc(v1, func(r rune) bool {
  177. return !unicode.IsNumber(r)
  178. })
  179. semVer2 := strings.TrimFunc(v2, func(r rune) bool {
  180. return !unicode.IsNumber(r)
  181. })
  182. sv1, err := semver.Parse(semVer1)
  183. if err != nil {
  184. return false, fmt.Errorf("failed to parse semver1 (%s): %w", semVer1, err)
  185. }
  186. sv2, err := semver.Parse(semVer2)
  187. if err != nil {
  188. return false, fmt.Errorf("failed to parse semver2 (%s): %w", semVer2, err)
  189. }
  190. return sv1.LT(sv2), nil
  191. }