config.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. package nebula
  2. import (
  3. "errors"
  4. "fmt"
  5. "io/ioutil"
  6. "net"
  7. "os"
  8. "os/signal"
  9. "path/filepath"
  10. "regexp"
  11. "sort"
  12. "strconv"
  13. "strings"
  14. "syscall"
  15. "time"
  16. "github.com/imdario/mergo"
  17. "github.com/sirupsen/logrus"
  18. "gopkg.in/yaml.v2"
  19. )
  // Config holds settings parsed from YAML (via Load or LoadString) together with the
  // bookkeeping used to reload them (ReloadConfig) and to detect changes (HasChanged).
  20. type Config struct {
  21. path string // path given to the most recent Load; ReloadConfig loads it again
  22. files []string // absolute paths of the config files found under path, parsed in sorted order
  23. Settings map[interface{}]interface{} // parsed configuration tree that Get resolves dot-separated keys against
  24. oldSettings map[interface{}]interface{} // shallow copy of Settings taken by ReloadConfig; nil until the first reload
  25. callbacks []func(*Config) // invoked in registration order after each successful ReloadConfig
  26. l *logrus.Logger // logger for reload/marshal errors; configLogger applies the logging.* settings to it
  27. }
  28. func NewConfig(l *logrus.Logger) *Config {
  29. return &Config{
  30. Settings: make(map[interface{}]interface{}),
  31. l: l,
  32. }
  33. }
  // Load makes path the configuration source and parses everything found there.
//
// If path is a file, that file is used whatever its extension. If path is a directory,
// it is walked recursively and every .yaml/.yml file inside is collected. The collected
// files are sorted lexically and merged into Settings in that order.
//
// It returns an error if resolving the path fails, if no config files are found, or if
// any file cannot be read, decoded, or merged.
func (c *Config) Load(path string) error {
	c.path = path
	c.files = []string{}

	if err := c.resolve(path, true); err != nil {
		return err
	}
	if len(c.files) == 0 {
		return fmt.Errorf("no config files found at %s", path)
	}

	sort.Strings(c.files)
	return c.parse()
}
  // LoadString parses raw as YAML and, on success, replaces Settings with the result.
// An empty raw string is rejected with an error. A YAML decode error is returned as is.
func (c *Config) LoadString(raw string) error {
	if len(raw) == 0 {
		return errors.New("Empty configuration")
	}
	data := []byte(raw)
	return c.parseRaw(data)
}
  58. // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
  59. // here should decide if they need to make a change to the current process before making the change. HasChanged can be
  60. // used to help decide if a change is necessary.
  61. // These functions should return quickly or spawn their own go routine if they will take a while
  62. func (c *Config) RegisterReloadCallback(f func(*Config)) {
  63. c.callbacks = append(c.callbacks, f)
  64. }
  65. // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
  66. // k in both the old and new settings will be serialized, the result of the string comparison is returned.
  67. // If k is an empty string the entire config is tested.
  68. // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
  69. // there is change when there actually wasn't any.
  70. func (c *Config) HasChanged(k string) bool {
  71. if c.oldSettings == nil {
  72. return false
  73. }
  74. var (
  75. nv interface{}
  76. ov interface{}
  77. )
  78. if k == "" {
  79. nv = c.Settings
  80. ov = c.oldSettings
  81. k = "all settings"
  82. } else {
  83. nv = c.get(k, c.Settings)
  84. ov = c.get(k, c.oldSettings)
  85. }
  86. newVals, err := yaml.Marshal(nv)
  87. if err != nil {
  88. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
  89. }
  90. oldVals, err := yaml.Marshal(ov)
  91. if err != nil {
  92. c.l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
  93. }
  94. return string(newVals) != string(oldVals)
  95. }
  // CatchHUP subscribes to SIGHUP and starts a background goroutine that calls
// ReloadConfig each time the signal arrives.
//
// The goroutine and the signal subscription are never torn down by this code. The
// channel buffers one signal, so a HUP received during a reload is queued; the signal
// package drops any further HUPs that arrive while that slot is full.
//
// NOTE(review): reloads mutate Settings from this goroutine without any locking.
// Confirm that concurrent readers elsewhere tolerate that.
func (c *Config) CatchHUP() {
	hup := make(chan os.Signal, 1)
	signal.Notify(hup, syscall.SIGHUP)

	reloadLoop := func() {
		for range hup {
			c.l.Info("Caught HUP, reloading config")
			c.ReloadConfig()
		}
	}
	go reloadLoop()
}
  108. func (c *Config) ReloadConfig() {
  109. c.oldSettings = make(map[interface{}]interface{})
  110. for k, v := range c.Settings {
  111. c.oldSettings[k] = v
  112. }
  113. err := c.Load(c.path)
  114. if err != nil {
  115. c.l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
  116. return
  117. }
  118. for _, v := range c.callbacks {
  119. v(c)
  120. }
  121. }
  // GetString returns the value at the dot-separated key k rendered with %v, or d when
// k resolves to nothing. Non-string values (numbers, bools, lists, maps) come back in
// their fmt representation rather than being rejected.
func (c *Config) GetString(k, d string) string {
	if v := c.Get(k); v != nil {
		return fmt.Sprintf("%v", v)
	}
	return d
}
  // GetStringSlice returns the list at k with every element rendered via %v. It
// returns d when k resolves to nothing or when the value there is not a list.
func (c *Config) GetStringSlice(k string, d []string) []string {
	items, ok := c.Get(k).([]interface{})
	if !ok {
		return d
	}

	out := make([]string, 0, len(items))
	for _, item := range items {
		out = append(out, fmt.Sprintf("%v", item))
	}
	return out
}
  146. // GetMap will get the map for k or return the default d if not found or invalid
  147. func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
  148. r := c.Get(k)
  149. if r == nil {
  150. return d
  151. }
  152. v, ok := r.(map[interface{}]interface{})
  153. if !ok {
  154. return d
  155. }
  156. return v
  157. }
  // GetInt returns the value at k, rendered as text and parsed as a base-10 int. It
// returns d when k resolves to nothing or the text is not a valid int (e.g. "1.5").
func (c *Config) GetInt(k string, d int) int {
	text := c.GetString(k, strconv.Itoa(d))
	if n, err := strconv.Atoi(text); err == nil {
		return n
	}
	return d
}
  // GetBool returns the value at k interpreted as a boolean, or d when k resolves to
// nothing or the value cannot be interpreted.
//
// Matching is case-insensitive. Every spelling strconv.ParseBool accepts is recognized,
// plus "y"/"yes" for true and "n"/"no" for false.
func (c *Config) GetBool(k string, d bool) bool {
	text := strings.ToLower(c.GetString(k, strconv.FormatBool(d)))
	if b, err := strconv.ParseBool(text); err == nil {
		return b
	}

	switch text {
	case "y", "yes":
		return true
	case "n", "no":
		return false
	default:
		return d
	}
}
  // GetDuration returns the value at k parsed with time.ParseDuration (e.g. "1m30s").
// It returns d when k resolves to nothing or the value is not a valid duration.
func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
	if dur, err := time.ParseDuration(c.GetString(k, "")); err == nil {
		return dur
	}
	return d
}
  // GetAllowList builds an AllowList from the map stored at config key k.
  //
  // Each key of that map must be a CIDR string and each value a bool. The bool is stored
  // in the CIDR tree for that network (presumably true = allow and false = deny; confirm
  // against AllowList). The special key `interfaces` holds interface-name rules (parsed by
  // getAllowListInterfaces) and is rejected unless allowInterfaces is true.
  //
  // Defaults are handled per address family (IPv4 and IPv6) that lacks an explicit
  // default (0.0.0.0/0 or ::/0):
  //   - If all of the family's rules share one value, a default with the opposite value
  //     is added. A family with no rules at all therefore gets a default of true.
  //   - If the family's rules mix true and false, an error is returned, because the
  //     intended default is ambiguous.
  //
  // Returns (nil, nil) when k is not set. Returns an error for a non-map value, a
  // non-string key, a non-bool value, an unparsable CIDR, or an invalid `interfaces`
  // section.
  191. func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) {
  192. r := c.Get(k)
  193. if r == nil {
  194. return nil, nil
  195. }
  196. rawMap, ok := r.(map[interface{}]interface{})
  197. if !ok {
  198. return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r)
  199. }
  200. tree := NewCIDR6Tree()
  201. var nameRules []AllowListNameRule
  202. // Keep track of the rules we have added for both ipv4 and ipv6
  203. type allowListRules struct {
  204. firstValue bool // true until the first rule of this address family has been seen
  205. allValuesMatch bool // becomes false once two rules of this family disagree
  206. defaultSet bool // a /0 rule (0.0.0.0/0 or ::/0) was present for this family
  207. allValues bool // value of the first rule; meaningful only while allValuesMatch is true
  208. }
  209. rules4 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
  210. rules6 := allowListRules{firstValue: true, allValuesMatch: true, defaultSet: false}
  // Map iteration order is random. The bookkeeping below only records whether every
  // value within a family agrees, so its outcome does not depend on that order.
  211. for rawKey, rawValue := range rawMap {
  212. rawCIDR, ok := rawKey.(string)
  213. if !ok {
  214. return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
  215. }
  216. // Special rule for interface names
  217. if rawCIDR == "interfaces" {
  218. if !allowInterfaces {
  219. return nil, fmt.Errorf("config `%s` does not support `interfaces`", k)
  220. }
  221. var err error
  222. nameRules, err = c.getAllowListInterfaces(k, rawValue)
  223. if err != nil {
  224. return nil, err
  225. }
  226. continue
  227. }
  228. value, ok := rawValue.(bool)
  229. if !ok {
  230. return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
  231. }
  // ParseCIDR returns the masked network (10.0.0.1/8 becomes 10.0.0.0/8). Its error
  // detail is not included in the returned error.
  232. _, cidr, err := net.ParseCIDR(rawCIDR)
  233. if err != nil {
  234. return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
  235. }
  // NOTE(review): distinct keys that parse to the same network both reach AddCIDR.
  // Which value wins then presumably depends on map order and on AddCIDR's overwrite
  // behavior.
  236. // TODO: should we error on duplicate CIDRs in the config?
  237. tree.AddCIDR(cidr, value)
  // maskSize is the total mask length: 32 for IPv4, 128 for IPv6. It selects the
  // address family. maskBits is the prefix length.
  238. maskBits, maskSize := cidr.Mask.Size()
  239. var rules *allowListRules
  240. if maskSize == 32 {
  241. rules = &rules4
  242. } else {
  243. rules = &rules6
  244. }
  245. if rules.firstValue {
  246. rules.allValues = value
  247. rules.firstValue = false
  248. } else {
  249. if value != rules.allValues {
  250. rules.allValuesMatch = false
  251. }
  252. }
  253. // Check if this is 0.0.0.0/0 or ::/0
  254. if maskBits == 0 {
  255. rules.defaultSet = true
  256. }
  257. }
  // No explicit IPv4 default. If every IPv4 rule agreed (or there were none), default to
  // the opposite value; otherwise the intended default is ambiguous.
  258. if !rules4.defaultSet {
  259. if rules4.allValuesMatch {
  260. _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
  261. tree.AddCIDR(zeroCIDR, !rules4.allValues)
  262. } else {
  263. return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
  264. }
  265. }
  // Same defaulting rule for IPv6.
  266. if !rules6.defaultSet {
  267. if rules6.allValuesMatch {
  268. _, zeroCIDR, _ := net.ParseCIDR("::/0")
  269. tree.AddCIDR(zeroCIDR, !rules6.allValues)
  270. } else {
  271. return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for ::/0", k)
  272. }
  273. }
  274. return &AllowList{cidrTree: tree, nameRules: nameRules}, nil
  275. }
  // getAllowListInterfaces parses the `interfaces` section of the allow list at config
// key k.
//
// v must be a map whose keys are regular expressions matched against interface names
// and whose values are bools. All values must be identical, otherwise an error is
// returned. One AllowListNameRule is produced per entry, in map iteration order.
//
// Each pattern must match the entire interface name. The pattern is wrapped in a
// non-capturing group before anchoring. A plain "^" + name + "$" would turn a pattern
// with top-level alternation such as `eth0|wg0` into `^eth0|wg0$`. That regex matches
// any name that merely starts with "eth0" or ends with "wg0" (e.g. "eth0.100").
func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
	rawRules, ok := v.(map[interface{}]interface{})
	if !ok {
		return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
	}

	var nameRules []AllowListNameRule
	firstEntry := true
	var allValues bool
	for rawName, rawAllow := range rawRules {
		name, ok := rawName.(string)
		if !ok {
			return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
		}
		allow, ok := rawAllow.(bool)
		if !ok {
			return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
		}

		// The group keeps alternation inside the anchors (see the doc comment).
		nameRE, err := regexp.Compile("^(?:" + name + ")$")
		if err != nil {
			return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
		}
		nameRules = append(nameRules, AllowListNameRule{
			Name:  nameRE,
			Allow: allow,
		})

		// Interface rules may not mix true and false values.
		if firstEntry {
			allValues = allow
			firstEntry = false
		} else if allow != allValues {
			return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
		}
	}
	return nameRules, nil
}
  312. func (c *Config) Get(k string) interface{} {
  313. return c.get(k, c.Settings)
  314. }
  315. func (c *Config) IsSet(k string) bool {
  316. return c.get(k, c.Settings) != nil
  317. }
  // get walks v along the dot-separated path k.
//
// Every step must land on a map[interface{}]interface{} that contains the next segment
// as a string key; otherwise get returns nil. The value at the final segment is
// returned as is, and may itself be nil.
func (c *Config) get(k string, v interface{}) interface{} {
	cur := v
	for _, segment := range strings.Split(k, ".") {
		node, isMap := cur.(map[interface{}]interface{})
		if !isMap {
			return nil
		}
		next, found := node[segment]
		if !found {
			return nil
		}
		cur = next
	}
	return cur
}
  // resolve collects the config file(s) at path into c.files.
//
// A non-directory path is handed to addFile. A directory is walked recursively, with
// its entries visited in sorted name order.
//
// direct is true only for the path the user supplied. Entries discovered while
// recursing are passed with direct=false, which makes addFile skip anything without a
// .yaml/.yml extension.
//
// Paths that cannot be stat'ed (for example, nonexistent paths or dangling symlinks)
// are skipped silently; Load reports an error later if nothing at all was collected.
// Errors from reading a directory or from addFile are returned.
func (c *Config) resolve(path string, direct bool) error {
	info, err := os.Stat(path)
	if err != nil {
		// Deliberate best-effort: an unreadable entry does not abort the whole load.
		return nil
	}

	if !info.IsDir() {
		// Propagate addFile's error (from filepath.Abs) instead of dropping it.
		return c.addFile(path, direct)
	}

	names, err := readDirNames(path)
	if err != nil {
		return fmt.Errorf("problem while reading directory %s: %s", path, err)
	}
	for _, name := range names {
		if err := c.resolve(filepath.Join(path, name), false); err != nil {
			return err
		}
	}
	return nil
}
  // addFile appends the absolute form of path to c.files.
//
// A path the user named directly (direct == true) is accepted regardless of extension.
// Paths found while recursing are accepted only with a ".yaml" or ".yml" extension;
// the match is case-sensitive. The error from filepath.Abs is returned if it fails.
func (c *Config) addFile(path string, direct bool) error {
	if !direct {
		switch filepath.Ext(path) {
		case ".yaml", ".yml":
			// accepted
		default:
			return nil
		}
	}

	abs, err := filepath.Abs(path)
	if err != nil {
		return err
	}
	c.files = append(c.files, abs)
	return nil
}
  367. func (c *Config) parseRaw(b []byte) error {
  368. var m map[interface{}]interface{}
  369. err := yaml.Unmarshal(b, &m)
  370. if err != nil {
  371. return err
  372. }
  373. c.Settings = m
  374. return nil
  375. }
  // parse reads every file in c.files, in slice order, and merges them into a single map
// that replaces Settings.
//
// Each newly read file is the merge destination, and the files read before it are the
// source. mergo.Merge is called without WithOverride, so presumably values set in the
// later file win and earlier files only fill in missing keys; confirm against mergo's
// semantics.
//
// If c.files is empty, Settings becomes nil. On any read, decode, or merge error, the
// error is returned and Settings is left unchanged.
func (c *Config) parse() error {
	var merged map[interface{}]interface{}
	for _, file := range c.files {
		raw, err := ioutil.ReadFile(file)
		if err != nil {
			return err
		}

		var current map[interface{}]interface{}
		if err := yaml.Unmarshal(raw, &current); err != nil {
			return err
		}

		// WithAppendSlice concatenates lists that appear in several files (e.g.
		// firewall rules split across files) instead of one replacing the other.
		if err := mergo.Merge(&current, merged, mergo.WithAppendSlice); err != nil {
			return err
		}
		merged = current
	}

	c.Settings = merged
	return nil
}
  // readDirNames returns the names of every entry in the directory at path, sorted
// lexically. The names are bare entry names and are not joined with path. Errors from
// opening or listing the directory are returned.
func readDirNames(path string) ([]string, error) {
	dir, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer dir.Close()

	names, err := dir.Readdirnames(-1)
	if err != nil {
		return nil, err
	}
	sort.Strings(names)
	return names, nil
}
  // configLogger applies the logging.* settings from c to c.l:
//   - logging.level (default "info", case-insensitive) sets the log level.
//   - logging.format ("text", the default, or "json", case-insensitive) selects the
//     formatter.
//   - logging.timestamp_format (default time.RFC3339) sets the timestamp layout. A
//     non-empty value also sets TextFormatter.FullTimestamp.
//   - logging.disable_timestamp (default false) sets DisableTimestamp on either formatter.
//
// It returns an error for an unknown level or format. The level is applied before the
// format is checked, so an unknown format still leaves the new level in place.
func configLogger(c *Config) error {
	levelName := strings.ToLower(c.GetString("logging.level", "info"))
	level, err := logrus.ParseLevel(levelName)
	if err != nil {
		return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
	}
	c.l.SetLevel(level)

	tsFormat := c.GetString("logging.timestamp_format", "")
	fullTS := tsFormat != ""
	if !fullTS {
		tsFormat = time.RFC3339
	}
	noTS := c.GetBool("logging.disable_timestamp", false)

	switch format := strings.ToLower(c.GetString("logging.format", "text")); format {
	case "text":
		c.l.Formatter = &logrus.TextFormatter{
			TimestampFormat:  tsFormat,
			FullTimestamp:    fullTS,
			DisableTimestamp: noTS,
		}
	case "json":
		c.l.Formatter = &logrus.JSONFormatter{
			TimestampFormat:  tsFormat,
			DisableTimestamp: noTS,
		}
	default:
		return fmt.Errorf("unknown log format `%s`. possible formats: %s", format, []string{"text", "json"})
	}
	return nil
}