mq.go 6.4 KB


  1. package mq
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "net"
  6. "strconv"
  7. "strings"
  8. "time"
  9. mqtt "github.com/eclipse/paho.mqtt.golang"
  10. "github.com/gravitl/netmaker/database"
  11. "github.com/gravitl/netmaker/logger"
  12. "github.com/gravitl/netmaker/logic"
  13. "github.com/gravitl/netmaker/models"
  14. "github.com/gravitl/netmaker/servercfg"
  15. "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
  16. )
  17. var DefaultHandler mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
  18. logger.Log(0, "MQTT Message: Topic: "+string(msg.Topic())+" Message: "+string(msg.Payload()))
  19. }
  20. var Ping mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
  21. logger.Log(0, "Ping Handler: "+msg.Topic())
  22. go func() {
  23. id, err := GetID(msg.Topic())
  24. if err != nil {
  25. logger.Log(0, "error getting node.ID sent on ping topic ")
  26. return
  27. }
  28. node, err := logic.GetNodeByID(id)
  29. if err != nil {
  30. logger.Log(0, "mq-ping error getting node: "+err.Error())
  31. record, err := database.FetchRecord(database.NODES_TABLE_NAME, id)
  32. if err != nil {
  33. logger.Log(0, "error reading database ", err.Error())
  34. return
  35. }
  36. logger.Log(0, "record from database")
  37. logger.Log(0, record)
  38. return
  39. }
  40. node.SetLastCheckIn()
  41. if err := logic.UpdateNode(&node, &node) ; err != nil {
  42. logger.Log(0, "error updating node "+ err.Error())
  43. }
  44. logger.Log(0, "ping processed")
  45. // --TODO --set client version once feature is implemented.
  46. //node.SetClientVersion(msg.Payload())
  47. }()
  48. }
  49. var PublicKeyUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
  50. logger.Log(0, "PublicKey Handler")
  51. go func() {
  52. logger.Log(0, "public key update "+msg.Topic())
  53. key := string(msg.Payload())
  54. id, err := GetID(msg.Topic())
  55. if err != nil {
  56. logger.Log(0, "error getting node.ID sent on "+msg.Topic()+" "+err.Error())
  57. }
  58. node, err := logic.GetNodeByID(id)
  59. if err != nil {
  60. logger.Log(0, "error retrieving node "+msg.Topic()+" "+err.Error())
  61. }
  62. node.PublicKey = key
  63. node.SetLastCheckIn()
  64. if err := logic.UpdateNode(&node, &node) ; err != nil {
  65. logger.Log(0, "error updating node "+ err.Error())
  66. }
  67. if err := UpdatePeers(client, node); err != nil {
  68. logger.Log(0, "error updating peers "+err.Error())
  69. }
  70. }()
  71. }
  72. var IPUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
  73. go func() {
  74. ip := string(msg.Payload())
  75. logger.Log(0, "IPUpdate Handler")
  76. id, err := GetID(msg.Topic())
  77. logger.Log(0, "ipUpdate recieved from "+id)
  78. if err != nil {
  79. logger.Log(0, "error getting node.ID sent on update/ip topic ")
  80. return
  81. }
  82. node, err := logic.GetNodeByID(id)
  83. if err != nil {
  84. logger.Log(0, "invalid ID recieved on update/ip topic: "+err.Error())
  85. return
  86. }
  87. node.Endpoint = ip
  88. node.SetLastCheckIn()
  89. if err := logic.UpdateNode(&node, &node) ; err != nil {
  90. logger.Log(0, "error updating node "+ err.Error())
  91. }
  92. if err != UpdatePeers(client, node) {
  93. logger.Log(0, "error updating peers "+err.Error())
  94. }
  95. }()
  96. }
  97. func UpdatePeers(client mqtt.Client, newnode models.Node) error {
  98. networkNodes, err := logic.GetNetworkNodes(newnode.Network)
  99. if err != nil {
  100. return err
  101. }
  102. keepalive, _ := time.ParseDuration(string(newnode.PersistentKeepalive)+"s")
  103. for _, node := range networkNodes {
  104. var peers []wgtypes.PeerConfig
  105. var peerUpdate models.PeerUpdate
  106. for _, peer := range networkNodes{
  107. if peer.ID == node.ID {
  108. //skip
  109. continue
  110. }
  111. pubkey, err := wgtypes.ParseKey(peer.PublicKey)
  112. if err != nil {
  113. return err
  114. }
  115. if node.Endpoint == peer.Endpoint {
  116. if node.LocalAddress != peer.LocalAddress && peer.LocalAddress != "" {
  117. peer.Endpoint = peer.LocalAddress
  118. }else {
  119. continue
  120. }
  121. }
  122. endpoint := peer.Endpoint + ":" + strconv.Itoa(int(peer.ListenPort))
  123. //fmt.Println("endpoint: ", endpoint, peer.Endpoint, peer.ListenPort)
  124. address, err := net.ResolveUDPAddr("udp", endpoint)
  125. if err != nil {
  126. return err
  127. }
  128. //calculate Allowed IPs.
  129. var peerData wgtypes.PeerConfig
  130. peerData = wgtypes.PeerConfig{
  131. PublicKey: pubkey,
  132. Endpoint: address,
  133. PersistentKeepaliveInterval: &keepalive,
  134. //AllowedIPs: allowedIPs
  135. }
  136. peers = append (peers, peerData)
  137. }
  138. peerUpdate.Network = node.Network
  139. peerUpdate.Peers = peers
  140. data, err := json.Marshal(&peerUpdate)
  141. if err != nil {
  142. logger.Log(0, "error marshaling peer update "+err.Error())
  143. return err
  144. }
  145. if token := client.Publish("/update/peers/"+node.ID, 0, false, data); token.Wait() && token.Error() != nil {
  146. logger.Log(0, "error sending peer updatte to no")
  147. return err
  148. }
  149. }
  150. return nil
  151. }
  152. var LocalAddressUpdate mqtt.MessageHandler = func(client mqtt.Client, msg mqtt.Message) {
  153. logger.Log(0, "LocalAddressUpdate Handler")
  154. go func() {
  155. logger.Log(0, "LocalAddressUpdate handler")
  156. id, err := GetID(msg.Topic())
  157. if err != nil {
  158. logger.Log(0, "error getting node.ID "+msg.Topic())
  159. return
  160. }
  161. node, err := logic.GetNodeByID(id)
  162. if err != nil {
  163. logger.Log(0, "error get node "+msg.Topic())
  164. return
  165. }
  166. node.LocalAddress = string(msg.Payload())
  167. node.SetLastCheckIn()
  168. if err := UpdatePeers(client, node); err != nil {
  169. logger.Log(0, "error updating peers "+err.Error())
  170. }
  171. }()
  172. }
  // GetID returns the node ID carried in an MQTT topic. By convention the ID is
// the last "/"-separated segment, e.g. "update/ip/<id>" yields "<id>".
//
// It returns an error when the topic contains no "/" or when the last segment
// is empty (for example a trailing slash), since neither identifies a node.
func GetID(topic string) (string, error) {
	sep := strings.LastIndex(topic, "/")
	if sep < 0 || sep == len(topic)-1 {
		return "", errors.New("invalid topic")
	}
	return topic[sep+1:], nil
}
  182. func NewPeer(node models.Node) error {
  183. opts := mqtt.NewClientOptions()
  184. broker := servercfg.GetMessageQueueEndpoint()
  185. logger.Log(0, "broker: "+broker)
  186. opts.AddBroker(broker)
  187. client := mqtt.NewClient(opts)
  188. if token := client.Connect(); token.Wait() && token.Error() != nil {
  189. return token.Error()
  190. }
  191. if err := UpdatePeers(client, node); err != nil {
  192. return err
  193. }
  194. return nil
  195. }