Browse Source

session management for remote access client (#2592)

* feat(NET-584): wip: session mgmt for RAC

* feat(NET-584): session mgmt for RAC

* feat(NET-584): session mgmt for RAC

* feat(NET-584): session mgmt for RAC

* feat(NET-584): session mgmt for RAC

* feat(NET-584): session mgmt for RAC

* feat(NET-584): session mgmt for RAC

* feat(NET-584): session mgmt for RAC

* feat(NET-584): only enable if client is disabled

* feat(NET-584): check only for normal users

* feat(NET-584): fix condition
Aceix 1 year ago
parent
commit
bfc61fa359

+ 59 - 55
config/config.go

@@ -7,6 +7,7 @@ package config
 import (
 	"fmt"
 	"os"
+	"time"
 
 	"gopkg.in/yaml.v3"
 )
@@ -32,61 +33,64 @@ type EnvironmentConfig struct {
 
 // ServerConfig - server conf struct
 type ServerConfig struct {
-	CoreDNSAddr                string `yaml:"corednsaddr"`
-	APIConnString              string `yaml:"apiconn"`
-	APIHost                    string `yaml:"apihost"`
-	APIPort                    string `yaml:"apiport"`
-	Broker                     string `yam:"broker"`
-	ServerBrokerEndpoint       string `yaml:"serverbrokerendpoint"`
-	BrokerType                 string `yaml:"brokertype"`
-	EmqxRestEndpoint           string `yaml:"emqxrestendpoint"`
-	NetclientAutoUpdate        string `yaml:"netclientautoupdate"`
-	NetclientEndpointDetection string `yaml:"netclientendpointdetection"`
-	MasterKey                  string `yaml:"masterkey"`
-	DNSKey                     string `yaml:"dnskey"`
-	AllowedOrigin              string `yaml:"allowedorigin"`
-	NodeID                     string `yaml:"nodeid"`
-	RestBackend                string `yaml:"restbackend"`
-	MessageQueueBackend        string `yaml:"messagequeuebackend"`
-	DNSMode                    string `yaml:"dnsmode"`
-	DisableRemoteIPCheck       string `yaml:"disableremoteipcheck"`
-	Version                    string `yaml:"version"`
-	SQLConn                    string `yaml:"sqlconn"`
-	Platform                   string `yaml:"platform"`
-	Database                   string `yaml:"database"`
-	Verbosity                  int32  `yaml:"verbosity"`
-	AuthProvider               string `yaml:"authprovider"`
-	OIDCIssuer                 string `yaml:"oidcissuer"`
-	ClientID                   string `yaml:"clientid"`
-	ClientSecret               string `yaml:"clientsecret"`
-	FrontendURL                string `yaml:"frontendurl"`
-	DisplayKeys                string `yaml:"displaykeys"`
-	AzureTenant                string `yaml:"azuretenant"`
-	Telemetry                  string `yaml:"telemetry"`
-	HostNetwork                string `yaml:"hostnetwork"`
-	Server                     string `yaml:"server"`
-	PublicIPService            string `yaml:"publicipservice"`
-	MQPassword                 string `yaml:"mqpassword"`
-	MQUserName                 string `yaml:"mqusername"`
-	MetricsExporter            string `yaml:"metrics_exporter"`
-	BasicAuth                  string `yaml:"basic_auth"`
-	LicenseValue               string `yaml:"license_value"`
-	NetmakerTenantID           string `yaml:"netmaker_tenant_id"`
-	IsPro                      string `yaml:"is_ee" json:"IsEE"`
-	StunPort                   int    `yaml:"stun_port"`
-	TurnServer                 string `yaml:"turn_server"`
-	TurnApiServer              string `yaml:"turn_api_server"`
-	TurnPort                   int    `yaml:"turn_port"`
-	TurnUserName               string `yaml:"turn_username"`
-	TurnPassword               string `yaml:"turn_password"`
-	UseTurn                    bool   `yaml:"use_turn"`
-	UsersLimit                 int    `yaml:"user_limit"`
-	NetworksLimit              int    `yaml:"network_limit"`
-	MachinesLimit              int    `yaml:"machines_limit"`
-	IngressesLimit             int    `yaml:"ingresses_limit"`
-	EgressesLimit              int    `yaml:"egresses_limit"`
-	DeployedByOperator         bool   `yaml:"deployed_by_operator"`
-	Environment                string `yaml:"environment"`
+	CoreDNSAddr                string        `yaml:"corednsaddr"`
+	APIConnString              string        `yaml:"apiconn"`
+	APIHost                    string        `yaml:"apihost"`
+	APIPort                    string        `yaml:"apiport"`
+	Broker                     string        `yam:"broker"`
+	ServerBrokerEndpoint       string        `yaml:"serverbrokerendpoint"`
+	BrokerType                 string        `yaml:"brokertype"`
+	EmqxRestEndpoint           string        `yaml:"emqxrestendpoint"`
+	NetclientAutoUpdate        string        `yaml:"netclientautoupdate"`
+	NetclientEndpointDetection string        `yaml:"netclientendpointdetection"`
+	MasterKey                  string        `yaml:"masterkey"`
+	DNSKey                     string        `yaml:"dnskey"`
+	AllowedOrigin              string        `yaml:"allowedorigin"`
+	NodeID                     string        `yaml:"nodeid"`
+	RestBackend                string        `yaml:"restbackend"`
+	MessageQueueBackend        string        `yaml:"messagequeuebackend"`
+	DNSMode                    string        `yaml:"dnsmode"`
+	DisableRemoteIPCheck       string        `yaml:"disableremoteipcheck"`
+	Version                    string        `yaml:"version"`
+	SQLConn                    string        `yaml:"sqlconn"`
+	Platform                   string        `yaml:"platform"`
+	Database                   string        `yaml:"database"`
+	Verbosity                  int32         `yaml:"verbosity"`
+	AuthProvider               string        `yaml:"authprovider"`
+	OIDCIssuer                 string        `yaml:"oidcissuer"`
+	ClientID                   string        `yaml:"clientid"`
+	ClientSecret               string        `yaml:"clientsecret"`
+	FrontendURL                string        `yaml:"frontendurl"`
+	DisplayKeys                string        `yaml:"displaykeys"`
+	AzureTenant                string        `yaml:"azuretenant"`
+	Telemetry                  string        `yaml:"telemetry"`
+	HostNetwork                string        `yaml:"hostnetwork"`
+	Server                     string        `yaml:"server"`
+	PublicIPService            string        `yaml:"publicipservice"`
+	MQPassword                 string        `yaml:"mqpassword"`
+	MQUserName                 string        `yaml:"mqusername"`
+	MetricsExporter            string        `yaml:"metrics_exporter"`
+	BasicAuth                  string        `yaml:"basic_auth"`
+	LicenseValue               string        `yaml:"license_value"`
+	NetmakerTenantID           string        `yaml:"netmaker_tenant_id"`
+	IsPro                      string        `yaml:"is_ee" json:"IsEE"`
+	StunPort                   int           `yaml:"stun_port"`
+	StunList                   string        `yaml:"stun_list"`
+	TurnServer                 string        `yaml:"turn_server"`
+	TurnApiServer              string        `yaml:"turn_api_server"`
+	TurnPort                   int           `yaml:"turn_port"`
+	TurnUserName               string        `yaml:"turn_username"`
+	TurnPassword               string        `yaml:"turn_password"`
+	UseTurn                    bool          `yaml:"use_turn"`
+	UsersLimit                 int           `yaml:"user_limit"`
+	NetworksLimit              int           `yaml:"network_limit"`
+	MachinesLimit              int           `yaml:"machines_limit"`
+	IngressesLimit             int           `yaml:"ingresses_limit"`
+	EgressesLimit              int           `yaml:"egresses_limit"`
+	DeployedByOperator         bool          `yaml:"deployed_by_operator"`
+	Environment                string        `yaml:"environment"`
+	JwtValidityDuration        time.Duration `yaml:"jwt_validity_duration"`
+	RacAutoDisable             bool          `yaml:"rac_auto_disable"`
 }
 
 // SQLConfig - Generic SQL Config

+ 28 - 1
controllers/user.go

@@ -12,6 +12,7 @@ import (
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
 	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/servercfg"
 	"golang.org/x/exp/slog"
 )
@@ -96,7 +97,6 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	}
 	// Send back the JWT
 	successJSONResponse, jsonError := json.Marshal(successResponse)
-
 	if jsonError != nil {
 		logger.Log(0, username,
 			"error marshalling resp: ", err.Error())
@@ -106,6 +106,33 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
 	logger.Log(2, username, "was authenticated")
 	response.Header().Set("Content-Type", "application/json")
 	response.Write(successJSONResponse)
+
+	go func() {
+		if servercfg.IsPro && servercfg.GetRacAutoDisable() {
+			// enable all associeated clients for the user
+			clients, err := logic.GetAllExtClients()
+			if err != nil {
+				slog.Error("error getting clients: ", "error", err)
+				return
+			}
+			for _, client := range clients {
+				if client.OwnerID == username && !client.Enabled {
+					slog.Info(fmt.Sprintf("enabling ext client %s for user %s due to RAC autodisabling feature", client.ClientID, client.OwnerID))
+					if newClient, err := logic.ToggleExtClientConnectivity(&client, true); err != nil {
+						slog.Error("error disabling ext client in RAC autodisable hook", "error", err)
+						continue // dont return but try for other clients
+					} else {
+						// publish peer update to ingress gateway
+						if ingressNode, err := logic.GetNodeByID(newClient.IngressGatewayID); err == nil {
+							if err = mq.PublishPeerUpdate(); err != nil {
+								slog.Error("error updating ext clients on", "ingress", ingressNode.ID.String(), "err", err.Error())
+							}
+						}
+					}
+				}
+			}
+		}
+	}()
 }
 
 // swagger:route GET /api/users/adm/hassuperadmin user hasSuperAdmin

+ 5 - 0
logic/auth.go

@@ -141,6 +141,11 @@ func VerifyAuthRequest(authRequest models.UserAuthParams) (string, error) {
 
 	// Create a new JWT for the node
 	tokenString, _ := CreateUserJWT(authRequest.UserName, result.IsSuperAdmin, result.IsAdmin)
+
+	// update last login time
+	result.LastLoginTime = time.Now()
+	UpsertUser(result)
+
 	return tokenString, nil
 }
 

+ 27 - 0
logic/extpeers.go

@@ -9,6 +9,7 @@ import (
 
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/models"
+	"golang.org/x/exp/slog"
 	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
@@ -276,3 +277,29 @@ func GetAllExtClients() ([]models.ExtClient, error) {
 
 	return clients, nil
 }
+
+// ToggleExtClientConnectivity - enables or disables an ext client
+func ToggleExtClientConnectivity(client *models.ExtClient, enable bool) (models.ExtClient, error) {
+	update := models.CustomExtClient{
+		Enabled:              enable,
+		ClientID:             client.ClientID,
+		PublicKey:            client.PublicKey,
+		DNS:                  client.DNS,
+		ExtraAllowedIPs:      client.ExtraAllowedIPs,
+		DeniedACLs:           client.DeniedACLs,
+		RemoteAccessClientID: client.RemoteAccessClientID,
+	}
+
+	// update in DB
+	newClient := UpdateExtClient(client, &update)
+	if err := DeleteExtClient(client.Network, client.ClientID); err != nil {
+		slog.Error("failed to delete ext client during update", "id", client.ClientID, "network", client.Network, "error", err)
+		return newClient, err
+	}
+	if err := SaveExtClient(&newClient); err != nil {
+		slog.Error("failed to save updated ext client during update", "id", newClient.ClientID, "network", newClient.Network, "error", err)
+		return newClient, err
+	}
+
+	return newClient, nil
+}

+ 1 - 1
logic/jwts.go

@@ -54,7 +54,7 @@ func CreateJWT(uuid string, macAddress string, network string) (response string,
 
 // CreateUserJWT - creates a user jwt token
 func CreateUserJWT(username string, issuperadmin, isadmin bool) (response string, err error) {
-	expirationTime := time.Now().Add(60 * 12 * time.Minute)
+	expirationTime := time.Now().Add(servercfg.GetServerConfig().JwtValidityDuration)
 	claims := &models.UserClaims{
 		UserName:     username,
 		IsSuperAdmin: issuperadmin,

+ 1 - 1
logic/timer.go

@@ -17,7 +17,7 @@ import (
 const timer_hours_between_runs = 24
 
 // HookManagerCh - channel to add any new hooks
-var HookManagerCh = make(chan models.HookDetails, 2)
+var HookManagerCh = make(chan models.HookDetails, 3)
 
 // == Public ==
 

+ 11 - 9
models/structs.go

@@ -24,19 +24,21 @@ type AuthParams struct {
 
 // User struct - struct for Users
 type User struct {
-	UserName     string              `json:"username" bson:"username" validate:"min=3,max=40,in_charset|email"`
-	Password     string              `json:"password" bson:"password" validate:"required,min=5"`
-	IsAdmin      bool                `json:"isadmin" bson:"isadmin"`
-	IsSuperAdmin bool                `json:"issuperadmin"`
-	RemoteGwIDs  map[string]struct{} `json:"remote_gw_ids"`
+	UserName      string              `json:"username" bson:"username" validate:"min=3,max=40,in_charset|email"`
+	Password      string              `json:"password" bson:"password" validate:"required,min=5"`
+	IsAdmin       bool                `json:"isadmin" bson:"isadmin"`
+	IsSuperAdmin  bool                `json:"issuperadmin"`
+	RemoteGwIDs   map[string]struct{} `json:"remote_gw_ids"`
+	LastLoginTime time.Time           `json:"last_login_time"`
 }
 
 // ReturnUser - return user struct
 type ReturnUser struct {
-	UserName     string              `json:"username"`
-	IsAdmin      bool                `json:"isadmin"`
-	IsSuperAdmin bool                `json:"issuperadmin"`
-	RemoteGwIDs  map[string]struct{} `json:"remote_gw_ids"`
+	UserName      string              `json:"username"`
+	IsAdmin       bool                `json:"isadmin"`
+	IsSuperAdmin  bool                `json:"issuperadmin"`
+	RemoteGwIDs   map[string]struct{} `json:"remote_gw_ids"`
+	LastLoginTime time.Time           `json:"last_login_time"`
 }
 
 // UserAuthParams - user auth params struct

+ 3 - 0
pro/initialize.go

@@ -38,6 +38,9 @@ func InitPro() {
 		logic.SetFreeTierForTelemetry(false)
 		// == End License Handling ==
 		AddLicenseHooks()
+		if servercfg.GetServerConfig().RacAutoDisable {
+			AddRacHooks()
+		}
 		resetFailover()
 	})
 	logic.EnterpriseFailoverFunc = proLogic.SetFailover

+ 76 - 0
pro/remote_access_client.go

@@ -0,0 +1,76 @@
+package pro
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/gravitl/netmaker/logic"
+	"github.com/gravitl/netmaker/models"
+	"github.com/gravitl/netmaker/mq"
+	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/exp/slog"
+)
+
+const racAutoDisableCheckInterval = 3 * time.Minute
+
+// AddRacHooks - adds hooks for Remote Access Client
+func AddRacHooks() {
+	slog.Debug("adding RAC autodisable hook")
+	logic.HookManagerCh <- models.HookDetails{
+		Hook:     racAutoDisableHook,
+		Interval: racAutoDisableCheckInterval,
+	}
+}
+
+// racAutoDisableHook - checks if RAC is enabled and if it is, checks if it should be disabled
+func racAutoDisableHook() error {
+	slog.Debug("running RAC autodisable hook")
+
+	users, err := logic.GetUsers()
+	if err != nil {
+		slog.Error("error getting users: ", "error", err)
+		return err
+	}
+	clients, err := logic.GetAllExtClients()
+	if err != nil {
+		slog.Error("error getting clients: ", "error", err)
+		return err
+	}
+
+	currentTime := time.Now()
+	validityDuration := servercfg.GetJwtValidityDuration()
+	for _, user := range users {
+		if !currentTime.After(user.LastLoginTime.Add(validityDuration)) {
+			continue
+		}
+		for _, client := range clients {
+			if (client.OwnerID == user.UserName) && !user.IsAdmin && !user.IsSuperAdmin && client.Enabled {
+				slog.Info(fmt.Sprintf("disabling ext client %s for user %s due to RAC autodisabling", client.ClientID, client.OwnerID))
+				if err := disableExtClient(&client); err != nil {
+					slog.Error("error disabling ext client in RAC autodisable hook", "error", err)
+					continue // dont return but try for other clients
+				}
+			}
+		}
+	}
+
+	slog.Debug("finished running RAC autodisable hook")
+	return nil
+}
+
+func disableExtClient(client *models.ExtClient) error {
+	if newClient, err := logic.ToggleExtClientConnectivity(client, false); err != nil {
+		return err
+	} else {
+		// publish peer update to ingress gateway
+		if ingressNode, err := logic.GetNodeByID(newClient.IngressGatewayID); err == nil {
+			if err = mq.PublishPeerUpdate(); err != nil {
+				slog.Error("error updating ext clients on", "ingress", ingressNode.ID.String(), "err", err.Error())
+			}
+		} else {
+			return err
+		}
+	}
+
+	return nil
+}

+ 4 - 0
scripts/netmaker.default.env

@@ -78,3 +78,7 @@ FRONTEND_URL=
 AZURE_TENANT=
 # https://oidc.yourprovider.com - URL of oidc provider
 OIDC_ISSUER=
+# Duration of JWT token validity in seconds
+JWT_VALIDITY_DURATION=43200
+# Auto disable a user's connecteds clients bassed on JWT token expiration
+RAC_AUTO_DISABLE="true"

+ 1 - 1
scripts/nm-quick.sh

@@ -310,7 +310,7 @@ save_config() { (
 		"CORS_ALLOWED_ORIGIN" "DISPLAY_KEYS" "DATABASE" "SERVER_BROKER_ENDPOINT" "STUN_PORT" "VERBOSITY"
 		"TURN_PORT" "USE_TURN" "DEBUG_MODE" "TURN_API_PORT" "REST_BACKEND"
 		"DISABLE_REMOTE_IP_CHECK" "NETCLIENT_ENDPOINT_DETECTION" "TELEMETRY" "AUTH_PROVIDER" "CLIENT_ID" "CLIENT_SECRET"
-		"FRONTEND_URL" "AZURE_TENANT" "OIDC_ISSUER" "EXPORTER_API_PORT")
+		"FRONTEND_URL" "AZURE_TENANT" "OIDC_ISSUER" "EXPORTER_API_PORT" "JWT_VALIDITY_DURATION" "RAC_AUTO_DISABLE")
 	for name in "${toCopy[@]}"; do
 		save_config_item $name "${!name}"
 	done

+ 1 - 1
scripts/nm-upgrade.sh

@@ -180,7 +180,7 @@ save_config() { (
 		"CORS_ALLOWED_ORIGIN" "DISPLAY_KEYS" "DATABASE" "SERVER_BROKER_ENDPOINT" "STUN_PORT" "VERBOSITY"
 		"TURN_PORT" "USE_TURN" "DEBUG_MODE" "TURN_API_PORT" "REST_BACKEND"
 		"DISABLE_REMOTE_IP_CHECK" "NETCLIENT_ENDPOINT_DETECTION" "TELEMETRY" "AUTH_PROVIDER" "CLIENT_ID" "CLIENT_SECRET"
-		"FRONTEND_URL" "AZURE_TENANT" "OIDC_ISSUER" "EXPORTER_API_PORT")
+		"FRONTEND_URL" "AZURE_TENANT" "OIDC_ISSUER" "EXPORTER_API_PORT" "JWT_VALIDITY_DURATION" "RAC_AUTO_DISABLE")
 	for name in "${toCopy[@]}"; do
 		save_config_item $name "${!name}"
 	done

+ 21 - 1
servercfg/serverconf.go

@@ -90,11 +90,31 @@ func GetServerConfig() config.ServerConfig {
 	if IsPro {
 		cfg.IsPro = "yes"
 	}
+	cfg.JwtValidityDuration = GetJwtValidityDuration()
+	cfg.RacAutoDisable = GetRacAutoDisable()
 
 	return cfg
 }
 
-// GetServerConfig - gets the server config into memory from file or env
+// GetJwtValidityDuration - returns the JWT validity duration in seconds
+func GetJwtValidityDuration() time.Duration {
+	var defaultDuration = time.Duration(24) * time.Hour
+	if os.Getenv("JWT_VALIDITY_DURATION") != "" {
+		t, err := strconv.Atoi(os.Getenv("JWT_VALIDITY_DURATION"))
+		if err != nil {
+			return defaultDuration
+		}
+		return time.Duration(t) * time.Second
+	}
+	return defaultDuration
+}
+
+// GetRacAutoDisable - returns whether the feature to autodisable RAC is enabled
+func GetRacAutoDisable() bool {
+	return os.Getenv("RAC_AUTO_DISABLE") == "true"
+}
+
+// GetServerInfo - gets the server config into memory from file or env
 func GetServerInfo() models.ServerConfig {
 	var cfg models.ServerConfig
 	cfg.Server = GetServer()