Browse Source

Merge pull request #339 from gravitl/feature_v0.9_postgres

postgres working
Alex 3 years ago
parent
commit
dd1e97a3b0
7 changed files with 223 additions and 2 deletions
  1. 12 0
      config/config.go
  2. 2 0
      database/database.go
  3. 131 0
      database/postgres.go
  4. 1 0
      go.mod
  5. 2 0
      go.sum
  6. 2 2
      servercfg/serverconf.go
  7. 73 0
      servercfg/sqlconf.go

+ 12 - 0
config/config.go

@@ -30,6 +30,7 @@ var Config *EnvironmentConfig
 // EnvironmentConfig :
 type EnvironmentConfig struct {
 	Server ServerConfig `yaml:"server"`
+	SQL SQLConfig `yaml:"sql"`
 }
 
 // ServerConfig :
@@ -61,6 +62,17 @@ type ServerConfig struct {
 	Verbosity            int32  `yaml:"verbosity"`
 }
 
+
+// Generic SQL Config
+type SQLConfig struct {
+	Host string `yaml:"host"`
+	Port int32 `yaml:"port"`
+	Username string `yaml:"username"`
+	Password string `yaml:"password"`
+	DB string `yaml:"db"`
+	SSLMode string `yaml:"sslmode"`
+}
+
 //reading in the env file
 func readConfig() *EnvironmentConfig {
 	file := fmt.Sprintf("config/environments/%s.yaml", getEnv())

+ 2 - 0
database/database.go

@@ -38,6 +38,8 @@ func getCurrentDB() map[string]interface{} {
 		return RQLITE_FUNCTIONS
 	case "sqlite":
 		return SQLITE_FUNCTIONS
+	case "postgres":
+		return PG_FUNCTIONS
 	default:
 		return SQLITE_FUNCTIONS
 	}

+ 131 - 0
database/postgres.go

@@ -0,0 +1,131 @@
+package database
+
+import (
+	"github.com/gravitl/netmaker/servercfg"
+	"database/sql"
+	"errors"
+	_ "github.com/lib/pq"
+	"fmt"
+)
+
+var PGDB *sql.DB
+
+var PG_FUNCTIONS = map[string]interface{}{
+	INIT_DB:      initPGDB,
+	CREATE_TABLE: pgCreateTable,
+	INSERT:       pgInsert,
+	INSERT_PEER:  pgInsertPeer,
+	DELETE:       pgDeleteRecord,
+	DELETE_ALL:   pgDeleteAllRecords,
+	FETCH_ALL:    pgFetchRecords,
+	CLOSE_DB:     pgCloseDB,
+}
+
+func getPGConnString() string{
+	pgconf := servercfg.GetSQLConf()
+	pgConn := fmt.Sprintf("host=%s port=%d user=%s "+
+	  "password=%s dbname=%s sslmode=%s",
+	  pgconf.Host, pgconf.Port, pgconf.Username, pgconf.Password, pgconf.DB, pgconf.SSLMode)
+	return pgConn
+}
+  
+
+func initPGDB() error {
+	connString := getPGConnString()
+	var dbOpenErr error
+	PGDB, dbOpenErr = sql.Open("postgres", connString)
+	if dbOpenErr != nil {
+		return dbOpenErr
+	}
+	dbOpenErr = PGDB.Ping()
+
+	return dbOpenErr
+}
+
+func pgCreateTable(tableName string) error {
+	statement, err := PGDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
+	if err != nil {
+		return err
+	}
+	_, err = statement.Exec()
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func pgInsert(key string, value string, tableName string) error {
+	if key != "" && value != "" && IsJSONString(value) {
+		insertSQL := "INSERT INTO " + tableName + " (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $3;"
+		statement, err := PGDB.Prepare(insertSQL)
+		if err != nil {
+			return err
+		}
+		_, err = statement.Exec(key, value, value)
+		if err != nil {
+			return err
+		}
+		return nil
+	} else {
+		return errors.New("invalid insert " + key + " : " + value)
+	}
+}
+
+func pgInsertPeer(key string, value string) error {
+	if key != "" && value != "" && IsJSONString(value) {
+		err := pgInsert(key, value, PEERS_TABLE_NAME)
+		if err != nil {
+			return err
+		}
+		return nil
+	} else {
+		return errors.New("invalid peer insert " + key + " : " + value)
+	}
+}
+
+func pgDeleteRecord(tableName string, key string) error {
+	deleteSQL := "DELETE FROM " + tableName + " WHERE key = \"" + key + "\""
+	statement, err := PGDB.Prepare(deleteSQL)
+	if err != nil {
+		return err
+	}
+	if _, err = statement.Exec(); err != nil {
+		return err
+	}
+	return nil
+}
+
+func pgDeleteAllRecords(tableName string) error {
+	deleteSQL := "DELETE FROM " + tableName
+	statement, err := PGDB.Prepare(deleteSQL)
+	if err != nil {
+		return err
+	}
+	if _, err = statement.Exec(); err != nil {
+		return err
+	}
+	return nil
+}
+
+func pgFetchRecords(tableName string) (map[string]string, error) {
+	row, err := PGDB.Query("SELECT * FROM " + tableName + " ORDER BY key")
+	if err != nil {
+		return nil, err
+	}
+	records := make(map[string]string)
+	defer row.Close()
+	for row.Next() { // Iterate and fetch the records from result cursor
+		var key string
+		var value string
+		row.Scan(&key, &value)
+		records[key] = value
+	}
+	if len(records) == 0 {
+		return nil, errors.New(NO_RECORDS)
+	}
+	return records, nil
+}
+
+func pgCloseDB() {
+	PGDB.Close()
+}

+ 1 - 0
go.mod

@@ -8,6 +8,7 @@ require (
 	github.com/golang/protobuf v1.5.2 // indirect
 	github.com/gorilla/handlers v1.5.1
 	github.com/gorilla/mux v1.8.0
+	github.com/lib/pq v1.10.3 // indirect
 	github.com/mattn/go-sqlite3 v1.14.8
 	github.com/rqlite/gorqlite v0.0.0-20210514125552-08ff1e76b22f
 	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e

+ 2 - 0
go.sum

@@ -73,6 +73,8 @@ github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b h1:c3NTyLNozI
 github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9Rh8m+aHZIG69YPGGem1i5VzoyRC8nw2kA8B+ik5U=
 github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y=
 github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
+github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg=
+github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
 github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ=
 github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU=
 github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=

+ 2 - 2
servercfg/serverconf.go

@@ -82,9 +82,9 @@ func GetVersion() string {
 }
 func GetDB() string {
 	database := "sqlite"
-	if os.Getenv("DATABASE") == "rqlite" {
+	if os.Getenv("DATABASE") != "" {
 		database = os.Getenv("DATABASE")
-	} else if config.Config.Server.Database == "rqlite" {
+	} else if config.Config.Server.Database != "" {
 		database = config.Config.Server.Database
 	}
 	return database

+ 73 - 0
servercfg/sqlconf.go

@@ -0,0 +1,73 @@
+package servercfg
+
+import (
+	"os"
+	"github.com/gravitl/netmaker/config"
+	"strconv"
+)
+
+func GetSQLConf() config.SQLConfig {
+	var cfg config.SQLConfig
+	cfg.Host = GetSQLHost()
+	cfg.Port = GetSQLPort()
+	cfg.Username = GetSQLUser()
+	cfg.Password = GetSQLPass()
+	cfg.DB = GetSQLDB()
+	cfg.SSLMode = GetSQLSSLMode()
+	return cfg
+}
+func GetSQLHost() string {
+	host := "localhost"
+	if os.Getenv("SQL_HOST") != "" {
+		host = os.Getenv("SQL_HOST")
+	} else if config.Config.SQL.Host != "" {
+		host = config.Config.SQL.Host
+	}
+	return host
+}
+func GetSQLPort() int32 {
+	port := int32(5432)
+	envport, err := strconv.Atoi(os.Getenv("SQL_PORT"))
+	if err == nil && envport != 0 {
+		port = int32(envport)
+	} else if config.Config.SQL.Port != 0 {
+		port = config.Config.SQL.Port
+	}
+	return port
+}
+func GetSQLUser() string {
+	user := "posgres"
+	if os.Getenv("SQL_USER") != "" {
+		user = os.Getenv("SQL_USER")
+	} else if config.Config.SQL.Username != "" {
+		user = config.Config.SQL.Username
+	}
+	return user
+}
+func GetSQLPass() string {
+	pass := "nopass"
+	if os.Getenv("SQL_PASS") != "" {
+		pass = os.Getenv("SQL_PASS")
+	} else if config.Config.SQL.Password != "" {
+		pass = config.Config.SQL.Password
+	}
+	return pass
+}
+func GetSQLDB() string {
+	db := "netmaker"
+	if os.Getenv("SQL_DB") != "" {
+		db = os.Getenv("SQL_DB")
+	} else if config.Config.SQL.DB != "" {
+		db = config.Config.SQL.DB
+	}
+	return db
+}
+func GetSQLSSLMode() string {
+	sslmode := "disable"
+	if os.Getenv("SQL_SSL_MODE") != "" {
+		sslmode = os.Getenv("SQL_SSL_MODE")
+	} else if config.Config.SQL.SSLMode != "" {
+		sslmode = config.Config.SQL.SSLMode
+	}
+	return sslmode
+}