浏览代码

added sqlite support and ability to add dbs easier

worker-9 4 年之前
父节点
当前提交
fbb999f36b
共有 8 个文件被更改,包括 323 次插入70 次删除
  1. 1 0
      config/config.go
  2. 35 49
      database/database.go
  3. 117 0
      database/rqlite.go
  4. 137 0
      database/sqlite.go
  5. 3 2
      main.go
  6. 1 1
      netclient/main.go
  7. 18 16
      netclient/wireguard/kernel.go
  8. 11 2
      servercfg/serverconf.go

+ 1 - 0
config/config.go

@@ -53,6 +53,7 @@ type ServerConfig struct {
 	GRPCSSL              string `yaml:"grpcssl"`
 	Version              string `yaml:"version"`
 	SQLConn              string `yaml:"sqlconn"`
+	Database             string `yaml:database`
 	DefaultNodeLimit     int32  `yaml:"defaultnodelimit"`
 	Verbosity            int32  `yaml:"verbosity"`
 }

+ 35 - 49
database/database.go

@@ -3,9 +3,8 @@ package database
 import (
 	"encoding/json"
 	"errors"
-	"log"
+
 	"github.com/gravitl/netmaker/servercfg"
-	"github.com/rqlite/gorqlite"
 )
 
 const NETWORKS_TABLE_NAME = "networks"
@@ -22,19 +21,32 @@ const DATABASE_FILENAME = "netmaker.db"
 const NO_RECORD = "no result found"
 const NO_RECORDS = "could not find any records"
 
-var Database gorqlite.Connection
+// == Constants ==
+const INIT_DB = "init"
+const CREATE_TABLE = "createtable"
+const INSERT = "insert"
+const INSERT_PEER = "insertpeer"
+const DELETE = "delete"
+const DELETE_ALL = "deleteall"
+const FETCH_ALL = "fetchall"
+const CLOSE_DB = "closedb"
+
+func getCurrentDB() map[string]interface{} {
+	switch servercfg.GetDB() {
+	case "rqlite":
+		return RQLITE_FUNCTIONS
+	case "sqlite":
+		return SQLITE_FUNCTIONS
+	default:
+		return RQLITE_FUNCTIONS
+	}
+}
 
 func InitializeDatabase() error {
 
-	//log.Println("sql conn value:",servercfg.GetSQLConn())
-	conn, err := gorqlite.Open(servercfg.GetSQLConn())
-	if err != nil {
+	if err := getCurrentDB()[INIT_DB].(func() error)(); err != nil {
 		return err
 	}
-
-	// sqliteDatabase, _ := sql.Open("sqlite3", "./database/"+dbFilename)
-	Database = conn
-	Database.SetConsistencyLevel("strong")
 	createTables()
 	return nil
 }
@@ -51,52 +63,36 @@ func createTables() {
 }
 
 func createTable(tableName string) error {
-	_, err := Database.WriteOne("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
-	if err != nil {
-		return err
-	}
-	return nil
+	return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
 }
 
-func isJSONString(value string) bool {
+func IsJSONString(value string) bool {
 	var jsonInt interface{}
 	return json.Unmarshal([]byte(value), &jsonInt) == nil
 }
 
 func Insert(key string, value string, tableName string) error {
-	if key != "" && value != "" && isJSONString(value) {
-		_, err := Database.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')")
-		if err != nil {
-			return err
-		}
-		return nil
+	if key != "" && value != "" && IsJSONString(value) {
+		return getCurrentDB()[INSERT].(func(string, string, string) error)(key, value, tableName)
 	} else {
 		return errors.New("invalid insert " + key + " : " + value)
 	}
 }
 
 func InsertPeer(key string, value string) error {
-	if key != "" && value != "" && isJSONString(value) {
-		_, err := Database.WriteOne("INSERT OR REPLACE INTO " + PEERS_TABLE_NAME + " (key, value) VALUES ('" + key + "', '" + value + "')")
-		if err != nil {
-			return err
-		}
-		return nil
+	if key != "" && value != "" && IsJSONString(value) {
+		return getCurrentDB()[INSERT_PEER].(func(string, string) error)(key, value)
 	} else {
 		return errors.New("invalid peer insert " + key + " : " + value)
 	}
 }
 
 func DeleteRecord(tableName string, key string) error {
-	_, err := Database.WriteOne("DELETE FROM " + tableName + " WHERE key = \"" + key + "\"")
-	if err != nil {
-		return err
-	}
-	return nil
+	return getCurrentDB()[DELETE].(func(string, string) error)(tableName, key)
 }
 
 func DeleteAllRecords(tableName string) error {
-	_, err := Database.WriteOne("DELETE TABLE " + tableName)
+	err := getCurrentDB()[DELETE_ALL].(func(string) error)(tableName)
 	if err != nil {
 		return err
 	}
@@ -119,19 +115,9 @@ func FetchRecord(tableName string, key string) (string, error) {
 }
 
 func FetchRecords(tableName string) (map[string]string, error) {
-	row, err := Database.QueryOne("SELECT * FROM " + tableName + " ORDER BY key")
-	if err != nil {
-		return nil, err
-	}
-	records := make(map[string]string)
-	for row.Next() { // Iterate and fetch the records from result cursor
-		var key string
-		var value string
-		row.Scan(&key, &value)
-		records[key] = value
-	}
-	if len(records) == 0 {
-		return nil, errors.New(NO_RECORDS)
-	}
-	return records, nil
+	return getCurrentDB()[FETCH_ALL].(func(string) (map[string]string, error))(tableName)
+}
+
+func CloseDB() {
+	getCurrentDB()[CLOSE_DB].(func())()
 }

+ 117 - 0
database/rqlite.go

@@ -0,0 +1,117 @@
+package database
+
+import (
+	"errors"
+
+	"github.com/gravitl/netmaker/servercfg"
+	"github.com/rqlite/gorqlite"
+)
+
+var RQliteDatabase gorqlite.Connection
+
+var RQLITE_FUNCTIONS = map[string]interface{}{
+	INIT_DB:      initRqliteDatabase,
+	CREATE_TABLE: rqliteCreateTable,
+	INSERT:       rqliteInsert,
+	INSERT_PEER:  rqliteInsertPeer,
+	DELETE:       rqliteDeleteRecord,
+	DELETE_ALL:   rqliteDeleteAllRecords,
+	FETCH_ALL:    rqliteFetchRecords,
+	CLOSE_DB:     rqliteCloseDB,
+}
+
+func initRqliteDatabase() error {
+
+	conn, err := gorqlite.Open(servercfg.GetSQLConn())
+	if err != nil {
+		return err
+	}
+	RQliteDatabase = conn
+	RQliteDatabase.SetConsistencyLevel("strong")
+	return nil
+}
+
+func rqliteCreateTable(tableName string) error {
+	_, err := RQliteDatabase.WriteOne("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func rqliteInsert(key string, value string, tableName string) error {
+	if key != "" && value != "" && IsJSONString(value) {
+		_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES ('" + key + "', '" + value + "')")
+		if err != nil {
+			return err
+		}
+		return nil
+	} else {
+		return errors.New("invalid insert " + key + " : " + value)
+	}
+}
+
+func rqliteInsertPeer(key string, value string) error {
+	if key != "" && value != "" && IsJSONString(value) {
+		_, err := RQliteDatabase.WriteOne("INSERT OR REPLACE INTO " + PEERS_TABLE_NAME + " (key, value) VALUES ('" + key + "', '" + value + "')")
+		if err != nil {
+			return err
+		}
+		return nil
+	} else {
+		return errors.New("invalid peer insert " + key + " : " + value)
+	}
+}
+
+func rqliteDeleteRecord(tableName string, key string) error {
+	_, err := RQliteDatabase.WriteOne("DELETE FROM " + tableName + " WHERE key = \"" + key + "\"")
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func rqliteDeleteAllRecords(tableName string) error {
+	_, err := RQliteDatabase.WriteOne("DELETE TABLE " + tableName)
+	if err != nil {
+		return err
+	}
+	err = rqliteCreateTable(tableName)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func rqliteFetchRecord(tableName string, key string) (string, error) {
+	results, err := FetchRecords(tableName)
+	if err != nil {
+		return "", err
+	}
+	if results[key] == "" {
+		return "", errors.New(NO_RECORD)
+	}
+	return results[key], nil
+}
+
+func rqliteFetchRecords(tableName string) (map[string]string, error) {
+	row, err := RQliteDatabase.QueryOne("SELECT * FROM " + tableName + " ORDER BY key")
+	if err != nil {
+		return nil, err
+	}
+	records := make(map[string]string)
+	for row.Next() { // Iterate and fetch the records from result cursor
+		var key string
+		var value string
+		row.Scan(&key, &value)
+		records[key] = value
+	}
+	if len(records) == 0 {
+		return nil, errors.New(NO_RECORDS)
+	}
+	return records, nil
+}
+
+func rqliteCloseDB() {
+	RQliteDatabase.Close()
+}

+ 137 - 0
database/sqlite.go

@@ -0,0 +1,137 @@
+package database
+
+import (
+	"database/sql"
+	"errors"
+	"log"
+	"os"
+	"path/filepath"
+
+	_ "github.com/mattn/go-sqlite3"
+)
+
+// == sqlite ==
+const dbFilename = "netmaker.db"
+
+var SqliteDB *sql.DB
+
+var SQLITE_FUNCTIONS = map[string]interface{}{
+	INIT_DB:      initSqliteDB,
+	CREATE_TABLE: sqliteCreateTable,
+	INSERT:       sqliteInsert,
+	INSERT_PEER:  sqliteInsertPeer,
+	DELETE:       sqliteDeleteRecord,
+	DELETE_ALL:   sqliteDeleteAllRecords,
+	FETCH_ALL:    sqliteFetchRecords,
+	CLOSE_DB:     sqliteCloseDB,
+}
+
+func initSqliteDB() error {
+	// == create db file if not present ==
+	if _, err := os.Stat("data"); os.IsNotExist(err) {
+		log.Println("Could not find data directory, creating it.")
+		os.Mkdir("data", 0644)
+	}
+	dbFilePath := filepath.Join("data", dbFilename)
+	if _, err := os.Stat(dbFilePath); os.IsNotExist(err) {
+		log.Println("Could not get database file, creating it.")
+		os.Create(dbFilePath)
+	}
+	// == "connect" the database ==
+	var dbOpenErr error
+	SqliteDB, dbOpenErr = sql.Open("sqlite3", dbFilePath)
+	if dbOpenErr != nil {
+		return dbOpenErr
+	}
+	return nil
+}
+
+func sqliteCreateTable(tableName string) error {
+	statement, err := SqliteDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
+	if err != nil {
+		return err
+	}
+	_, err = statement.Exec()
+	if err != nil {
+		return err
+	}
+	log.Println(tableName, "table created")
+	return nil
+}
+
+func sqliteInsert(key string, value string, tableName string) error {
+	if key != "" && value != "" && IsJSONString(value) {
+		insertSQL := "INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)"
+		statement, err := SqliteDB.Prepare(insertSQL)
+		if err != nil {
+			return err
+		}
+		_, err = statement.Exec(key, value)
+		if err != nil {
+			return err
+		}
+		log.Println("inserted", key, ":", value, "into ", tableName)
+		return nil
+	} else {
+		return errors.New("invalid insert " + key + " : " + value)
+	}
+}
+
+func sqliteInsertPeer(key string, value string) error {
+	if key != "" && value != "" && IsJSONString(value) {
+		err := sqliteInsert(key, value, PEERS_TABLE_NAME)
+		if err != nil {
+			return err
+		}
+		return nil
+	} else {
+		return errors.New("invalid peer insert " + key + " : " + value)
+	}
+}
+
+func sqliteDeleteRecord(tableName string, key string) error {
+	deleteSQL := "DELETE FROM " + tableName + " WHERE key = \"" + key + "\""
+	statement, err := SqliteDB.Prepare(deleteSQL)
+	if err != nil {
+		return err
+	}
+	if _, err = statement.Exec(); err != nil {
+		return err
+	}
+	return nil
+}
+
+func sqliteDeleteAllRecords(tableName string) error {
+	deleteSQL := "DELETE FROM " + tableName
+	statement, err := SqliteDB.Prepare(deleteSQL)
+	if err != nil {
+		return err
+	}
+	if _, err = statement.Exec(); err != nil {
+		return err
+	}
+	return nil
+}
+
+func sqliteFetchRecords(tableName string) (map[string]string, error) {
+	row, err := SqliteDB.Query("SELECT * FROM " + tableName + " ORDER BY key")
+	if err != nil {
+		return nil, err
+	}
+	records := make(map[string]string)
+	defer row.Close()
+	for row.Next() { // Iterate and fetch the records from result cursor
+		var key string
+		var value string
+		row.Scan(&key, &value)
+		records[key] = value
+	}
+	if len(records) == 0 {
+		return nil, errors.New(NO_RECORDS)
+	}
+	return records, nil
+}
+
+func sqliteCloseDB() {
+	SqliteDB.Close()
+}

+ 3 - 2
main.go

@@ -11,6 +11,7 @@ import (
 	"os/signal"
 	"strconv"
 	"sync"
+
 	controller "github.com/gravitl/netmaker/controllers"
 	"github.com/gravitl/netmaker/database"
 	"github.com/gravitl/netmaker/functions"
@@ -25,7 +26,7 @@ import (
 func main() {
 	fmt.Println(models.RetrieveLogo()) // print the logo
 	initialize()                       // initial db and grpc server
-	defer database.Database.Close()
+	defer database.CloseDB()
 	startControllers() // start the grpc or rest endpoints
 }
 
@@ -40,7 +41,7 @@ func initialize() { // Client Mode Prereq Check
 
 	if err != nil {
 		log.Println("Error running 'id -u' for prereq check. Please investigate or disable client mode.")
-		log.Fatal(err)
+		log.Fatal(output, err)
 	}
 	uid, err := strconv.Atoi(string(output[:len(output)-1]))
 	if err != nil {

+ 1 - 1
netclient/main.go

@@ -316,7 +316,7 @@ func main() {
 	out, err := local.RunCmd("id -u")
 
 	if err != nil {
-		log.Fatal(err)
+		log.Fatal(out, err)
 	}
 	id, err := strconv.Atoi(string(out[:len(out)-1]))
 

+ 18 - 16
netclient/wireguard/kernel.go

@@ -64,17 +64,18 @@ func InitWireguard(node *models.Node, privkey string, peers []wgtypes.PeerConfig
 		network = node.Network
 	}
 
-	_, delErr := local.RunCmd("ip link delete dev " + ifacename)
-	_, addLinkErr := local.RunCmd(ipExec + " link add dev " + ifacename + " type wireguard")
-	_, addErr := local.RunCmd(ipExec + " address add dev " + ifacename + " " + node.Address + "/24")
+	delOut, delErr := local.RunCmd("ip link delete dev " + ifacename)
+	addLinkOut, addLinkErr := local.RunCmd(ipExec + " link add dev " + ifacename + " type wireguard")
+	addOut, addErr := local.RunCmd(ipExec + " address add dev " + ifacename + " " + node.Address + "/24")
 	if delErr != nil {
 		// pass
+		log.Println(delOut, delErr)
 	}
 	if addLinkErr != nil {
-		log.Println(addLinkErr)
+		log.Println(addLinkOut, addLinkErr)
 	}
 	if addErr != nil {
-		log.Println(addErr)
+		log.Println(addOut, addErr)
 	}
 	var nodeport int
 	nodeport = int(node.ListenPort)
@@ -162,16 +163,16 @@ func InitWireguard(node *models.Node, privkey string, peers []wgtypes.PeerConfig
 			out, err := local.RunCmd(ipExec + " -4 route add " + gateway + " dev " + ifacename)
 			fmt.Println(string(out))
 			if err != nil {
-				fmt.Println("Error encountered adding gateway: " + err.Error())
+				fmt.Println("error encountered adding gateway: " + err.Error())
 			}
 		}
 	}
 	if node.Address6 != "" && node.IsDualStack == "yes" {
-		fmt.Println("Adding address: " + node.Address6)
+		fmt.Println("adding address: " + node.Address6)
 		out, err := local.RunCmd(ipExec + " address add dev " + ifacename + " " + node.Address6 + "/64")
 		if err != nil {
 			fmt.Println(out)
-			fmt.Println("Error encountered adding ipv6: " + err.Error())
+			fmt.Println("error encountered adding ipv6: " + err.Error())
 		}
 	}
 
@@ -268,9 +269,9 @@ func SetPeers(iface string, keepalive int32, peers []wgtypes.PeerConfig) error {
 		for _, currentPeer := range devicePeers {
 			if currentPeer.AllowedIPs[0].String() == peer.AllowedIPs[0].String() &&
 				currentPeer.PublicKey.String() != peer.PublicKey.String() {
-				_, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove")
+				output, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove")
 				if err != nil {
-					log.Println("error removing peer", peer.Endpoint.String())
+					log.Println(output, "error removing peer", peer.Endpoint.String())
 				}
 			}
 		}
@@ -285,18 +286,19 @@ func SetPeers(iface string, keepalive int32, peers []wgtypes.PeerConfig) error {
 		if keepAliveString == "0" {
 			keepAliveString = "5"
 		}
+		var output string
 		if peer.Endpoint != nil {
-			_, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() +
+			output, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() +
 				" endpoint " + udpendpoint +
 				" persistent-keepalive " + keepAliveString +
 				" allowed-ips " + allowedips)
 		} else {
-			_, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() +
+			output, err = local.RunCmd("wg set " + iface + " peer " + peer.PublicKey.String() +
 				" persistent-keepalive " + keepAliveString +
 				" allowed-ips " + allowedips)
 		}
 		if err != nil {
-			log.Println("error setting peer", peer.PublicKey.String(), err)
+			log.Println(output, "error setting peer", peer.PublicKey.String(), err)
 		}
 	}
 
@@ -308,15 +310,15 @@ func SetPeers(iface string, keepalive int32, peers []wgtypes.PeerConfig) error {
 			}
 		}
 		if shouldDelete {
-			_, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove")
+			output, err := local.RunCmd("wg set " + iface + " peer " + currentPeer.PublicKey.String() + " remove")
 			if err != nil {
-				log.Println("error removing peer", currentPeer.PublicKey.String())
+				log.Println(output, "error removing peer", currentPeer.PublicKey.String())
 			} else {
 				log.Println("removed peer " + currentPeer.PublicKey.String())
 			}
 		}
 	}
-	
+
 	return nil
 }
 

+ 11 - 2
servercfg/serverconf.go

@@ -78,6 +78,15 @@ func GetVersion() string {
 	}
 	return version
 }
+func GetDB() string {
+	database := "rqlite"
+	if os.Getenv("DATABASE") == "sqlite" {
+		database = os.Getenv("DATABASE")
+	} else if config.Config.Server.Database == "sqlite" {
+		database = config.Config.Server.Database
+	}
+	return database
+}
 func GetAPIHost() string {
 	serverhost := "127.0.0.1"
 	remoteip, _ := GetPublicIP()
@@ -313,8 +322,8 @@ func GetSQLConn() string {
 	sqlconn := "http://"
 	if os.Getenv("SQL_CONN") != "" {
 		sqlconn = os.Getenv("SQL_CONN")
-	} else if config.Config.Server.SQLConn != ""  {
+	} else if config.Config.Server.SQLConn != "" {
 		sqlconn = config.Config.Server.SQLConn
 	}
 	return sqlconn
-}
+}