config.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. package config
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io/ioutil"
  7. "math"
  8. "os"
  9. "os/signal"
  10. "path/filepath"
  11. "sort"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "syscall"
  16. "time"
  17. "dario.cat/mergo"
  18. "github.com/sirupsen/logrus"
  19. "gopkg.in/yaml.v2"
  20. )
  21. type C struct {
  22. path string
  23. files []string
  24. Settings map[interface{}]interface{}
  25. oldSettings map[interface{}]interface{}
  26. callbacks []func(*C)
  27. l *logrus.Logger
  28. reloadLock sync.Mutex
  29. }
  30. func NewC(l *logrus.Logger) *C {
  31. return &C{
  32. Settings: make(map[interface{}]interface{}),
  33. l: l,
  34. }
  35. }
  36. // Load will find all yaml files within path and load them in lexical order
  37. func (c *C) Load(path string) error {
  38. c.path = path
  39. c.files = make([]string, 0)
  40. err := c.resolve(path, true)
  41. if err != nil {
  42. return err
  43. }
  44. if len(c.files) == 0 {
  45. return fmt.Errorf("no config files found at %s", path)
  46. }
  47. sort.Strings(c.files)
  48. err = c.parse()
  49. if err != nil {
  50. return err
  51. }
  52. return nil
  53. }
  54. func (c *C) LoadString(raw string) error {
  55. if raw == "" {
  56. return errors.New("Empty configuration")
  57. }
  58. return c.parseRaw([]byte(raw))
  59. }
  60. // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
  61. // here should decide if they need to make a change to the current process before making the change. HasChanged can be
  62. // used to help decide if a change is necessary.
  63. // These functions should return quickly or spawn their own go routine if they will take a while
  64. func (c *C) RegisterReloadCallback(f func(*C)) {
  65. c.callbacks = append(c.callbacks, f)
  66. }
  67. // InitialLoad returns true if this is the first load of the config, and ReloadConfig has not been called yet.
  68. func (c *C) InitialLoad() bool {
  69. return c.oldSettings == nil
  70. }
  71. // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
  72. // k in both the old and new settings will be serialized, the result of the string comparison is returned.
  73. // If k is an empty string the entire config is tested.
  74. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
  75. // there is change when there actually wasn't any.
  76. func (c *C) HasChanged(k string) bool {
  77. if c.oldSettings == nil {
  78. return false
  79. }
  80. var (
  81. nv interface{}
  82. ov interface{}
  83. )
  84. if k == "" {
  85. nv = c.Settings
  86. ov = c.oldSettings
  87. k = "all settings"
  88. } else {
  89. nv = c.get(k, c.Settings)
  90. ov = c.get(k, c.oldSettings)
  91. }
  92. newVals, err := yaml.Marshal(nv)
  93. if err != nil {
  94. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
  95. }
  96. oldVals, err := yaml.Marshal(ov)
  97. if err != nil {
  98. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
  99. }
  100. return string(newVals) != string(oldVals)
  101. }
  102. // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
  103. // original path provided to Load. The old settings are shallow copied for change detection after the reload.
  104. func (c *C) CatchHUP(ctx context.Context) {
  105. ch := make(chan os.Signal, 1)
  106. signal.Notify(ch, syscall.SIGHUP)
  107. go func() {
  108. for {
  109. select {
  110. case <-ctx.Done():
  111. signal.Stop(ch)
  112. close(ch)
  113. return
  114. case <-ch:
  115. c.l.Info("Caught HUP, reloading config")
  116. c.ReloadConfig()
  117. }
  118. }
  119. }()
  120. }
  121. func (c *C) ReloadConfig() {
  122. c.reloadLock.Lock()
  123. defer c.reloadLock.Unlock()
  124. c.oldSettings = make(map[interface{}]interface{})
  125. for k, v := range c.Settings {
  126. c.oldSettings[k] = v
  127. }
  128. err := c.Load(c.path)
  129. if err != nil {
  130. c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
  131. return
  132. }
  133. for _, v := range c.callbacks {
  134. v(c)
  135. }
  136. }
  137. func (c *C) ReloadConfigString(raw string) error {
  138. c.reloadLock.Lock()
  139. defer c.reloadLock.Unlock()
  140. c.oldSettings = make(map[interface{}]interface{})
  141. for k, v := range c.Settings {
  142. c.oldSettings[k] = v
  143. }
  144. err := c.LoadString(raw)
  145. if err != nil {
  146. return err
  147. }
  148. for _, v := range c.callbacks {
  149. v(c)
  150. }
  151. return nil
  152. }
  153. // GetString will get the string for k or return the default d if not found or invalid
  154. func (c *C) GetString(k, d string) string {
  155. r := c.Get(k)
  156. if r == nil {
  157. return d
  158. }
  159. return fmt.Sprintf("%v", r)
  160. }
  161. // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
  162. func (c *C) GetStringSlice(k string, d []string) []string {
  163. r := c.Get(k)
  164. if r == nil {
  165. return d
  166. }
  167. rv, ok := r.([]interface{})
  168. if !ok {
  169. return d
  170. }
  171. v := make([]string, len(rv))
  172. for i := 0; i < len(v); i++ {
  173. v[i] = fmt.Sprintf("%v", rv[i])
  174. }
  175. return v
  176. }
  177. // GetMap will get the map for k or return the default d if not found or invalid
  178. func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
  179. r := c.Get(k)
  180. if r == nil {
  181. return d
  182. }
  183. v, ok := r.(map[interface{}]interface{})
  184. if !ok {
  185. return d
  186. }
  187. return v
  188. }
  189. // GetInt will get the int for k or return the default d if not found or invalid
  190. func (c *C) GetInt(k string, d int) int {
  191. r := c.GetString(k, strconv.Itoa(d))
  192. v, err := strconv.Atoi(r)
  193. if err != nil {
  194. return d
  195. }
  196. return v
  197. }
  198. // GetUint32 will get the uint32 for k or return the default d if not found or invalid
  199. func (c *C) GetUint32(k string, d uint32) uint32 {
  200. r := c.GetInt(k, int(d))
  201. if uint64(r) > uint64(math.MaxUint32) {
  202. return d
  203. }
  204. return uint32(r)
  205. }
  206. // GetBool will get the bool for k or return the default d if not found or invalid
  207. func (c *C) GetBool(k string, d bool) bool {
  208. r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
  209. v, err := strconv.ParseBool(r)
  210. if err != nil {
  211. switch r {
  212. case "y", "yes":
  213. return true
  214. case "n", "no":
  215. return false
  216. }
  217. return d
  218. }
  219. return v
  220. }
  221. // GetDuration will get the duration for k or return the default d if not found or invalid
  222. func (c *C) GetDuration(k string, d time.Duration) time.Duration {
  223. r := c.GetString(k, "")
  224. v, err := time.ParseDuration(r)
  225. if err != nil {
  226. return d
  227. }
  228. return v
  229. }
  230. func (c *C) Get(k string) interface{} {
  231. return c.get(k, c.Settings)
  232. }
  233. func (c *C) IsSet(k string) bool {
  234. return c.get(k, c.Settings) != nil
  235. }
  236. func (c *C) get(k string, v interface{}) interface{} {
  237. parts := strings.Split(k, ".")
  238. for _, p := range parts {
  239. m, ok := v.(map[interface{}]interface{})
  240. if !ok {
  241. return nil
  242. }
  243. v, ok = m[p]
  244. if !ok {
  245. return nil
  246. }
  247. }
  248. return v
  249. }
  250. // direct signifies if this is the config path directly specified by the user,
  251. // versus a file/dir found by recursing into that path
  252. func (c *C) resolve(path string, direct bool) error {
  253. i, err := os.Stat(path)
  254. if err != nil {
  255. return nil
  256. }
  257. if !i.IsDir() {
  258. c.addFile(path, direct)
  259. return nil
  260. }
  261. paths, err := readDirNames(path)
  262. if err != nil {
  263. return fmt.Errorf("problem while reading directory %s: %s", path, err)
  264. }
  265. for _, p := range paths {
  266. err := c.resolve(filepath.Join(path, p), false)
  267. if err != nil {
  268. return err
  269. }
  270. }
  271. return nil
  272. }
  273. func (c *C) addFile(path string, direct bool) error {
  274. ext := filepath.Ext(path)
  275. if !direct && ext != ".yaml" && ext != ".yml" {
  276. return nil
  277. }
  278. ap, err := filepath.Abs(path)
  279. if err != nil {
  280. return err
  281. }
  282. c.files = append(c.files, ap)
  283. return nil
  284. }
  285. func (c *C) parseRaw(b []byte) error {
  286. var m map[interface{}]interface{}
  287. err := yaml.Unmarshal(b, &m)
  288. if err != nil {
  289. return err
  290. }
  291. c.Settings = m
  292. return nil
  293. }
  294. func (c *C) parse() error {
  295. var m map[interface{}]interface{}
  296. for _, path := range c.files {
  297. b, err := ioutil.ReadFile(path)
  298. if err != nil {
  299. return err
  300. }
  301. var nm map[interface{}]interface{}
  302. err = yaml.Unmarshal(b, &nm)
  303. if err != nil {
  304. return err
  305. }
  306. // We need to use WithAppendSlice so that firewall rules in separate
  307. // files are appended together
  308. err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
  309. m = nm
  310. if err != nil {
  311. return err
  312. }
  313. }
  314. c.Settings = m
  315. return nil
  316. }
  // readDirNames returns the names (not full paths) of every entry in the
  // directory at path, sorted lexically. Any open or read error is returned
  // with a nil slice.
  func readDirNames(path string) ([]string, error) {
  	dir, err := os.Open(path)
  	if err != nil {
  		return nil, err
  	}
  	defer dir.Close()

  	names, err := dir.Readdirnames(-1)
  	if err != nil {
  		return nil, err
  	}
  	sort.Strings(names)
  	return names, nil
  }