3
0

firewall_test.go 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "github.com/rcrowley/go-metrics"
  6. "github.com/stretchr/testify/assert"
  7. "math"
  8. "net"
  9. "github.com/slackhq/nebula/cert"
  10. "testing"
  11. "time"
  12. )
  13. func TestNewFirewall(t *testing.T) {
  14. c := &cert.NebulaCertificate{}
  15. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  16. assert.NotNil(t, fw.Conns)
  17. assert.NotNil(t, fw.InRules)
  18. assert.NotNil(t, fw.OutRules)
  19. assert.NotNil(t, fw.TimerWheel)
  20. assert.Equal(t, time.Second, fw.TCPTimeout)
  21. assert.Equal(t, time.Minute, fw.UDPTimeout)
  22. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  23. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  24. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  25. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  26. fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
  27. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  28. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  29. fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
  30. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  31. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  32. fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
  33. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  34. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  35. fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
  36. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  37. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  38. fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
  39. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  40. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  41. }
  42. func TestFirewall_AddRule(t *testing.T) {
  43. c := &cert.NebulaCertificate{}
  44. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  45. assert.NotNil(t, fw.InRules)
  46. assert.NotNil(t, fw.OutRules)
  47. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  48. assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
  49. // Make sure an empty rule creates structure but doesn't allow anything to flow
  50. //TODO: ideally an empty rule would return an error
  51. assert.False(t, fw.InRules.TCP[1].Any)
  52. assert.Empty(t, fw.InRules.TCP[1].Groups)
  53. assert.Empty(t, fw.InRules.TCP[1].Hosts)
  54. assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
  55. assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
  56. assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
  57. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  58. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
  59. assert.False(t, fw.InRules.UDP[1].Any)
  60. assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
  61. assert.Empty(t, fw.InRules.UDP[1].Hosts)
  62. assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
  63. assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
  64. assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
  65. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  66. assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
  67. assert.False(t, fw.InRules.ICMP[1].Any)
  68. assert.Empty(t, fw.InRules.ICMP[1].Groups)
  69. assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
  70. assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
  71. assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
  72. assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
  73. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  74. assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
  75. assert.False(t, fw.OutRules.AnyProto[1].Any)
  76. assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
  77. assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
  78. assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
  79. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  80. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
  81. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  82. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  83. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
  84. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  85. // Set any and clear fields
  86. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  87. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
  88. assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
  89. assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
  90. assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
  91. // run twice just to make sure
  92. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  93. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  94. assert.True(t, fw.OutRules.AnyProto[0].Any)
  95. assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
  96. assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
  97. assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
  98. assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
  99. assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
  100. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  101. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  102. assert.True(t, fw.OutRules.AnyProto[0].Any)
  103. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  104. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  105. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
  106. assert.True(t, fw.OutRules.AnyProto[0].Any)
  107. // Test error conditions
  108. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  109. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
  110. assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
  111. }
  112. func TestFirewall_Drop(t *testing.T) {
  113. p := FirewallPacket{
  114. ip2int(net.IPv4(1, 2, 3, 4)),
  115. 101,
  116. 10,
  117. 90,
  118. fwProtoUDP,
  119. false,
  120. }
  121. ipNet := net.IPNet{
  122. IP: net.IPv4(1, 2, 3, 4),
  123. Mask: net.IPMask{255, 255, 255, 0},
  124. }
  125. c := cert.NebulaCertificate{
  126. Details: cert.NebulaCertificateDetails{
  127. Name: "host1",
  128. Ips: []*net.IPNet{&ipNet},
  129. Groups: []string{"default-group"},
  130. Issuer: "signer-shasum",
  131. },
  132. }
  133. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  134. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  135. cp := cert.NewCAPool()
  136. // Drop outbound
  137. assert.True(t, fw.Drop([]byte{}, p, false, &c, cp))
  138. // Allow inbound
  139. assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
  140. // Allow outbound because conntrack
  141. assert.False(t, fw.Drop([]byte{}, p, false, &c, cp))
  142. // test caSha assertions true
  143. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  144. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
  145. assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
  146. // test caSha assertions false
  147. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  148. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
  149. assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
  150. // test caName true
  151. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  152. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  153. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
  154. assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
  155. // test caName false
  156. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  157. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  158. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
  159. assert.True(t, fw.Drop([]byte{}, p, true, &c, cp))
  160. }
  161. func BenchmarkFirewallTable_match(b *testing.B) {
  162. ft := FirewallTable{
  163. TCP: firewallPort{},
  164. }
  165. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  166. ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
  167. ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
  168. ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
  169. ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
  170. ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
  171. cp := cert.NewCAPool()
  172. b.Run("fail on proto", func(b *testing.B) {
  173. c := &cert.NebulaCertificate{}
  174. for n := 0; n < b.N; n++ {
  175. ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
  176. }
  177. })
  178. b.Run("fail on port", func(b *testing.B) {
  179. c := &cert.NebulaCertificate{}
  180. for n := 0; n < b.N; n++ {
  181. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
  182. }
  183. })
  184. b.Run("fail all group, name, and cidr", func(b *testing.B) {
  185. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  186. c := &cert.NebulaCertificate{
  187. Details: cert.NebulaCertificateDetails{
  188. InvertedGroups: map[string]struct{}{"nope": {}},
  189. Name: "nope",
  190. Ips: []*net.IPNet{ip},
  191. },
  192. }
  193. for n := 0; n < b.N; n++ {
  194. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  195. }
  196. })
  197. b.Run("pass on group", func(b *testing.B) {
  198. c := &cert.NebulaCertificate{
  199. Details: cert.NebulaCertificateDetails{
  200. InvertedGroups: map[string]struct{}{"good-group": {}},
  201. Name: "nope",
  202. },
  203. }
  204. for n := 0; n < b.N; n++ {
  205. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  206. }
  207. })
  208. b.Run("pass on name", func(b *testing.B) {
  209. c := &cert.NebulaCertificate{
  210. Details: cert.NebulaCertificateDetails{
  211. InvertedGroups: map[string]struct{}{"nope": {}},
  212. Name: "good-host",
  213. },
  214. }
  215. for n := 0; n < b.N; n++ {
  216. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  217. }
  218. })
  219. b.Run("pass on ip", func(b *testing.B) {
  220. ip := ip2int(net.IPv4(172, 1, 1, 1))
  221. c := &cert.NebulaCertificate{
  222. Details: cert.NebulaCertificateDetails{
  223. InvertedGroups: map[string]struct{}{"nope": {}},
  224. Name: "good-host",
  225. },
  226. }
  227. for n := 0; n < b.N; n++ {
  228. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  229. }
  230. })
  231. ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
  232. b.Run("pass on ip with any port", func(b *testing.B) {
  233. ip := ip2int(net.IPv4(172, 1, 1, 1))
  234. c := &cert.NebulaCertificate{
  235. Details: cert.NebulaCertificateDetails{
  236. InvertedGroups: map[string]struct{}{"nope": {}},
  237. Name: "good-host",
  238. },
  239. }
  240. for n := 0; n < b.N; n++ {
  241. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  242. }
  243. })
  244. }
  245. func TestFirewall_Drop2(t *testing.T) {
  246. p := FirewallPacket{
  247. ip2int(net.IPv4(1, 2, 3, 4)),
  248. 101,
  249. 10,
  250. 90,
  251. fwProtoUDP,
  252. false,
  253. }
  254. ipNet := net.IPNet{
  255. IP: net.IPv4(1, 2, 3, 4),
  256. Mask: net.IPMask{255, 255, 255, 0},
  257. }
  258. c := cert.NebulaCertificate{
  259. Details: cert.NebulaCertificateDetails{
  260. Name: "host1",
  261. Ips: []*net.IPNet{&ipNet},
  262. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  263. },
  264. }
  265. c1 := cert.NebulaCertificate{
  266. Details: cert.NebulaCertificateDetails{
  267. Name: "host1",
  268. Ips: []*net.IPNet{&ipNet},
  269. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  270. },
  271. }
  272. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  273. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
  274. cp := cert.NewCAPool()
  275. // c1 lacks the proper groups
  276. assert.True(t, fw.Drop([]byte{}, p, true, &c1, cp))
  277. // c has the proper groups
  278. assert.False(t, fw.Drop([]byte{}, p, true, &c, cp))
  279. }
  280. func BenchmarkLookup(b *testing.B) {
  281. ml := func(m map[string]struct{}, a [][]string) {
  282. for n := 0; n < b.N; n++ {
  283. for _, sg := range a {
  284. found := false
  285. for _, g := range sg {
  286. if _, ok := m[g]; !ok {
  287. found = false
  288. break
  289. }
  290. found = true
  291. }
  292. if found {
  293. return
  294. }
  295. }
  296. }
  297. }
  298. b.Run("array to map best", func(b *testing.B) {
  299. m := map[string]struct{}{
  300. "1ne": {},
  301. "2wo": {},
  302. "3hr": {},
  303. "4ou": {},
  304. "5iv": {},
  305. "6ix": {},
  306. }
  307. a := [][]string{
  308. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  309. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  310. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  311. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  312. {"one", "two", "thr", "fou", "5iv", "6ix"},
  313. {"one", "two", "thr", "fou", "fiv", "6ix"},
  314. {"one", "two", "thr", "fou", "fiv", "six"},
  315. }
  316. for n := 0; n < b.N; n++ {
  317. ml(m, a)
  318. }
  319. })
  320. b.Run("array to map worst", func(b *testing.B) {
  321. m := map[string]struct{}{
  322. "one": {},
  323. "two": {},
  324. "thr": {},
  325. "fou": {},
  326. "fiv": {},
  327. "six": {},
  328. }
  329. a := [][]string{
  330. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  331. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  332. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  333. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  334. {"one", "two", "thr", "fou", "5iv", "6ix"},
  335. {"one", "two", "thr", "fou", "fiv", "6ix"},
  336. {"one", "two", "thr", "fou", "fiv", "six"},
  337. }
  338. for n := 0; n < b.N; n++ {
  339. ml(m, a)
  340. }
  341. })
  342. //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  343. }
  344. func Test_parsePort(t *testing.T) {
  345. _, _, err := parsePort("")
  346. assert.EqualError(t, err, "was not a number; ``")
  347. _, _, err = parsePort(" ")
  348. assert.EqualError(t, err, "was not a number; ` `")
  349. _, _, err = parsePort("-")
  350. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  351. _, _, err = parsePort(" - ")
  352. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  353. _, _, err = parsePort("a-b")
  354. assert.EqualError(t, err, "beginning range was not a number; `a`")
  355. _, _, err = parsePort("1-b")
  356. assert.EqualError(t, err, "ending range was not a number; `b`")
  357. s, e, err := parsePort(" 1 - 2 ")
  358. assert.Equal(t, int32(1), s)
  359. assert.Equal(t, int32(2), e)
  360. assert.Nil(t, err)
  361. s, e, err = parsePort("0-1")
  362. assert.Equal(t, int32(0), s)
  363. assert.Equal(t, int32(0), e)
  364. assert.Nil(t, err)
  365. s, e, err = parsePort("9919")
  366. assert.Equal(t, int32(9919), s)
  367. assert.Equal(t, int32(9919), e)
  368. assert.Nil(t, err)
  369. s, e, err = parsePort("any")
  370. assert.Equal(t, int32(0), s)
  371. assert.Equal(t, int32(0), e)
  372. assert.Nil(t, err)
  373. }
  374. func TestNewFirewallFromConfig(t *testing.T) {
  375. // Test a bad rule definition
  376. c := &cert.NebulaCertificate{}
  377. conf := NewConfig()
  378. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  379. _, err := NewFirewallFromConfig(c, conf)
  380. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  381. // Test both port and code
  382. conf = NewConfig()
  383. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  384. _, err = NewFirewallFromConfig(c, conf)
  385. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  386. // Test missing host, group, cidr, ca_name and ca_sha
  387. conf = NewConfig()
  388. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  389. _, err = NewFirewallFromConfig(c, conf)
  390. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
  391. // Test code/port error
  392. conf = NewConfig()
  393. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  394. _, err = NewFirewallFromConfig(c, conf)
  395. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  396. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  397. _, err = NewFirewallFromConfig(c, conf)
  398. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  399. // Test proto error
  400. conf = NewConfig()
  401. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  402. _, err = NewFirewallFromConfig(c, conf)
  403. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  404. // Test cidr parse error
  405. conf = NewConfig()
  406. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  407. _, err = NewFirewallFromConfig(c, conf)
  408. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  409. // Test both group and groups
  410. conf = NewConfig()
  411. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  412. _, err = NewFirewallFromConfig(c, conf)
  413. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  414. }
  415. func TestAddFirewallRulesFromConfig(t *testing.T) {
  416. // Test adding tcp rule
  417. conf := NewConfig()
  418. mf := &mockFirewall{}
  419. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  420. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  421. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  422. // Test adding udp rule
  423. conf = NewConfig()
  424. mf = &mockFirewall{}
  425. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  426. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  427. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  428. // Test adding icmp rule
  429. conf = NewConfig()
  430. mf = &mockFirewall{}
  431. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  432. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  433. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  434. // Test adding any rule
  435. conf = NewConfig()
  436. mf = &mockFirewall{}
  437. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  438. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  439. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  440. // Test adding rule with ca_sha
  441. conf = NewConfig()
  442. mf = &mockFirewall{}
  443. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  444. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  445. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
  446. // Test adding rule with ca_name
  447. conf = NewConfig()
  448. mf = &mockFirewall{}
  449. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  450. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  451. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
  452. // Test single group
  453. conf = NewConfig()
  454. mf = &mockFirewall{}
  455. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  456. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  457. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  458. // Test single groups
  459. conf = NewConfig()
  460. mf = &mockFirewall{}
  461. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  462. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  463. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  464. // Test multiple AND groups
  465. conf = NewConfig()
  466. mf = &mockFirewall{}
  467. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  468. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  469. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
  470. // Test Add error
  471. conf = NewConfig()
  472. mf = &mockFirewall{}
  473. mf.nextCallReturn = errors.New("test error")
  474. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  475. assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
  476. }
  477. func TestTCPRTTTracking(t *testing.T) {
  478. b := make([]byte, 200)
  479. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  480. b[0] = 15
  481. b[60+12] = 15 << 4
  482. f := Firewall{
  483. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  484. }
  485. // Set SEQ to 1
  486. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  487. c := &conn{}
  488. setTCPRTTTracking(c, b)
  489. assert.Equal(t, uint32(1), c.Seq)
  490. // Bad ack - no ack flag
  491. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  492. assert.False(t, f.checkTCPRTT(c, b))
  493. // Bad ack, number is too low
  494. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  495. b[60+13] = uint8(0x10)
  496. assert.False(t, f.checkTCPRTT(c, b))
  497. // Good ack
  498. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  499. assert.True(t, f.checkTCPRTT(c, b))
  500. assert.Equal(t, uint32(0), c.Seq)
  501. // Set SEQ to 1
  502. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  503. c = &conn{}
  504. setTCPRTTTracking(c, b)
  505. assert.Equal(t, uint32(1), c.Seq)
  506. // Good acks
  507. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  508. assert.True(t, f.checkTCPRTT(c, b))
  509. assert.Equal(t, uint32(0), c.Seq)
  510. // Set SEQ to max uint32 - 20
  511. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  512. c = &conn{}
  513. setTCPRTTTracking(c, b)
  514. assert.Equal(t, ^uint32(0)-20, c.Seq)
  515. // Good acks
  516. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  517. assert.True(t, f.checkTCPRTT(c, b))
  518. assert.Equal(t, uint32(0), c.Seq)
  519. // Set SEQ to max uint32 / 2
  520. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  521. c = &conn{}
  522. setTCPRTTTracking(c, b)
  523. assert.Equal(t, ^uint32(0)/2, c.Seq)
  524. // Below
  525. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  526. assert.False(t, f.checkTCPRTT(c, b))
  527. assert.Equal(t, ^uint32(0)/2, c.Seq)
  528. // Halfway below
  529. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  530. assert.False(t, f.checkTCPRTT(c, b))
  531. assert.Equal(t, ^uint32(0)/2, c.Seq)
  532. // Halfway above is ok
  533. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  534. assert.True(t, f.checkTCPRTT(c, b))
  535. assert.Equal(t, uint32(0), c.Seq)
  536. // Set SEQ to max uint32
  537. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  538. c = &conn{}
  539. setTCPRTTTracking(c, b)
  540. assert.Equal(t, ^uint32(0), c.Seq)
  541. // Halfway + 1 above
  542. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  543. assert.False(t, f.checkTCPRTT(c, b))
  544. assert.Equal(t, ^uint32(0), c.Seq)
  545. // Halfway above
  546. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  547. assert.True(t, f.checkTCPRTT(c, b))
  548. assert.Equal(t, uint32(0), c.Seq)
  549. }
  550. type addRuleCall struct {
  551. incoming bool
  552. proto uint8
  553. startPort int32
  554. endPort int32
  555. groups []string
  556. host string
  557. ip *net.IPNet
  558. caName string
  559. caSha string
  560. }
  561. type mockFirewall struct {
  562. lastCall addRuleCall
  563. nextCallReturn error
  564. }
  565. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  566. mf.lastCall = addRuleCall{
  567. incoming: incoming,
  568. proto: proto,
  569. startPort: startPort,
  570. endPort: endPort,
  571. groups: groups,
  572. host: host,
  573. ip: ip,
  574. caName: caName,
  575. caSha: caSha,
  576. }
  577. err := mf.nextCallReturn
  578. mf.nextCallReturn = nil
  579. return err
  580. }