nodesession.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. package auth
  2. import (
  3. "encoding/hex"
  4. "encoding/json"
  5. "fmt"
  6. "strings"
  7. "time"
  8. "github.com/gorilla/websocket"
  9. "github.com/gravitl/netmaker/logger"
  10. "github.com/gravitl/netmaker/logic"
  11. "github.com/gravitl/netmaker/logic/pro/netcache"
  12. "github.com/gravitl/netmaker/models"
  13. "github.com/gravitl/netmaker/models/promodels"
  14. "github.com/gravitl/netmaker/servercfg"
  15. )
  16. // SessionHandler - called by the HTTP router when user
  17. // is calling netclient with --login-server parameter in order to authenticate
  18. // via SSO mechanism by OAuth2 protocol flow.
  19. // This triggers a session start and it is managed by the flow implemented here and callback
  20. // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
  21. func SessionHandler(conn *websocket.Conn) {
  22. defer conn.Close()
  23. logger.Log(1, "Running sessionHandler")
  24. // If reached here we have a session from user to handle...
  25. messageType, message, err := conn.ReadMessage()
  26. if err != nil {
  27. logger.Log(0, "Error during message reading:", err.Error())
  28. return
  29. }
  30. var loginMessage promodels.LoginMsg
  31. err = json.Unmarshal(message, &loginMessage)
  32. if err != nil {
  33. logger.Log(0, "Failed to unmarshall data err=", err.Error())
  34. return
  35. }
  36. logger.Log(1, "SSO node join attempted with info network:", loginMessage.Network, "node identifier:", loginMessage.Mac, "user:", loginMessage.User)
  37. req := new(netcache.CValue)
  38. req.Value = string(loginMessage.Mac)
  39. req.Network = loginMessage.Network
  40. req.Pass = ""
  41. req.User = ""
  42. // Add any extra parameter provided in the configuration to the Authorize Endpoint request??
  43. stateStr := hex.EncodeToString([]byte(logic.RandomString(node_signin_length)))
  44. if err := netcache.Set(stateStr, req); err != nil {
  45. logger.Log(0, "Failed to process sso request -", err.Error())
  46. return
  47. }
  48. // Wait for the user to finish his auth flow...
  49. // TBD: what should be the timeout here ?
  50. timeout := make(chan bool, 1)
  51. answer := make(chan string, 1)
  52. defer close(answer)
  53. defer close(timeout)
  54. if loginMessage.User != "" { // handle basic auth
  55. // verify that server supports basic auth, then authorize the request with given credentials
  56. // check if user is allowed to join via node sso
  57. // i.e. user is admin or user has network permissions
  58. if !servercfg.IsBasicAuthEnabled() {
  59. err = conn.WriteMessage(messageType, []byte("Basic Auth Disabled"))
  60. if err != nil {
  61. logger.Log(0, "error during message writing:", err.Error())
  62. }
  63. }
  64. _, err := logic.VerifyAuthRequest(models.UserAuthParams{
  65. UserName: loginMessage.User,
  66. Password: loginMessage.Password,
  67. })
  68. if err != nil {
  69. err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("Failed to authenticate, %s.", loginMessage.User)))
  70. if err != nil {
  71. logger.Log(0, "error during message writing:", err.Error())
  72. }
  73. return
  74. }
  75. user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false)
  76. if err != nil {
  77. err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("%s lacks permission to join.", loginMessage.User)))
  78. if err != nil {
  79. logger.Log(0, "error during message writing:", err.Error())
  80. }
  81. return
  82. }
  83. accessToken, err := requestAccessKey(loginMessage.Network, 1, user.UserName)
  84. if err != nil {
  85. req.Pass = fmt.Sprintf("Error from the netmaker controller %s", err.Error())
  86. } else {
  87. req.Pass = fmt.Sprintf("AccessToken: %s", accessToken)
  88. }
  89. // Give the user the access token via Pass in the DB
  90. if err = netcache.Set(stateStr, req); err != nil {
  91. logger.Log(0, "machine failed to complete join on network,", loginMessage.Network, "-", err.Error())
  92. return
  93. }
  94. } else { // handle SSO / OAuth
  95. redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
  96. err = conn.WriteMessage(messageType, []byte(redirectUrl))
  97. if err != nil {
  98. logger.Log(0, "error during message writing:", err.Error())
  99. }
  100. }
  101. go func() {
  102. for {
  103. cachedReq, err := netcache.Get(stateStr)
  104. if err != nil {
  105. if strings.Contains(err.Error(), "expired") {
  106. logger.Log(0, "timeout occurred while waiting for SSO on network", loginMessage.Network)
  107. timeout <- true
  108. break
  109. }
  110. continue
  111. } else if cachedReq.Pass != "" {
  112. logger.Log(0, "node SSO process completed for user", cachedReq.User, "on network", loginMessage.Network)
  113. answer <- cachedReq.Pass
  114. break
  115. }
  116. time.Sleep(500) // try it 2 times per second to see if auth is completed
  117. }
  118. }()
  119. select {
  120. case result := <-answer:
  121. // a read from req.answerCh has occurred
  122. err = conn.WriteMessage(messageType, []byte(result))
  123. if err != nil {
  124. logger.Log(0, "Error during message writing:", err.Error())
  125. }
  126. case <-timeout:
  127. logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network)
  128. // the read from req.answerCh has timed out
  129. err = conn.WriteMessage(messageType, []byte("Authentication server time out"))
  130. if err != nil {
  131. logger.Log(0, "Error during message writing:", err.Error())
  132. }
  133. }
  134. // The entry is not needed anymore, but we will let the producer to close it to avoid panic cases
  135. if err = netcache.Del(stateStr); err != nil {
  136. logger.Log(0, "failed to remove node SSO cache entry", err.Error())
  137. }
  138. // Cleanly close the connection by sending a close message and then
  139. // waiting (with timeout) for the server to close the connection.
  140. err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
  141. if err != nil {
  142. logger.Log(0, "write close:", err.Error())
  143. return
  144. }
  145. }