Bläddra i källkod

use semaphore to protect database calls

Matthew R Kasun 3 år sedan
förälder
incheckning
a322894c7f
1 ändrade filer med 39 tillägg och 0 borttagningar
  1. 39 0
      database/database.go

+ 39 - 0
database/database.go

@@ -1,12 +1,14 @@
 package database
 
 import (
+	"context"
 	"encoding/json"
 	"errors"
 	"time"
 
 	"github.com/gravitl/netmaker/logger"
 	"github.com/gravitl/netmaker/servercfg"
+	"golang.org/x/sync/semaphore"
 )
 
 // NETWORKS_TABLE_NAME - networks table
@@ -89,6 +91,9 @@ func getCurrentDB() map[string]interface{} {
 	}
 }
 
+var ctx context.Context
+var sem semaphore.Weighted
+
 // InitializeDatabase - initializes database
 func InitializeDatabase() error {
 	logger.Log(0, "connecting to", servercfg.GetDB())
@@ -104,6 +109,8 @@ func InitializeDatabase() error {
 		}
 		time.Sleep(2 * time.Second)
 	}
+	ctx = context.TODO()
+	sem = semaphore.NewWeighted(1)
 	createTables()
 	return nil
 }
@@ -122,6 +129,10 @@ func createTables() {
 }
 
 func createTable(tableName string) error {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		return errors.New("semphore error")
+	}
+	defer sem.Release(1)
 	return getCurrentDB()[CREATE_TABLE].(func(string) error)(tableName)
 }
 
@@ -133,6 +144,10 @@ func IsJSONString(value string) bool {
 
 // Insert - inserts object into db
 func Insert(key string, value string, tableName string) error {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		return errors.New("semphore error")
+	}
+	defer sem.Release(1)
 	if key != "" && value != "" && IsJSONString(value) {
 		return getCurrentDB()[INSERT].(func(string, string, string) error)(key, value, tableName)
 	} else {
@@ -142,6 +157,10 @@ func Insert(key string, value string, tableName string) error {
 
 // InsertPeer - inserts peer into db
 func InsertPeer(key string, value string) error {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		return errors.New("semphore error")
+	}
+	defer sem.Release(1)
 	if key != "" && value != "" && IsJSONString(value) {
 		return getCurrentDB()[INSERT_PEER].(func(string, string) error)(key, value)
 	} else {
@@ -151,11 +170,19 @@ func InsertPeer(key string, value string) error {
 
 // DeleteRecord - deletes a record from db
 func DeleteRecord(tableName string, key string) error {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		return errors.New("semphore error")
+	}
+	defer sem.Release(1)
 	return getCurrentDB()[DELETE].(func(string, string) error)(tableName, key)
 }
 
 // DeleteAllRecords - removes a table and remakes
 func DeleteAllRecords(tableName string) error {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		return errors.New("semphore error")
+	}
+	defer sem.Release(1)
 	err := getCurrentDB()[DELETE_ALL].(func(string) error)(tableName)
 	if err != nil {
 		return err
@@ -169,6 +196,10 @@ func DeleteAllRecords(tableName string) error {
 
 // FetchRecord - fetches a record
 func FetchRecord(tableName string, key string) (string, error) {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		return "", errors.New("semphore error")
+	}
+	defer sem.Release(1)
 	results, err := FetchRecords(tableName)
 	if err != nil {
 		return "", err
@@ -181,10 +212,18 @@ func FetchRecord(tableName string, key string) (string, error) {
 
 // FetchRecords - fetches all records in given table
 func FetchRecords(tableName string) (map[string]string, error) {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		return nil, errors.New("semphore error")
+	}
+	defer sem.Release(1)
 	return getCurrentDB()[FETCH_ALL].(func(string) (map[string]string, error))(tableName)
 }
 
 // CloseDB - closes a database gracefully
 func CloseDB() {
+	if err := sem.Acquire(ctx, 1); err != nil {
+		logger.Log(0, "semaphore error closing DB")
+	}
+	defer sem.Release(1)
 	getCurrentDB()[CLOSE_DB].(func())()
 }