123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484 |
- package nebula
- import (
- "fmt"
- "github.com/imdario/mergo"
- "github.com/sirupsen/logrus"
- "gopkg.in/yaml.v2"
- "io/ioutil"
- "net"
- "os"
- "os/signal"
- "path/filepath"
- "regexp"
- "sort"
- "strconv"
- "strings"
- "syscall"
- "time"
- )
// Config holds the merged settings decoded from one or more YAML files, along
// with the bookkeeping needed to reload them later.
type Config struct {
	// path is the file or directory most recently passed to Load; ReloadConfig
	// loads it again.
	path string
	// files holds the absolute paths of the config files discovered under path,
	// sorted lexically by Load.
	files []string
	// Settings is the merged configuration tree as decoded by yaml.v2.
	Settings map[interface{}]interface{}
	// oldSettings is a shallow copy of Settings taken at the start of the most
	// recent ReloadConfig. It stays nil until the first reload; HasChanged
	// compares against it.
	oldSettings map[interface{}]interface{}
	// callbacks are run in registration order after each successful reload.
	callbacks []func(*Config)
}

// NewConfig returns an empty Config whose Settings map is allocated and ready
// to be written to. Nothing is loaded until Load is called.
func NewConfig() *Config {
	cfg := new(Config)
	cfg.Settings = map[interface{}]interface{}{}
	return cfg
}
- // Load will find all yaml files within path and load them in lexical order
- func (c *Config) Load(path string) error {
- c.path = path
- c.files = make([]string, 0)
- err := c.resolve(path, true)
- if err != nil {
- return err
- }
- if len(c.files) == 0 {
- return fmt.Errorf("no config files found at %s", path)
- }
- sort.Strings(c.files)
- err = c.parse()
- if err != nil {
- return err
- }
- return nil
- }
- // RegisterReloadCallback stores a function to be called when a config reload is triggered. The functions registered
- // here should decide if they need to make a change to the current process before making the change. HasChanged can be
- // used to help decide if a change is necessary.
- // These functions should return quickly or spawn their own go routine if they will take a while
- func (c *Config) RegisterReloadCallback(f func(*Config)) {
- c.callbacks = append(c.callbacks, f)
- }
- // HasChanged checks if the underlying structure of the provided key has changed after a config reload. The value of
- // k in both the old and new settings will be serialized, the result of the string comparison is returned.
- // If k is an empty string the entire config is tested.
- // It's important to note that this is very rudimentary and susceptible to configuration ordering issues indicating
- // there is change when there actually wasn't any.
- func (c *Config) HasChanged(k string) bool {
- if c.oldSettings == nil {
- return false
- }
- var (
- nv interface{}
- ov interface{}
- )
- if k == "" {
- nv = c.Settings
- ov = c.oldSettings
- k = "all settings"
- } else {
- nv = c.get(k, c.Settings)
- ov = c.get(k, c.oldSettings)
- }
- newVals, err := yaml.Marshal(nv)
- if err != nil {
- l.WithField("config_path", k).WithError(err).Error("Error while marshaling new config")
- }
- oldVals, err := yaml.Marshal(ov)
- if err != nil {
- l.WithField("config_path", k).WithError(err).Error("Error while marshaling old config")
- }
- return string(newVals) != string(oldVals)
- }
- // CatchHUP will listen for the HUP signal in a go routine and reload all configs found in the
- // original path provided to Load. The old settings are shallow copied for change detection after the reload.
- func (c *Config) CatchHUP() {
- ch := make(chan os.Signal, 1)
- signal.Notify(ch, syscall.SIGHUP)
- go func() {
- for range ch {
- l.Info("Caught HUP, reloading config")
- c.ReloadConfig()
- }
- }()
- }
- func (c *Config) ReloadConfig() {
- c.oldSettings = make(map[interface{}]interface{})
- for k, v := range c.Settings {
- c.oldSettings[k] = v
- }
- err := c.Load(c.path)
- if err != nil {
- l.WithField("config_path", c.path).WithError(err).Error("Error occurred while reloading config")
- return
- }
- for _, v := range c.callbacks {
- v(c)
- }
- }
- // GetString will get the string for k or return the default d if not found or invalid
- func (c *Config) GetString(k, d string) string {
- r := c.Get(k)
- if r == nil {
- return d
- }
- return fmt.Sprintf("%v", r)
- }
- // GetStringSlice will get the slice of strings for k or return the default d if not found or invalid
- func (c *Config) GetStringSlice(k string, d []string) []string {
- r := c.Get(k)
- if r == nil {
- return d
- }
- rv, ok := r.([]interface{})
- if !ok {
- return d
- }
- v := make([]string, len(rv))
- for i := 0; i < len(v); i++ {
- v[i] = fmt.Sprintf("%v", rv[i])
- }
- return v
- }
- // GetMap will get the map for k or return the default d if not found or invalid
- func (c *Config) GetMap(k string, d map[interface{}]interface{}) map[interface{}]interface{} {
- r := c.Get(k)
- if r == nil {
- return d
- }
- v, ok := r.(map[interface{}]interface{})
- if !ok {
- return d
- }
- return v
- }
- // GetInt will get the int for k or return the default d if not found or invalid
- func (c *Config) GetInt(k string, d int) int {
- r := c.GetString(k, strconv.Itoa(d))
- v, err := strconv.Atoi(r)
- if err != nil {
- return d
- }
- return v
- }
- // GetBool will get the bool for k or return the default d if not found or invalid
- func (c *Config) GetBool(k string, d bool) bool {
- r := strings.ToLower(c.GetString(k, fmt.Sprintf("%v", d)))
- v, err := strconv.ParseBool(r)
- if err != nil {
- switch r {
- case "y", "yes":
- return true
- case "n", "no":
- return false
- }
- return d
- }
- return v
- }
- // GetDuration will get the duration for k or return the default d if not found or invalid
- func (c *Config) GetDuration(k string, d time.Duration) time.Duration {
- r := c.GetString(k, "")
- v, err := time.ParseDuration(r)
- if err != nil {
- return d
- }
- return v
- }
- func (c *Config) GetAllowList(k string, allowInterfaces bool) (*AllowList, error) {
- r := c.Get(k)
- if r == nil {
- return nil, nil
- }
- rawMap, ok := r.(map[interface{}]interface{})
- if !ok {
- return nil, fmt.Errorf("config `%s` has invalid type: %T", k, r)
- }
- tree := NewCIDRTree()
- var nameRules []AllowListNameRule
- firstValue := true
- allValuesMatch := true
- defaultSet := false
- var allValues bool
- for rawKey, rawValue := range rawMap {
- rawCIDR, ok := rawKey.(string)
- if !ok {
- return nil, fmt.Errorf("config `%s` has invalid key (type %T): %v", k, rawKey, rawKey)
- }
- // Special rule for interface names
- if rawCIDR == "interfaces" {
- if !allowInterfaces {
- return nil, fmt.Errorf("config `%s` does not support `interfaces`", k)
- }
- var err error
- nameRules, err = c.getAllowListInterfaces(k, rawValue)
- if err != nil {
- return nil, err
- }
- continue
- }
- value, ok := rawValue.(bool)
- if !ok {
- return nil, fmt.Errorf("config `%s` has invalid value (type %T): %v", k, rawValue, rawValue)
- }
- _, cidr, err := net.ParseCIDR(rawCIDR)
- if err != nil {
- return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR)
- }
- // TODO: should we error on duplicate CIDRs in the config?
- tree.AddCIDR(cidr, value)
- if firstValue {
- allValues = value
- firstValue = false
- } else {
- if value != allValues {
- allValuesMatch = false
- }
- }
- // Check if this is 0.0.0.0/0
- bits, size := cidr.Mask.Size()
- if bits == 0 && size == 32 {
- defaultSet = true
- }
- }
- if !defaultSet {
- if allValuesMatch {
- _, zeroCIDR, _ := net.ParseCIDR("0.0.0.0/0")
- tree.AddCIDR(zeroCIDR, !allValues)
- } else {
- return nil, fmt.Errorf("config `%s` contains both true and false rules, but no default set for 0.0.0.0/0", k)
- }
- }
- return &AllowList{cidrTree: tree, nameRules: nameRules}, nil
- }
- func (c *Config) getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error) {
- var nameRules []AllowListNameRule
- rawRules, ok := v.(map[interface{}]interface{})
- if !ok {
- return nil, fmt.Errorf("config `%s.interfaces` is invalid (type %T): %v", k, v, v)
- }
- firstEntry := true
- var allValues bool
- for rawName, rawAllow := range rawRules {
- name, ok := rawName.(string)
- if !ok {
- return nil, fmt.Errorf("config `%s.interfaces` has invalid key (type %T): %v", k, rawName, rawName)
- }
- allow, ok := rawAllow.(bool)
- if !ok {
- return nil, fmt.Errorf("config `%s.interfaces` has invalid value (type %T): %v", k, rawAllow, rawAllow)
- }
- nameRE, err := regexp.Compile("^" + name + "$")
- if err != nil {
- return nil, fmt.Errorf("config `%s.interfaces` has invalid key: %s: %v", k, name, err)
- }
- nameRules = append(nameRules, AllowListNameRule{
- Name: nameRE,
- Allow: allow,
- })
- if firstEntry {
- allValues = allow
- firstEntry = false
- } else {
- if allow != allValues {
- return nil, fmt.Errorf("config `%s.interfaces` values must all be the same true/false value", k)
- }
- }
- }
- return nameRules, nil
- }
- func (c *Config) Get(k string) interface{} {
- return c.get(k, c.Settings)
- }
- func (c *Config) IsSet(k string) bool {
- return c.get(k, c.Settings) != nil
- }
- func (c *Config) get(k string, v interface{}) interface{} {
- parts := strings.Split(k, ".")
- for _, p := range parts {
- m, ok := v.(map[interface{}]interface{})
- if !ok {
- return nil
- }
- v, ok = m[p]
- if !ok {
- return nil
- }
- }
- return v
- }
- // direct signifies if this is the config path directly specified by the user,
- // versus a file/dir found by recursing into that path
- func (c *Config) resolve(path string, direct bool) error {
- i, err := os.Stat(path)
- if err != nil {
- return nil
- }
- if !i.IsDir() {
- c.addFile(path, direct)
- return nil
- }
- paths, err := readDirNames(path)
- if err != nil {
- return fmt.Errorf("problem while reading directory %s: %s", path, err)
- }
- for _, p := range paths {
- err := c.resolve(filepath.Join(path, p), false)
- if err != nil {
- return err
- }
- }
- return nil
- }
- func (c *Config) addFile(path string, direct bool) error {
- ext := filepath.Ext(path)
- if !direct && ext != ".yaml" && ext != ".yml" {
- return nil
- }
- ap, err := filepath.Abs(path)
- if err != nil {
- return err
- }
- c.files = append(c.files, ap)
- return nil
- }
- func (c *Config) parse() error {
- var m map[interface{}]interface{}
- for _, path := range c.files {
- b, err := ioutil.ReadFile(path)
- if err != nil {
- return err
- }
- var nm map[interface{}]interface{}
- err = yaml.Unmarshal(b, &nm)
- if err != nil {
- return err
- }
- // We need to use WithAppendSlice so that firewall rules in separate
- // files are appended together
- err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
- m = nm
- if err != nil {
- return err
- }
- }
- c.Settings = m
- return nil
- }
// readDirNames returns the names of all entries in the directory at path (files
// and subdirectories alike, without the directory prefix), sorted in increasing
// lexical order.
func readDirNames(path string) ([]string, error) {
	dir, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// The close error is deliberately ignored: the directory was only read.
	defer dir.Close()

	names, err := dir.Readdirnames(-1)
	if err != nil {
		return nil, err
	}

	sort.Strings(names)
	return names, nil
}
- func configLogger(c *Config) error {
- // set up our logging level
- logLevel, err := logrus.ParseLevel(strings.ToLower(c.GetString("logging.level", "info")))
- if err != nil {
- return fmt.Errorf("%s; possible levels: %s", err, logrus.AllLevels)
- }
- l.SetLevel(logLevel)
- timestampFormat := c.GetString("logging.timestamp_format", "")
- fullTimestamp := (timestampFormat != "")
- if timestampFormat == "" {
- timestampFormat = time.RFC3339
- }
- logFormat := strings.ToLower(c.GetString("logging.format", "text"))
- switch logFormat {
- case "text":
- l.Formatter = &logrus.TextFormatter{
- TimestampFormat: timestampFormat,
- FullTimestamp: fullTimestamp,
- }
- case "json":
- l.Formatter = &logrus.JSONFormatter{
- TimestampFormat: timestampFormat,
- }
- default:
- return fmt.Errorf("unknown log format `%s`. possible formats: %s", logFormat, []string{"text", "json"})
- }
- return nil
- }
|