config.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. package config
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io/ioutil"
  7. "os"
  8. "os/signal"
  9. "path/filepath"
  10. "sort"
  11. "strconv"
  12. "strings"
  13. "syscall"
  14. "time"
  15. "github.com/imdario/mergo"
  16. "github.com/sirupsen/logrus"
  17. "gopkg.in/yaml.v2"
  18. )
  19. type C struct {
  20. path string
  21. files []string
  22. Settings map[interface{}]interface{}
  23. oldSettings map[interface{}]interface{}
  24. callbacks []func(*C)
  25. l *logrus.Logger
  26. }
  27. func NewC(l *logrus.Logger) *C {
  28. return &C{
  29. Settings: make(map[interface{}]interface{}),
  30. l: l,
  31. }
  32. }
  33. // Load will find all yaml files within path and load them in lexical order
  34. func (c *C) Load(path string) error {
  35. c.path = path
  36. c.files = make([]string, 0)
  37. err := c.resolve(path, true)
  38. if err != nil {
  39. return err
  40. }
  41. if len(c.files) == 0 {
  42. return fmt.Errorf("no config files found at %s", path)
  43. }
  44. sort.Strings(c.files)
  45. err = c.parse()
  46. if err != nil {
  47. return err
  48. }
  49. return nil
  50. }
  51. func (c *C) LoadString(raw string) error {
  52. if raw == "" {
  53. return errors.New("Empty configuration")
  54. }
  55. return c.parseRaw([]byte(raw))
  56. }
  57. // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
  58. // here should decide if they need to make a change to the current process before making the change. HasChanged can be
  59. // used to help decide if a change is necessary.
  60. // These functions should return quickly or spawn their own go routine if they will take a while
  61. func (c *C) RegisterReloadCallback(f func(*C)) {
  62. c.callbacks = append(c.callbacks, f)
  63. }
  64. // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
  65. // k in both the old and new settings will be serialized, the result of the string comparison is returned.
  66. // If k is an empty string the entire config is tested.
  67. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
  68. // there is change when there actually wasn't any.
  69. func (c *C) HasChanged(k string) bool {
  70. if c.oldSettings == nil {
  71. return false
  72. }
  73. var (
  74. nv interface{}
  75. ov interface{}
  76. )
  77. if k == "" {
  78. nv = c.Settings
  79. ov = c.oldSettings
  80. k = "all settings"
  81. } else {
  82. nv = c.get(k, c.Settings)
  83. ov = c.get(k, c.oldSettings)
  84. }
  85. newVals, err := yaml.Marshal(nv)
  86. if err != nil {
  87. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
  88. }
  89. oldVals, err := yaml.Marshal(ov)
  90. if err != nil {
  91. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
  92. }
  93. return string(newVals) != string(oldVals)
  94. }
  95. // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
  96. // original path provided to Load. The old settings are shallow copied for change detection after the reload.
  97. func (c *C) CatchHUP(ctx context.Context) {
  98. ch := make(chan os.Signal, 1)
  99. signal.Notify(ch, syscall.SIGHUP)
  100. go func() {
  101. for {
  102. select {
  103. case <-ctx.Done():
  104. signal.Stop(ch)
  105. close(ch)
  106. return
  107. case <-ch:
  108. c.l.Info("Caught HUP, reloading config")
  109. c.ReloadConfig()
  110. }
  111. }
  112. }()
  113. }
  114. func (c *C) ReloadConfig() {
  115. c.oldSettings = make(map[interface{}]interface{})
  116. for k, v := range c.Settings {
  117. c.oldSettings[k] = v
  118. }
  119. err := c.Load(c.path)
  120. if err != nil {
  121. c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
  122. return
  123. }
  124. for _, v := range c.callbacks {
  125. v(c)
  126. }
  127. }
  128. // GetString will get the string for k or return the default d if not found or invalid
  129. func (c *C) GetString(k, d string) string {
  130. r := c.Get(k)
  131. if r == nil {
  132. return d
  133. }
  134. return fmt.Sprintf("%v", r)
  135. }
  136. // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
  137. func (c *C) GetStringSlice(k string, d []string) []string {
  138. r := c.Get(k)
  139. if r == nil {
  140. return d
  141. }
  142. rv, ok := r.([]interface{})
  143. if !ok {
  144. return d
  145. }
  146. v := make([]string, len(rv))
  147. for i := 0; i < len(v); i++ {
  148. v[i] = fmt.Sprintf("%v", rv[i])
  149. }
  150. return v
  151. }
  152. // GetMap will get the map for k or return the default d if not found or invalid
  153. func (c *C) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
  154. r := c.Get(k)
  155. if r == nil {
  156. return d
  157. }
  158. v, ok := r.(map[interface{}]interface{})
  159. if !ok {
  160. return d
  161. }
  162. return v
  163. }
  164. // GetInt will get the int for k or return the default d if not found or invalid
  165. func (c *C) GetInt(k string, d int) int {
  166. r := c.GetString(k, strconv.Itoa(d))
  167. v, err := strconv.Atoi(r)
  168. if err != nil {
  169. return d
  170. }
  171. return v
  172. }
  173. // GetBool will get the bool for k or return the default d if not found or invalid
  174. func (c *C) GetBool(k string, d bool) bool {
  175. r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
  176. v, err := strconv.ParseBool(r)
  177. if err != nil {
  178. switch r {
  179. case "y", "yes":
  180. return true
  181. case "n", "no":
  182. return false
  183. }
  184. return d
  185. }
  186. return v
  187. }
  188. // GetDuration will get the duration for k or return the default d if not found or invalid
  189. func (c *C) GetDuration(k string, d time.Duration) time.Duration {
  190. r := c.GetString(k, "")
  191. v, err := time.ParseDuration(r)
  192. if err != nil {
  193. return d
  194. }
  195. return v
  196. }
  197. func (c *C) Get(k string) interface{} {
  198. return c.get(k, c.Settings)
  199. }
  200. func (c *C) IsSet(k string) bool {
  201. return c.get(k, c.Settings) != nil
  202. }
  203. func (c *C) get(k string, v interface{}) interface{} {
  204. parts := strings.Split(k, ".")
  205. for _, p := range parts {
  206. m, ok := v.(map[interface{}]interface{})
  207. if !ok {
  208. return nil
  209. }
  210. v, ok = m[p]
  211. if !ok {
  212. return nil
  213. }
  214. }
  215. return v
  216. }
  217. // direct signifies if this is the config path directly specified by the user,
  218. // versus a file/dir found by recursing into that path
  219. func (c *C) resolve(path string, direct bool) error {
  220. i, err := os.Stat(path)
  221. if err != nil {
  222. return nil
  223. }
  224. if !i.IsDir() {
  225. c.addFile(path, direct)
  226. return nil
  227. }
  228. paths, err := readDirNames(path)
  229. if err != nil {
  230. return fmt.Errorf("problem while reading directory %s: %s", path, err)
  231. }
  232. for _, p := range paths {
  233. err := c.resolve(filepath.Join(path, p), false)
  234. if err != nil {
  235. return err
  236. }
  237. }
  238. return nil
  239. }
  240. func (c *C) addFile(path string, direct bool) error {
  241. ext := filepath.Ext(path)
  242. if !direct && ext != ".yaml" && ext != ".yml" {
  243. return nil
  244. }
  245. ap, err := filepath.Abs(path)
  246. if err != nil {
  247. return err
  248. }
  249. c.files = append(c.files, ap)
  250. return nil
  251. }
  252. func (c *C) parseRaw(b []byte) error {
  253. var m map[interface{}]interface{}
  254. err := yaml.Unmarshal(b, &m)
  255. if err != nil {
  256. return err
  257. }
  258. c.Settings = m
  259. return nil
  260. }
  261. func (c *C) parse() error {
  262. var m map[interface{}]interface{}
  263. for _, path := range c.files {
  264. b, err := ioutil.ReadFile(path)
  265. if err != nil {
  266. return err
  267. }
  268. var nm map[interface{}]interface{}
  269. err = yaml.Unmarshal(b, &nm)
  270. if err != nil {
  271. return err
  272. }
  273. // We need to use WithAppendSlice so that firewall rules in separate
  274. // files are appended together
  275. err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
  276. m = nm
  277. if err != nil {
  278. return err
  279. }
  280. }
  281. c.Settings = m
  282. return nil
  283. }
  284. func readDirNames(path string) ([]string, error) {
  285. f, err := os.Open(path)
  286. if err != nil {
  287. return nil, err
  288. }
  289. paths, err := f.Readdirnames(-1)
  290. f.Close()
  291. if err != nil {
  292. return nil, err
  293. }
  294. sort.Strings(paths)
  295. return paths, nil
  296. }