2
0
Эх сурвалжийг харах

Merge pull request #3488 from gravitl/patch/db-conn-pool

Patch: Use single db handle and use connection pool
Abhishek K 3 сар өмнө
parent
commit
deb3be363b

+ 5 - 0
controllers/network_test.go

@@ -2,6 +2,8 @@ package controller
 
 import (
 	"context"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"os"
 	"testing"
 
@@ -23,6 +25,9 @@ type NetworkValidationTestCase struct {
 var netHost models.Host
 
 func TestMain(m *testing.M) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	logic.CreateSuperAdmin(&models.User{

+ 3 - 0
controllers/server.go

@@ -233,6 +233,9 @@ func getConfig(w http.ResponseWriter, r *http.Request) {
 	if servercfg.IsPro {
 		scfg.IsPro = "yes"
 	}
+
+	scfg.ClientID = logic.Mask()
+	scfg.ClientSecret = logic.Mask()
 	json.NewEncoder(w).Encode(scfg)
 	// w.WriteHeader(http.StatusOK)
 }

+ 11 - 15
database/postgres.go

@@ -1,11 +1,12 @@
 package database
 
 import (
+	"context"
 	"database/sql"
 	"errors"
-	"fmt"
+	"github.com/gravitl/netmaker/db"
+	"time"
 
-	"github.com/gravitl/netmaker/servercfg"
 	_ "github.com/lib/pq"
 )
 
@@ -25,24 +26,19 @@ var PG_FUNCTIONS = map[string]interface{}{
 	isConnected:  pgIsConnected,
 }
 
-func getPGConnString() string {
-	pgconf := servercfg.GetSQLConf()
-	pgConn := fmt.Sprintf("host=%s port=%d user=%s "+
-		"password=%s dbname=%s sslmode=%s connect_timeout=5",
-		pgconf.Host, pgconf.Port, pgconf.Username, pgconf.Password, pgconf.DB, pgconf.SSLMode)
-	return pgConn
-}
-
 func initPGDB() error {
-	connString := getPGConnString()
+	gormDB := db.FromContext(db.WithContext(context.TODO()))
+
 	var dbOpenErr error
-	PGDB, dbOpenErr = sql.Open("postgres", connString)
+	PGDB, dbOpenErr = gormDB.DB()
 	if dbOpenErr != nil {
 		return dbOpenErr
 	}
-	dbOpenErr = PGDB.Ping()
 
-	return dbOpenErr
+	PGDB.SetMaxOpenConns(5)
+	PGDB.SetConnMaxLifetime(time.Hour)
+
+	return PGDB.Ping()
 }
 
 func pgCreateTable(tableName string) error {
@@ -134,7 +130,7 @@ func pgFetchRecords(tableName string) (map[string]string, error) {
 }
 
 func pgCloseDB() {
-	PGDB.Close()
+	//PGDB.Close()
 }
 
 func pgIsConnected() bool {

+ 11 - 17
database/sqlite.go

@@ -1,17 +1,15 @@
 package database
 
 import (
+	"context"
 	"database/sql"
 	"errors"
-	"os"
-	"path/filepath"
+	"github.com/gravitl/netmaker/db"
+	"time"
 
 	_ "github.com/mattn/go-sqlite3" // need to blank import this package
 )
 
-// == sqlite ==
-const dbFilename = "netmaker.db"
-
 // SqliteDB is the db object for sqlite database connections
 var SqliteDB *sql.DB
 
@@ -29,21 +27,17 @@ var SQLITE_FUNCTIONS = map[string]interface{}{
 }
 
 func initSqliteDB() error {
-	// == create db file if not present ==
-	if _, err := os.Stat("data"); os.IsNotExist(err) {
-		os.Mkdir("data", 0700)
-	}
-	dbFilePath := filepath.Join("data", dbFilename)
-	if _, err := os.Stat(dbFilePath); os.IsNotExist(err) {
-		os.Create(dbFilePath)
-	}
-	// == "connect" the database ==
+	gormDB := db.FromContext(db.WithContext(context.TODO()))
+
 	var dbOpenErr error
-	SqliteDB, dbOpenErr = sql.Open("sqlite3", dbFilePath)
+	SqliteDB, dbOpenErr = gormDB.DB()
 	if dbOpenErr != nil {
 		return dbOpenErr
 	}
-	SqliteDB.SetMaxOpenConns(1)
+
+	SqliteDB.SetMaxOpenConns(5)
+	SqliteDB.SetConnMaxLifetime(time.Hour)
+
 	return nil
 }
 
@@ -134,7 +128,7 @@ func sqliteFetchRecords(tableName string) (map[string]string, error) {
 }
 
 func sqliteCloseDB() {
-	SqliteDB.Close()
+	//SqliteDB.Close()
 }
 
 func sqliteConnected() bool {

+ 21 - 0
db/db.go

@@ -110,3 +110,24 @@ func BeginTx(ctx context.Context) context.Context {
 
 	return context.WithValue(ctx, dbCtxKey, dbInCtx.Begin())
 }
+
+// CloseDB close a connection to the database
+// (if one exists). It panics if any error
+// occurs.
+func CloseDB() {
+	if db == nil {
+		return
+	}
+
+	sqlDB, err := db.DB()
+	if err != nil {
+		panic(err)
+	}
+
+	err = sqlDB.Close()
+	if err != nil {
+		panic(err)
+	}
+
+	db = nil
+}

+ 4 - 19
db/postgres.go

@@ -2,6 +2,7 @@ package db
 
 import (
 	"fmt"
+	"github.com/gravitl/netmaker/servercfg"
 	"os"
 	"strconv"
 
@@ -18,7 +19,7 @@ type postgresConnector struct{}
 // postgresConnector.connect connects and
 // initializes a connection to postgres.
 func (pg *postgresConnector) connect() (*gorm.DB, error) {
-	pgConf := GetSQLConf()
+	pgConf := servercfg.GetSQLConf()
 	dsn := fmt.Sprintf(
 		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
 		pgConf.Host,
@@ -29,27 +30,11 @@ func (pg *postgresConnector) connect() (*gorm.DB, error) {
 		pgConf.SSLMode,
 	)
 
-	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
+	return gorm.Open(postgres.Open(dsn), &gorm.Config{
 		Logger: logger.Default.LogMode(logger.Silent),
 	})
-	if err != nil {
-		return nil, err
-	}
-
-	// ensure netmaker_v1 schema exists.
-	err = db.Exec("CREATE SCHEMA IF NOT EXISTS netmaker_v1").Error
-	if err != nil {
-		return nil, err
-	}
-
-	// set the netmaker_v1 schema as the default schema.
-	err = db.Exec("SET search_path TO netmaker_v1").Error
-	if err != nil {
-		return nil, err
-	}
-
-	return db, nil
 }
+
 func GetSQLConf() config.SQLConfig {
 	var cfg config.SQLConfig
 	cfg.Host = GetSQLHost()

+ 5 - 0
functions/helpers_test.go

@@ -3,6 +3,8 @@ package functions
 import (
 	"context"
 	"encoding/json"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"os"
 	"testing"
 
@@ -23,6 +25,9 @@ var (
 )
 
 func TestMain(m *testing.M) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	logic.CreateSuperAdmin(&models.User{

+ 20 - 0
logic/enrollmentkey_test.go

@@ -1,6 +1,8 @@
 package logic
 
 import (
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"testing"
 	"time"
 
@@ -11,6 +13,9 @@ import (
 )
 
 func TestCreateEnrollmentKey(t *testing.T) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	t.Run("Can_Not_Create_Key", func(t *testing.T) {
@@ -60,6 +65,9 @@ func TestCreateEnrollmentKey(t *testing.T) {
 }
 
 func TestDelete_EnrollmentKey(t *testing.T) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
@@ -81,6 +89,9 @@ func TestDelete_EnrollmentKey(t *testing.T) {
 }
 
 func TestDecrement_EnrollmentKey(t *testing.T) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
@@ -105,6 +116,9 @@ func TestDecrement_EnrollmentKey(t *testing.T) {
 }
 
 func TestUsability_EnrollmentKey(t *testing.T) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false)
@@ -143,6 +157,9 @@ func removeAllEnrollments() {
 //Test that cheks if it can't tokenize
 
 func TestTokenize_EnrollmentKeys(t *testing.T) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)
@@ -176,6 +193,9 @@ func TestTokenize_EnrollmentKeys(t *testing.T) {
 }
 
 func TestDeTokenize_EnrollmentKeys(t *testing.T) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false)

+ 8 - 0
logic/host_test.go

@@ -3,6 +3,8 @@ package logic
 import (
 	"context"
 	"fmt"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"net"
 	"os"
 	"testing"
@@ -14,6 +16,9 @@ import (
 )
 
 func TestMain(m *testing.M) {
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	defer database.CloseDB()
 	peerUpdate := make(chan *models.Node)
@@ -41,6 +46,9 @@ func TestCheckPorts(t *testing.T) {
 	}
 	//not sure why this initialization is required but without it
 	// RemoveHost returns database is closed
+	db.InitializeDB(schema.ListModels()...)
+	defer db.CloseDB()
+
 	database.InitializeDatabase()
 	RemoveHost(&h, true)
 	CreateHost(&h)

+ 11 - 6
main.go

@@ -7,6 +7,8 @@ import (
 	"encoding/json"
 	"flag"
 	"fmt"
+	"github.com/gravitl/netmaker/db"
+	"github.com/gravitl/netmaker/schema"
 	"os"
 	"os/signal"
 	"path/filepath"
@@ -18,7 +20,6 @@ import (
 	"github.com/gravitl/netmaker/config"
 	controller "github.com/gravitl/netmaker/controllers"
 	"github.com/gravitl/netmaker/database"
-	"github.com/gravitl/netmaker/db"
 	"github.com/gravitl/netmaker/functions"
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/logic"
@@ -26,7 +27,6 @@ import (
 	"github.com/gravitl/netmaker/models"
 	"github.com/gravitl/netmaker/mq"
 	"github.com/gravitl/netmaker/netclient/ncutils"
-	"github.com/gravitl/netmaker/schema"
 	"github.com/gravitl/netmaker/servercfg"
 	"github.com/gravitl/netmaker/serverctl"
 	_ "go.uber.org/automaxprocs"
@@ -62,6 +62,7 @@ func main() {
 	if servercfg.DeployedByOperator() && !servercfg.IsPro {
 		logic.SetFreeTierLimits()
 	}
+	defer db.CloseDB()
 	defer database.CloseDB()
 	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, os.Interrupt)
 	defer stop()
@@ -102,15 +103,19 @@ func initialize() { // Client Mode Prereq Check
 		logger.FatalLog("error: must set NODE_ID, currently blank")
 	}
 
-	if err = database.InitializeDatabase(); err != nil {
-		logger.FatalLog("Error connecting to database: ", err.Error())
-	}
 	// initialize sql schema db.
 	err = db.InitializeDB(schema.ListModels()...)
 	if err != nil {
-		logger.FatalLog("Error connecting to v1 database: ", err.Error())
+		logger.FatalLog("error connecting to database: ", err.Error())
 	}
+
 	logger.Log(0, "database successfully connected")
+
+	// initialize kv schema db.
+	if err = database.InitializeDatabase(); err != nil {
+		logger.FatalLog("error initializing database: ", err.Error())
+	}
+
 	initializeUUID()
 	//initialize cache
 	_, _ = logic.GetNetworks()