config.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. package nebula
  2. import (
  3. "fmt"
  4. "github.com/imdario/mergo"
  5. "github.com/sirupsen/logrus"
  6. "gopkg.in/yaml.v2"
  7. "io/ioutil"
  8. "os"
  9. "os/signal"
  10. "path/filepath"
  11. "sort"
  12. "strconv"
  13. "strings"
  14. "syscall"
  15. "time"
  16. )
  // Config holds the merged configuration loaded from one or more YAML files,
// together with the bookkeeping needed to reload it and to report what changed.
type Config struct {
	// path is the file or directory most recently passed to Load; reloads reuse it.
	path string
	// files are the absolute paths of the config files discovered under path,
	// in the order they are merged.
	files []string
	// Settings is the merged YAML tree. Nested mappings are stored as
	// map[interface{}]interface{} values; look up dotted keys with Get.
	Settings map[interface{}]interface{}
	// oldSettings is a shallow copy of Settings taken at the start of the most
	// recent reload and consulted by HasChanged; it stays nil until a reload.
	oldSettings map[interface{}]interface{}
	// callbacks are invoked, in registration order, after a successful reload.
	callbacks []func(*Config)
}
  24. func NewConfig() *Config {
  25. return &Config{
  26. Settings: make(map[interface{}]interface{}),
  27. }
  28. }
  29. // Load will find all yaml files within path and load them in lexical order
  30. func (c *Config) Load(path string) error {
  31. c.path = path
  32. c.files = make([]string, 0)
  33. err := c.resolve(path, true)
  34. if err != nil {
  35. return err
  36. }
  37. if len(c.files) == 0 {
  38. return fmt.Errorf("no config files found at %s", path)
  39. }
  40. sort.Strings(c.files)
  41. err = c.parse()
  42. if err != nil {
  43. return err
  44. }
  45. return nil
  46. }
  47. // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
  48. // here should decide if they need to make a change to the current process before making the change. HasChanged can be
  49. // used to help decide if a change is necessary.
  50. // These functions should return quickly or spawn their own go routine if they will take a while
  51. func (c *Config) RegisterReloadCallback(f func(*Config)) {
  52. c.callbacks = append(c.callbacks, f)
  53. }
  54. // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
  55. // k in both the old and new settings will be serialized, the result of the string comparison is returned.
  56. // If k is an empty string the entire config is tested.
  57. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
  58. // there is change when there actually wasn't any.
  59. func (c *Config) HasChanged(k string) bool {
  60. if c.oldSettings == nil {
  61. return false
  62. }
  63. var (
  64. nv interface{}
  65. ov interface{}
  66. )
  67. if k == "" {
  68. nv = c.Settings
  69. ov = c.oldSettings
  70. k = "all settings"
  71. } else {
  72. nv = c.get(k, c.Settings)
  73. ov = c.get(k, c.oldSettings)
  74. }
  75. newVals, err := yaml.Marshal(nv)
  76. if err != nil {
  77. l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
  78. }
  79. oldVals, err := yaml.Marshal(ov)
  80. if err != nil {
  81. l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
  82. }
  83. return string(newVals) != string(oldVals)
  84. }
  85. // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
  86. // original path provided to Load. The old settings are shallow copied for change detection after the reload.
  87. func (c *Config) CatchHUP() {
  88. ch := make(chan os.Signal, 1)
  89. signal.Notify(ch, syscall.SIGHUP)
  90. go func() {
  91. for range ch {
  92. l.Info("Caught HUP, reloading config")
  93. c.ReloadConfig()
  94. }
  95. }()
  96. }
  97. func (c *Config) ReloadConfig() {
  98. c.oldSettings = make(map[interface{}]interface{})
  99. for k, v := range c.Settings {
  100. c.oldSettings[k] = v
  101. }
  102. err := c.Load(c.path)
  103. if err != nil {
  104. l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
  105. return
  106. }
  107. for _, v := range c.callbacks {
  108. v(c)
  109. }
  110. }
  111. // GetString will get the string for k or return the default d if not found or invalid
  112. func (c *Config) GetString(k, d string) string {
  113. r := c.Get(k)
  114. if r == nil {
  115. return d
  116. }
  117. return fmt.Sprintf("%v", r)
  118. }
  119. // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
  120. func (c *Config) GetStringSlice(k string, d []string) []string {
  121. r := c.Get(k)
  122. if r == nil {
  123. return d
  124. }
  125. rv, ok := r.([]interface{})
  126. if !ok {
  127. return d
  128. }
  129. v := make([]string, len(rv))
  130. for i := 0; i < len(v); i++ {
  131. v[i] = fmt.Sprintf("%v", rv[i])
  132. }
  133. return v
  134. }
  135. // GetMap will get the map for k or return the default d if not found or invalid
  136. func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
  137. r := c.Get(k)
  138. if r == nil {
  139. return d
  140. }
  141. v, ok := r.(map[interface{}]interface{})
  142. if !ok {
  143. return d
  144. }
  145. return v
  146. }
  147. // GetInt will get the int for k or return the default d if not found or invalid
  148. func (c *Config) GetInt(k string, d int) int {
  149. r := c.GetString(k, strconv.Itoa(d))
  150. v, err := strconv.Atoi(r)
  151. if err != nil {
  152. return d
  153. }
  154. return v
  155. }
  156. // GetBool will get the bool for k or return the default d if not found or invalid
  157. func (c *Config) GetBool(k string, d bool) bool {
  158. r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
  159. v, err := strconv.ParseBool(r)
  160. if err != nil {
  161. switch r {
  162. case "y", "yes":
  163. return true
  164. case "n", "no":
  165. return false
  166. }
  167. return d
  168. }
  169. return v
  170. }
  171. // GetDuration will get the duration for k or return the default d if not found or invalid
  172. func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
  173. r := c.GetString(k, "")
  174. v, err := time.ParseDuration(r)
  175. if err != nil {
  176. return d
  177. }
  178. return v
  179. }
  180. func (c *Config) Get(k string) interface{} {
  181. return c.get(k, c.Settings)
  182. }
  183. func (c *Config) IsSet(k string) bool {
  184. return c.get(k, c.Settings) != nil
  185. }
  186. func (c *Config) get(k string, v interface{}) interface{} {
  187. parts := strings.Split(k, ".")
  188. for _, p := range parts {
  189. m, ok := v.(map[interface{}]interface{})
  190. if !ok {
  191. return nil
  192. }
  193. v, ok = m[p]
  194. if !ok {
  195. return nil
  196. }
  197. }
  198. return v
  199. }
  200. // direct signifies if this is the config path directly specified by the user,
  201. // versus a file/dir found by recursing into that path
  202. func (c *Config) resolve(path string, direct bool) error {
  203. i, err := os.Stat(path)
  204. if err != nil {
  205. return nil
  206. }
  207. if !i.IsDir() {
  208. c.addFile(path, direct)
  209. return nil
  210. }
  211. paths, err := readDirNames(path)
  212. if err != nil {
  213. return fmt.Errorf("problem while reading directory %s: %s", path, err)
  214. }
  215. for _, p := range paths {
  216. err := c.resolve(filepath.Join(path, p), false)
  217. if err != nil {
  218. return err
  219. }
  220. }
  221. return nil
  222. }
  223. func (c *Config) addFile(path string, direct bool) error {
  224. ext := filepath.Ext(path)
  225. if !direct && ext != ".yaml" && ext != ".yml" {
  226. return nil
  227. }
  228. ap, err := filepath.Abs(path)
  229. if err != nil {
  230. return err
  231. }
  232. c.files = append(c.files, ap)
  233. return nil
  234. }
  235. func (c *Config) parse() error {
  236. var m map[interface{}]interface{}
  237. for _, path := range c.files {
  238. b, err := ioutil.ReadFile(path)
  239. if err != nil {
  240. return err
  241. }
  242. var nm map[interface{}]interface{}
  243. err = yaml.Unmarshal(b, &nm)
  244. if err != nil {
  245. return err
  246. }
  247. // We need to use WithAppendSlice so that firewall rules in separate
  248. // files are appended together
  249. err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
  250. m = nm
  251. if err != nil {
  252. return err
  253. }
  254. }
  255. c.Settings = m
  256. return nil
  257. }
  // readDirNames returns the names of every entry in the directory at path,
// sorted lexically. The names are bare entry names, not joined with path.
func readDirNames(path string) ([]string, error) {
	dir, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer dir.Close()

	names, err := dir.Readdirnames(-1)
	if err != nil {
		return nil, err
	}

	sort.Strings(names)
	return names, nil
}
  271. func configLogger(c *Config) error {
  272. // set up our logging level
  273. logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
  274. if err != nil {
  275. return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
  276. }
  277. l.SetLevel(logLevel)
  278. timestampFormat := c.GetString("logging.timestamp_format", "")
  279. fullTimestamp := (timestampFormat != "")
  280. if timestampFormat == "" {
  281. timestampFormat = time.RFC3339
  282. }
  283. logFormat := strings.ToLower(c.GetString("logging.format", "text"))
  284. switch logFormat {
  285. case "text":
  286. l.Formatter = &logrus.TextFormatter{
  287. TimestampFormat: timestampFormat,
  288. FullTimestamp: fullTimestamp,
  289. }
  290. case "json":
  291. l.Formatter = &logrus.JSONFormatter{
  292. TimestampFormat: timestampFormat,
  293. }
  294. default:
  295. return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
  296. }
  297. return nil
  298. }