host_session.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. package auth
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "log/slog"
  7. "strings"
  8. "time"
  9. "github.com/google/uuid"
  10. "github.com/gorilla/websocket"
  11. "github.com/gravitl/netmaker/db"
  12. "github.com/gravitl/netmaker/logger"
  13. "github.com/gravitl/netmaker/logic"
  14. "github.com/gravitl/netmaker/logic/hostactions"
  15. "github.com/gravitl/netmaker/logic/pro/netcache"
  16. "github.com/gravitl/netmaker/models"
  17. "github.com/gravitl/netmaker/mq"
  18. "github.com/gravitl/netmaker/schema"
  19. "github.com/gravitl/netmaker/servercfg"
  20. )
  21. // SessionHandler - called by the HTTP router when user
  22. // is calling netclient with join/register -s parameter in order to authenticate
  23. // via SSO mechanism by OAuth2 protocol flow.
  24. // This triggers a session start and it is managed by the flow implemented here and callback
  25. // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
  26. func SessionHandler(conn *websocket.Conn) {
  27. defer conn.Close()
  28. // If reached here we have a session from user to handle...
  29. messageType, message, err := conn.ReadMessage()
  30. if err != nil {
  31. logger.Log(0, "Error during message reading:", err.Error())
  32. return
  33. }
  34. var registerMessage models.RegisterMsg
  35. if err = json.Unmarshal(message, &registerMessage); err != nil {
  36. logger.Log(0, "Failed to unmarshall data err=", err.Error())
  37. return
  38. }
  39. if registerMessage.RegisterHost.ID == uuid.Nil {
  40. logger.Log(0, "invalid host registration attempted")
  41. return
  42. }
  43. req := new(netcache.CValue)
  44. req.Value = string(registerMessage.RegisterHost.ID.String())
  45. req.Network = registerMessage.Network
  46. req.Host = registerMessage.RegisterHost
  47. req.ALL = registerMessage.JoinAll
  48. req.Pass = ""
  49. req.User = registerMessage.User
  50. if len(req.User) > 0 && len(registerMessage.Password) == 0 {
  51. logger.Log(0, "invalid host registration attempted")
  52. return
  53. }
  54. // Add any extra parameter provided in the configuration to the Authorize Endpoint request??
  55. stateStr := logic.RandomString(node_signin_length)
  56. if err := netcache.Set(stateStr, req); err != nil {
  57. logger.Log(0, "Failed to process sso request -", err.Error())
  58. return
  59. }
  60. defer netcache.Del(stateStr)
  61. // Wait for the user to finish his auth flow...
  62. timeout := make(chan bool, 2)
  63. answer := make(chan netcache.CValue, 1)
  64. defer close(answer)
  65. defer close(timeout)
  66. if len(registerMessage.User) > 0 { // handle basic auth
  67. logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "user:", registerMessage.User)
  68. if !logic.IsBasicAuthEnabled() {
  69. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  70. if err != nil {
  71. logger.Log(0, "error during message writing:", err.Error())
  72. }
  73. }
  74. _, err := logic.VerifyAuthRequest(models.UserAuthParams{
  75. UserName: registerMessage.User,
  76. Password: registerMessage.Password,
  77. }, logic.NetclientApp)
  78. if err != nil {
  79. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  80. if err != nil {
  81. logger.Log(0, "error during message writing:", err.Error())
  82. }
  83. return
  84. }
  85. req.Pass = req.Host.ID.String()
  86. // user, err := logic.GetUser(req.User)
  87. // if err != nil {
  88. // logger.Log(0, "failed to get user", req.User, "from database")
  89. // err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  90. // if err != nil {
  91. // logger.Log(0, "error during message writing:", err.Error())
  92. // }
  93. // return
  94. // }
  95. // if !user.IsAdmin && !user.IsSuperAdmin {
  96. // logger.Log(0, "user", req.User, "is neither an admin or superadmin. denying registeration")
  97. // conn.WriteMessage(messageType, []byte("cannot register with a non-admin or non-superadmin"))
  98. // err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  99. // if err != nil {
  100. // logger.Log(0, "error during message writing:", err.Error())
  101. // }
  102. // return
  103. // }
  104. if err = netcache.Set(stateStr, req); err != nil { // give the user's host access in the DB
  105. logger.Log(0, "machine failed to complete join on network,", registerMessage.Network, "-", err.Error())
  106. return
  107. }
  108. } else { // handle SSO / OAuth
  109. if !logic.IsOAuthConfigured() {
  110. err = conn.WriteMessage(messageType, []byte("Oauth not configured"))
  111. if err != nil {
  112. logger.Log(0, "error during message writing:", err.Error())
  113. }
  114. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  115. if err != nil {
  116. logger.Log(0, "error during message writing:", err.Error())
  117. }
  118. return
  119. }
  120. logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "via SSO")
  121. redirectUrl := fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
  122. err = conn.WriteMessage(messageType, []byte(redirectUrl))
  123. if err != nil {
  124. logger.Log(0, "error during message writing:", err.Error())
  125. }
  126. }
  127. go func() {
  128. for {
  129. msgType, _, err := conn.ReadMessage()
  130. if err != nil || msgType == websocket.CloseMessage {
  131. netcache.Del(stateStr)
  132. return
  133. }
  134. }
  135. }()
  136. go func() {
  137. for {
  138. cachedReq, err := netcache.Get(stateStr)
  139. if err != nil {
  140. logger.Log(0, "oauth state has been deleted ", err.Error())
  141. timeout <- true
  142. break
  143. } else if len(cachedReq.User) > 0 {
  144. logger.Log(0, "host SSO process completed for user", cachedReq.User)
  145. answer <- *cachedReq
  146. break
  147. }
  148. time.Sleep(time.Second)
  149. }
  150. }()
  151. select {
  152. case result := <-answer: // a read from req.answerCh has occurred
  153. // add the host, if not exists, handle like enrollment registration
  154. if !logic.HostExists(&result.Host) { // check if host already exists, add if not
  155. if servercfg.GetBrokerType() == servercfg.EmqxBrokerType {
  156. if err := mq.GetEmqxHandler().CreateEmqxUser(result.Host.ID.String(), result.Host.HostPass); err != nil {
  157. logger.Log(0, "failed to create host credentials for EMQX: ", err.Error())
  158. return
  159. }
  160. }
  161. _ = logic.CheckHostPorts(&result.Host)
  162. if err := logic.CreateHost(&result.Host); err != nil {
  163. handleHostRegErr(conn, err)
  164. return
  165. }
  166. }
  167. key, keyErr := logic.RetrievePublicTrafficKey()
  168. if keyErr != nil {
  169. handleHostRegErr(conn, err)
  170. return
  171. }
  172. currHost, err := logic.GetHost(result.Host.ID.String())
  173. if err != nil {
  174. handleHostRegErr(conn, err)
  175. return
  176. }
  177. var currentNetworks = []string{}
  178. if result.ALL {
  179. currentNets, err := logic.GetNetworks()
  180. if err == nil && len(currentNets) > 0 {
  181. for i := range currentNets {
  182. currentNetworks = append(currentNetworks, currentNets[i].NetID)
  183. }
  184. }
  185. } else if len(result.Network) > 0 {
  186. currentNetworks = append(currentNetworks, result.Network)
  187. }
  188. var netsToAdd = []string{} // track the networks not currently owned by host
  189. hostNets := logic.GetHostNetworks(currHost.ID.String())
  190. for _, newNet := range currentNetworks {
  191. if !logic.StringSliceContains(hostNets, newNet) {
  192. if len(result.User) > 0 {
  193. _, err := isUserIsAllowed(result.User, newNet)
  194. if err != nil {
  195. logger.Log(0, "unauthorized user", result.User, "attempted to register to network", newNet)
  196. handleHostRegErr(conn, err)
  197. return
  198. }
  199. }
  200. netsToAdd = append(netsToAdd, newNet)
  201. }
  202. }
  203. server := logic.GetServerInfo()
  204. server.TrafficKey = key
  205. result.Host.HostPass = ""
  206. response := models.RegisterResponse{
  207. ServerConf: server,
  208. RequestedHost: result.Host,
  209. }
  210. reponseData, err := json.Marshal(&response)
  211. if err != nil {
  212. handleHostRegErr(conn, err)
  213. return
  214. }
  215. if err = conn.WriteMessage(messageType, reponseData); err != nil {
  216. logger.Log(0, "error during message writing:", err.Error())
  217. }
  218. go CheckNetRegAndHostUpdate(models.EnrollmentKey{Networks: netsToAdd}, &result.Host, "")
  219. case <-timeout: // the read from req.answerCh has timed out
  220. logger.Log(0, "timeout signal recv,exiting oauth socket conn")
  221. break
  222. }
  223. // Cleanly close the connection by sending a close message and then
  224. // waiting (with timeout) for the server to close the connection.
  225. if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
  226. logger.Log(0, "write close:", err.Error())
  227. return
  228. }
  229. }
  230. // CheckNetRegAndHostUpdate - run through networks and send a host update
  231. func CheckNetRegAndHostUpdate(key models.EnrollmentKey, h *models.Host, username string) {
  232. // publish host update through MQ
  233. for _, netID := range key.Networks {
  234. if network, err := logic.GetNetwork(netID); err == nil {
  235. if network.AutoJoin == "false" {
  236. if logic.DoesHostExistinTheNetworkAlready(h, models.NetworkID(netID)) {
  237. continue
  238. }
  239. if err := (&schema.PendingHost{
  240. HostID: h.ID.String(),
  241. Network: netID,
  242. }).CheckIfPendingHostExists(db.WithContext(context.TODO())); err == nil {
  243. continue
  244. }
  245. keyB, _ := json.Marshal(key)
  246. // add host to pending host table
  247. p := schema.PendingHost{
  248. ID: uuid.NewString(),
  249. HostID: h.ID.String(),
  250. Hostname: h.Name,
  251. Network: netID,
  252. PublicKey: h.PublicKey.String(),
  253. OS: h.OS,
  254. Location: h.Location,
  255. Version: h.Version,
  256. EnrollmentKey: keyB,
  257. RequestedAt: time.Now().UTC(),
  258. }
  259. p.Create(db.WithContext(context.TODO()))
  260. continue
  261. }
  262. logic.LogEvent(&models.Event{
  263. Action: models.JoinHostToNet,
  264. Source: models.Subject{
  265. ID: key.Value,
  266. Name: key.Tags[0],
  267. Type: models.EnrollmentKeySub,
  268. },
  269. TriggeredBy: username,
  270. Target: models.Subject{
  271. ID: h.ID.String(),
  272. Name: h.Name,
  273. Type: models.DeviceSub,
  274. },
  275. NetworkID: models.NetworkID(netID),
  276. Origin: models.Dashboard,
  277. })
  278. newNode, err := logic.UpdateHostNetwork(h, netID, true)
  279. if err == nil || strings.Contains(err.Error(), "host already part of network") {
  280. if len(key.Groups) > 0 {
  281. newNode.Tags = make(map[models.TagID]struct{})
  282. for _, tagI := range key.Groups {
  283. newNode.Tags[tagI] = struct{}{}
  284. }
  285. logic.UpsertNode(newNode)
  286. }
  287. if key.Relay != uuid.Nil && !newNode.IsRelayed {
  288. // check if relay node exists and acting as relay
  289. relaynode, err := logic.GetNodeByID(key.Relay.String())
  290. if err == nil && relaynode.IsGw && relaynode.Network == newNode.Network {
  291. slog.Error(fmt.Sprintf("adding relayed node %s to relay %s on network %s", newNode.ID.String(), key.Relay.String(), netID))
  292. newNode.IsRelayed = true
  293. newNode.RelayedBy = key.Relay.String()
  294. updatedRelayNode := relaynode
  295. updatedRelayNode.RelayedNodes = append(updatedRelayNode.RelayedNodes, newNode.ID.String())
  296. logic.UpdateRelayed(&relaynode, &updatedRelayNode)
  297. if err := logic.UpsertNode(&updatedRelayNode); err != nil {
  298. slog.Error("failed to update node", "nodeid", key.Relay.String())
  299. }
  300. if err := logic.UpsertNode(newNode); err != nil {
  301. slog.Error("failed to update node", "nodeid", key.Relay.String())
  302. }
  303. } else {
  304. slog.Error("failed to relay node. maybe specified relay node is actually not a relay? Or the relayed node is not in the same network with relay?", "err", err)
  305. }
  306. }
  307. if err != nil && strings.Contains(err.Error(), "host already part of network") {
  308. continue
  309. }
  310. } else {
  311. logger.Log(0, "failed to add host to network:", h.ID.String(), h.Name, netID, err.Error())
  312. continue
  313. }
  314. logger.Log(1, "added new node", newNode.ID.String(), "to host", h.Name)
  315. hostactions.AddAction(models.HostUpdate{
  316. Action: models.JoinHostToNetwork,
  317. Host: *h,
  318. Node: *newNode,
  319. })
  320. if h.IsDefault {
  321. // make host failover
  322. logic.CreateFailOver(*newNode)
  323. // make host remote access gateway
  324. logic.CreateIngressGateway(netID, newNode.ID.String(), models.IngressRequest{})
  325. logic.CreateRelay(models.RelayRequest{
  326. NodeID: newNode.ID.String(),
  327. NetID: netID,
  328. })
  329. }
  330. }
  331. }
  332. if servercfg.IsMessageQueueBackend() {
  333. mq.HostUpdate(&models.HostUpdate{
  334. Action: models.RequestAck,
  335. Host: *h,
  336. })
  337. if err := mq.PublishPeerUpdate(false); err != nil {
  338. logger.Log(0, "failed to publish peer update during registration -", err.Error())
  339. }
  340. }
  341. }
  342. func handleHostRegErr(conn *websocket.Conn, err error) {
  343. _ = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  344. if err != nil {
  345. logger.Log(0, "error during host registration via auth:", err.Error())
  346. }
  347. }