nodesession.go 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package auth
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "strings"
  6. "time"
  7. "github.com/gorilla/websocket"
  8. "github.com/gravitl/netmaker/logger"
  9. "github.com/gravitl/netmaker/logic"
  10. "github.com/gravitl/netmaker/logic/pro/netcache"
  11. "github.com/gravitl/netmaker/models"
  12. "github.com/gravitl/netmaker/models/promodels"
  13. "github.com/gravitl/netmaker/servercfg"
  14. )
  15. // SessionHandler - called by the HTTP router when user
  16. // is calling netclient with --login-server parameter in order to authenticate
  17. // via SSO mechanism by OAuth2 protocol flow.
  18. // This triggers a session start and it is managed by the flow implemented here and callback
  19. // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
  20. func SessionHandler(conn *websocket.Conn) {
  21. defer conn.Close()
  22. // If reached here we have a session from user to handle...
  23. messageType, message, err := conn.ReadMessage()
  24. if err != nil {
  25. logger.Log(0, "Error during message reading:", err.Error())
  26. return
  27. }
  28. var loginMessage promodels.LoginMsg
  29. err = json.Unmarshal(message, &loginMessage)
  30. if err != nil {
  31. logger.Log(0, "Failed to unmarshall data err=", err.Error())
  32. return
  33. }
  34. logger.Log(1, "SSO node join attempted with info network:", loginMessage.Network, "node identifier:", loginMessage.Mac, "user:", loginMessage.User)
  35. req := new(netcache.CValue)
  36. req.Value = string(loginMessage.Mac)
  37. req.Network = loginMessage.Network
  38. req.Pass = ""
  39. req.User = ""
  40. // Add any extra parameter provided in the configuration to the Authorize Endpoint request??
  41. stateStr := logic.RandomString(node_signin_length)
  42. if err := netcache.Set(stateStr, req); err != nil {
  43. logger.Log(0, "Failed to process sso request -", err.Error())
  44. return
  45. }
  46. // Wait for the user to finish his auth flow...
  47. // TBD: what should be the timeout here ?
  48. timeout := make(chan bool, 1)
  49. answer := make(chan string, 1)
  50. defer close(answer)
  51. defer close(timeout)
  52. if _, err = logic.GetNetwork(loginMessage.Network); err != nil {
  53. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  54. if err != nil {
  55. logger.Log(0, "error during message writing:", err.Error())
  56. }
  57. return
  58. }
  59. if loginMessage.User != "" { // handle basic auth
  60. // verify that server supports basic auth, then authorize the request with given credentials
  61. // check if user is allowed to join via node sso
  62. // i.e. user is admin or user has network permissions
  63. if !servercfg.IsBasicAuthEnabled() {
  64. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  65. if err != nil {
  66. logger.Log(0, "error during message writing:", err.Error())
  67. }
  68. }
  69. _, err := logic.VerifyAuthRequest(models.UserAuthParams{
  70. UserName: loginMessage.User,
  71. Password: loginMessage.Password,
  72. })
  73. if err != nil {
  74. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  75. if err != nil {
  76. logger.Log(0, "error during message writing:", err.Error())
  77. }
  78. return
  79. }
  80. user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false)
  81. if err != nil {
  82. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  83. if err != nil {
  84. logger.Log(0, "error during message writing:", err.Error())
  85. }
  86. return
  87. }
  88. accessToken, err := requestAccessKey(loginMessage.Network, 1, user.UserName)
  89. if err != nil {
  90. req.Pass = fmt.Sprintf("Error from the netmaker controller %s", err.Error())
  91. } else {
  92. req.Pass = fmt.Sprintf("AccessToken: %s", accessToken)
  93. }
  94. // Give the user the access token via Pass in the DB
  95. if err = netcache.Set(stateStr, req); err != nil {
  96. logger.Log(0, "machine failed to complete join on network,", loginMessage.Network, "-", err.Error())
  97. return
  98. }
  99. } else { // handle SSO / OAuth
  100. if auth_provider == nil {
  101. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  102. if err != nil {
  103. logger.Log(0, "error during message writing:", err.Error())
  104. }
  105. return
  106. }
  107. redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
  108. err = conn.WriteMessage(messageType, []byte(redirectUrl))
  109. if err != nil {
  110. logger.Log(0, "error during message writing:", err.Error())
  111. }
  112. }
  113. go func() {
  114. for {
  115. cachedReq, err := netcache.Get(stateStr)
  116. if err != nil {
  117. if strings.Contains(err.Error(), "expired") {
  118. logger.Log(0, "timeout occurred while waiting for SSO on network", loginMessage.Network)
  119. timeout <- true
  120. break
  121. }
  122. continue
  123. } else if cachedReq.Pass != "" {
  124. logger.Log(0, "node SSO process completed for user", cachedReq.User, "on network", loginMessage.Network)
  125. answer <- cachedReq.Pass
  126. break
  127. }
  128. time.Sleep(500) // try it 2 times per second to see if auth is completed
  129. }
  130. }()
  131. select {
  132. case result := <-answer:
  133. // a read from req.answerCh has occurred
  134. err = conn.WriteMessage(messageType, []byte(result))
  135. if err != nil {
  136. logger.Log(0, "Error during message writing:", err.Error())
  137. }
  138. case <-timeout:
  139. logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network)
  140. // the read from req.answerCh has timed out
  141. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  142. if err != nil {
  143. logger.Log(0, "Error during message writing:", err.Error())
  144. }
  145. }
  146. // The entry is not needed anymore, but we will let the producer to close it to avoid panic cases
  147. if err = netcache.Del(stateStr); err != nil {
  148. logger.Log(0, "failed to remove node SSO cache entry", err.Error())
  149. }
  150. // Cleanly close the connection by sending a close message and then
  151. // waiting (with timeout) for the server to close the connection.
  152. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  153. if err != nil {
  154. logger.Log(0, "write close:", err.Error())
  155. return
  156. }
  157. }