Browse Source

Merge pull request #1574 from gravitl/bugfix_v0.16.0_sso_errors

fixing sso error handling
Alex Feiszli 2 years ago
parent
commit
b2e5e410dd

+ 6 - 4
auth/nodecallback.go

@@ -155,8 +155,11 @@ func returnErrTemplate(uname, message, state string, ncache *netcache.CValue) []
 // Listens in /oidc/register/:regKey.
 func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) {
 
-	logger.Log(1, "RegisterNodeSSO\n")
-
+	if auth_provider == nil {
+		w.WriteHeader(http.StatusBadRequest)
+		w.Write([]byte("invalid login attempt"))
+		return
+	}
 	vars := mux.Vars(r)
 
 	// machineKeyStr this is not key but state
@@ -165,8 +168,7 @@ func RegisterNodeSSO(w http.ResponseWriter, r *http.Request) {
 
 	if machineKeyStr == "" {
 		w.WriteHeader(http.StatusBadRequest)
-		w.Write([]byte("Wrong params"))
-		logger.Log(0, "Wrong params ", machineKeyStr)
+		w.Write([]byte("invalid login attempt"))
 		return
 	}
 

+ 19 - 5
auth/nodesession.go

@@ -23,7 +23,6 @@ import (
 // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
 func SessionHandler(conn *websocket.Conn) {
 	defer conn.Close()
-	logger.Log(1, "Running  sessionHandler")
 
 	// If reached here we have a session from user to handle...
 	messageType, message, err := conn.ReadMessage()
@@ -58,12 +57,20 @@ func SessionHandler(conn *websocket.Conn) {
 	defer close(answer)
 	defer close(timeout)
 
+	if _, err = logic.GetNetwork(loginMessage.Network); err != nil {
+		err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
+		if err != nil {
+			logger.Log(0, "error during message writing:", err.Error())
+		}
+		return
+	}
+
 	if loginMessage.User != "" { // handle basic auth
 		// verify that server supports basic auth, then authorize the request with given credentials
 		// check if user is allowed to join via node sso
 		// i.e. user is admin or user has network permissions
 		if !servercfg.IsBasicAuthEnabled() {
-			err = conn.WriteMessage(messageType, []byte("Basic Auth Disabled"))
+			err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 			if err != nil {
 				logger.Log(0, "error during message writing:", err.Error())
 			}
@@ -73,7 +80,7 @@ func SessionHandler(conn *websocket.Conn) {
 			Password: loginMessage.Password,
 		})
 		if err != nil {
-			err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("Failed to authenticate, %s.", loginMessage.User)))
+			err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 			if err != nil {
 				logger.Log(0, "error during message writing:", err.Error())
 			}
@@ -81,7 +88,7 @@ func SessionHandler(conn *websocket.Conn) {
 		}
 		user, err := isUserIsAllowed(loginMessage.User, loginMessage.Network, false)
 		if err != nil {
-			err = conn.WriteMessage(messageType, []byte(fmt.Sprintf("%s lacks permission to join.", loginMessage.User)))
+			err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 			if err != nil {
 				logger.Log(0, "error during message writing:", err.Error())
 			}
@@ -99,6 +106,13 @@ func SessionHandler(conn *websocket.Conn) {
 			return
 		}
 	} else { // handle SSO / OAuth
+		if auth_provider == nil {
+			err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
+			if err != nil {
+				logger.Log(0, "error during message writing:", err.Error())
+			}
+			return
+		}
 		redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
 		err = conn.WriteMessage(messageType, []byte(redirectUrl))
 		if err != nil {
@@ -135,7 +149,7 @@ func SessionHandler(conn *websocket.Conn) {
 	case <-timeout:
 		logger.Log(0, "Authentication server time out for a node on network", loginMessage.Network)
 		// the read from req.answerCh has timed out
-		err = conn.WriteMessage(messageType, []byte("Authentication server time out"))
+		err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 		if err != nil {
 			logger.Log(0, "Error during message writing:", err.Error())
 		}

+ 1 - 1
logger/logger.go

@@ -134,7 +134,7 @@ func Retrieve(filePath string) string {
 
 // FatalLog - exits os after logging
 func FatalLog(message ...string) {
-	fmt.Printf("[netmaker] Fatal: %s \n", MakeString(" ", message...))
+	fmt.Printf("[%s] Fatal: %s \n", program, MakeString(" ", message...))
 	os.Exit(2)
 }
 

+ 2 - 3
netclient/command/commands.go

@@ -30,14 +30,13 @@ func Join(cfg *config.ClientConfig, privateKey string) error {
 		logger.Log(1, "Logging into %s via:", cfg.Network, cfg.SsoServer)
 		err = functions.JoinViaSSo(cfg, privateKey)
 		if err != nil {
-			logger.Log(0, "Join via OIDC failed: ", err.Error())
+			logger.Log(0, "Join failed: ", err.Error())
 			return err
 		}
 
 		if cfg.AccessKey == "" {
-			return errors.New("failed to get access key")
+			return errors.New("login failed")
 		}
-		logger.Log(1, "Got an access key to ", cfg.Network, " via:", cfg.SsoServer)
 	}
 
 	logger.Log(1, "Joining network: ", cfg.Network)

+ 1 - 1
netclient/daemon/freebsd.go

@@ -28,7 +28,7 @@ func SetupFreebsdDaemon() error {
 	}
 	err = ncutils.Copy(binarypath, EXEC_DIR+"netclient")
 	if err != nil {
-		log.Println(err)
+		logger.Log(0, err.Error())
 		return err
 	}
 

+ 1 - 1
netclient/daemon/macos.go

@@ -25,7 +25,7 @@ func SetupMacDaemon() error {
 	}
 	err = ncutils.Copy(binarypath, MAC_EXEC_DIR+"netclient")
 	if err != nil {
-		log.Println(err)
+		logger.Log(0, err.Error())
 		return err
 	}
 

+ 3 - 3
netclient/daemon/systemd.go

@@ -38,7 +38,7 @@ func SetupSystemDDaemon() error {
 	}
 	err = ncutils.Copy(binarypath, EXEC_DIR+"netclient")
 	if err != nil {
-		log.Println(err)
+		logger.Log(0, err.Error())
 		return err
 	}
 
@@ -64,7 +64,7 @@ WantedBy=multi-user.target
 	if !ncutils.FileExists("/etc/systemd/system/netclient.service") {
 		err = os.WriteFile("/etc/systemd/system/netclient.service", servicebytes, 0644)
 		if err != nil {
-			log.Println(err)
+			logger.Log(0, err.Error())
 			return err
 		}
 	}
@@ -106,7 +106,7 @@ func RemoveSystemDServices() error {
 	var err error
 	if !ncutils.IsWindows() && isOnlyService() {
 		if err != nil {
-			log.Println(err)
+			logger.Log(0, err.Error())
 		}
 		ncutils.RunCmd("systemctl disable netclient.service", false)
 		ncutils.RunCmd("systemctl disable netclient.timer", false)

+ 1 - 2
netclient/functions/common.go

@@ -301,8 +301,7 @@ func WipeLocal(cfg *config.ClientConfig) error {
 	if cfg.Node.Interface != "" {
 		if ncutils.FileExists(dir + cfg.Node.Interface + ".conf") {
 			if err := os.Remove(dir + cfg.Node.Interface + ".conf"); err != nil {
-				log.Println("error removing .conf:")
-				log.Println(err.Error())
+				logger.Log(0, err.Error())
 				fail = true
 			}
 		}

+ 6 - 1
netclient/functions/join.go

@@ -82,6 +82,7 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 		}
 		loginMsg.User = global_settings.User
 		loginMsg.Password = string(pass)
+		fmt.Println("attempting login...")
 	}
 
 	msgTx, err := json.Marshal(loginMsg)
@@ -101,7 +102,6 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 		// Wait to receive something from server
 		_, msg, err := conn.ReadMessage()
 		if err != nil {
-			log.Println("Error in receive:", err)
 			return err
 		}
 		// Print message from the netmaker controller to the user
@@ -121,6 +121,11 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 		for {
 			msgType, msg, err := conn.ReadMessage()
 			if err != nil {
+				if msgType < 0 {
+					logger.Log(1, "received close message from server")
+					done <- struct{}{}
+					return
+				}
 				// Error reading a message from the server
 				if !strings.Contains(err.Error(), "normal") {
 					logger.Log(0, "read:", err.Error())

+ 2 - 2
netclient/main.go

@@ -4,10 +4,10 @@
 package main
 
 import (
-	"log"
 	"os"
 	"runtime/debug"
 
+	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/netclient/cli_options"
 	"github.com/gravitl/netmaker/netclient/config"
 	"github.com/gravitl/netmaker/netclient/functions"
@@ -47,7 +47,7 @@ func main() {
 	} else {
 		err := app.Run(os.Args)
 		if err != nil {
-			log.Fatal(err)
+			logger.FatalLog(err.Error())
 		}
 	}
 }