Browse Source

convert access token to sql schema

abhishek9686 5 months ago
parent
commit
5d1f2a39c4
11 changed files with 200 additions and 164 deletions
  1. 7 15
      controllers/user.go
  2. 3 54
      database/database.go
  3. 2 2
      database/postgres.go
  4. 2 2
      database/rqlite.go
  5. 2 2
      database/sqlite.go
  6. 15 2
      db/connector.go
  7. 70 2
      db/postgres.go
  8. 4 76
      logic/auth.go
  9. 4 1
      logic/jwts.go
  10. 43 1
      main.go
  11. 48 7
      models/accessToken.go

+ 7 - 15
controllers/user.go

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net/http"
 	"reflect"
+	"time"
 
 	"github.com/google/uuid"
 	"github.com/gorilla/mux"
@@ -84,6 +85,8 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 	req.ID = uuid.New().String()
+	req.CreatedBy = r.Header.Get("user")
+	req.CreatedAt = time.Now()
 	jwt, err := logic.CreateUserAccessJwtToken(user.UserName, user.PlatformRoleID, req.ExpiresAt, req.ID)
 	if jwt == "" {
 		// very unlikely that err is !nil and no jwt returned, but handle it anyways.
@@ -94,7 +97,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
 		)
 		return
 	}
-	err = logic.CreateAccessToken(req)
+	err = req.Create()
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,
@@ -124,13 +127,7 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 		logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
 		return
 	}
-	_, err := logic.GetUser(username)
-	if err != nil {
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "unauthorized"))
-		return
-	}
-
-	logic.ReturnSuccessResponseWithJson(w, r, logic.ListAccessTokens(username), "fetched api access tokens for user "+username)
+	logic.ReturnSuccessResponseWithJson(w, r, (&models.AccessToken{}).ListByUser(), "fetched api access tokens for user "+username)
 }
 
 // @Summary     Authenticate a user to retrieve an authorization token
@@ -149,7 +146,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	err := logic.RevokeAccessToken(models.AccessToken{ID: id})
+	err := (&models.AccessToken{ID: id}).Delete()
 	if err != nil {
 		logic.ReturnErrorResponse(
 			w,
@@ -792,17 +789,12 @@ func deleteUser(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 	}
-	success, err := logic.DeleteUser(username)
+	err = logic.DeleteUser(username)
 	if err != nil {
 		logger.Log(0, username,
 			"failed to delete user: ", err.Error())
 		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
 		return
-	} else if !success {
-		err := errors.New("delete unsuccessful")
-		logger.Log(0, username, err.Error())
-		logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
-		return
 	}
 	// check and delete extclient with this ownerID
 	go func() {

+ 3 - 54
database/database.go

@@ -1,18 +1,12 @@
 package database
 
 import (
-	"crypto/rand"
-	"encoding/json"
 	"errors"
 	"sync"
 	"time"
 
-	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/logger"
-	"github.com/gravitl/netmaker/models"
-	"github.com/gravitl/netmaker/netclient/ncutils"
 	"github.com/gravitl/netmaker/servercfg"
-	"golang.org/x/crypto/nacl/box"
 )
 
 const (
@@ -26,7 +20,7 @@ const (
 	// USERS_TABLE_NAME - users table
 	USERS_TABLE_NAME = "users"
 	// ACCESS_TOKENS_TABLE_NAME - access tokens table
-	ACCESS_TOKENS_TABLE_NAME = "access_tokens"
+	ACCESS_TOKENS_TABLE_NAME = "user_access_tokens"
 	// USER_PERMISSIONS_TABLE_NAME - user permissions table
 	USER_PERMISSIONS_TABLE_NAME = "user_permissions"
 	// CERTS_TABLE_NAME - certificates table
@@ -163,7 +157,7 @@ func InitializeDatabase() error {
 		time.Sleep(2 * time.Second)
 	}
 	createTables()
-	return initializeUUID()
+	return nil
 }
 
 func createTables() {
@@ -176,18 +170,11 @@ func CreateTable(tableName string) error {
 	return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
 }
 
-// IsJSONString - checks if valid json
-func IsJSONString(value string) bool {
-	var jsonInt interface{}
-	var nodeInt models.Node
-	return json.Unmarshal([]byte(value), &jsonInt) == nil || json.Unmarshal([]byte(value), &nodeInt) == nil
-}
-
 // Insert - inserts object into db
 func Insert(key string, value string, tableName string) error {
 	dbMutex.Lock()
 	defer dbMutex.Unlock()
-	if key != "" && value != "" && IsJSONString(value) {
+	if key != "" && value != "" {
 		return getCurrentDB()[INSERT].(func(string, string, string) error)(key, value, tableName)
 	} else {
 		return errors.New("invalid insert " + key + " : " + value)
@@ -235,44 +222,6 @@ func FetchRecords(tableName string) (map[string]string, error) {
 	return getCurrentDB()[FETCH_ALL].(func(string) (map[string]string, error))(tableName)
 }
 
-// initializeUUID - create a UUID record for server if none exists
-func initializeUUID() error {
-	records, err := FetchRecords(SERVER_UUID_TABLE_NAME)
-	if err != nil {
-		if !IsEmptyRecord(err) {
-			return err
-		}
-	} else if len(records) > 0 {
-		return nil
-	}
-	// setup encryption keys
-	var trafficPubKey, trafficPrivKey, errT = box.GenerateKey(rand.Reader) // generate traffic keys
-	if errT != nil {
-		return errT
-	}
-	tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
-	if err != nil {
-		return err
-	}
-
-	tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
-	if err != nil {
-		return err
-	}
-
-	telemetry := models.Telemetry{
-		UUID:           uuid.NewString(),
-		TrafficKeyPriv: tPriv,
-		TrafficKeyPub:  tPub,
-	}
-	telJSON, err := json.Marshal(&telemetry)
-	if err != nil {
-		return err
-	}
-
-	return Insert(SERVER_UUID_RECORD_KEY, string(telJSON), SERVER_UUID_TABLE_NAME)
-}
-
 // CloseDB - closes a database gracefully
 func CloseDB() {
 	getCurrentDB()[CLOSE_DB].(func())()

+ 2 - 2
database/postgres.go

@@ -59,7 +59,7 @@ func pgCreateTable(tableName string) error {
 }
 
 func pgInsert(key string, value string, tableName string) error {
-	if key != "" && value != "" && IsJSONString(value) {
+	if key != "" && value != "" {
 		insertSQL := "INSERT INTO " + tableName + " (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $3;"
 		statement, err := PGDB.Prepare(insertSQL)
 		if err != nil {
@@ -77,7 +77,7 @@ func pgInsert(key string, value string, tableName string) error {
 }
 
 func pgInsertPeer(key string, value string) error {
-	if key != "" && value != "" && IsJSONString(value) {
+	if key != "" && value != "" {
 		err := pgInsert(key, value, PEERS_TABLE_NAME)
 		if err != nil {
 			return err

+ 2 - 2
database/rqlite.go

@@ -43,7 +43,7 @@ func rqliteCreateTable(tableName string) error {
 }
 
 func rqliteInsert(key string, value string, tableName string) error {
-	if key != "" && value != "" && IsJSONString(value) {
+	if key != "" && value != "" {
 		_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')")
 		if err != nil {
 			return err
@@ -54,7 +54,7 @@ func rqliteInsert(key string, value string, tableName string) error {
 }
 
 func rqliteInsertPeer(key string, value string) error {
-	if key != "" && value != "" && IsJSONString(value) {
+	if key != "" && value != "" {
 		_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + PEERS_TABLE_NAME + " (key, value) VALUES ('" + key + "', '" + value + "')")
 		if err != nil {
 			return err

+ 2 - 2
database/sqlite.go

@@ -61,7 +61,7 @@ func sqliteCreateTable(tableName string) error {
 }
 
 func sqliteInsert(key string, value string, tableName string) error {
-	if key != "" && value != "" && IsJSONString(value) {
+	if key != "" && value != "" {
 		insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)"
 		statement, err := SqliteDB.Prepare(insertSQL)
 		if err != nil {
@@ -78,7 +78,7 @@ func sqliteInsert(key string, value string, tableName string) error {
 }
 
 func sqliteInsertPeer(key string, value string) error {
-	if key != "" && value != "" && IsJSONString(value) {
+	if key != "" && value != "" {
 		err := sqliteInsert(key, value, PEERS_TABLE_NAME)
 		if err != nil {
 			return err

+ 15 - 2
db/connector.go

@@ -2,7 +2,9 @@ package db
 
 import (
 	"errors"
-	"github.com/gravitl/netmaker/servercfg"
+	"os"
+
+	"github.com/gravitl/netmaker/config"
 	"gorm.io/gorm"
 )
 
@@ -14,10 +16,21 @@ type connector interface {
 	connect() (*gorm.DB, error)
 }
 
+// GetDB - gets the database type
+func GetDB() string {
+	database := "sqlite"
+	if os.Getenv("DATABASE") != "" {
+		database = os.Getenv("DATABASE")
+	} else if config.Config.Server.Database != "" {
+		database = config.Config.Server.Database
+	}
+	return database
+}
+
 // newConnector detects the database being
 // used and returns the corresponding connector.
 func newConnector() (connector, error) {
-	switch servercfg.GetDB() {
+	switch GetDB() {
 	case "sqlite":
 		return &sqliteConnector{}, nil
 	case "postgres":

+ 70 - 2
db/postgres.go

@@ -2,7 +2,10 @@ package db
 
 import (
 	"fmt"
-	"github.com/gravitl/netmaker/servercfg"
+	"os"
+	"strconv"
+
+	"github.com/gravitl/netmaker/config"
 	"gorm.io/driver/postgres"
 	"gorm.io/gorm"
 	"gorm.io/gorm/logger"
@@ -15,7 +18,7 @@ type postgresConnector struct{}
 // postgresConnector.connect connects and
 // initializes a connection to postgres.
 func (pg *postgresConnector) connect() (*gorm.DB, error) {
-	pgConf := servercfg.GetSQLConf()
+	pgConf := GetSQLConf()
 	dsn := fmt.Sprintf(
 		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
 		pgConf.Host,
@@ -47,3 +50,68 @@ func (pg *postgresConnector) connect() (*gorm.DB, error) {
 
 	return db, nil
 }
+func GetSQLConf() config.SQLConfig {
+	var cfg config.SQLConfig
+	cfg.Host = GetSQLHost()
+	cfg.Port = GetSQLPort()
+	cfg.Username = GetSQLUser()
+	cfg.Password = GetSQLPass()
+	cfg.DB = GetSQLDB()
+	cfg.SSLMode = GetSQLSSLMode()
+	return cfg
+}
+func GetSQLHost() string {
+	host := "localhost"
+	if os.Getenv("SQL_HOST") != "" {
+		host = os.Getenv("SQL_HOST")
+	} else if config.Config.SQL.Host != "" {
+		host = config.Config.SQL.Host
+	}
+	return host
+}
+func GetSQLPort() int32 {
+	port := int32(5432)
+	envport, err := strconv.Atoi(os.Getenv("SQL_PORT"))
+	if err == nil && envport != 0 {
+		port = int32(envport)
+	} else if config.Config.SQL.Port != 0 {
+		port = config.Config.SQL.Port
+	}
+	return port
+}
+func GetSQLUser() string {
+	user := "postgres"
+	if os.Getenv("SQL_USER") != "" {
+		user = os.Getenv("SQL_USER")
+	} else if config.Config.SQL.Username != "" {
+		user = config.Config.SQL.Username
+	}
+	return user
+}
+func GetSQLPass() string {
+	pass := "nopass"
+	if os.Getenv("SQL_PASS") != "" {
+		pass = os.Getenv("SQL_PASS")
+	} else if config.Config.SQL.Password != "" {
+		pass = config.Config.SQL.Password
+	}
+	return pass
+}
+func GetSQLDB() string {
+	db := "netmaker"
+	if os.Getenv("SQL_DB") != "" {
+		db = os.Getenv("SQL_DB")
+	} else if config.Config.SQL.DB != "" {
+		db = config.Config.SQL.DB
+	}
+	return db
+}
+func GetSQLSSLMode() string {
+	sslmode := "disable"
+	if os.Getenv("SQL_SSL_MODE") != "" {
+		sslmode = os.Getenv("SQL_SSL_MODE")
+	} else if config.Config.SQL.SSLMode != "" {
+		sslmode = config.Config.SQL.SSLMode
+	}
+	return sslmode
+}

+ 4 - 76
logic/auth.go

@@ -138,77 +138,6 @@ func FetchPassValue(newValue string) (string, error) {
 	return string(b64CurrentValue), nil
 }
 
-func RevokeAccessToken(a models.AccessToken) error {
-	err := database.DeleteRecord(database.ACCESS_TOKENS_TABLE_NAME, a.ID)
-	if err != nil {
-		return err
-	}
-	return nil
-}
-func RevokeAllUserTokens(username string) {
-	collection, err := database.FetchRecords(database.USERS_TABLE_NAME)
-	if err != nil {
-		return
-	}
-
-	for key, value := range collection {
-
-		var a models.AccessToken
-		err = json.Unmarshal([]byte(value), &a)
-		if err != nil {
-			continue // get users
-		}
-		if a.UserName == username {
-			database.DeleteRecord(database.ACCESS_TOKENS_TABLE_NAME, key)
-		}
-	}
-}
-
-func GetAccessToken(k string) (a models.AccessToken, err error) {
-	value, err := database.FetchRecord(database.ACCESS_TOKENS_TABLE_NAME, k)
-	if err != nil {
-		return
-	}
-	err = json.Unmarshal([]byte(value), &a)
-	return
-}
-
-func ListAccessTokens(username string) (tokens []models.AccessToken) {
-	collection, err := database.FetchRecords(database.ACCESS_TOKENS_TABLE_NAME)
-	if err != nil {
-		return
-	}
-
-	for _, value := range collection {
-
-		var a models.AccessToken
-		err = json.Unmarshal([]byte(value), &a)
-		if err != nil {
-			continue // get users
-		}
-		if a.UserName == username {
-			tokens = append(tokens, a)
-		}
-
-	}
-	return
-}
-
-func CreateAccessToken(a models.AccessToken) error {
-
-	data, err := json.Marshal(a)
-	if err != nil {
-		logger.Log(0, "failed to marshal", err.Error())
-		return err
-	}
-	err = database.Insert(a.ID, string(data), database.ACCESS_TOKENS_TABLE_NAME)
-	if err != nil {
-		logger.Log(0, "failed to insert user", err.Error())
-		return err
-	}
-	return nil
-}
-
 // CreateUser - creates a user
 func CreateUser(user *models.User) error {
 	// check if user exists
@@ -420,19 +349,18 @@ func ValidateUser(user *models.User) error {
 }
 
 // DeleteUser - deletes a given user
-func DeleteUser(user string) (bool, error) {
+func DeleteUser(user string) error {
 
 	if userRecord, err := database.FetchRecord(database.USERS_TABLE_NAME, user); err != nil || len(userRecord) == 0 {
-		return false, errors.New("user does not exist")
+		return errors.New("user does not exist")
 	}
 
 	err := database.DeleteRecord(database.USERS_TABLE_NAME, user)
 	if err != nil {
-		return false, err
+		return err
 	}
 	go RemoveUserFromAclPolicy(user)
-	go RevokeAllUserTokens(user)
-	return true, nil
+	return (&models.AccessToken{UserName: user}).DeleteAllUserTokens()
 }
 
 func SetAuthSecret(secret string) error {

+ 4 - 1
logic/jwts.go

@@ -154,12 +154,15 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
 	if claims.TokenType == models.AccessTokenType {
 		jti := claims.ID
 		if jti != "" {
+			a := models.AccessToken{ID: jti}
 			// check if access token is active
-			_, err := GetAccessToken(jti)
+			err := a.Get()
 			if err != nil {
 				err = errors.New("token revoked")
 				return "", false, false, err
 			}
+			a.LastUsed = time.Now()
+			a.Update()
 		}
 	}
 	if token != nil && token.Valid {

+ 43 - 1
main.go

@@ -3,6 +3,8 @@ package main
 
 import (
 	"context"
+	"crypto/rand"
+	"encoding/json"
 	"flag"
 	"fmt"
 	"os"
@@ -12,6 +14,7 @@ import (
 	"sync"
 	"syscall"
 
+	"github.com/google/uuid"
 	"github.com/gravitl/netmaker/config"
 	controller "github.com/gravitl/netmaker/controllers"
 	"github.com/gravitl/netmaker/database"
@@ -25,6 +28,7 @@ import (
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/serverctl"
 	_ "go.uber.org/automaxprocs"
+	"golang.org/x/crypto/nacl/box"
 	"golang.org/x/exp/slog"
 )
 
@@ -100,7 +104,7 @@ func initialize() { // Client Mode Prereq Check
 		logger.FatalLog("Error connecting to database: ", err.Error())
 	}
 	logger.Log(0, "database successfully connected")
-
+	initializeUUID()
 	//initialize cache
 	_, _ = logic.GetNetworks()
 	_, _ = logic.GetAllNodes()
@@ -247,3 +251,41 @@ func setGarbageCollection() {
 		debug.SetGCPercent(ncutils.DEFAULT_GC_PERCENT)
 	}
 }
+
+// initializeUUID - create a UUID record for server if none exists
+func initializeUUID() error {
+	records, err := database.FetchRecords(database.SERVER_UUID_TABLE_NAME)
+	if err != nil {
+		if !database.IsEmptyRecord(err) {
+			return err
+		}
+	} else if len(records) > 0 {
+		return nil
+	}
+	// setup encryption keys
+	var trafficPubKey, trafficPrivKey, errT = box.GenerateKey(rand.Reader) // generate traffic keys
+	if errT != nil {
+		return errT
+	}
+	tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey)
+	if err != nil {
+		return err
+	}
+
+	tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey)
+	if err != nil {
+		return err
+	}
+
+	telemetry := models.Telemetry{
+		UUID:           uuid.NewString(),
+		TrafficKeyPriv: tPriv,
+		TrafficKeyPub:  tPub,
+	}
+	telJSON, err := json.Marshal(&telemetry)
+	if err != nil {
+		return err
+	}
+
+	return database.Insert(database.SERVER_UUID_RECORD_KEY, string(telJSON), database.SERVER_UUID_TABLE_NAME)
+}

+ 48 - 7
models/accessToken.go

@@ -1,16 +1,57 @@
 package models
 
 import (
+	"context"
 	"time"
+
+	"github.com/gravitl/netmaker/db"
 )
 
+// accessTokenTableName - access tokens table
+const accessTokenTableName = "user_access_tokens"
+
 // AccessToken - token used to access netmaker
 type AccessToken struct {
-	ID        string    `json:"id"`
-	Name      string    `json:"name"`
-	UserName  string    `json:"user_name"`
-	ExpiresAt time.Time `json:"expires_at"`
-	LastUsed  time.Time `json:"last_used"`
-	CreatedBy time.Time `json:"created_by"`
-	CreatedAt time.Time `json:"created_at"`
+	ID        string    `gorm:"id,primary_key" json:"id"`
+	Name      string    `gorm:"name" json:"name"`
+	UserName  string    `gorm:"user_name" json:"user_name"`
+	ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
+	LastUsed  time.Time `gorm:"last_used" json:"last_used"`
+	CreatedBy string    `gorm:"created_by" json:"created_by"`
+	CreatedAt time.Time `gorm:"created_at" json:"created_at"`
+}
+
+func (a *AccessToken) Table() string {
+	return accessTokenTableName
+}
+
+func (a *AccessToken) Get() error {
+	return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
+}
+
+func (a *AccessToken) Update() error {
+	return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
+}
+
+func (a *AccessToken) Create() error {
+	return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
+}
+
+func (a *AccessToken) List() (ats []AccessToken, err error) {
+	err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
+	return
+}
+
+func (a *AccessToken) ListByUser() (ats []AccessToken) {
+	db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
+	return
+}
+
+func (a *AccessToken) Delete() error {
+	return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
+}
+
+func (a *AccessToken) DeleteAllUserTokens() error {
+	return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Delete(&a).Error
+
 }