| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 | package authimport (	"encoding/base64"	"encoding/json"	"errors"	"fmt"	"net/http"	"strings"	"time"	"golang.org/x/crypto/bcrypt"	"golang.org/x/oauth2"	"github.com/gorilla/websocket"	"github.com/gravitl/netmaker/logger"	"github.com/gravitl/netmaker/logic"	"github.com/gravitl/netmaker/logic/pro/netcache"	"github.com/gravitl/netmaker/models"	"github.com/gravitl/netmaker/servercfg")// == consts ==const (	init_provider          = "initprovider"	get_user_info          = "getuserinfo"	handle_callback        = "handlecallback"	handle_login           = "handlelogin"	google_provider_name   = "google"	azure_ad_provider_name = "azure-ad"	github_provider_name   = "github"	oidc_provider_name     = "oidc"	verify_user            = "verifyuser"	auth_key               = "netmaker_auth"	user_signin_length     = 16	node_signin_length     = 64	headless_signin_length = 32)// OAuthUser - generic OAuth strategy usertype OAuthUser struct {	Name              string `json:"name" bson:"name"`	Email             string `json:"email" bson:"email"`	Login             string `json:"login" bson:"login"`	UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`	AccessToken       string `json:"accesstoken" bson:"accesstoken"`}var (	auth_provider *oauth2.Config	upgrader      = websocket.Upgrader{})func getCurrentAuthFunctions() map[string]interface{} {	var authInfo = servercfg.GetAuthProviderInfo()	var authProvider = authInfo[0]	switch authProvider {	case google_provider_name:		return google_functions	case azure_ad_provider_name:		return azure_ad_functions	case github_provider_name:		return github_functions	case oidc_provider_name:		return oidc_functions	default:		return nil	}}// InitializeAuthProvider - initializes the auth provider if any is presentfunc InitializeAuthProvider() string {	var functions = getCurrentAuthFunctions()	if functions == nil {		return ""	}	var _, err = fetchPassValue(logic.RandomString(64))	if err != nil {		logger.Log(0, err.Error())		return ""	}	var authInfo = servercfg.GetAuthProviderInfo()	var serverConn = servercfg.GetAPIHost()	if strings.Contains(serverConn, "localhost") || strings.Contains(serverConn, "127.0.0.1") {		serverConn = "http://" + serverConn		logger.Log(1, "localhost OAuth detected, proceeding with insecure http redirect: (", serverConn, ")")	} else {		serverConn = "https://" + serverConn		logger.Log(1, "external OAuth detected, proceeding with https redirect: ("+serverConn+")")	}	if authInfo[0] == "oidc" {		functions[init_provider].(func(string, string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2], authInfo[3])		return authInfo[0]	}	functions[init_provider].(func(string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2])	return authInfo[0]}// HandleAuthCallback - handles oauth callback// Note: not included in API reference as part of the OAuth process itself.func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {	if auth_provider == nil {		handleOauthNotConfigured(w)		return	}	var functions = getCurrentAuthFunctions()	if functions == nil {		return	}	state, _ := getStateAndCode(r)	_, err := netcache.Get(state) // if in netcache proceeed with node registration login	if err == nil || errors.Is(err, netcache.ErrExpired) {		switch len(state) {		case node_signin_length:			logger.Log(0, "proceeding with node SSO callback")			HandleNodeSSOCallback(w, r)		case headless_signin_length:			logger.Log(0, "proceeding with headless SSO callback")			HandleHeadlessSSOCallback(w, r)		default:			logger.Log(1, "invalid state length: ", fmt.Sprintf("%d", len(state)))		}	} else { // handle normal login		functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)	}}// swagger:route GET /api/oauth/login nodes HandleAuthLogin//// Handles OAuth login.////			Schemes: https////			Security://	  		oauthfunc HandleAuthLogin(w http.ResponseWriter, r *http.Request) {	if auth_provider == nil {		handleOauthNotConfigured(w)		return	}	var functions = getCurrentAuthFunctions()	if functions == nil {		return	}	if servercfg.GetFrontendURL() == "" {		handleOauthNotConfigured(w)		return	}	functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)}// IsOauthUser - returnsfunc IsOauthUser(user *models.User) error {	var currentValue, err = fetchPassValue("")	if err != nil {		return err	}	var bCryptErr = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(currentValue))	return bCryptErr}// HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocketfunc HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {	conn, err := upgrader.Upgrade(w, r, nil)	if err != nil {		logger.Log(0, "error during connection upgrade for headless sign-in:", err.Error())		return	}	if conn == nil {		logger.Log(0, "failed to establish web-socket connection during headless sign-in")		return	}	defer conn.Close()	req := &netcache.CValue{User: "", Pass: ""}	stateStr := logic.RandomString(headless_signin_length)	if err = netcache.Set(stateStr, req); err != nil {		logger.Log(0, "Failed to process sso request -", err.Error())		return	}	timeout := make(chan bool, 1)	answer := make(chan string, 1)	defer close(answer)	defer close(timeout)	if auth_provider == nil {		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {			logger.Log(0, "error during message writing:", err.Error())		}		return	}	redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)	if err = conn.WriteMessage(websocket.TextMessage, []byte(redirectUrl)); err != nil {		logger.Log(0, "error during message writing:", err.Error())	}	go func() {		for {			cachedReq, err := netcache.Get(stateStr)			if err != nil {				if strings.Contains(err.Error(), "expired") {					logger.Log(0, "timeout occurred while waiting for SSO")					timeout <- true					break				}				continue			} else if cachedReq.Pass != "" {				logger.Log(0, "SSO process completed for user ", cachedReq.User)				answer <- cachedReq.Pass				break			}			time.Sleep(500) // try it 2 times per second to see if auth is completed		}	}()	select {	case result := <-answer:		if err = conn.WriteMessage(websocket.TextMessage, []byte(result)); err != nil {			logger.Log(0, "Error during message writing:", err.Error())		}	case <-timeout:		logger.Log(0, "Authentication server time out for headless SSO login")		if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {			logger.Log(0, "Error during message writing:", err.Error())		}	}	if err = netcache.Del(stateStr); err != nil {		logger.Log(0, "failed to remove SSO cache entry", err.Error())	}	if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {		logger.Log(0, "write close:", err.Error())	}}// == private methods ==func addUser(email string) error {	var hasAdmin, err = logic.HasAdmin()	if err != nil {		logger.Log(1, "error checking for existence of admin user during OAuth login for", email, "; user not added")		return err	} // generate random password to adapt to current model	var newPass, fetchErr = fetchPassValue("")	if fetchErr != nil {		return fetchErr	}	var newUser = models.User{		UserName: email,		Password: newPass,	}	if !hasAdmin { // must be first attempt, create an admin		if err = logic.CreateAdmin(&newUser); err != nil {			logger.Log(1, "error creating admin from user,", email, "; user not added")		} else {			logger.Log(1, "admin created from user,", email, "; was first user added")		}	} else { // otherwise add to db as admin..?		// TODO: add ability to add users with preemptive permissions		newUser.IsAdmin = false		if err = logic.CreateUser(&newUser); err != nil {			logger.Log(1, "error creating user,", email, "; user not added")		} else {			logger.Log(0, "user created from ", email)		}	}	return nil}func fetchPassValue(newValue string) (string, error) {	type valueHolder struct {		Value string `json:"value" bson:"value"`	}	var b64NewValue = base64.StdEncoding.EncodeToString([]byte(newValue))	var newValueHolder = &valueHolder{		Value: b64NewValue,	}	var data, marshalErr = json.Marshal(newValueHolder)	if marshalErr != nil {		return "", marshalErr	}	var currentValue, err = logic.FetchAuthSecret(auth_key, string(data))	if err != nil {		return "", err	}	var unmarshErr = json.Unmarshal([]byte(currentValue), newValueHolder)	if unmarshErr != nil {		return "", unmarshErr	}	var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value)	if b64Err != nil {		logger.Log(0, "could not decode pass")		return "", nil	}	return string(b64CurrentValue), nil}func getStateAndCode(r *http.Request) (string, string) {	var state, code string	if r.FormValue("state") != "" && r.FormValue("code") != "" {		state = r.FormValue("state")		code = r.FormValue("code")	} else if r.URL.Query().Get("state") != "" && r.URL.Query().Get("code") != "" {		state = r.URL.Query().Get("state")		code = r.URL.Query().Get("code")	}	return state, code}func (user *OAuthUser) getUserName() string {	var userName string	if user.Email != "" {		userName = user.Email	} else if user.Login != "" {		userName = user.Login	} else if user.UserPrincipalName != "" {		userName = user.UserPrincipalName	} else if user.Name != "" {		userName = user.Name	}	return userName}func isStateCached(state string) bool {	_, err := netcache.Get(state)	return err == nil || strings.Contains(err.Error(), "expired")}
 |