nodesession.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package auth
  2. import (
  3. "encoding/hex"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. "time"
  8. "github.com/gorilla/websocket"
  9. "github.com/gravitl/netmaker/logger"
  10. "github.com/gravitl/netmaker/logic"
  11. "github.com/gravitl/netmaker/logic/pro/netcache"
  12. "github.com/gravitl/netmaker/models"
  13. "github.com/gravitl/netmaker/models/promodels"
  14. "github.com/gravitl/netmaker/servercfg"
  15. )
  16. // SessionHandler - called by the HTTP router when user
  17. // is calling netclient with --login-server parameter in order to authenticate
  18. // via SSO mechanism by OAuth2 protocol flow.
  19. // This triggers a session start and it is managed by the flow implemented here and callback
  20. // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
  21. func SessionHandler(conn *websocket.Conn) {
  22. defer conn.Close()
  23. // If reached here we have a session from user to handle...
  24. messageType, message, err := conn.ReadMessage()
  25. if err != nil {
  26. logger.Log(0, "Error during message reading:", err.Error())
  27. return
  28. }
  29. var loginMessage promodels.LoginMsg
  30. err = json.Unmarshal(message, &loginMessage)
  31. if err != nil {
  32. logger.Log(0, "Failed to unmarshall data err=", err.Error())
  33. return
  34. }
  35. logger.Log(1, "SSO node join attempted with info network:", loginMessage.Network, "node identifier:", loginMessage.Mac, "user:", loginMessage.User)
  36. req := new(netcache.CValue)
  37. req.Value = string(loginMessage.Mac)
  38. req.Network = loginMessage.Network
  39. req.Pass = ""
  40. req.User = ""
  41. // Add any extra parameter provided in the configuration to the Authorize Endpoint request??
  42. stateStr := hex.EncodeToString([]byte(logic.RandomString(node_signin_length)))
  43. if err := netcache.Set(stateStr, req); err != nil {
  44. logger.Log(0, "Failed to process sso request -", err.Error())
  45. return
  46. }
  47. // Wait for the user to finish his auth flow...
  48. // TBD: what should be the timeout here ?
  49. timeout := make(chan bool, 1)
  50. answer := make(chan string, 1)
  51. defer close(answer)
  52. defer close(timeout)
  53. if _, err = logic.GetNetwork(loginMessage.Network); err != nil {
  54. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  55. if err != nil {
  56. logger.Log(0, "error during message writing:", err.Error())
  57. }
  58. return
  59. }
  60. if loginMessage.User != "" { // handle basic auth
  61. // verify that server supports basic auth, then authorize the request with given credentials
  62. // check if user is allowed to join via node sso
  63. // i.e. user is admin or user has network permissions
  64. if !servercfg.IsBasicAuthEnabled() {
  65. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  66. if err != nil {
  67. logger.Log(0, "error during message writing:", err.Error())
  68. }
  69. }
  70. _, err := logic.VerifyAuthRequest(models.UserAuthParams{
  71. UserName: loginMessage.User,
  72. Password: loginMessage.Password,
  73. })
  74. if err != nil {
  75. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  76. if err != nil {
  77. logger.Log(0, "error during message writing:", err.Error())
  78. }
  79. return
  80. }
  81. user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false)
  82. if err != nil {
  83. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  84. if err != nil {
  85. logger.Log(0, "error during message writing:", err.Error())
  86. }
  87. return
  88. }
  89. accessToken, err := requestAccessKey(loginMessage.Network, 1, user.UserName)
  90. if err != nil {
  91. req.Pass = fmt.Sprintf("Error from the netmaker controller %s", err.Error())
  92. } else {
  93. req.Pass = fmt.Sprintf("AccessToken: %s", accessToken)
  94. }
  95. // Give the user the access token via Pass in the DB
  96. if err = netcache.Set(stateStr, req); err != nil {
  97. logger.Log(0, "machine failed to complete join on network,", loginMessage.Network, "-", err.Error())
  98. return
  99. }
  100. } else { // handle SSO / OAuth
  101. if auth_provider == nil {
  102. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  103. if err != nil {
  104. logger.Log(0, "error during message writing:", err.Error())
  105. }
  106. return
  107. }
  108. redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
  109. err = conn.WriteMessage(messageType, []byte(redirectUrl))
  110. if err != nil {
  111. logger.Log(0, "error during message writing:", err.Error())
  112. }
  113. }
  114. go func() {
  115. for {
  116. cachedReq, err := netcache.Get(stateStr)
  117. if err != nil {
  118. if strings.Contains(err.Error(), "expired") {
  119. logger.Log(0, "timeout occurred while waiting for SSO on network", loginMessage.Network)
  120. timeout <- true
  121. break
  122. }
  123. continue
  124. } else if cachedReq.Pass != "" {
  125. logger.Log(0, "node SSO process completed for user", cachedReq.User, "on network", loginMessage.Network)
  126. answer <- cachedReq.Pass
  127. break
  128. }
  129. time.Sleep(500) // try it 2 times per second to see if auth is completed
  130. }
  131. }()
  132. select {
  133. case result := <-answer:
  134. // a read from req.answerCh has occurred
  135. err = conn.WriteMessage(messageType, []byte(result))
  136. if err != nil {
  137. logger.Log(0, "Error during message writing:", err.Error())
  138. }
  139. case <-timeout:
  140. logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network)
  141. // the read from req.answerCh has timed out
  142. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  143. if err != nil {
  144. logger.Log(0, "Error during message writing:", err.Error())
  145. }
  146. }
  147. // The entry is not needed anymore, but we will let the producer to close it to avoid panic cases
  148. if err = netcache.Del(stateStr); err != nil {
  149. logger.Log(0, "failed to remove node SSO cache entry", err.Error())
  150. }
  151. // Cleanly close the connection by sending a close message and then
  152. // waiting (with timeout) for the server to close the connection.
  153. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  154. if err != nil {
  155. logger.Log(0, "write close:", err.Error())
  156. return
  157. }
  158. }