config.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. package nebula
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io/ioutil"
  7. "net"
  8. "os"
  9. "os/signal"
  10. "path/filepath"
  11. "regexp"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "syscall"
  16. "time"
  17. "github.com/imdario/mergo"
  18. "github.com/sirupsen/logrus"
  19. "gopkg.in/yaml.v2"
  20. )
  21. type Config struct {
  22. path string
  23. files []string
  24. Settings map[interface{}]interface{}
  25. oldSettings map[interface{}]interface{}
  26. callbacks []func(*Config)
  27. l *logrus.Logger
  28. }
  29. func NewConfig(l *logrus.Logger) *Config {
  30. return &Config{
  31. Settings: make(map[interface{}]interface{}),
  32. l: l,
  33. }
  34. }
  35. // Load will find all yaml files within path and load them in lexical order
  36. func (c *Config) Load(path string) error {
  37. c.path = path
  38. c.files = make([]string, 0)
  39. err := c.resolve(path, true)
  40. if err != nil {
  41. return err
  42. }
  43. if len(c.files) == 0 {
  44. return fmt.Errorf("no config files found at %s", path)
  45. }
  46. sort.Strings(c.files)
  47. err = c.parse()
  48. if err != nil {
  49. return err
  50. }
  51. return nil
  52. }
  53. func (c *Config) LoadString(raw string) error {
  54. if raw == "" {
  55. return errors.New("Empty configuration")
  56. }
  57. return c.parseRaw([]byte(raw))
  58. }
  59. // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
  60. // here should decide if they need to make a change to the current process before making the change. HasChanged can be
  61. // used to help decide if a change is necessary.
  62. // These functions should return quickly or spawn their own go routine if they will take a while
  63. func (c *Config) RegisterReloadCallback(f func(*Config)) {
  64. c.callbacks = append(c.callbacks, f)
  65. }
  66. // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
  67. // k in both the old and new settings will be serialized, the result of the string comparison is returned.
  68. // If k is an empty string the entire config is tested.
  69. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
  70. // there is change when there actually wasn't any.
  71. func (c *Config) HasChanged(k string) bool {
  72. if c.oldSettings == nil {
  73. return false
  74. }
  75. var (
  76. nv interface{}
  77. ov interface{}
  78. )
  79. if k == "" {
  80. nv = c.Settings
  81. ov = c.oldSettings
  82. k = "all settings"
  83. } else {
  84. nv = c.get(k, c.Settings)
  85. ov = c.get(k, c.oldSettings)
  86. }
  87. newVals, err := yaml.Marshal(nv)
  88. if err != nil {
  89. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
  90. }
  91. oldVals, err := yaml.Marshal(ov)
  92. if err != nil {
  93. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
  94. }
  95. return string(newVals) != string(oldVals)
  96. }
  97. // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
  98. // original path provided to Load. The old settings are shallow copied for change detection after the reload.
  99. func (c *Config) CatchHUP(ctx context.Context) {
  100. ch := make(chan os.Signal, 1)
  101. signal.Notify(ch, syscall.SIGHUP)
  102. go func() {
  103. for {
  104. select {
  105. case <-ctx.Done():
  106. signal.Stop(ch)
  107. close(ch)
  108. return
  109. case <-ch:
  110. c.l.Info("Caught HUP, reloading config")
  111. c.ReloadConfig()
  112. }
  113. }
  114. }()
  115. }
  116. func (c *Config) ReloadConfig() {
  117. c.oldSettings = make(map[interface{}]interface{})
  118. for k, v := range c.Settings {
  119. c.oldSettings[k] = v
  120. }
  121. err := c.Load(c.path)
  122. if err != nil {
  123. c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
  124. return
  125. }
  126. for _, v := range c.callbacks {
  127. v(c)
  128. }
  129. }
  130. // GetString will get the string for k or return the default d if not found or invalid
  131. func (c *Config) GetString(k, d string) string {
  132. r := c.Get(k)
  133. if r == nil {
  134. return d
  135. }
  136. return fmt.Sprintf("%v", r)
  137. }
  138. // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
  139. func (c *Config) GetStringSlice(k string, d []string) []string {
  140. r := c.Get(k)
  141. if r == nil {
  142. return d
  143. }
  144. rv, ok := r.([]interface{})
  145. if !ok {
  146. return d
  147. }
  148. v := make([]string, len(rv))
  149. for i := 0; i < len(v); i++ {
  150. v[i] = fmt.Sprintf("%v", rv[i])
  151. }
  152. return v
  153. }
  154. // GetMap will get the map for k or return the default d if not found or invalid
  155. func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
  156. r := c.Get(k)
  157. if r == nil {
  158. return d
  159. }
  160. v, ok := r.(map[interface{}]interface{})
  161. if !ok {
  162. return d
  163. }
  164. return v
  165. }
  166. // GetInt will get the int for k or return the default d if not found or invalid
  167. func (c *Config) GetInt(k string, d int) int {
  168. r := c.GetString(k, strconv.Itoa(d))
  169. v, err := strconv.Atoi(r)
  170. if err != nil {
  171. return d
  172. }
  173. return v
  174. }
  175. // GetBool will get the bool for k or return the default d if not found or invalid
  176. func (c *Config) GetBool(k string, d bool) bool {
  177. r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
  178. v, err := strconv.ParseBool(r)
  179. if err != nil {
  180. switch r {
  181. case "y", "yes":
  182. return true
  183. case "n", "no":
  184. return false
  185. }
  186. return d
  187. }
  188. return v
  189. }
  190. // GetDuration will get the duration for k or return the default d if not found or invalid
  191. func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
  192. r := c.GetString(k, "")
  193. v, err := time.ParseDuration(r)
  194. if err != nil {
  195. return d
  196. }
  197. return v
  198. }
  199. func (c *Config) GetLocalAllowList(k string) (*LocalAllowList, error) {
  200. var nameRules []AllowListNameRule
  201. handleKey := func(key string, value interface{}) (bool, error) {
  202. if key == "interfaces" {
  203. var err error
  204. nameRules, err = c.getAllowListInterfaces(k, value)
  205. if err != nil {
  206. return false, err
  207. }
  208. return true, nil
  209. }
  210. return false, nil
  211. }
  212. al, err := c.GetAllowList(k, handleKey)
  213. if err != nil {
  214. return nil, err
  215. }
  216. return &LocalAllowList{AllowList: al, nameRules: nameRules}, nil
  217. }
  218. func (c *Config) GetRemoteAllowList(k, rangesKey string) (*RemoteAllowList, error) {
  219. al, err := c.GetAllowList(k, nil)
  220. if err != nil {
  221. return nil, err
  222. }
  223. remoteAllowRanges, err := c.getRemoteAllowRanges(rangesKey)
  224. if err != nil {
  225. return nil, err
  226. }
  227. return &RemoteAllowList{AllowList: al, insideAllowLists: remoteAllowRanges}, nil
  228. }
  229. func (c *Config) getRemoteAllowRanges(k string) (*CIDR6Tree, error) {
  230. value := c.Get(k)
  231. if value == nil {
  232. return nil, nil
  233. }
  234. remoteAllowRanges := NewCIDR6Tree()
  235. rawMap, ok := value.(map[interface{}]interface{})
  236. if !ok {
  237. return nil, fmt.Errorf("config `%s` has invalid type: %T", k, value)
  238. }
  239. for rawKey, rawValue := range rawMap {
  240. rawCIDR, ok := rawKey.(string)
  241. if !ok {
  242. return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
  243. }
  244. allowList, err := c.getAllowList(fmt.Sprintf("%s.%s", k, rawCIDR), rawValue, nil)
  245. if err != nil {
  246. return nil, err
  247. }
  248. _, cidr, err := net.ParseCIDR(rawCIDR)
  249. if err != nil {
  250. return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
  251. }
  252. remoteAllowRanges.AddCIDR(cidr, allowList)
  253. }
  254. return remoteAllowRanges, nil
  255. }
  256. // If the handleKey func returns true, the rest of the parsing is skipped
  257. // for this key. This allows parsing of special values like `interfaces`.
  258. func (c *Config) GetAllowList(k string, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
  259. r := c.Get(k)
  260. if r == nil {
  261. return nil, nil
  262. }
  263. return c.getAllowList(k, r, handleKey)
  264. }
  265. // If the handleKey func returns true, the rest of the parsing is skipped
  266. // for this key. This allows parsing of special values like `interfaces`.
  267. func (c *Config) getAllowList(k string, raw interface{}, handleKey func(key string, value interface{}) (bool, error)) (*AllowList, error) {
  268. rawMap, ok := raw.(map[interface{}]interface{})
  269. if !ok {
  270. return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
  271. }
  272. tree := NewCIDR6Tree()
  273. // Keep track of the rules we have added for both ipv4 and ipv6
  274. type allowListRules struct {
  275. firstValue bool
  276. allValuesMatch bool
  277. defaultSet bool
  278. allValues bool
  279. }
  280. rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
  281. rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
  282. for rawKey, rawValue := range rawMap {
  283. rawCIDR, ok := rawKey.(string)
  284. if !ok {
  285. return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
  286. }
  287. if handleKey != nil {
  288. handled, err := handleKey(rawCIDR, rawValue)
  289. if err != nil {
  290. return nil, err
  291. }
  292. if handled {
  293. continue
  294. }
  295. }
  296. value, ok := rawValue.(bool)
  297. if !ok {
  298. return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
  299. }
  300. _, cidr, err := net.ParseCIDR(rawCIDR)
  301. if err != nil {
  302. return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
  303. }
  304. // TODO: should we error on duplicate CIDRs in the config?
  305. tree.AddCIDR(cidr, value)
  306. maskBits, maskSize := cidr.Mask.Size()
  307. var rules *allowListRules
  308. if maskSize == 32 {
  309. rules = &rules4
  310. } else {
  311. rules = &rules6
  312. }
  313. if rules.firstValue {
  314. rules.allValues = value
  315. rules.firstValue = false
  316. } else {
  317. if value != rules.allValues {
  318. rules.allValuesMatch = false
  319. }
  320. }
  321. // Check if this is 0.0.0.0/0 or ::/0
  322. if maskBits == 0 {
  323. rules.defaultSet = true
  324. }
  325. }
  326. if !rules4.defaultSet {
  327. if rules4.allValuesMatch {
  328. _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
  329. tree.AddCIDR(zeroCIDR, !rules4.allValues)
  330. } else {
  331. return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
  332. }
  333. }
  334. if !rules6.defaultSet {
  335. if rules6.allValuesMatch {
  336. _, zeroCIDR, _ := net.ParseCIDR("::/0")
  337. tree.AddCIDR(zeroCIDR, !rules6.allValues)
  338. } else {
  339. return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
  340. }
  341. }
  342. return &AllowList{cidrTree: tree}, nil
  343. }
  344. func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
  345. var nameRules []AllowListNameRule
  346. rawRules, ok := v.(map[interface{}]interface{})
  347. if !ok {
  348. return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
  349. }
  350. firstEntry := true
  351. var allValues bool
  352. for rawName, rawAllow := range rawRules {
  353. name, ok := rawName.(string)
  354. if !ok {
  355. return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
  356. }
  357. allow, ok := rawAllow.(bool)
  358. if !ok {
  359. return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
  360. }
  361. nameRE, err := regexp.Compile("^" + name + "$")
  362. if err != nil {
  363. return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
  364. }
  365. nameRules = append(nameRules, AllowListNameRule{
  366. Name: nameRE,
  367. Allow: allow,
  368. })
  369. if firstEntry {
  370. allValues = allow
  371. firstEntry = false
  372. } else {
  373. if allow != allValues {
  374. return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
  375. }
  376. }
  377. }
  378. return nameRules, nil
  379. }
  380. func (c *Config) Get(k string) interface{} {
  381. return c.get(k, c.Settings)
  382. }
  383. func (c *Config) IsSet(k string) bool {
  384. return c.get(k, c.Settings) != nil
  385. }
  386. func (c *Config) get(k string, v interface{}) interface{} {
  387. parts := strings.Split(k, ".")
  388. for _, p := range parts {
  389. m, ok := v.(map[interface{}]interface{})
  390. if !ok {
  391. return nil
  392. }
  393. v, ok = m[p]
  394. if !ok {
  395. return nil
  396. }
  397. }
  398. return v
  399. }
  400. // direct signifies if this is the config path directly specified by the user,
  401. // versus a file/dir found by recursing into that path
  402. func (c *Config) resolve(path string, direct bool) error {
  403. i, err := os.Stat(path)
  404. if err != nil {
  405. return nil
  406. }
  407. if !i.IsDir() {
  408. c.addFile(path, direct)
  409. return nil
  410. }
  411. paths, err := readDirNames(path)
  412. if err != nil {
  413. return fmt.Errorf("problem while reading directory %s: %s", path, err)
  414. }
  415. for _, p := range paths {
  416. err := c.resolve(filepath.Join(path, p), false)
  417. if err != nil {
  418. return err
  419. }
  420. }
  421. return nil
  422. }
  423. func (c *Config) addFile(path string, direct bool) error {
  424. ext := filepath.Ext(path)
  425. if !direct && ext != ".yaml" && ext != ".yml" {
  426. return nil
  427. }
  428. ap, err := filepath.Abs(path)
  429. if err != nil {
  430. return err
  431. }
  432. c.files = append(c.files, ap)
  433. return nil
  434. }
  435. func (c *Config) parseRaw(b []byte) error {
  436. var m map[interface{}]interface{}
  437. err := yaml.Unmarshal(b, &m)
  438. if err != nil {
  439. return err
  440. }
  441. c.Settings = m
  442. return nil
  443. }
  444. func (c *Config) parse() error {
  445. var m map[interface{}]interface{}
  446. for _, path := range c.files {
  447. b, err := ioutil.ReadFile(path)
  448. if err != nil {
  449. return err
  450. }
  451. var nm map[interface{}]interface{}
  452. err = yaml.Unmarshal(b, &nm)
  453. if err != nil {
  454. return err
  455. }
  456. // We need to use WithAppendSlice so that firewall rules in separate
  457. // files are appended together
  458. err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
  459. m = nm
  460. if err != nil {
  461. return err
  462. }
  463. }
  464. c.Settings = m
  465. return nil
  466. }
  467. func readDirNames(path string) ([]string, error) {
  468. f, err := os.Open(path)
  469. if err != nil {
  470. return nil, err
  471. }
  472. paths, err := f.Readdirnames(-1)
  473. f.Close()
  474. if err != nil {
  475. return nil, err
  476. }
  477. sort.Strings(paths)
  478. return paths, nil
  479. }
  480. func configLogger(c *Config) error {
  481. // set up our logging level
  482. logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
  483. if err != nil {
  484. return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
  485. }
  486. c.l.SetLevel(logLevel)
  487. disableTimestamp := c.GetBool("logging.disable_timestamp", false)
  488. timestampFormat := c.GetString("logging.timestamp_format", "")
  489. fullTimestamp := (timestampFormat != "")
  490. if timestampFormat == "" {
  491. timestampFormat = time.RFC3339
  492. }
  493. logFormat := strings.ToLower(c.GetString("logging.format", "text"))
  494. switch logFormat {
  495. case "text":
  496. c.l.Formatter = &logrus.TextFormatter{
  497. TimestampFormat: timestampFormat,
  498. FullTimestamp: fullTimestamp,
  499. DisableTimestamp: disableTimestamp,
  500. }
  501. case "json":
  502. c.l.Formatter = &logrus.JSONFormatter{
  503. TimestampFormat: timestampFormat,
  504. DisableTimestamp: disableTimestamp,
  505. }
  506. default:
  507. return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
  508. }
  509. return nil
  510. }