Browse Source

PR comments addressed

0xdcarns 2 years ago
parent
commit
b1b497faa4

+ 1 - 1
auth/google.go

@@ -88,7 +88,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 
 
 	logger.Log(1, "completed google OAuth sigin in for", content.Email)
 	logger.Log(1, "completed google OAuth sigin in for", content.Email)
-	http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?login="+jwt+"&user="+content.Email, http.StatusPermanentRedirect)
+	http.Redirect(w, r, fmt.Sprintf("%s/login?login=%s&user=%s", servercfg.GetFrontendURL(), jwt, content.Email), http.StatusPermanentRedirect)
 }
 }
 
 
 func getGoogleUserInfo(state string, code string) (*OAuthUser, error) {
 func getGoogleUserInfo(state string, code string) (*OAuthUser, error) {

+ 1 - 1
auth/nodecallback.go

@@ -58,7 +58,7 @@ func HandleNodeSSOCallback(w http.ResponseWriter, r *http.Request) {
 	// retrieve machinekey from state cache
 	// retrieve machinekey from state cache
 	reqKeyIf, machineKeyFoundErr := netcache.Get(state)
 	reqKeyIf, machineKeyFoundErr := netcache.Get(state)
 	if machineKeyFoundErr != nil {
 	if machineKeyFoundErr != nil {
-		logger.Log(0, "requested machine state key expired before authorisation completed -", err.Error())
+		logger.Log(0, "requested machine state key expired before authorisation completed -", machineKeyFoundErr.Error())
 		reqKeyIf = &netcache.CValue{
 		reqKeyIf = &netcache.CValue{
 			Network:    "invalid",
 			Network:    "invalid",
 			Value:      state,
 			Value:      state,

+ 3 - 4
auth/nodesession.go

@@ -19,7 +19,7 @@ import (
 // SessionHandler - called by the HTTP router when user
 // SessionHandler - called by the HTTP router when user
 // is calling netclient with --login-server parameter in order to authenticate
 // is calling netclient with --login-server parameter in order to authenticate
 // via SSO mechanism by OAuth2 protocol flow.
 // via SSO mechanism by OAuth2 protocol flow.
-// This triggers a session start and it is managed by the flow implmented here and callback
+// This triggers a session start and it is managed by the flow implemented here and callback
 // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
 // When this method finishes - the auth flow has finished either OK or by timeout or any other error occured
 func SessionHandler(conn *websocket.Conn) {
 func SessionHandler(conn *websocket.Conn) {
 	defer conn.Close()
 	defer conn.Close()
@@ -55,6 +55,8 @@ func SessionHandler(conn *websocket.Conn) {
 	// TBD: what should be the timeout here ?
 	// TBD: what should be the timeout here ?
 	timeout := make(chan bool, 1)
 	timeout := make(chan bool, 1)
 	answer := make(chan string, 1)
 	answer := make(chan string, 1)
+	defer close(answer)
+	defer close(timeout)
 
 
 	if loginMessage.User != "" { // handle basic auth
 	if loginMessage.User != "" { // handle basic auth
 		// verify that server supports basic auth, then authorize the request with given credentials
 		// verify that server supports basic auth, then authorize the request with given credentials
@@ -149,7 +151,4 @@ func SessionHandler(conn *websocket.Conn) {
 		logger.Log(0, "write close:", err.Error())
 		logger.Log(0, "write close:", err.Error())
 		return
 		return
 	}
 	}
-	time.After(time.Second)
-	close(answer)
-	close(timeout)
 }
 }

+ 0 - 1
auth/oidc.go

@@ -62,7 +62,6 @@ func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
 		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
 		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
 		return
 		return
 	}
 	}
-	logger.Log(3, "using state string:", oauth_state_string)
 	var url = auth_provider.AuthCodeURL(oauth_state_string)
 	var url = auth_provider.AuthCodeURL(oauth_state_string)
 	http.Redirect(w, r, url, http.StatusTemporaryRedirect)
 	http.Redirect(w, r, url, http.StatusTemporaryRedirect)
 }
 }

+ 5 - 1
controllers/user.go

@@ -476,7 +476,11 @@ func socketHandler(w http.ResponseWriter, r *http.Request) {
 	// Upgrade our raw HTTP connection to a websocket based one
 	// Upgrade our raw HTTP connection to a websocket based one
 	conn, err := upgrader.Upgrade(w, r, nil)
 	conn, err := upgrader.Upgrade(w, r, nil)
 	if err != nil {
 	if err != nil {
-		logger.Log(0, "error during connection upgrade for node SSO sign-in:", err.Error())
+		logger.Log(0, "error during connection upgrade for node sign-in:", err.Error())
+		return
+	}
+	if conn == nil {
+		logger.Log(0, "failed to establish web-socket connection during node sign-in")
 		return
 		return
 	}
 	}
 	// Start handling the session
 	// Start handling the session

+ 1 - 1
ee/initialize.go

@@ -13,7 +13,7 @@ import (
 
 
 // InitEE - Initialize EE Logic
 // InitEE - Initialize EE Logic
 func InitEE() {
 func InitEE() {
-	SetIsEnterprise()
+	setIsEnterprise()
 	models.SetLogo(retrieveEELogo())
 	models.SetLogo(retrieveEELogo())
 	controller.HttpHandlers = append(controller.HttpHandlers, ee_controllers.MetricHandlers)
 	controller.HttpHandlers = append(controller.HttpHandlers, ee_controllers.MetricHandlers)
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {
 	logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func() {

+ 5 - 5
ee/util.go

@@ -8,16 +8,16 @@ import (
 
 
 var isEnterprise bool
 var isEnterprise bool
 
 
-// SetIsEnterprise - sets server to use enterprise features
-func SetIsEnterprise() {
-	isEnterprise = true
-}
-
 // IsEnterprise - checks if enterprise binary or not
 // IsEnterprise - checks if enterprise binary or not
 func IsEnterprise() bool {
 func IsEnterprise() bool {
 	return isEnterprise
 	return isEnterprise
 }
 }
 
 
+// setIsEnterprise - sets server to use enterprise features
+func setIsEnterprise() {
+	isEnterprise = true
+}
+
 // base64encode - base64 encode helper function
 // base64encode - base64 encode helper function
 func base64encode(input []byte) string {
 func base64encode(input []byte) string {
 	return base64.StdEncoding.EncodeToString(input)
 	return base64.StdEncoding.EncodeToString(input)

+ 0 - 1
mq/publishers.go

@@ -229,7 +229,6 @@ func collectServerMetrics(networks []models.Network) {
 
 
 func pushMetricsToExporter(metrics models.Metrics) error {
 func pushMetricsToExporter(metrics models.Metrics) error {
 	logger.Log(2, "----> Pushing metrics to exporter")
 	logger.Log(2, "----> Pushing metrics to exporter")
-	SetupMQTT()
 	data, err := json.Marshal(metrics)
 	data, err := json.Marshal(metrics)
 	if err != nil {
 	if err != nil {
 		return errors.New("failed to marshal metrics: " + err.Error())
 		return errors.New("failed to marshal metrics: " + err.Error())

+ 11 - 10
netclient/functions/join.go

@@ -13,7 +13,6 @@ import (
 	"runtime"
 	"runtime"
 	"strings"
 	"strings"
 	"syscall"
 	"syscall"
-	"time"
 
 
 	"github.com/gorilla/websocket"
 	"github.com/gorilla/websocket"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logger"
@@ -56,7 +55,7 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 	// Dial the netmaker server controller
 	// Dial the netmaker server controller
 	conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil)
 	conn, _, err := websocket.DefaultDialer.Dial(socketUrl, nil)
 	if err != nil {
 	if err != nil {
-		logger.Log(0, fmt.Sprintf("Error connecting to %s : %s", cfg.Server.API, err.Error()))
+		logger.Log(0, fmt.Sprintf("error connecting to %s : %s", cfg.Server.API, err.Error()))
 		return err
 		return err
 	}
 	}
 	// Don't forget to close when finished
 	// Don't forget to close when finished
@@ -113,14 +112,14 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 	// An answer from the server.
 	// An answer from the server.
 	// Server waits ~5 min - If takes too long timeout will be triggered by the server
 	// Server waits ~5 min - If takes too long timeout will be triggered by the server
 	done := make(chan struct{})
 	done := make(chan struct{})
+	defer close(done)
 	// Following code will run in a separate go routine
 	// Following code will run in a separate go routine
 	// it reads a message from the server which either contains 'AccessToken:' string or not
 	// it reads a message from the server which either contains 'AccessToken:' string or not
 	// if not - then it contains an Error to display.
 	// if not - then it contains an Error to display.
 	// if yes - then AccessToken is to be used to proceed joining the network
 	// if yes - then AccessToken is to be used to proceed joining the network
 	go func() {
 	go func() {
-		defer close(done)
 		for {
 		for {
-			_, msg, err := conn.ReadMessage()
+			msgType, msg, err := conn.ReadMessage()
 			if err != nil {
 			if err != nil {
 				// Error reading a message from the server
 				// Error reading a message from the server
 				if !strings.Contains(err.Error(), "normal") {
 				if !strings.Contains(err.Error(), "normal") {
@@ -128,13 +127,19 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 				}
 				}
 				return
 				return
 			}
 			}
+
+			if msgType == websocket.CloseMessage {
+				logger.Log(1, "received close message from server")
+				done <- struct{}{}
+				return
+			}
 			// Get the access token from the response
 			// Get the access token from the response
 			if strings.Contains(string(msg), "AccessToken: ") {
 			if strings.Contains(string(msg), "AccessToken: ") {
 				// Access was granted
 				// Access was granted
 				rxToken := strings.TrimPrefix(string(msg), "AccessToken: ")
 				rxToken := strings.TrimPrefix(string(msg), "AccessToken: ")
 				accesstoken, err := config.ParseAccessToken(rxToken)
 				accesstoken, err := config.ParseAccessToken(rxToken)
 				if err != nil {
 				if err != nil {
-					log.Printf("Failed to parse received access token %s,err=%s\n", accesstoken, err.Error())
+					logger.Log(0, fmt.Sprintf("failed to parse received access token %s,err=%s\n", accesstoken, err.Error()))
 					return
 					return
 				}
 				}
 
 
@@ -159,7 +164,7 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 			logger.Log(1, "finished")
 			logger.Log(1, "finished")
 			return nil
 			return nil
 		case <-interrupt:
 		case <-interrupt:
-			log.Println("interrupt")
+			logger.Log(0, "interrupt received, closing connection")
 			// Cleanly close the connection by sending a close message and then
 			// Cleanly close the connection by sending a close message and then
 			// waiting (with timeout) for the server to close the connection.
 			// waiting (with timeout) for the server to close the connection.
 			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
 			err := conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
@@ -167,10 +172,6 @@ func JoinViaSSo(cfg *config.ClientConfig, privateKey string) error {
 				logger.Log(0, "write close:", err.Error())
 				logger.Log(0, "write close:", err.Error())
 				return err
 				return err
 			}
 			}
-			select {
-			case <-done:
-			case <-time.After(time.Second):
-			}
 			return nil
 			return nil
 		}
 		}
 	}
 	}

+ 1 - 1
netclient/functions/mqpublish.go

@@ -167,7 +167,7 @@ func publishMetrics(nodeCfg *config.ClientConfig) {
 		logger.Log(1, "failed to authenticate when publishing metrics", err.Error())
 		logger.Log(1, "failed to authenticate when publishing metrics", err.Error())
 		return
 		return
 	}
 	}
-	url := "https://" + nodeCfg.Server.API + "/api/nodes/" + nodeCfg.Network + "/" + nodeCfg.Node.ID
+	url := fmt.Sprintf("https://%s/api/nodes/%s/%s", nodeCfg.Server.API, nodeCfg.Network, nodeCfg.Node.ID)
 	response, err := API("", http.MethodGet, url, token)
 	response, err := API("", http.MethodGet, url, token)
 	if err != nil {
 	if err != nil {
 		logger.Log(1, "failed to read from server during metrics publish", err.Error())
 		logger.Log(1, "failed to read from server during metrics publish", err.Error())