2
0

postgres.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. package db
  2. import (
  3. "fmt"
  4. "os"
  5. "strconv"
  6. "github.com/gravitl/netmaker/config"
  7. "gorm.io/driver/postgres"
  8. "gorm.io/gorm"
  9. "gorm.io/gorm/logger"
  10. )
  11. // postgresConnector for initializing and
  12. // connecting to a postgres database.
  13. type postgresConnector struct{}
  14. // postgresConnector.connect connects and
  15. // initializes a connection to postgres.
  16. func (pg *postgresConnector) connect() (*gorm.DB, error) {
  17. pgConf := GetSQLConf()
  18. dsn := fmt.Sprintf(
  19. "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s connect_timeout=5",
  20. pgConf.Host,
  21. pgConf.Port,
  22. pgConf.Username,
  23. pgConf.Password,
  24. pgConf.DB,
  25. pgConf.SSLMode,
  26. )
  27. db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
  28. Logger: logger.Default.LogMode(logger.Silent),
  29. })
  30. if err != nil {
  31. return nil, err
  32. }
  33. // ensure netmaker_v1 schema exists.
  34. err = db.Exec("CREATE SCHEMA IF NOT EXISTS netmaker_v1").Error
  35. if err != nil {
  36. return nil, err
  37. }
  38. // set the netmaker_v1 schema as the default schema.
  39. err = db.Exec("SET search_path TO netmaker_v1").Error
  40. if err != nil {
  41. return nil, err
  42. }
  43. return db, nil
  44. }
  45. func GetSQLConf() config.SQLConfig {
  46. var cfg config.SQLConfig
  47. cfg.Host = GetSQLHost()
  48. cfg.Port = GetSQLPort()
  49. cfg.Username = GetSQLUser()
  50. cfg.Password = GetSQLPass()
  51. cfg.DB = GetSQLDB()
  52. cfg.SSLMode = GetSQLSSLMode()
  53. return cfg
  54. }
  55. func GetSQLHost() string {
  56. host := "localhost"
  57. if os.Getenv("SQL_HOST") != "" {
  58. host = os.Getenv("SQL_HOST")
  59. } else if config.Config.SQL.Host != "" {
  60. host = config.Config.SQL.Host
  61. }
  62. return host
  63. }
  64. func GetSQLPort() int32 {
  65. port := int32(5432)
  66. envport, err := strconv.Atoi(os.Getenv("SQL_PORT"))
  67. if err == nil && envport != 0 {
  68. port = int32(envport)
  69. } else if config.Config.SQL.Port != 0 {
  70. port = config.Config.SQL.Port
  71. }
  72. return port
  73. }
  74. func GetSQLUser() string {
  75. user := "postgres"
  76. if os.Getenv("SQL_USER") != "" {
  77. user = os.Getenv("SQL_USER")
  78. } else if config.Config.SQL.Username != "" {
  79. user = config.Config.SQL.Username
  80. }
  81. return user
  82. }
  83. func GetSQLPass() string {
  84. pass := "nopass"
  85. if os.Getenv("SQL_PASS") != "" {
  86. pass = os.Getenv("SQL_PASS")
  87. } else if config.Config.SQL.Password != "" {
  88. pass = config.Config.SQL.Password
  89. }
  90. return pass
  91. }
  92. func GetSQLDB() string {
  93. db := "netmaker"
  94. if os.Getenv("SQL_DB") != "" {
  95. db = os.Getenv("SQL_DB")
  96. } else if config.Config.SQL.DB != "" {
  97. db = config.Config.SQL.DB
  98. }
  99. return db
  100. }
  101. func GetSQLSSLMode() string {
  102. sslmode := "disable"
  103. if os.Getenv("SQL_SSL_MODE") != "" {
  104. sslmode = os.Getenv("SQL_SSL_MODE")
  105. } else if config.Config.SQL.SSLMode != "" {
  106. sslmode = config.Config.SQL.SSLMode
  107. }
  108. return sslmode
  109. }