postgres.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. package database
  import (
  	"database/sql"
  	"errors"
  	"fmt"
  	"strings"

  	"github.com/gravitl/netmaker/servercfg"
  	_ "github.com/lib/pq"
  )
  9. // PGDB - database object for PostGreSQL
  10. var PGDB *sql.DB
  11. // PG_FUNCTIONS - map of db functions for PostGreSQL
  12. var PG_FUNCTIONS = map[string]interface{}{
  13. INIT_DB: initPGDB,
  14. CREATE_TABLE: pgCreateTable,
  15. INSERT: pgInsert,
  16. INSERT_PEER: pgInsertPeer,
  17. DELETE: pgDeleteRecord,
  18. DELETE_ALL: pgDeleteAllRecords,
  19. FETCH_ALL: pgFetchRecords,
  20. CLOSE_DB: pgCloseDB,
  21. }
  22. func getPGConnString() string {
  23. pgconf := servercfg.GetSQLConf()
  24. pgConn := fmt.Sprintf("host=%s port=%d user=%s "+
  25. "password=%s dbname=%s sslmode=%s connect_timeout=5",
  26. pgconf.Host, pgconf.Port, pgconf.Username, pgconf.Password, pgconf.DB, pgconf.SSLMode)
  27. return pgConn
  28. }
  29. func initPGDB() error {
  30. connString := getPGConnString()
  31. var dbOpenErr error
  32. PGDB, dbOpenErr = sql.Open("postgres", connString)
  33. if dbOpenErr != nil {
  34. return dbOpenErr
  35. }
  36. dbOpenErr = PGDB.Ping()
  37. return dbOpenErr
  38. }
  39. func pgCreateTable(tableName string) error {
  40. statement, err := PGDB.Prepare("CREATE TABLE IF NOT EXISTS " + tableName + " (key TEXT NOT NULL UNIQUE PRIMARY KEY, value TEXT)")
  41. if err != nil {
  42. return err
  43. }
  44. defer statement.Close()
  45. _, err = statement.Exec()
  46. if err != nil {
  47. return err
  48. }
  49. return nil
  50. }
  51. func pgInsert(key string, value string, tableName string) error {
  52. if key != "" && value != "" && IsJSONString(value) {
  53. insertSQL := "INSERT INTO " + tableName + " (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $3;"
  54. statement, err := PGDB.Prepare(insertSQL)
  55. if err != nil {
  56. return err
  57. }
  58. defer statement.Close()
  59. _, err = statement.Exec(key, value, value)
  60. if err != nil {
  61. return err
  62. }
  63. return nil
  64. } else {
  65. return errors.New("invalid insert " + key + " : " + value)
  66. }
  67. }
  68. func pgInsertPeer(key string, value string) error {
  69. if key != "" && value != "" && IsJSONString(value) {
  70. err := pgInsert(key, value, PEERS_TABLE_NAME)
  71. if err != nil {
  72. return err
  73. }
  74. return nil
  75. } else {
  76. return errors.New("invalid peer insert " + key + " : " + value)
  77. }
  78. }
  79. func pgDeleteRecord(tableName string, key string) error {
  80. deleteSQL := "DELETE FROM " + tableName + " WHERE key = $1;"
  81. statement, err := PGDB.Prepare(deleteSQL)
  82. if err != nil {
  83. return err
  84. }
  85. defer statement.Close()
  86. if _, err = statement.Exec(key); err != nil {
  87. return err
  88. }
  89. return nil
  90. }
  91. func pgDeleteAllRecords(tableName string) error {
  92. deleteSQL := "DELETE FROM " + tableName
  93. statement, err := PGDB.Prepare(deleteSQL)
  94. if err != nil {
  95. return err
  96. }
  97. defer statement.Close()
  98. if _, err = statement.Exec(); err != nil {
  99. return err
  100. }
  101. return nil
  102. }
  103. func pgFetchRecords(tableName string) (map[string]string, error) {
  104. row, err := PGDB.Query("SELECT * FROM " + tableName + " ORDER BY key")
  105. if err != nil {
  106. return nil, err
  107. }
  108. records := make(map[string]string)
  109. defer row.Close()
  110. for row.Next() { // Iterate and fetch the records from result cursor
  111. var key string
  112. var value string
  113. row.Scan(&key, &value)
  114. records[key] = value
  115. }
  116. if len(records) == 0 {
  117. return nil, errors.New(NO_RECORDS)
  118. }
  119. return records, nil
  120. }
  121. func pgCloseDB() {
  122. PGDB.Close()
  123. }