util.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package mq
  2. import (
  3. "errors"
  4. "fmt"
  5. "math"
  6. "strings"
  7. "time"
  8. "github.com/gravitl/netmaker/logic"
  9. "github.com/gravitl/netmaker/models"
  10. "github.com/gravitl/netmaker/netclient/ncutils"
  11. "golang.org/x/exp/slog"
  12. )
  13. func decryptMsgWithHost(host *models.Host, msg []byte) ([]byte, error) {
  14. if host.OS == models.OS_Types.IoT { // just pass along IoT messages
  15. return msg, nil
  16. }
  17. trafficKey, trafficErr := logic.RetrievePrivateTrafficKey() // get server private key
  18. if trafficErr != nil {
  19. return nil, trafficErr
  20. }
  21. serverPrivTKey, err := ncutils.ConvertBytesToKey(trafficKey)
  22. if err != nil {
  23. return nil, err
  24. }
  25. nodePubTKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  26. if err != nil {
  27. return nil, err
  28. }
  29. return ncutils.DeChunk(msg, nodePubTKey, serverPrivTKey)
  30. }
  31. func DecryptMsg(node *models.Node, msg []byte) ([]byte, error) {
  32. if len(msg) <= 24 { // make sure message is of appropriate length
  33. return nil, fmt.Errorf("received invalid message from broker %v", msg)
  34. }
  35. host, err := logic.GetHost(node.HostID.String())
  36. if err != nil {
  37. return nil, err
  38. }
  39. return decryptMsgWithHost(host, msg)
  40. }
  // BatchItems splits items into consecutive batches of batchSize elements.
  // When len(items) is not a multiple of batchSize, the final batch holds
  // the remainder. Every batch is a freshly allocated copy, so mutating a
  // batch never affects items or the other batches.
  //
  // A non-positive batchSize yields nil. An empty items slice yields an
  // empty, non-nil result.
  func BatchItems[T any](items []T, batchSize int) [][]T {
  	if batchSize < 1 {
  		return nil
  	}
  	total := len(items)
  	count := int(math.Ceil(float64(total) / float64(batchSize)))
  	out := make([][]T, count)
  	for b := 0; b < count; b++ {
  		start := b * batchSize
  		end := start + batchSize
  		if end > total {
  			// Only the last batch can be short.
  			end = total
  		}
  		chunk := make([]T, end-start)
  		copy(chunk, items[start:end])
  		out[b] = chunk
  	}
  	return out
  }
  // encryptMsg encrypts msg for delivery to host, using host's public traffic
  // key and the server's private traffic key. IoT hosts receive msg unchanged.
  //
  // Hosts whose Version contains "0.10.0" get a payload produced by
  // ncutils.BoxEncrypt. All other hosts get ncutils.Chunk output. Errors from
  // key retrieval, key conversion or encryption are returned unwrapped.
  func encryptMsg(host *models.Host, msg []byte) ([]byte, error) {
  	if host.OS == models.OS_Types.IoT {
  		return msg, nil
  	}
  	// The server's private traffic key is fetched on every call rather than
  	// cached, so the current key is always used.
  	privBytes, err := logic.RetrievePrivateTrafficKey()
  	if err != nil {
  		return nil, err
  	}
  	serverKey, err := ncutils.ConvertBytesToKey(privBytes)
  	if err != nil {
  		return nil, err
  	}
  	hostKey, err := ncutils.ConvertBytesToKey(host.TrafficKeyPublic)
  	if err != nil {
  		return nil, err
  	}
  	switch {
  	case strings.Contains(host.Version, "0.10.0"):
  		// Presumably legacy 0.10.0 clients that expect a single box-encrypted
  		// payload. NOTE(review): the substring match also accepts versions
  		// such as "10.10.0"; confirm whether an exact comparison is intended.
  		return ncutils.BoxEncrypt(msg, hostKey, serverKey)
  	default:
  		return ncutils.Chunk(msg, hostKey, serverKey)
  	}
  }
  82. func publish(host *models.Host, dest string, msg []byte) error {
  83. encrypted, encryptErr := encryptMsg(host, msg)
  84. if encryptErr != nil {
  85. return encryptErr
  86. }
  87. if mqclient == nil || !mqclient.IsConnectionOpen() {
  88. return errors.New("cannot publish ... mqclient not connected")
  89. }
  90. if token := mqclient.Publish(dest, 0, true, encrypted); !token.WaitTimeout(MQ_TIMEOUT*time.Second) || token.Error() != nil {
  91. var err error
  92. if token.Error() == nil {
  93. err = errors.New("connection timeout")
  94. } else {
  95. slog.Error("publish to mq error", "error", token.Error().Error())
  96. err = token.Error()
  97. }
  98. return err
  99. }
  100. return nil
  101. }
  // GetID decodes a message queue topic and returns the embedded node.ID,
  // which is the final "/"-separated segment of the topic
  // (e.g. "update/net1/abc" -> "abc").
  //
  // It returns an "invalid topic" error when the topic contains no "/" or
  // when the final segment is empty (e.g. "update/" or "/"), because neither
  // carries an ID. Previously a trailing "/" produced an empty ID with a nil
  // error.
  func GetID(topic string) (string, error) {
  	sep := strings.LastIndexByte(topic, '/')
  	if sep < 0 {
  		return "", errors.New("invalid topic")
  	}
  	id := topic[sep+1:]
  	if id == "" {
  		return "", errors.New("invalid topic")
  	}
  	return id, nil
  }