config_test.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. package config
  2. import (
  3. "os"
  4. "path/filepath"
  5. "testing"
  6. "time"
  7. "dario.cat/mergo"
  8. "github.com/slackhq/nebula/test"
  9. "github.com/stretchr/testify/assert"
  10. "github.com/stretchr/testify/require"
  11. "gopkg.in/yaml.v2"
  12. )
  13. func TestConfig_Load(t *testing.T) {
  14. l := test.NewLogger()
  15. dir, err := os.MkdirTemp("", "config-test")
  16. // invalid yaml
  17. c := NewC(l)
  18. os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
  19. require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
  20. // simple multi config merge
  21. c = NewC(l)
  22. os.RemoveAll(dir)
  23. os.Mkdir(dir, 0755)
  24. require.NoError(t, err)
  25. os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
  26. os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
  27. require.NoError(t, c.Load(dir))
  28. expected := map[interface{}]interface{}{
  29. "outer": map[interface{}]interface{}{
  30. "inner": "override",
  31. },
  32. "new": "hi",
  33. }
  34. assert.Equal(t, expected, c.Settings)
  35. }
  36. func TestConfig_Get(t *testing.T) {
  37. l := test.NewLogger()
  38. // test simple type
  39. c := NewC(l)
  40. c.Settings["firewall"] = map[interface{}]interface{}{"outbound": "hi"}
  41. assert.Equal(t, "hi", c.Get("firewall.outbound"))
  42. // test complex type
  43. inner := []map[interface{}]interface{}{{"port": "1", "code": "2"}}
  44. c.Settings["firewall"] = map[interface{}]interface{}{"outbound": inner}
  45. assert.EqualValues(t, inner, c.Get("firewall.outbound"))
  46. // test missing
  47. assert.Nil(t, c.Get("firewall.nope"))
  48. }
  49. func TestConfig_GetStringSlice(t *testing.T) {
  50. l := test.NewLogger()
  51. c := NewC(l)
  52. c.Settings["slice"] = []interface{}{"one", "two"}
  53. assert.Equal(t, []string{"one", "two"}, c.GetStringSlice("slice", []string{}))
  54. }
  55. func TestConfig_GetBool(t *testing.T) {
  56. l := test.NewLogger()
  57. c := NewC(l)
  58. c.Settings["bool"] = true
  59. assert.True(t, c.GetBool("bool", false))
  60. c.Settings["bool"] = "true"
  61. assert.True(t, c.GetBool("bool", false))
  62. c.Settings["bool"] = false
  63. assert.False(t, c.GetBool("bool", true))
  64. c.Settings["bool"] = "false"
  65. assert.False(t, c.GetBool("bool", true))
  66. c.Settings["bool"] = "Y"
  67. assert.True(t, c.GetBool("bool", false))
  68. c.Settings["bool"] = "yEs"
  69. assert.True(t, c.GetBool("bool", false))
  70. c.Settings["bool"] = "N"
  71. assert.False(t, c.GetBool("bool", true))
  72. c.Settings["bool"] = "nO"
  73. assert.False(t, c.GetBool("bool", true))
  74. }
  75. func TestConfig_HasChanged(t *testing.T) {
  76. l := test.NewLogger()
  77. // No reload has occurred, return false
  78. c := NewC(l)
  79. c.Settings["test"] = "hi"
  80. assert.False(t, c.HasChanged(""))
  81. // Test key change
  82. c = NewC(l)
  83. c.Settings["test"] = "hi"
  84. c.oldSettings = map[interface{}]interface{}{"test": "no"}
  85. assert.True(t, c.HasChanged("test"))
  86. assert.True(t, c.HasChanged(""))
  87. // No key change
  88. c = NewC(l)
  89. c.Settings["test"] = "hi"
  90. c.oldSettings = map[interface{}]interface{}{"test": "hi"}
  91. assert.False(t, c.HasChanged("test"))
  92. assert.False(t, c.HasChanged(""))
  93. }
  94. func TestConfig_ReloadConfig(t *testing.T) {
  95. l := test.NewLogger()
  96. done := make(chan bool, 1)
  97. dir, err := os.MkdirTemp("", "config-test")
  98. require.NoError(t, err)
  99. os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
  100. c := NewC(l)
  101. require.NoError(t, c.Load(dir))
  102. assert.False(t, c.HasChanged("outer.inner"))
  103. assert.False(t, c.HasChanged("outer"))
  104. assert.False(t, c.HasChanged(""))
  105. os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: ho"), 0644)
  106. c.RegisterReloadCallback(func(c *C) {
  107. done <- true
  108. })
  109. c.ReloadConfig()
  110. assert.True(t, c.HasChanged("outer.inner"))
  111. assert.True(t, c.HasChanged("outer"))
  112. assert.True(t, c.HasChanged(""))
  113. // Make sure we call the callbacks
  114. select {
  115. case <-done:
  116. case <-time.After(1 * time.Second):
  117. panic("timeout")
  118. }
  119. }
  120. // Ensure mergo merges are done the way we expect.
  121. // This is needed to test for potential regressions, like:
  122. // - https://github.com/imdario/mergo/issues/187
  123. func TestConfig_MergoMerge(t *testing.T) {
  124. configs := [][]byte{
  125. []byte(`
  126. listen:
  127. port: 1234
  128. `),
  129. []byte(`
  130. firewall:
  131. inbound:
  132. - port: 443
  133. proto: tcp
  134. groups:
  135. - server
  136. - port: 443
  137. proto: tcp
  138. groups:
  139. - webapp
  140. `),
  141. []byte(`
  142. listen:
  143. host: 0.0.0.0
  144. port: 4242
  145. firewall:
  146. outbound:
  147. - port: any
  148. proto: any
  149. host: any
  150. inbound:
  151. - port: any
  152. proto: icmp
  153. host: any
  154. `),
  155. }
  156. var m map[any]any
  157. // merge the same way config.parse() merges
  158. for _, b := range configs {
  159. var nm map[any]any
  160. err := yaml.Unmarshal(b, &nm)
  161. require.NoError(t, err)
  162. // We need to use WithAppendSlice so that firewall rules in separate
  163. // files are appended together
  164. err = mergo.Merge(&nm, m, mergo.WithAppendSlice)
  165. m = nm
  166. require.NoError(t, err)
  167. }
  168. t.Logf("Merged Config: %#v", m)
  169. mYaml, err := yaml.Marshal(m)
  170. require.NoError(t, err)
  171. t.Logf("Merged Config as YAML:\n%s", mYaml)
  172. // If a bug is present, some items might be replaced instead of merged like we expect
  173. expected := map[any]any{
  174. "firewall": map[any]any{
  175. "inbound": []any{
  176. map[any]any{"host": "any", "port": "any", "proto": "icmp"},
  177. map[any]any{"groups": []any{"server"}, "port": 443, "proto": "tcp"},
  178. map[any]any{"groups": []any{"webapp"}, "port": 443, "proto": "tcp"}},
  179. "outbound": []any{
  180. map[any]any{"host": "any", "port": "any", "proto": "any"}}},
  181. "listen": map[any]any{
  182. "host": "0.0.0.0",
  183. "port": 4242,
  184. },
  185. }
  186. assert.Equal(t, expected, m)
  187. }