Browse Source

Merge pull request #1444 from gravitl/feature_v0.14.7_ha_sso

added better state management to make OAuth sign-ins HA
dcarns 3 years ago
parent
commit
61553d70ab
8 changed files with 107 additions and 9 deletions
  1. 0 1
      auth/auth.go
  2. 9 2
      auth/azure-ad.go
  3. 9 2
      auth/github.go
  4. 9 2
      auth/google.go
  5. 9 2
      auth/oidc.go
  6. 4 0
      database/database.go
  7. 50 0
      logic/auth.go
  8. 17 0
      models/ssocache.go

+ 0 - 1
auth/auth.go

@@ -29,7 +29,6 @@ const (
 	auth_key               = "netmaker_auth"
 )
 
-var oauth_state_string = "netmaker-oauth-state" // should be set randomly each provider login
 var auth_provider *oauth2.Config
 
 func getCurrentAuthFunctions() map[string]interface{} {

+ 9 - 2
auth/azure-ad.go

@@ -41,7 +41,7 @@ func initAzureAD(redirectURL string, clientID string, clientSecret string) {
 }
 
 func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
-	oauth_state_string = logic.RandomString(16)
+	var oauth_state_string = logic.RandomString(16)
 	if auth_provider == nil && servercfg.GetFrontendURL() != "" {
 		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
 		return
@@ -49,6 +49,12 @@ func handleAzureLogin(w http.ResponseWriter, r *http.Request) {
 		fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
 		return
 	}
+
+	if err := logic.SetState(oauth_state_string); err != nil {
+		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
+		return
+	}
+
 	var url = auth_provider.AuthCodeURL(oauth_state_string)
 	http.Redirect(w, r, url, http.StatusTemporaryRedirect)
 }
@@ -88,7 +94,8 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) {
 }
 
 func getAzureUserInfo(state string, code string) (*azureOauthUser, error) {
-	if state != oauth_state_string {
+	oauth_state_string, isValid := logic.IsStateValid(state)
+	if !isValid || state != oauth_state_string {
 		return nil, fmt.Errorf("invalid oauth state")
 	}
 	var token, err = auth_provider.Exchange(context.Background(), code)

+ 9 - 2
auth/github.go

@@ -41,7 +41,7 @@ func initGithub(redirectURL string, clientID string, clientSecret string) {
 }
 
 func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
-	oauth_state_string = logic.RandomString(16)
+	var oauth_state_string = logic.RandomString(16)
 	if auth_provider == nil && servercfg.GetFrontendURL() != "" {
 		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
 		return
@@ -49,6 +49,12 @@ func handleGithubLogin(w http.ResponseWriter, r *http.Request) {
 		fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
 		return
 	}
+
+	if err := logic.SetState(oauth_state_string); err != nil {
+		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
+		return
+	}
+
 	var url = auth_provider.AuthCodeURL(oauth_state_string)
 	http.Redirect(w, r, url, http.StatusTemporaryRedirect)
 }
@@ -88,7 +94,8 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) {
 }
 
 func getGithubUserInfo(state string, code string) (*githubOauthUser, error) {
-	if state != oauth_state_string {
+	oauth_state_string, isValid := logic.IsStateValid(state)
+	if !isValid || state != oauth_state_string {
 		return nil, fmt.Errorf("invalid OAuth state")
 	}
 	var token, err = auth_provider.Exchange(context.Background(), code)

+ 9 - 2
auth/google.go

@@ -42,7 +42,7 @@ func initGoogle(redirectURL string, clientID string, clientSecret string) {
 }
 
 func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
-	oauth_state_string = logic.RandomString(16)
+	var oauth_state_string = logic.RandomString(16)
 	if auth_provider == nil && servercfg.GetFrontendURL() != "" {
 		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
 		return
@@ -50,6 +50,12 @@ func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
 		fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
 		return
 	}
+
+	if err := logic.SetState(oauth_state_string); err != nil {
+		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
+		return
+	}
+
 	var url = auth_provider.AuthCodeURL(oauth_state_string)
 	http.Redirect(w, r, url, http.StatusTemporaryRedirect)
 }
@@ -89,7 +95,8 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
 }
 
 func getGoogleUserInfo(state string, code string) (*googleOauthUser, error) {
-	if state != oauth_state_string {
+	oauth_state_string, isValid := logic.IsStateValid(state)
+	if !isValid || state != oauth_state_string {
 		return nil, fmt.Errorf("invalid OAuth state")
 	}
 	var token, err = auth_provider.Exchange(context.Background(), code)

+ 9 - 2
auth/oidc.go

@@ -54,7 +54,7 @@ func initOIDC(redirectURL string, clientID string, clientSecret string, issuer s
 }
 
 func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
-	oauth_state_string = logic.RandomString(16)
+	var oauth_state_string = logic.RandomString(16)
 	if auth_provider == nil && servercfg.GetFrontendURL() != "" {
 		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
 		return
@@ -62,6 +62,12 @@ func handleOIDCLogin(w http.ResponseWriter, r *http.Request) {
 		fmt.Fprintf(w, "%s", []byte("no frontend URL was provided and an OAuth login was attempted\nplease reconfigure server to use OAuth or use basic credentials"))
 		return
 	}
+
+	if err := logic.SetState(oauth_state_string); err != nil {
+		http.Redirect(w, r, servercfg.GetFrontendURL()+"/login?oauth=callback-error", http.StatusTemporaryRedirect)
+		return
+	}
+
 	var url = auth_provider.AuthCodeURL(oauth_state_string)
 	http.Redirect(w, r, url, http.StatusTemporaryRedirect)
 }
@@ -101,7 +107,8 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) {
 }
 
 func getOIDCUserInfo(state string, code string) (u *OIDCUser, e error) {
-	if state != oauth_state_string {
+	oauth_state_string, isValid := logic.IsStateValid(state)
+	if !isValid || state != oauth_state_string {
 		return nil, fmt.Errorf("invalid OAuth state")
 	}
 

+ 4 - 0
database/database.go

@@ -56,6 +56,9 @@ const GENERATED_TABLE_NAME = "generated"
 // NODE_ACLS_TABLE_NAME - stores the node ACL rules
 const NODE_ACLS_TABLE_NAME = "nodeacls"
 
+// SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins
+const SSO_STATE_CACHE = "ssostatecache"
+
 // == ERROR CONSTS ==
 
 // NO_RECORD - no singular result found
@@ -135,6 +138,7 @@ func createTables() {
 	createTable(SERVER_UUID_TABLE_NAME)
 	createTable(GENERATED_TABLE_NAME)
 	createTable(NODE_ACLS_TABLE_NAME)
+	createTable(SSO_STATE_CACHE)
 }
 
 func createTable(tableName string) error {

+ 50 - 0
logic/auth.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"time"
 
 	"github.com/go-playground/validator/v10"
 	"github.com/gravitl/netmaker/database"
@@ -270,3 +271,52 @@ func FetchAuthSecret(key string, secret string) (string, error) {
 	}
 	return record, nil
 }
+
+// GetState - gets an SsoState from DB, if expired returns error
+func GetState(state string) (*models.SsoState, error) {
+	var s models.SsoState
+	record, err := database.FetchRecord(database.SSO_STATE_CACHE, state)
+	if err != nil {
+		return &s, err
+	}
+
+	if err = json.Unmarshal([]byte(record), &s); err != nil {
+		return &s, err
+	}
+
+	if s.IsExpired() {
+		return &s, fmt.Errorf("state expired")
+	}
+
+	return &s, nil
+}
+
+// SetState - sets a state with new expiration
+func SetState(state string) error {
+	s := models.SsoState{
+		Value:      state,
+		Expiration: time.Now().Add(models.DefaultExpDuration),
+	}
+
+	data, err := json.Marshal(&s)
+	if err != nil {
+		return err
+	}
+
+	return database.Insert(state, string(data), database.SSO_STATE_CACHE)
+}
+
+// IsStateValid - checks if given state is valid or not
+// deletes state after call is made to clean up, should only be called once per sign-in
+func IsStateValid(state string) (string, bool) {
+	s, err := GetState(state)
+	if s.Value != "" {
+		delState(state)
+	}
+	return s.Value, err == nil
+}
+
+// delState - removes a state from cache/db
+func delState(state string) error {
+	return database.DeleteRecord(database.SSO_STATE_CACHE, state)
+}

+ 17 - 0
models/ssocache.go

@@ -0,0 +1,17 @@
+package models
+
+import "time"
+
+// DefaultExpDuration - the default expiration time of SsoState
+const DefaultExpDuration = time.Minute * 5
+
+// SsoState - holds SSO sign-in session data
+type SsoState struct {
+	Value      string    `json:"value"`
+	Expiration time.Time `json:"expiration"`
+}
+
+// SsoState.IsExpired - tells if an SsoState is expired or not
+func (s *SsoState) IsExpired() bool {
+	return time.Now().After(s.Expiration)
+}