host_session.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. package auth
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/google/uuid"
  8. "github.com/gorilla/websocket"
  9. "github.com/gravitl/netmaker/logger"
  10. "github.com/gravitl/netmaker/logic"
  11. "github.com/gravitl/netmaker/logic/hostactions"
  12. "github.com/gravitl/netmaker/logic/pro/netcache"
  13. "github.com/gravitl/netmaker/models"
  14. "github.com/gravitl/netmaker/mq"
  15. "github.com/gravitl/netmaker/servercfg"
  16. "golang.org/x/exp/slog"
  17. )
  18. // SessionHandler - called by the HTTP router when user
  19. // is calling netclient with join/register -s parameter in order to authenticate
  20. // via SSO mechanism by OAuth2 protocol flow.
  21. // This triggers a session start and it is managed by the flow implemented here and callback
  22. // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
  23. func SessionHandler(conn *websocket.Conn) {
  24. defer conn.Close()
  25. // If reached here we have a session from user to handle...
  26. messageType, message, err := conn.ReadMessage()
  27. if err != nil {
  28. logger.Log(0, "Error during message reading:", err.Error())
  29. return
  30. }
  31. var registerMessage models.RegisterMsg
  32. if err = json.Unmarshal(message, &registerMessage); err != nil {
  33. logger.Log(0, "Failed to unmarshall data err=", err.Error())
  34. return
  35. }
  36. if registerMessage.RegisterHost.ID == uuid.Nil {
  37. logger.Log(0, "invalid host registration attempted")
  38. return
  39. }
  40. req := new(netcache.CValue)
  41. req.Value = string(registerMessage.RegisterHost.ID.String())
  42. req.Network = registerMessage.Network
  43. req.Host = registerMessage.RegisterHost
  44. req.ALL = registerMessage.JoinAll
  45. req.Pass = ""
  46. req.User = registerMessage.User
  47. if len(req.User) > 0 && len(registerMessage.Password) == 0 {
  48. logger.Log(0, "invalid host registration attempted")
  49. return
  50. }
  51. // Add any extra parameter provided in the configuration to the Authorize Endpoint request??
  52. stateStr := logic.RandomString(node_signin_length)
  53. if err := netcache.Set(stateStr, req); err != nil {
  54. logger.Log(0, "Failed to process sso request -", err.Error())
  55. return
  56. }
  57. // Wait for the user to finish his auth flow...
  58. timeout := make(chan bool, 1)
  59. answer := make(chan netcache.CValue, 1)
  60. defer close(answer)
  61. defer close(timeout)
  62. if len(registerMessage.User) > 0 { // handle basic auth
  63. logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "user:", registerMessage.User)
  64. if !servercfg.IsBasicAuthEnabled() {
  65. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  66. if err != nil {
  67. logger.Log(0, "error during message writing:", err.Error())
  68. }
  69. }
  70. _, err := logic.VerifyAuthRequest(models.UserAuthParams{
  71. UserName: registerMessage.User,
  72. Password: registerMessage.Password,
  73. })
  74. if err != nil {
  75. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  76. if err != nil {
  77. logger.Log(0, "error during message writing:", err.Error())
  78. }
  79. return
  80. }
  81. req.Pass = req.Host.ID.String()
  82. if err = netcache.Set(stateStr, req); err != nil { // give the user's host access in the DB
  83. logger.Log(0, "machine failed to complete join on network,", registerMessage.Network, "-", err.Error())
  84. return
  85. }
  86. } else { // handle SSO / OAuth
  87. if auth_provider == nil {
  88. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  89. if err != nil {
  90. logger.Log(0, "error during message writing:", err.Error())
  91. }
  92. return
  93. }
  94. logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "via SSO")
  95. redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
  96. err = conn.WriteMessage(messageType, []byte(redirectUrl))
  97. if err != nil {
  98. logger.Log(0, "error during message writing:", err.Error())
  99. }
  100. }
  101. go func() {
  102. for {
  103. cachedReq, err := netcache.Get(stateStr)
  104. if err != nil {
  105. if strings.Contains(err.Error(), "expired") {
  106. logger.Log(1, "timeout occurred while waiting for SSO registration")
  107. timeout <- true
  108. break
  109. }
  110. continue
  111. } else if len(cachedReq.User) > 0 {
  112. logger.Log(0, "host SSO process completed for user", cachedReq.User)
  113. answer <- *cachedReq
  114. break
  115. }
  116. time.Sleep(500) // try it 2 times per second to see if auth is completed
  117. }
  118. }()
  119. select {
  120. case result := <-answer: // a read from req.answerCh has occurred
  121. // add the host, if not exists, handle like enrollment registration
  122. hostPass := result.Host.HostPass
  123. if !logic.HostExists(&result.Host) { // check if host already exists, add if not
  124. if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
  125. if err := mq.CreateEmqxUser(result.Host.ID.String(), result.Host.HostPass, false); err != nil {
  126. logger.Log(0, "failed to create host credentials for EMQX: ", err.Error())
  127. return
  128. }
  129. if err := mq.CreateHostACL(result.Host.ID.String(), servercfg.GetServerInfo().Server); err != nil {
  130. logger.Log(0, "failed to add host ACL rules to EMQX: ", err.Error())
  131. return
  132. }
  133. }
  134. logic.CheckHostPorts(&result.Host)
  135. if err := logic.CreateHost(&result.Host); err != nil {
  136. handleHostRegErr(conn, err)
  137. return
  138. }
  139. }
  140. key, keyErr := logic.RetrievePublicTrafficKey()
  141. if keyErr != nil {
  142. handleHostRegErr(conn, err)
  143. return
  144. }
  145. currHost, err := logic.GetHost(result.Host.ID.String())
  146. if err != nil {
  147. handleHostRegErr(conn, err)
  148. return
  149. }
  150. var currentNetworks = []string{}
  151. if result.ALL {
  152. currentNets, err := logic.GetNetworks()
  153. if err == nil && len(currentNets) > 0 {
  154. for i := range currentNets {
  155. currentNetworks = append(currentNetworks, currentNets[i].NetID)
  156. }
  157. }
  158. } else if len(result.Network) > 0 {
  159. currentNetworks = append(currentNetworks, result.Network)
  160. }
  161. var netsToAdd = []string{} // track the networks not currently owned by host
  162. hostNets := logic.GetHostNetworks(currHost.ID.String())
  163. for _, newNet := range currentNetworks {
  164. if !logic.StringSliceContains(hostNets, newNet) {
  165. if len(result.User) > 0 {
  166. _, err := isUserIsAllowed(result.User, newNet, false)
  167. if err != nil {
  168. logger.Log(0, "unauthorized user", result.User, "attempted to register to network", newNet)
  169. handleHostRegErr(conn, err)
  170. return
  171. }
  172. }
  173. netsToAdd = append(netsToAdd, newNet)
  174. }
  175. }
  176. server := servercfg.GetServerInfo()
  177. server.TrafficKey = key
  178. if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
  179. // set MQ username and password for EMQX clients
  180. server.MQUserName = result.Host.ID.String()
  181. server.MQPassword = hostPass
  182. }
  183. result.Host.HostPass = ""
  184. response := models.RegisterResponse{
  185. ServerConf: server,
  186. RequestedHost: result.Host,
  187. }
  188. reponseData, err := json.Marshal(&response)
  189. if err != nil {
  190. handleHostRegErr(conn, err)
  191. return
  192. }
  193. if err = conn.WriteMessage(messageType, reponseData); err != nil {
  194. logger.Log(0, "error during message writing:", err.Error())
  195. }
  196. go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil)
  197. case <-timeout: // the read from req.answerCh has timed out
  198. if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
  199. logger.Log(0, "error during timeout message writing:", err.Error())
  200. }
  201. }
  202. // The entry is not needed anymore, but we will let the producer to close it to avoid panic cases
  203. if err = netcache.Del(stateStr); err != nil {
  204. logger.Log(0, "failed to remove node SSO cache entry", err.Error())
  205. }
  206. // Cleanly close the connection by sending a close message and then
  207. // waiting (with timeout) for the server to close the connection.
  208. if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
  209. logger.Log(0, "write close:", err.Error())
  210. return
  211. }
  212. }
  213. // CheckNetRegAndHostUpdate - run through networks and send a host update
  214. func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID) {
  215. // publish host update through MQ
  216. for i := range networks {
  217. network := networks[i]
  218. if ok, _ := logic.NetworkExists(network); ok {
  219. newNode, err := logic.UpdateHostNetwork(h, network, true)
  220. if err != nil {
  221. logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
  222. continue
  223. }
  224. if relayNodeId != uuid.Nil && !newNode.IsRelayed {
  225. newNode.IsRelayed = true
  226. newNode.RelayedBy = relayNodeId.String()
  227. slog.Info(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), relayNodeId.String(), network))
  228. if err := logic.UpsertNode(newNode); err != nil {
  229. slog.Error("failed to update node", "nodeid", relayNodeId.String())
  230. }
  231. }
  232. logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
  233. hostactions.AddAction(models.HostUpdate{
  234. Action: models.JoinHostToNetwork,
  235. Host: *h,
  236. Node: *newNode,
  237. })
  238. }
  239. }
  240. if servercfg.IsMessageQueueBackend() {
  241. mq.HostUpdate(&models.HostUpdate{
  242. Action: models.RequestAck,
  243. Host: *h,
  244. })
  245. if err := mq.PublishPeerUpdate(false); err != nil {
  246. logger.Log(0, "failed to publish peer update during registration -", err.Error())
  247. }
  248. }
  249. }
  250. func handleHostRegErr(conn *websocket.Conn, err error) {
  251. _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  252. if err != nil {
  253. logger.Log(0, "error during host registration via auth:", err.Error())
  254. }
  255. }