config.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. package config
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "math"
  7. "os"
  8. "os/signal"
  9. "path/filepath"
  10. "sort"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "syscall"
  15. "time"
  16. "dario.cat/mergo"
  17. "github.com/sirupsen/logrus"
  18. "gopkg.in/yaml.v3"
  19. )
  20. type C struct {
  21. path string
  22. files []string
  23. Settings map[string]any
  24. oldSettings map[string]any
  25. callbacks []func(*C)
  26. l *logrus.Logger
  27. reloadLock sync.Mutex
  28. }
  29. func NewC(l *logrus.Logger) *C {
  30. return &C{
  31. Settings: make(map[string]any),
  32. l: l,
  33. }
  34. }
  35. // Load will find all yaml files within path and load them in lexical order
  36. func (c *C) Load(path string) error {
  37. c.path = path
  38. c.files = make([]string, 0)
  39. err := c.resolve(path, true)
  40. if err != nil {
  41. return err
  42. }
  43. if len(c.files) == 0 {
  44. return fmt.Errorf("no config files found at %s", path)
  45. }
  46. sort.Strings(c.files)
  47. err = c.parse()
  48. if err != nil {
  49. return err
  50. }
  51. return nil
  52. }
  53. func (c *C) LoadString(raw string) error {
  54. if raw == "" {
  55. return errors.New("Empty configuration")
  56. }
  57. return c.parseRaw([]byte(raw))
  58. }
  59. // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
  60. // here should decide if they need to make a change to the current process before making the change. HasChanged can be
  61. // used to help decide if a change is necessary.
  62. // These functions should return quickly or spawn their own go routine if they will take a while
  63. func (c *C) RegisterReloadCallback(f func(*C)) {
  64. c.callbacks = append(c.callbacks, f)
  65. }
  66. // InitialLoad returns true if this is the first load of the config, and ReloadConfig has not been called yet.
  67. func (c *C) InitialLoad() bool {
  68. return c.oldSettings == nil
  69. }
  70. // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
  71. // k in both the old and new settings will be serialized, the result of the string comparison is returned.
  72. // If k is an empty string the entire config is tested.
  73. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
  74. // there is change when there actually wasn't any.
  75. func (c *C) HasChanged(k string) bool {
  76. if c.oldSettings == nil {
  77. return false
  78. }
  79. var (
  80. nv any
  81. ov any
  82. )
  83. if k == "" {
  84. nv = c.Settings
  85. ov = c.oldSettings
  86. k = "all settings"
  87. } else {
  88. nv = c.get(k, c.Settings)
  89. ov = c.get(k, c.oldSettings)
  90. }
  91. newVals, err := yaml.Marshal(nv)
  92. if err != nil {
  93. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
  94. }
  95. oldVals, err := yaml.Marshal(ov)
  96. if err != nil {
  97. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
  98. }
  99. return string(newVals) != string(oldVals)
  100. }
  101. // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
  102. // original path provided to Load. The old settings are shallow copied for change detection after the reload.
  103. func (c *C) CatchHUP(ctx context.Context) {
  104. if c.path == "" {
  105. return
  106. }
  107. ch := make(chan os.Signal, 1)
  108. signal.Notify(ch, syscall.SIGHUP)
  109. go func() {
  110. for {
  111. select {
  112. case <-ctx.Done():
  113. signal.Stop(ch)
  114. close(ch)
  115. return
  116. case <-ch:
  117. c.l.Info("Caught HUP, reloading config")
  118. c.ReloadConfig()
  119. }
  120. }
  121. }()
  122. }
  123. func (c *C) ReloadConfig() {
  124. c.reloadLock.Lock()
  125. defer c.reloadLock.Unlock()
  126. c.oldSettings = make(map[string]any)
  127. for k, v := range c.Settings {
  128. c.oldSettings[k] = v
  129. }
  130. err := c.Load(c.path)
  131. if err != nil {
  132. c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
  133. return
  134. }
  135. for _, v := range c.callbacks {
  136. v(c)
  137. }
  138. }
  139. func (c *C) ReloadConfigString(raw string) error {
  140. c.reloadLock.Lock()
  141. defer c.reloadLock.Unlock()
  142. c.oldSettings = make(map[string]any)
  143. for k, v := range c.Settings {
  144. c.oldSettings[k] = v
  145. }
  146. err := c.LoadString(raw)
  147. if err != nil {
  148. return err
  149. }
  150. for _, v := range c.callbacks {
  151. v(c)
  152. }
  153. return nil
  154. }
  155. // GetString will get the string for k or return the default d if not found or invalid
  156. func (c *C) GetString(k, d string) string {
  157. r := c.Get(k)
  158. if r == nil {
  159. return d
  160. }
  161. return fmt.Sprintf("%v", r)
  162. }
  163. // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
  164. func (c *C) GetStringSlice(k string, d []string) []string {
  165. r := c.Get(k)
  166. if r == nil {
  167. return d
  168. }
  169. rv, ok := r.([]any)
  170. if !ok {
  171. return d
  172. }
  173. v := make([]string, len(rv))
  174. for i := 0; i < len(v); i++ {
  175. v[i] = fmt.Sprintf("%v", rv[i])
  176. }
  177. return v
  178. }
  179. // GetMap will get the map for k or return the default d if not found or invalid
  180. func (c *C) GetMap(k string, d map[string]any) map[string]any {
  181. r := c.Get(k)
  182. if r == nil {
  183. return d
  184. }
  185. v, ok := r.(map[string]any)
  186. if !ok {
  187. return d
  188. }
  189. return v
  190. }
  191. // GetInt will get the int for k or return the default d if not found or invalid
  192. func (c *C) GetInt(k string, d int) int {
  193. r := c.GetString(k, strconv.Itoa(d))
  194. v, err := strconv.Atoi(r)
  195. if err != nil {
  196. return d
  197. }
  198. return v
  199. }
  200. // GetUint32 will get the uint32 for k or return the default d if not found or invalid
  201. func (c *C) GetUint32(k string, d uint32) uint32 {
  202. r := c.GetInt(k, int(d))
  203. if uint64(r) > uint64(math.MaxUint32) {
  204. return d
  205. }
  206. return uint32(r)
  207. }
  208. // GetBool will get the bool for k or return the default d if not found or invalid
  209. func (c *C) GetBool(k string, d bool) bool {
  210. r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
  211. v, err := strconv.ParseBool(r)
  212. if err != nil {
  213. switch r {
  214. case "y", "yes":
  215. return true
  216. case "n", "no":
  217. return false
  218. }
  219. return d
  220. }
  221. return v
  222. }
  223. func AsBool(v any) (value bool, ok bool) {
  224. switch x := v.(type) {
  225. case bool:
  226. return x, true
  227. case string:
  228. switch x {
  229. case "y", "yes":
  230. return true, true
  231. case "n", "no":
  232. return false, true
  233. }
  234. }
  235. return false, false
  236. }
  237. // GetDuration will get the duration for k or return the default d if not found or invalid
  238. func (c *C) GetDuration(k string, d time.Duration) time.Duration {
  239. r := c.GetString(k, "")
  240. v, err := time.ParseDuration(r)
  241. if err != nil {
  242. return d
  243. }
  244. return v
  245. }
  246. func (c *C) Get(k string) any {
  247. return c.get(k, c.Settings)
  248. }
  249. func (c *C) IsSet(k string) bool {
  250. return c.get(k, c.Settings) != nil
  251. }
  252. func (c *C) get(k string, v any) any {
  253. parts := strings.Split(k, ".")
  254. for _, p := range parts {
  255. m, ok := v.(map[string]any)
  256. if !ok {
  257. return nil
  258. }
  259. v, ok = m[p]
  260. if !ok {
  261. return nil
  262. }
  263. }
  264. return v
  265. }
  266. // direct signifies if this is the config path directly specified by the user,
  267. // versus a file/dir found by recursing into that path
  268. func (c *C) resolve(path string, direct bool) error {
  269. i, err := os.Stat(path)
  270. if err != nil {
  271. return nil
  272. }
  273. if !i.IsDir() {
  274. c.addFile(path, direct)
  275. return nil
  276. }
  277. paths, err := readDirNames(path)
  278. if err != nil {
  279. return fmt.Errorf("problem while reading directory %s: %s", path, err)
  280. }
  281. for _, p := range paths {
  282. err := c.resolve(filepath.Join(path, p), false)
  283. if err != nil {
  284. return err
  285. }
  286. }
  287. return nil
  288. }
  289. func (c *C) addFile(path string, direct bool) error {
  290. ext := filepath.Ext(path)
  291. if !direct && ext != ".yaml" && ext != ".yml" {
  292. return nil
  293. }
  294. ap, err := filepath.Abs(path)
  295. if err != nil {
  296. return err
  297. }
  298. c.files = append(c.files, ap)
  299. return nil
  300. }
  301. func (c *C) parseRaw(b []byte) error {
  302. var m map[string]any
  303. err := yaml.Unmarshal(b, &m)
  304. if err != nil {
  305. return err
  306. }
  307. c.Settings = m
  308. return nil
  309. }
  310. func (c *C) parse() error {
  311. var m map[string]any
  312. for _, path := range c.files {
  313. b, err := os.ReadFile(path)
  314. if err != nil {
  315. return err
  316. }
  317. var nm map[string]any
  318. err = yaml.Unmarshal(b, &nm)
  319. if err != nil {
  320. return err
  321. }
  322. // We need to use WithAppendSlice so that firewall rules in separate
  323. // files are appended together
  324. err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
  325. m = nm
  326. if err != nil {
  327. return err
  328. }
  329. }
  330. c.Settings = m
  331. return nil
  332. }
  333. func readDirNames(path string) ([]string, error) {
  334. f, err := os.Open(path)
  335. if err != nil {
  336. return nil, err
  337. }
  338. paths, err := f.Readdirnames(-1)
  339. f.Close()
  340. if err != nil {
  341. return nil, err
  342. }
  343. sort.Strings(paths)
  344. return paths, nil
  345. }