config.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. package nebula
  2. import (
  3. "fmt"
  4. "github.com/imdario/mergo"
  5. "github.com/sirupsen/logrus"
  6. "gopkg.in/yaml.v2"
  7. "io/ioutil"
  8. "net"
  9. "os"
  10. "os/signal"
  11. "path/filepath"
  12. "regexp"
  13. "sort"
  14. "strconv"
  15. "strings"
  16. "syscall"
  17. "time"
  18. )
  // Config holds the merged settings parsed from one or more YAML files,
// together with the state needed to reload them later.
type Config struct {
	// path is the location originally passed to Load; reloads re-read it.
	path string
	// files is the sorted list of absolute file paths resolved from path.
	files []string
	// Settings is the merged document produced by the last successful load.
	Settings map[interface{}]interface{}
	// oldSettings is a shallow copy of Settings taken at the start of a
	// reload and consulted by HasChanged. It stays nil until the first reload.
	oldSettings map[interface{}]interface{}
	// callbacks run in registration order after each successful reload.
	callbacks []func(*Config)
}
  26. func NewConfig() *Config {
  27. return &Config{
  28. Settings: make(map[interface{}]interface{}),
  29. }
  30. }
  31. // Load will find all yaml files within path and load them in lexical order
  32. func (c *Config) Load(path string) error {
  33. c.path = path
  34. c.files = make([]string, 0)
  35. err := c.resolve(path, true)
  36. if err != nil {
  37. return err
  38. }
  39. if len(c.files) == 0 {
  40. return fmt.Errorf("no config files found at %s", path)
  41. }
  42. sort.Strings(c.files)
  43. err = c.parse()
  44. if err != nil {
  45. return err
  46. }
  47. return nil
  48. }
  49. // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
  50. // here should decide if they need to make a change to the current process before making the change. HasChanged can be
  51. // used to help decide if a change is necessary.
  52. // These functions should return quickly or spawn their own go routine if they will take a while
  53. func (c *Config) RegisterReloadCallback(f func(*Config)) {
  54. c.callbacks = append(c.callbacks, f)
  55. }
  56. // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
  57. // k in both the old and new settings will be serialized, the result of the string comparison is returned.
  58. // If k is an empty string the entire config is tested.
  59. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
  60. // there is change when there actually wasn't any.
  61. func (c *Config) HasChanged(k string) bool {
  62. if c.oldSettings == nil {
  63. return false
  64. }
  65. var (
  66. nv interface{}
  67. ov interface{}
  68. )
  69. if k == "" {
  70. nv = c.Settings
  71. ov = c.oldSettings
  72. k = "all settings"
  73. } else {
  74. nv = c.get(k, c.Settings)
  75. ov = c.get(k, c.oldSettings)
  76. }
  77. newVals, err := yaml.Marshal(nv)
  78. if err != nil {
  79. l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
  80. }
  81. oldVals, err := yaml.Marshal(ov)
  82. if err != nil {
  83. l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
  84. }
  85. return string(newVals) != string(oldVals)
  86. }
  87. // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
  88. // original path provided to Load. The old settings are shallow copied for change detection after the reload.
  89. func (c *Config) CatchHUP() {
  90. ch := make(chan os.Signal, 1)
  91. signal.Notify(ch, syscall.SIGHUP)
  92. go func() {
  93. for range ch {
  94. l.Info("Caught HUP, reloading config")
  95. c.ReloadConfig()
  96. }
  97. }()
  98. }
  99. func (c *Config) ReloadConfig() {
  100. c.oldSettings = make(map[interface{}]interface{})
  101. for k, v := range c.Settings {
  102. c.oldSettings[k] = v
  103. }
  104. err := c.Load(c.path)
  105. if err != nil {
  106. l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
  107. return
  108. }
  109. for _, v := range c.callbacks {
  110. v(c)
  111. }
  112. }
  113. // GetString will get the string for k or return the default d if not found or invalid
  114. func (c *Config) GetString(k, d string) string {
  115. r := c.Get(k)
  116. if r == nil {
  117. return d
  118. }
  119. return fmt.Sprintf("%v", r)
  120. }
  121. // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
  122. func (c *Config) GetStringSlice(k string, d []string) []string {
  123. r := c.Get(k)
  124. if r == nil {
  125. return d
  126. }
  127. rv, ok := r.([]interface{})
  128. if !ok {
  129. return d
  130. }
  131. v := make([]string, len(rv))
  132. for i := 0; i < len(v); i++ {
  133. v[i] = fmt.Sprintf("%v", rv[i])
  134. }
  135. return v
  136. }
  137. // GetMap will get the map for k or return the default d if not found or invalid
  138. func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
  139. r := c.Get(k)
  140. if r == nil {
  141. return d
  142. }
  143. v, ok := r.(map[interface{}]interface{})
  144. if !ok {
  145. return d
  146. }
  147. return v
  148. }
  149. // GetInt will get the int for k or return the default d if not found or invalid
  150. func (c *Config) GetInt(k string, d int) int {
  151. r := c.GetString(k, strconv.Itoa(d))
  152. v, err := strconv.Atoi(r)
  153. if err != nil {
  154. return d
  155. }
  156. return v
  157. }
  158. // GetBool will get the bool for k or return the default d if not found or invalid
  159. func (c *Config) GetBool(k string, d bool) bool {
  160. r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
  161. v, err := strconv.ParseBool(r)
  162. if err != nil {
  163. switch r {
  164. case "y", "yes":
  165. return true
  166. case "n", "no":
  167. return false
  168. }
  169. return d
  170. }
  171. return v
  172. }
  173. // GetDuration will get the duration for k or return the default d if not found or invalid
  174. func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
  175. r := c.GetString(k, "")
  176. v, err := time.ParseDuration(r)
  177. if err != nil {
  178. return d
  179. }
  180. return v
  181. }
  182. func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) {
  183. r := c.Get(k)
  184. if r == nil {
  185. return nil, nil
  186. }
  187. rawMap, ok := r.(map[interface{}]interface{})
  188. if !ok {
  189. return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r)
  190. }
  191. tree := NewCIDRTree()
  192. var nameRules []AllowListNameRule
  193. firstValue := true
  194. allValuesMatch := true
  195. defaultSet := false
  196. var allValues bool
  197. for rawKey, rawValue := range rawMap {
  198. rawCIDR, ok := rawKey.(string)
  199. if !ok {
  200. return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
  201. }
  202. // Special rule for interface names
  203. if rawCIDR == "interfaces" {
  204. if !allowInterfaces {
  205. return nil, fmt.Errorf("config `%s` does not support `interfaces`", k)
  206. }
  207. var err error
  208. nameRules, err = c.getAllowListInterfaces(k, rawValue)
  209. if err != nil {
  210. return nil, err
  211. }
  212. continue
  213. }
  214. value, ok := rawValue.(bool)
  215. if !ok {
  216. return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
  217. }
  218. _, cidr, err := net.ParseCIDR(rawCIDR)
  219. if err != nil {
  220. return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
  221. }
  222. // TODO: should we error on duplicate CIDRs in the config?
  223. tree.AddCIDR(cidr, value)
  224. if firstValue {
  225. allValues = value
  226. firstValue = false
  227. } else {
  228. if value != allValues {
  229. allValuesMatch = false
  230. }
  231. }
  232. // Check if this is 0.0.0.0/0
  233. bits, size := cidr.Mask.Size()
  234. if bits == 0 && size == 32 {
  235. defaultSet = true
  236. }
  237. }
  238. if !defaultSet {
  239. if allValuesMatch {
  240. _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
  241. tree.AddCIDR(zeroCIDR, !allValues)
  242. } else {
  243. return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
  244. }
  245. }
  246. return &AllowList{cidrTree: tree, nameRules: nameRules}, nil
  247. }
  248. func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
  249. var nameRules []AllowListNameRule
  250. rawRules, ok := v.(map[interface{}]interface{})
  251. if !ok {
  252. return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
  253. }
  254. firstEntry := true
  255. var allValues bool
  256. for rawName, rawAllow := range rawRules {
  257. name, ok := rawName.(string)
  258. if !ok {
  259. return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
  260. }
  261. allow, ok := rawAllow.(bool)
  262. if !ok {
  263. return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
  264. }
  265. nameRE, err := regexp.Compile("^" + name + "$")
  266. if err != nil {
  267. return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
  268. }
  269. nameRules = append(nameRules, AllowListNameRule{
  270. Name: nameRE,
  271. Allow: allow,
  272. })
  273. if firstEntry {
  274. allValues = allow
  275. firstEntry = false
  276. } else {
  277. if allow != allValues {
  278. return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
  279. }
  280. }
  281. }
  282. return nameRules, nil
  283. }
  284. func (c *Config) Get(k string) interface{} {
  285. return c.get(k, c.Settings)
  286. }
  287. func (c *Config) IsSet(k string) bool {
  288. return c.get(k, c.Settings) != nil
  289. }
  290. func (c *Config) get(k string, v interface{}) interface{} {
  291. parts := strings.Split(k, ".")
  292. for _, p := range parts {
  293. m, ok := v.(map[interface{}]interface{})
  294. if !ok {
  295. return nil
  296. }
  297. v, ok = m[p]
  298. if !ok {
  299. return nil
  300. }
  301. }
  302. return v
  303. }
  304. // direct signifies if this is the config path directly specified by the user,
  305. // versus a file/dir found by recursing into that path
  306. func (c *Config) resolve(path string, direct bool) error {
  307. i, err := os.Stat(path)
  308. if err != nil {
  309. return nil
  310. }
  311. if !i.IsDir() {
  312. c.addFile(path, direct)
  313. return nil
  314. }
  315. paths, err := readDirNames(path)
  316. if err != nil {
  317. return fmt.Errorf("problem while reading directory %s: %s", path, err)
  318. }
  319. for _, p := range paths {
  320. err := c.resolve(filepath.Join(path, p), false)
  321. if err != nil {
  322. return err
  323. }
  324. }
  325. return nil
  326. }
  327. func (c *Config) addFile(path string, direct bool) error {
  328. ext := filepath.Ext(path)
  329. if !direct && ext != ".yaml" && ext != ".yml" {
  330. return nil
  331. }
  332. ap, err := filepath.Abs(path)
  333. if err != nil {
  334. return err
  335. }
  336. c.files = append(c.files, ap)
  337. return nil
  338. }
  339. func (c *Config) parse() error {
  340. var m map[interface{}]interface{}
  341. for _, path := range c.files {
  342. b, err := ioutil.ReadFile(path)
  343. if err != nil {
  344. return err
  345. }
  346. var nm map[interface{}]interface{}
  347. err = yaml.Unmarshal(b, &nm)
  348. if err != nil {
  349. return err
  350. }
  351. // We need to use WithAppendSlice so that firewall rules in separate
  352. // files are appended together
  353. err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
  354. m = nm
  355. if err != nil {
  356. return err
  357. }
  358. }
  359. c.Settings = m
  360. return nil
  361. }
  // readDirNames returns the names (not full paths) of every entry in the
// directory at path, sorted lexically.
func readDirNames(path string) ([]string, error) {
	dir, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer dir.Close()

	names, err := dir.Readdirnames(-1)
	if err != nil {
		return nil, err
	}
	sort.Strings(names)
	return names, nil
}
  375. func configLogger(c *Config) error {
  376. // set up our logging level
  377. logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
  378. if err != nil {
  379. return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
  380. }
  381. l.SetLevel(logLevel)
  382. timestampFormat := c.GetString("logging.timestamp_format", "")
  383. fullTimestamp := (timestampFormat != "")
  384. if timestampFormat == "" {
  385. timestampFormat = time.RFC3339
  386. }
  387. logFormat := strings.ToLower(c.GetString("logging.format", "text"))
  388. switch logFormat {
  389. case "text":
  390. l.Formatter = &logrus.TextFormatter{
  391. TimestampFormat: timestampFormat,
  392. FullTimestamp: fullTimestamp,
  393. }
  394. case "json":
  395. l.Formatter = &logrus.JSONFormatter{
  396. TimestampFormat: timestampFormat,
  397. }
  398. default:
  399. return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
  400. }
  401. return nil
  402. }