2
0

// host_session.go (11 KB)
  1. package auth
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "log/slog"
  6. "strings"
  7. "time"
  8. "github.com/google/uuid"
  9. "github.com/gorilla/websocket"
  10. "github.com/gravitl/netmaker/logger"
  11. "github.com/gravitl/netmaker/logic"
  12. "github.com/gravitl/netmaker/logic/hostactions"
  13. "github.com/gravitl/netmaker/logic/pro/netcache"
  14. "github.com/gravitl/netmaker/models"
  15. "github.com/gravitl/netmaker/mq"
  16. "github.com/gravitl/netmaker/servercfg"
  17. )
  18. // SessionHandler - called by the HTTP router when user
  19. // is calling netclient with join/register -s parameter in order to authenticate
  20. // via SSO mechanism by OAuth2 protocol flow.
  21. // This triggers a session start and it is managed by the flow implemented here and callback
  22. // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
  23. func SessionHandler(conn *websocket.Conn) {
  24. defer conn.Close()
  25. // If reached here we have a session from user to handle...
  26. messageType, message, err := conn.ReadMessage()
  27. if err != nil {
  28. logger.Log(0, "Error during message reading:", err.Error())
  29. return
  30. }
  31. var registerMessage models.RegisterMsg
  32. if err = json.Unmarshal(message, &registerMessage); err != nil {
  33. logger.Log(0, "Failed to unmarshall data err=", err.Error())
  34. return
  35. }
  36. if registerMessage.RegisterHost.ID == uuid.Nil {
  37. logger.Log(0, "invalid host registration attempted")
  38. return
  39. }
  40. req := new(netcache.CValue)
  41. req.Value = string(registerMessage.RegisterHost.ID.String())
  42. req.Network = registerMessage.Network
  43. req.Host = registerMessage.RegisterHost
  44. req.ALL = registerMessage.JoinAll
  45. req.Pass = ""
  46. req.User = registerMessage.User
  47. if len(req.User) > 0 && len(registerMessage.Password) == 0 {
  48. logger.Log(0, "invalid host registration attempted")
  49. return
  50. }
  51. // Add any extra parameter provided in the configuration to the Authorize Endpoint request??
  52. stateStr := logic.RandomString(node_signin_length)
  53. if err := netcache.Set(stateStr, req); err != nil {
  54. logger.Log(0, "Failed to process sso request -", err.Error())
  55. return
  56. }
  57. defer netcache.Del(stateStr)
  58. // Wait for the user to finish his auth flow...
  59. timeout := make(chan bool, 2)
  60. answer := make(chan netcache.CValue, 1)
  61. defer close(answer)
  62. defer close(timeout)
  63. if len(registerMessage.User) > 0 { // handle basic auth
  64. logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "user:", registerMessage.User)
  65. if !logic.IsBasicAuthEnabled() {
  66. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  67. if err != nil {
  68. logger.Log(0, "error during message writing:", err.Error())
  69. }
  70. }
  71. _, err := logic.VerifyAuthRequest(models.UserAuthParams{
  72. UserName: registerMessage.User,
  73. Password: registerMessage.Password,
  74. })
  75. if err != nil {
  76. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  77. if err != nil {
  78. logger.Log(0, "error during message writing:", err.Error())
  79. }
  80. return
  81. }
  82. req.Pass = req.Host.ID.String()
  83. // user, err := logic.GetUser(req.User)
  84. // if err != nil {
  85. // logger.Log(0, "failed to get user", req.User, "from database")
  86. // err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  87. // if err != nil {
  88. // logger.Log(0, "error during message writing:", err.Error())
  89. // }
  90. // return
  91. // }
  92. // if !user.IsAdmin && !user.IsSuperAdmin {
  93. // logger.Log(0, "user", req.User, "is neither an admin or superadmin. denying registeration")
  94. // conn.WriteMessage(messageType, []byte("cannot register with a non-admin or non-superadmin"))
  95. // err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  96. // if err != nil {
  97. // logger.Log(0, "error during message writing:", err.Error())
  98. // }
  99. // return
  100. // }
  101. if err = netcache.Set(stateStr, req); err != nil { // give the user's host access in the DB
  102. logger.Log(0, "machine failed to complete join on network,", registerMessage.Network, "-", err.Error())
  103. return
  104. }
  105. } else { // handle SSO / OAuth
  106. if auth_provider == nil {
  107. err = conn.WriteMessage(messageType, []byte("Oauth not configured"))
  108. if err != nil {
  109. logger.Log(0, "error during message writing:", err.Error())
  110. }
  111. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  112. if err != nil {
  113. logger.Log(0, "error during message writing:", err.Error())
  114. }
  115. return
  116. }
  117. logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "via SSO")
  118. redirectUrl := fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
  119. err = conn.WriteMessage(messageType, []byte(redirectUrl))
  120. if err != nil {
  121. logger.Log(0, "error during message writing:", err.Error())
  122. }
  123. }
  124. go func() {
  125. for {
  126. msgType, _, err := conn.ReadMessage()
  127. if err != nil || msgType == websocket.CloseMessage {
  128. netcache.Del(stateStr)
  129. return
  130. }
  131. }
  132. }()
  133. go func() {
  134. for {
  135. cachedReq, err := netcache.Get(stateStr)
  136. if err != nil {
  137. logger.Log(0, "oauth state has been deleted ", err.Error())
  138. timeout <- true
  139. break
  140. } else if len(cachedReq.User) > 0 {
  141. logger.Log(0, "host SSO process completed for user", cachedReq.User)
  142. answer <- *cachedReq
  143. break
  144. }
  145. time.Sleep(time.Second)
  146. }
  147. }()
  148. select {
  149. case result := <-answer: // a read from req.answerCh has occurred
  150. // add the host, if not exists, handle like enrollment registration
  151. if !logic.HostExists(&result.Host) { // check if host already exists, add if not
  152. if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
  153. if err := mq.GetEmqxHandler().CreateEmqxUser(result.Host.ID.String(), result.Host.HostPass); err != nil {
  154. logger.Log(0, "failed to create host credentials for EMQX: ", err.Error())
  155. return
  156. }
  157. }
  158. _ = logic.CheckHostPorts(&result.Host)
  159. if err := logic.CreateHost(&result.Host); err != nil {
  160. handleHostRegErr(conn, err)
  161. return
  162. }
  163. }
  164. key, keyErr := logic.RetrievePublicTrafficKey()
  165. if keyErr != nil {
  166. handleHostRegErr(conn, err)
  167. return
  168. }
  169. currHost, err := logic.GetHost(result.Host.ID.String())
  170. if err != nil {
  171. handleHostRegErr(conn, err)
  172. return
  173. }
  174. var currentNetworks = []string{}
  175. if result.ALL {
  176. currentNets, err := logic.GetNetworks()
  177. if err == nil && len(currentNets) > 0 {
  178. for i := range currentNets {
  179. currentNetworks = append(currentNetworks, currentNets[i].NetID)
  180. }
  181. }
  182. } else if len(result.Network) > 0 {
  183. currentNetworks = append(currentNetworks, result.Network)
  184. }
  185. var netsToAdd = []string{} // track the networks not currently owned by host
  186. hostNets := logic.GetHostNetworks(currHost.ID.String())
  187. for _, newNet := range currentNetworks {
  188. if !logic.StringSliceContains(hostNets, newNet) {
  189. if len(result.User) > 0 {
  190. _, err := isUserIsAllowed(result.User, newNet)
  191. if err != nil {
  192. logger.Log(0, "unauthorized user", result.User, "attempted to register to network", newNet)
  193. handleHostRegErr(conn, err)
  194. return
  195. }
  196. }
  197. netsToAdd = append(netsToAdd, newNet)
  198. }
  199. }
  200. server := logic.GetServerInfo()
  201. server.TrafficKey = key
  202. result.Host.HostPass = ""
  203. response := models.RegisterResponse{
  204. ServerConf: server,
  205. RequestedHost: result.Host,
  206. }
  207. reponseData, err := json.Marshal(&response)
  208. if err != nil {
  209. handleHostRegErr(conn, err)
  210. return
  211. }
  212. if err = conn.WriteMessage(messageType, reponseData); err != nil {
  213. logger.Log(0, "error during message writing:", err.Error())
  214. }
  215. go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil, []models.TagID{})
  216. case <-timeout: // the read from req.answerCh has timed out
  217. logger.Log(0, "timeout signal recv,exiting oauth socket conn")
  218. break
  219. }
  220. // Cleanly close the connection by sending a close message and then
  221. // waiting (with timeout) for the server to close the connection.
  222. if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
  223. logger.Log(0, "write close:", err.Error())
  224. return
  225. }
  226. }
  227. // CheckNetRegAndHostUpdate - run through networks and send a host update
  228. func CheckNetRegAndHostUpdate(networks []string, h *models.Host, relayNodeId uuid.UUID, tags []models.TagID) {
  229. // publish host update through MQ
  230. for i := range networks {
  231. network := networks[i]
  232. if ok, _ := logic.NetworkExists(network); ok {
  233. newNode, err := logic.UpdateHostNetwork(h, network, true)
  234. if err == nil || strings.Contains(err.Error(), "host already part of network") {
  235. if len(tags) > 0 {
  236. newNode.Tags = make(map[models.TagID]struct{})
  237. for _, tagI := range tags {
  238. newNode.Tags[tagI] = struct{}{}
  239. }
  240. logic.UpsertNode(newNode)
  241. }
  242. if relayNodeId != uuid.Nil && !newNode.IsRelayed {
  243. // check if relay node exists and acting as relay
  244. relaynode, err := logic.GetNodeByID(relayNodeId.String())
  245. if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network {
  246. slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), relayNodeId.String(), network))
  247. newNode.IsRelayed = true
  248. newNode.RelayedBy = relayNodeId.String()
  249. updatedRelayNode := relaynode
  250. updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String())
  251. logic.UpdateRelayed(&relaynode, &updatedRelayNode)
  252. if err := logic.UpsertNode(&updatedRelayNode); err != nil {
  253. slog.Error("failed to update node", "nodeid", relayNodeId.String())
  254. }
  255. if err := logic.UpsertNode(newNode); err != nil {
  256. slog.Error("failed to update node", "nodeid", relayNodeId.String())
  257. }
  258. } else {
  259. slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err)
  260. }
  261. }
  262. if err != nil && strings.Contains(err.Error(), "host already part of network") {
  263. continue
  264. }
  265. } else {
  266. logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, network, err.Error())
  267. continue
  268. }
  269. logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
  270. hostactions.AddAction(models.HostUpdate{
  271. Action: models.JoinHostToNetwork,
  272. Host: *h,
  273. Node: *newNode,
  274. })
  275. if h.IsDefault {
  276. // make host failover
  277. logic.CreateFailOver(*newNode)
  278. // make host remote access gateway
  279. logic.CreateIngressGateway(network, newNode.ID.String(), models.IngressRequest{})
  280. logic.CreateRelay(models.RelayRequest{
  281. NodeID: newNode.ID.String(),
  282. NetID: network,
  283. })
  284. }
  285. }
  286. }
  287. if servercfg.IsMessageQueueBackend() {
  288. mq.HostUpdate(&models.HostUpdate{
  289. Action: models.RequestAck,
  290. Host: *h,
  291. })
  292. if err := mq.PublishPeerUpdate(false); err != nil {
  293. logger.Log(0, "failed to publish peer update during registration -", err.Error())
  294. }
  295. }
  296. }
  297. func handleHostRegErr(conn *websocket.Conn, err error) {
  298. _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  299. if err != nil {
  300. logger.Log(0, "error during host registration via auth:", err.Error())
  301. }
  302. }