瀏覽代碼

NET-1061: check if user exists, handle oauth not configured for host SSO (#2917)

* add debug logs

* check if user exists, handle oauth not configured for host SSO

* check if user exists, handle oauth not configured for host SSO

* check if user exists, handle oauth not configured for host SSO

* quit when websocket is closed

* quit when websocket is closed

* quit when websocket is closed

* avoid sending msg onb closed channel

* add debug log

* exit when oauth state is deleted

* add debug log

* handle oauth state not valid with appropirate message

* handle oauth state not valid with appropirate message

* check for invalid oauth state

* rm debug logs
Abhishek K 1 年之前
父節點
當前提交
33846a5124
共有 7 個文件被更改,包括 67 次插入18 次删除
  1. 5 0
      auth/azure-ad.go
  2. 12 0
      auth/error.go
  3. 5 0
      auth/github.go
  4. 5 0
      auth/google.go
  5. 23 17
      auth/host_session.go
  6. 5 0
      auth/oidc.go
  7. 12 1
      auth/register_callback.go

+ 5 - 0
auth/azure-ad.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
@@ -58,6 +59,10 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
 	var content, err = getAzureUserInfo(rState, rCode)
 	if err != nil {
 		logger.Log(1, "error when getting user info from azure:", err.Error())
+		if strings.Contains(err.Error(), "invalid oauth state") {
+			handleOauthNotValid(w)
+			return
+		}
 		handleOauthNotConfigured(w)
 		return
 	}

+ 12 - 0
auth/error.go

@@ -10,6 +10,12 @@ const oauthNotConfigured = `<!DOCTYPE html><html>
 </body>
 </html>`
 
+const oauthStateInvalid = `<!DOCTYPE html><html>
+<body>
+<h3>Invalid OAuth Session. Please re-try again.</h3>
+</body>
+</html>`
+
 const userNotAllowed = `<!DOCTYPE html><html>
 <body>
 <h3>Only administrators can access the Dashboard. Please contact your administrator to elevate your account.</h3>
@@ -86,6 +92,12 @@ func handleOauthNotConfigured(response http.ResponseWriter) {
 	response.Write([]byte(oauthNotConfigured))
 }
 
+func handleOauthNotValid(response http.ResponseWriter) {
+	response.Header().Set("Content-Type", "text/html; charset=utf-8")
+	response.WriteHeader(http.StatusBadRequest)
+	response.Write([]byte(oauthStateInvalid))
+}
+
 func handleSomethingWentWrong(response http.ResponseWriter) {
 	response.Header().Set("Content-Type", "text/html; charset=utf-8")
 	response.WriteHeader(http.StatusInternalServerError)

+ 5 - 0
auth/github.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/logger"
@@ -58,6 +59,10 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
 	var content, err = getGithubUserInfo(rState, rCode)
 	if err != nil {
 		logger.Log(1, "error when getting user info from github:", err.Error())
+		if strings.Contains(err.Error(), "invalid oauth state") {
+			handleOauthNotValid(w)
+			return
+		}
 		handleOauthNotConfigured(w)
 		return
 	}

+ 5 - 0
auth/google.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"strings"
 	"time"
 
 	"github.com/gravitl/netmaker/database"
@@ -60,6 +61,10 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 	var content, err = getGoogleUserInfo(rState, rCode)
 	if err != nil {
 		logger.Log(1, "error when getting user info from google:", err.Error())
+		if strings.Contains(err.Error(), "invalid oauth state") {
+			handleOauthNotValid(w)
+			return
+		}
 		handleOauthNotConfigured(w)
 		return
 	}

+ 23 - 17
auth/host_session.go

@@ -3,7 +3,6 @@ package auth
 import (
 	"encoding/json"
 	"fmt"
-	"strings"
 	"time"
 
 	"github.com/google/uuid"
@@ -59,12 +58,12 @@ func SessionHandler(conn *websocket.Conn) {
 		logger.Log(0, "Failed to process sso request -", err.Error())
 		return
 	}
+	defer netcache.Del(stateStr)
 	// Wait for the user to finish his auth flow...
-	timeout := make(chan bool, 1)
+	timeout := make(chan bool, 2)
 	answer := make(chan netcache.CValue, 1)
 	defer close(answer)
 	defer close(timeout)
-
 	if len(registerMessage.User) > 0 { // handle basic auth
 		logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "user:", registerMessage.User)
 
@@ -111,6 +110,10 @@ func SessionHandler(conn *websocket.Conn) {
 		}
 	} else { // handle SSO / OAuth
 		if auth_provider == nil {
+			err = conn.WriteMessage(messageType, []byte("Oauth not configured"))
+			if err != nil {
+				logger.Log(0, "error during message writing:", err.Error())
+			}
 			err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 			if err != nil {
 				logger.Log(0, "error during message writing:", err.Error())
@@ -125,22 +128,30 @@ func SessionHandler(conn *websocket.Conn) {
 		}
 	}
 
+	go func() {
+		for {
+			msgType, _, err := conn.ReadMessage()
+			if err != nil || msgType == websocket.CloseMessage {
+				netcache.Del(stateStr)
+				return
+			}
+		}
+	}()
+
 	go func() {
 		for {
 			cachedReq, err := netcache.Get(stateStr)
 			if err != nil {
-				if strings.Contains(err.Error(), "expired") {
-					logger.Log(1, "timeout occurred while waiting for SSO registration")
-					timeout <- true
-					break
-				}
-				continue
+				logger.Log(0, "oauth state has been deleted ", err.Error())
+				timeout <- true
+				break
+
 			} else if len(cachedReq.User) > 0 {
 				logger.Log(0, "host SSO process completed for user", cachedReq.User)
 				answer <- *cachedReq
 				break
 			}
-			time.Sleep(500) // try it 2 times per second to see if auth is completed
+			time.Sleep(time.Second)
 		}
 	}()
 
@@ -217,13 +228,8 @@ func SessionHandler(conn *websocket.Conn) {
 		}
 		go CheckNetRegAndHostUpdate(netsToAdd[:], &result.Host, uuid.Nil)
 	case <-timeout: // the read from req.answerCh has timed out
-		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
-			logger.Log(0, "error during timeout message writing:", err.Error())
-		}
-	}
-	// The entry is not needed anymore, but we will let the producer to close it to avoid panic cases
-	if err = netcache.Del(stateStr); err != nil {
-		logger.Log(0, "failed to remove node SSO cache entry", err.Error())
+		logger.Log(0, "timeout signal recv,exiting oauth socket conn")
+		break
 	}
 	// Cleanly close the connection by sending a close message and then
 	// waiting (with timeout) for the server to close the connection.

+ 5 - 0
auth/oidc.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"net/http"
+	"strings"
 	"time"
 
 	"github.com/coreos/go-oidc/v3/oidc"
@@ -71,6 +72,10 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
 	var content, err = getOIDCUserInfo(rState, rCode)
 	if err != nil {
 		logger.Log(1, "error when getting user info from callback:", err.Error())
+		if strings.Contains(err.Error(), "invalid oauth state") {
+			handleOauthNotValid(w)
+			return
+		}
 		handleOauthNotConfigured(w)
 		return
 	}

+ 12 - 1
auth/register_callback.go

@@ -68,7 +68,18 @@ func HandleHostSSOCallback(w http.ResponseWriter, r *http.Request) {
 		w.Write(response)
 		return
 	}
-
+	// check if user exists
+	user, err := logic.GetUser(userClaims.getUserName())
+	if err != nil {
+		handleOauthUserNotFound(w)
+		return
+	}
+	if !user.IsAdmin && !user.IsSuperAdmin {
+		response := returnErrTemplate(userClaims.getUserName(), "only admin users can register using SSO", state, reqKeyIf)
+		w.WriteHeader(http.StatusForbidden)
+		w.Write(response)
+		return
+	}
 	logger.Log(1, "registering host for user:", userClaims.getUserName(), reqKeyIf.Host.Name, reqKeyIf.Host.ID.String())
 
 	// Send OK to user in the browser