sqlite.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package database
  2. import (
  3. "context"
  4. "database/sql"
  5. "errors"
  6. "github.com/gravitl/netmaker/db"
  7. "time"
  8. _ "github.com/mattn/go-sqlite3" // need to blank import this package
  9. )
  10. // == sqlite ==
  11. const dbFilename = "netmaker.db"
  12. // SqliteDB is the db object for sqlite database connections
  13. var SqliteDB *sql.DB
  14. // SQLITE_FUNCTIONS - contains a map of the functions for sqlite
  15. var SQLITE_FUNCTIONS = map[string]interface{}{
  16. INIT_DB: initSqliteDB,
  17. CREATE_TABLE: sqliteCreateTable,
  18. INSERT: sqliteInsert,
  19. INSERT_PEER: sqliteInsertPeer,
  20. DELETE: sqliteDeleteRecord,
  21. DELETE_ALL: sqliteDeleteAllRecords,
  22. FETCH_ALL: sqliteFetchRecords,
  23. CLOSE_DB: sqliteCloseDB,
  24. isConnected: sqliteConnected,
  25. }
  // initSqliteDB points SqliteDB at the *sql.DB underlying the handle returned
  // by db.FromContext, then caps the pool at 5 open connections and recycles
  // connections after one hour. Whatever DB() returns is stored in SqliteDB
  // even when it also returns an error; that error is passed back unchanged.
  func initSqliteDB() error {
  	var err error
  	ormHandle := db.FromContext(db.WithContext(context.TODO()))
  	if SqliteDB, err = ormHandle.DB(); err != nil {
  		return err
  	}

  	const (
  		maxOpen     = 5
  		maxLifetime = time.Hour
  	)
  	SqliteDB.SetMaxOpenConns(maxOpen)
  	SqliteDB.SetConnMaxLifetime(maxLifetime)
  	return nil
  }
  37. func sqliteCreateTable(tableName string) error {
  38. statement, err := SqliteDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
  39. if err != nil {
  40. return err
  41. }
  42. defer statement.Close()
  43. _, err = statement.Exec()
  44. if err != nil {
  45. return err
  46. }
  47. return nil
  48. }
  // sqliteInsert upserts (INSERT OR REPLACE) the key/value pair into tableName.
  // Both key and value must be non-empty; otherwise an error naming the pair is
  // returned and nothing is written. key and value are bound as parameters;
  // tableName is concatenated and must be trusted.
  func sqliteInsert(key string, value string, tableName string) error {
  	if key == "" || value == "" {
  		return errors.New("invalid insert " + key + " : " + value)
  	}

  	stmt, err := SqliteDB.Prepare("INSERT OR REPLACE INTO " + tableName + " (key, value) VALUES (?, ?)")
  	if err != nil {
  		return err
  	}
  	defer stmt.Close()

  	_, err = stmt.Exec(key, value)
  	return err
  }
  // sqliteInsertPeer upserts the key/value pair into the PEERS_TABLE_NAME table.
  // Empty keys or values are rejected with a peer-specific error before the
  // insert is attempted; any error from sqliteInsert is returned as-is.
  func sqliteInsertPeer(key string, value string) error {
  	if key == "" || value == "" {
  		return errors.New("invalid peer insert " + key + " : " + value)
  	}
  	return sqliteInsert(key, value, PEERS_TABLE_NAME)
  }
  75. func sqliteDeleteRecord(tableName string, key string) error {
  76. deleteSQL := "DELETE FROM " + tableName + " WHERE key = \"" + key + "\""
  77. statement, err := SqliteDB.Prepare(deleteSQL)
  78. if err != nil {
  79. return err
  80. }
  81. defer statement.Close()
  82. if _, err = statement.Exec(); err != nil {
  83. return err
  84. }
  85. return nil
  86. }
  87. func sqliteDeleteAllRecords(tableName string) error {
  88. deleteSQL := "DELETE FROM " + tableName
  89. statement, err := SqliteDB.Prepare(deleteSQL)
  90. if err != nil {
  91. return err
  92. }
  93. defer statement.Close()
  94. if _, err = statement.Exec(); err != nil {
  95. return err
  96. }
  97. return nil
  98. }
  // sqliteFetchRecords returns every key/value row of tableName as a map.
  //
  // It returns errors.New(NO_RECORDS) when the table is empty. It returns the
  // underlying error if the query, a row scan, or row iteration fails; in
  // those cases the map is nil, never partially filled. tableName is
  // concatenated into the query and must come from trusted code.
  func sqliteFetchRecords(tableName string) (map[string]string, error) {
  	// Name the columns explicitly instead of relying on SELECT * column order.
  	// No ORDER BY: the result is a map, so any ordering would be discarded.
  	rows, err := SqliteDB.Query("SELECT key, value FROM " + tableName)
  	if err != nil {
  		return nil, err
  	}
  	defer rows.Close()

  	records := make(map[string]string)
  	for rows.Next() {
  		var key, value string
  		if err := rows.Scan(&key, &value); err != nil {
  			return nil, err
  		}
  		records[key] = value
  	}
  	// Next returns false both at the end of the result set and on failure;
  	// Err distinguishes the two so a broken read is not reported as success.
  	if err := rows.Err(); err != nil {
  		return nil, err
  	}
  	if len(records) == 0 {
  		return nil, errors.New(NO_RECORDS)
  	}
  	return records, nil
  }
  // sqliteCloseDB intentionally does nothing; the Close call was commented out.
  // SqliteDB is the *sql.DB obtained from the db package's handle in
  // initSqliteDB, so closing it here would presumably also shut down the pool
  // that handle relies on. NOTE(review): confirm the db package owns closing it.
  func sqliteCloseDB() {}
  120. func sqliteConnected() bool {
  121. stats := SqliteDB.Stats()
  122. return stats.OpenConnections > 0
  123. }