Browse Source

initialized

0xdcarns 3 years ago
parent
commit
8a54f50676
5 changed files with 192 additions and 32 deletions
  1. 66 0
      auth/auth.go
  2. 65 0
      auth/google.go
  3. 34 32
      config/config.go
  4. 1 0
      go.mod
  5. 26 0
      servercfg/serverconf.go

+ 66 - 0
auth/auth.go

@@ -0,0 +1,66 @@
+package auth
+
+import (
+	"net/http"
+
+	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/oauth2"
+)
+
+// == consts ==
+const (
+	init_provider          = "initprovider"
+	get_user_info          = "getuserinfo"
+	handle_callback        = "handlecallback"
+	handle_login           = "handlelogin"
+	oauth_state_string     = "netmaker-oauth-state"
+	google_provider_name   = "google"
+	azure_ad_provider_name = "azure-ad"
+	github_provider_name   = "github"
+)
+
+var auth_provider *oauth2.Config
+
+func getCurrentAuthFunctions() map[string]interface{} {
+	var authInfo = servercfg.GetAuthProviderInfo()
+	var authProvider = authInfo[0]
+	switch authProvider {
+	case google_provider_name:
+		return google_functions
+	case azure_ad_provider_name:
+		return google_functions
+	case github_provider_name:
+		return google_functions
+	default:
+		return nil
+	}
+}
+
+// InitializeAuthProvider - initializes the auth provider if any is present
+func InitializeAuthProvider() bool {
+	var functions = getCurrentAuthFunctions()
+	if functions == nil {
+		return false
+	}
+	var authInfo = servercfg.GetAuthProviderInfo()
+	functions[init_provider].(func(string, string, string))(servercfg.GetAPIConnString(), authInfo[1], authInfo[2])
+	return auth_provider != nil
+}
+
+// HandleAuthCallback - handles oauth callback
+func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
+	var functions = getCurrentAuthFunctions()
+	if functions == nil {
+		return
+	}
+	functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
+}
+
+// HandleAuthLogin - handles oauth login
+func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
+	var functions = getCurrentAuthFunctions()
+	if functions == nil {
+		return
+	}
+	functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
+}

+ 65 - 0
auth/google.go

@@ -0,0 +1,65 @@
+package auth
+
+import (
+	"fmt"
+	"io/ioutil"
+	"net/http"
+
+	"golang.org/x/oauth2"
+	"golang.org/x/oauth2/google"
+)
+
+var google_functions = map[string]interface{}{
+	init_provider:   initGoogle,
+	get_user_info:   getUserInfo,
+	handle_callback: handleGoogleCallback,
+	handle_login:    handleGoogleLogin,
+}
+
+// == handle google authentication here ==
+
+func initGoogle(redirectURL string, clientID string, clientSecret string) {
+	auth_provider = &oauth2.Config{
+		RedirectURL:  redirectURL,
+		ClientID:     clientID,
+		ClientSecret: clientSecret,
+		Scopes:       []string{"https://www.googleapis.com/auth/userinfo.email"},
+		Endpoint:     google.Endpoint,
+	}
+}
+
+func handleGoogleLogin(w http.ResponseWriter, r *http.Request) {
+	url := auth_provider.AuthCodeURL(oauth_state_string)
+	http.Redirect(w, r, url, http.StatusTemporaryRedirect)
+}
+
+func handleGoogleCallback(w http.ResponseWriter, r *http.Request) {
+
+	var content, err = getUserInfo(r.FormValue("state"), r.FormValue("code"))
+	if err != nil {
+		fmt.Println(err.Error())
+		http.Redirect(w, r, "/api/oauth/error", http.StatusTemporaryRedirect)
+		return
+	}
+	fmt.Fprintf(w, "Content: %s\n", content)
+}
+
+func getUserInfo(state string, code string) ([]byte, error) {
+	if state != oauth_state_string {
+		return nil, fmt.Errorf("invalid oauth state")
+	}
+	token, err := auth_provider.Exchange(oauth2.NoContext, code)
+	if err != nil {
+		return nil, fmt.Errorf("code exchange failed: %s", err.Error())
+	}
+	response, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken)
+	if err != nil {
+		return nil, fmt.Errorf("failed getting user info: %s", err.Error())
+	}
+	defer response.Body.Close()
+	contents, err := ioutil.ReadAll(response.Body)
+	if err != nil {
+		return nil, fmt.Errorf("failed reading response body: %s", err.Error())
+	}
+	return contents, nil
+}

+ 34 - 32
config/config.go

@@ -30,49 +30,51 @@ var Config *EnvironmentConfig
 // EnvironmentConfig :
 type EnvironmentConfig struct {
 	Server ServerConfig `yaml:"server"`
-	SQL SQLConfig `yaml:"sql"`
+	SQL    SQLConfig    `yaml:"sql"`
 }
 
 // ServerConfig :
 type ServerConfig struct {
-	CoreDNSAddr          string `yaml:"corednsaddr"`
-	APIConnString        string `yaml:"apiconn"`
-	APIHost              string `yaml:"apihost"`
-	APIPort              string `yaml:"apiport"`
-	GRPCConnString       string `yaml:"grpcconn"`
-	GRPCHost             string `yaml:"grpchost"`
-	GRPCPort             string `yaml:"grpcport"`
-	GRPCSecure           string `yaml:"grpcsecure"`
-	MasterKey            string `yaml:"masterkey"`
-	AllowedOrigin        string `yaml:"allowedorigin"`
-	NodeID        string `yaml:"nodeid"`
-	RestBackend          string `yaml:"restbackend"`
-	AgentBackend         string `yaml:"agentbackend"`
-	ClientMode           string `yaml:"clientmode"`
-	DNSMode              string `yaml:"dnsmode"`
-	SplitDNS             string `yaml:"splitdns"`
-	DisableRemoteIPCheck string `yaml:"disableremoteipcheck"`
-	DisableDefaultNet    string `yaml:"disabledefaultnet"`
-	GRPCSSL              string `yaml:"grpcssl"`
-	Version              string `yaml:"version"`
-	SQLConn              string `yaml:"sqlconn"`
-	Platform             string `yaml:"platform"`
-	Database             string `yaml:database`
-	CheckinInterval      string `yaml:checkininterval`
-	DefaultNodeLimit     int32  `yaml:"defaultnodelimit"`
-	Verbosity            int32  `yaml:"verbosity"`
+	CoreDNSAddr           string `yaml:"corednsaddr"`
+	APIConnString         string `yaml:"apiconn"`
+	APIHost               string `yaml:"apihost"`
+	APIPort               string `yaml:"apiport"`
+	GRPCConnString        string `yaml:"grpcconn"`
+	GRPCHost              string `yaml:"grpchost"`
+	GRPCPort              string `yaml:"grpcport"`
+	GRPCSecure            string `yaml:"grpcsecure"`
+	MasterKey             string `yaml:"masterkey"`
+	AllowedOrigin         string `yaml:"allowedorigin"`
+	NodeID                string `yaml:"nodeid"`
+	RestBackend           string `yaml:"restbackend"`
+	AgentBackend          string `yaml:"agentbackend"`
+	ClientMode            string `yaml:"clientmode"`
+	DNSMode               string `yaml:"dnsmode"`
+	SplitDNS              string `yaml:"splitdns"`
+	DisableRemoteIPCheck  string `yaml:"disableremoteipcheck"`
+	DisableDefaultNet     string `yaml:"disabledefaultnet"`
+	GRPCSSL               string `yaml:"grpcssl"`
+	Version               string `yaml:"version"`
+	SQLConn               string `yaml:"sqlconn"`
+	Platform              string `yaml:"platform"`
+	Database              string `yaml:database`
+	CheckinInterval       string `yaml:checkininterval`
+	DefaultNodeLimit      int32  `yaml:"defaultnodelimit"`
+	Verbosity             int32  `yaml:"verbosity"`
 	ServerCheckinInterval int64  `yaml:"servercheckininterval"`
+	AuthProvider          string `yaml:"authprovider"`
+	ClientID              string `yaml:"clientid"`
+	ClientSecret          string `yaml:"clientsecret"`
 }
 
-
 // Generic SQL Config
 type SQLConfig struct {
-	Host string `yaml:"host"`
-	Port int32 `yaml:"port"`
+	Host     string `yaml:"host"`
+	Port     int32  `yaml:"port"`
 	Username string `yaml:"username"`
 	Password string `yaml:"password"`
-	DB string `yaml:"db"`
-	SSLMode string `yaml:"sslmode"`
+	DB       string `yaml:"db"`
+	SSLMode  string `yaml:"sslmode"`
 }
 
 //reading in the env file

+ 1 - 0
go.mod

@@ -17,6 +17,7 @@ require (
 	github.com/urfave/cli/v2 v2.3.0
 	golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97
 	golang.org/x/net v0.0.0-20210726213435-c6fcb2dbf985 // indirect
+	golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be // indirect
 	golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e // indirect
 	golang.org/x/text v0.3.7-0.20210524175448-3115f89c4b99 // indirect
 	golang.zx2c4.com/wireguard v0.0.0-20210805125648-3957e9b9dd19 // indirect

+ 26 - 0
servercfg/serverconf.go

@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"os"
 	"strconv"
+	"strings"
 
 	"github.com/gravitl/netmaker/config"
 )
@@ -65,6 +66,12 @@ func GetServerConfig() config.ServerConfig {
 	cfg.Database = GetDB()
 	cfg.Platform = GetPlatform()
 	cfg.Version = GetVersion()
+
+	// == auth config ==
+	var authInfo = GetAuthProviderInfo()
+	cfg.AuthProvider = authInfo[0]
+	cfg.ClientID = authInfo[1]
+	cfg.ClientSecret = authInfo[2]
 	return cfg
 }
 func GetAPIConnString() string {
@@ -398,6 +405,25 @@ func GetServerCheckinInterval() int64 {
 	return t
 }
 
+// GetAuthProviderInfo = gets the oauth provider info
+func GetAuthProviderInfo() []string {
+	var authProvider = ""
+	if os.Getenv("AUTH_PROVIDER") != "" && os.Getenv("CLIENT_ID") != "" && os.Getenv("CLIENT_SECRET") != "" {
+		authProvider = strings.ToLower(os.Getenv("AUTH_PROVIDER"))
+		if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" {
+			return []string{authProvider, os.Getenv("CLIENT_ID"), os.Getenv("CLIENT_SECRET")}
+		} else {
+			authProvider = ""
+		}
+	} else if config.Config.Server.AuthProvider != "" && config.Config.Server.ClientID != "" && config.Config.Server.ClientSecret != "" {
+		authProvider = strings.ToLower(config.Config.Server.AuthProvider)
+		if authProvider == "google" || authProvider == "azure-ad" || authProvider == "github" {
+			return []string{authProvider, config.Config.Server.ClientID, config.Config.Server.ClientSecret}
+		}
+	}
+	return []string{"", "", ""}
+}
+
 // GetMacAddr - get's mac address
 func getMacAddr() string {
 	ifas, err := net.Interfaces()