firewall_test.go 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712
  1. package nebula
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "math"
  6. "net"
  7. "testing"
  8. "time"
  9. "github.com/rcrowley/go-metrics"
  10. "github.com/slackhq/nebula/cert"
  11. "github.com/stretchr/testify/assert"
  12. )
  13. func TestNewFirewall(t *testing.T) {
  14. c := &cert.NebulaCertificate{}
  15. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  16. assert.NotNil(t, fw.Conns)
  17. assert.NotNil(t, fw.InRules)
  18. assert.NotNil(t, fw.OutRules)
  19. assert.NotNil(t, fw.TimerWheel)
  20. assert.Equal(t, time.Second, fw.TCPTimeout)
  21. assert.Equal(t, time.Minute, fw.UDPTimeout)
  22. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  23. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  24. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  25. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  26. fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
  27. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  28. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  29. fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
  30. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  31. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  32. fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
  33. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  34. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  35. fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
  36. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  37. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  38. fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
  39. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  40. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  41. }
  42. func TestFirewall_AddRule(t *testing.T) {
  43. c := &cert.NebulaCertificate{}
  44. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  45. assert.NotNil(t, fw.InRules)
  46. assert.NotNil(t, fw.OutRules)
  47. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  48. assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
  49. // Make sure an empty rule creates structure but doesn't allow anything to flow
  50. //TODO: ideally an empty rule would return an error
  51. assert.False(t, fw.InRules.TCP[1].Any)
  52. assert.Empty(t, fw.InRules.TCP[1].Groups)
  53. assert.Empty(t, fw.InRules.TCP[1].Hosts)
  54. assert.Nil(t, fw.InRules.TCP[1].CIDR.root.left)
  55. assert.Nil(t, fw.InRules.TCP[1].CIDR.root.right)
  56. assert.Nil(t, fw.InRules.TCP[1].CIDR.root.value)
  57. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  58. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
  59. assert.False(t, fw.InRules.UDP[1].Any)
  60. assert.Contains(t, fw.InRules.UDP[1].Groups[0], "g1")
  61. assert.Empty(t, fw.InRules.UDP[1].Hosts)
  62. assert.Nil(t, fw.InRules.UDP[1].CIDR.root.left)
  63. assert.Nil(t, fw.InRules.UDP[1].CIDR.root.right)
  64. assert.Nil(t, fw.InRules.UDP[1].CIDR.root.value)
  65. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  66. assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
  67. assert.False(t, fw.InRules.ICMP[1].Any)
  68. assert.Empty(t, fw.InRules.ICMP[1].Groups)
  69. assert.Contains(t, fw.InRules.ICMP[1].Hosts, "h1")
  70. assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.left)
  71. assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.right)
  72. assert.Nil(t, fw.InRules.ICMP[1].CIDR.root.value)
  73. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  74. assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
  75. assert.False(t, fw.OutRules.AnyProto[1].Any)
  76. assert.Empty(t, fw.OutRules.AnyProto[1].Groups)
  77. assert.Empty(t, fw.OutRules.AnyProto[1].Hosts)
  78. assert.NotNil(t, fw.OutRules.AnyProto[1].CIDR.Match(ip2int(ti.IP)))
  79. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  80. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
  81. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  82. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  83. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
  84. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  85. // Set any and clear fields
  86. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  87. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
  88. assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Groups[0])
  89. assert.Contains(t, fw.OutRules.AnyProto[0].Hosts, "h1")
  90. assert.NotNil(t, fw.OutRules.AnyProto[0].CIDR.Match(ip2int(ti.IP)))
  91. // run twice just to make sure
  92. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  93. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  94. assert.True(t, fw.OutRules.AnyProto[0].Any)
  95. assert.Empty(t, fw.OutRules.AnyProto[0].Groups)
  96. assert.Empty(t, fw.OutRules.AnyProto[0].Hosts)
  97. assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.left)
  98. assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.right)
  99. assert.Nil(t, fw.OutRules.AnyProto[0].CIDR.root.value)
  100. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  101. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  102. assert.True(t, fw.OutRules.AnyProto[0].Any)
  103. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  104. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  105. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
  106. assert.True(t, fw.OutRules.AnyProto[0].Any)
  107. // Test error conditions
  108. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  109. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
  110. assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
  111. }
  112. func TestFirewall_Drop(t *testing.T) {
  113. p := FirewallPacket{
  114. ip2int(net.IPv4(1, 2, 3, 4)),
  115. ip2int(net.IPv4(1, 2, 3, 4)),
  116. 10,
  117. 90,
  118. fwProtoUDP,
  119. false,
  120. }
  121. ipNet := net.IPNet{
  122. IP: net.IPv4(1, 2, 3, 4),
  123. Mask: net.IPMask{255, 255, 255, 0},
  124. }
  125. c := cert.NebulaCertificate{
  126. Details: cert.NebulaCertificateDetails{
  127. Name: "host1",
  128. Ips: []*net.IPNet{&ipNet},
  129. Groups: []string{"default-group"},
  130. Issuer: "signer-shasum",
  131. },
  132. }
  133. h := HostInfo{
  134. ConnectionState: &ConnectionState{
  135. peerCert: &c,
  136. },
  137. }
  138. h.CreateRemoteCIDR(&c)
  139. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  140. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  141. cp := cert.NewCAPool()
  142. // Drop outbound
  143. assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
  144. // Allow inbound
  145. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  146. // Allow outbound because conntrack
  147. assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
  148. // test remote mismatch
  149. oldRemote := p.RemoteIP
  150. p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
  151. assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
  152. p.RemoteIP = oldRemote
  153. // test caSha assertions true
  154. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  155. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum"))
  156. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  157. // test caSha assertions false
  158. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  159. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", "signer-shasum-nope"))
  160. assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
  161. // test caName true
  162. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  163. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  164. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-good", ""))
  165. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  166. // test caName false
  167. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  168. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  169. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "ca-bad", ""))
  170. assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
  171. }
  172. func BenchmarkFirewallTable_match(b *testing.B) {
  173. ft := FirewallTable{
  174. TCP: firewallPort{},
  175. }
  176. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  177. ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
  178. ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
  179. ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
  180. ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
  181. ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
  182. cp := cert.NewCAPool()
  183. b.Run("fail on proto", func(b *testing.B) {
  184. c := &cert.NebulaCertificate{}
  185. for n := 0; n < b.N; n++ {
  186. ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
  187. }
  188. })
  189. b.Run("fail on port", func(b *testing.B) {
  190. c := &cert.NebulaCertificate{}
  191. for n := 0; n < b.N; n++ {
  192. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
  193. }
  194. })
  195. b.Run("fail all group, name, and cidr", func(b *testing.B) {
  196. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  197. c := &cert.NebulaCertificate{
  198. Details: cert.NebulaCertificateDetails{
  199. InvertedGroups: map[string]struct{}{"nope": {}},
  200. Name: "nope",
  201. Ips: []*net.IPNet{ip},
  202. },
  203. }
  204. for n := 0; n < b.N; n++ {
  205. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  206. }
  207. })
  208. b.Run("pass on group", func(b *testing.B) {
  209. c := &cert.NebulaCertificate{
  210. Details: cert.NebulaCertificateDetails{
  211. InvertedGroups: map[string]struct{}{"good-group": {}},
  212. Name: "nope",
  213. },
  214. }
  215. for n := 0; n < b.N; n++ {
  216. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  217. }
  218. })
  219. b.Run("pass on name", func(b *testing.B) {
  220. c := &cert.NebulaCertificate{
  221. Details: cert.NebulaCertificateDetails{
  222. InvertedGroups: map[string]struct{}{"nope": {}},
  223. Name: "good-host",
  224. },
  225. }
  226. for n := 0; n < b.N; n++ {
  227. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  228. }
  229. })
  230. b.Run("pass on ip", func(b *testing.B) {
  231. ip := ip2int(net.IPv4(172, 1, 1, 1))
  232. c := &cert.NebulaCertificate{
  233. Details: cert.NebulaCertificateDetails{
  234. InvertedGroups: map[string]struct{}{"nope": {}},
  235. Name: "good-host",
  236. },
  237. }
  238. for n := 0; n < b.N; n++ {
  239. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  240. }
  241. })
  242. ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
  243. b.Run("pass on ip with any port", func(b *testing.B) {
  244. ip := ip2int(net.IPv4(172, 1, 1, 1))
  245. c := &cert.NebulaCertificate{
  246. Details: cert.NebulaCertificateDetails{
  247. InvertedGroups: map[string]struct{}{"nope": {}},
  248. Name: "good-host",
  249. },
  250. }
  251. for n := 0; n < b.N; n++ {
  252. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  253. }
  254. })
  255. }
  256. func TestFirewall_Drop2(t *testing.T) {
  257. p := FirewallPacket{
  258. ip2int(net.IPv4(1, 2, 3, 4)),
  259. ip2int(net.IPv4(1, 2, 3, 4)),
  260. 10,
  261. 90,
  262. fwProtoUDP,
  263. false,
  264. }
  265. ipNet := net.IPNet{
  266. IP: net.IPv4(1, 2, 3, 4),
  267. Mask: net.IPMask{255, 255, 255, 0},
  268. }
  269. c := cert.NebulaCertificate{
  270. Details: cert.NebulaCertificateDetails{
  271. Name: "host1",
  272. Ips: []*net.IPNet{&ipNet},
  273. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  274. },
  275. }
  276. h := HostInfo{
  277. ConnectionState: &ConnectionState{
  278. peerCert: &c,
  279. },
  280. }
  281. h.CreateRemoteCIDR(&c)
  282. c1 := cert.NebulaCertificate{
  283. Details: cert.NebulaCertificateDetails{
  284. Name: "host1",
  285. Ips: []*net.IPNet{&ipNet},
  286. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  287. },
  288. }
  289. h1 := HostInfo{
  290. ConnectionState: &ConnectionState{
  291. peerCert: &c1,
  292. },
  293. }
  294. h1.CreateRemoteCIDR(&c1)
  295. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  296. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
  297. cp := cert.NewCAPool()
  298. // h1/c1 lacks the proper groups
  299. assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
  300. // c has the proper groups
  301. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  302. }
  // BenchmarkLookup compares the best and worst case of checking whether a set
  // of held groups (m) satisfies at least one of several AND-ed group lists (a).
  func BenchmarkLookup(b *testing.B) {
  	// ml reports whether every group of at least one list in a is present in m.
  	// It does no iteration of its own. Each sub-benchmark drives the b.N loop
  	// with its own *testing.B; previously ml also looped over the *parent*
  	// benchmark's b.N, which coupled the two iteration counts.
  	ml := func(m map[string]struct{}, a [][]string) bool {
  		for _, sg := range a {
  			found := false
  			for _, g := range sg {
  				if _, ok := m[g]; !ok {
  					found = false
  					break
  				}
  				found = true
  			}
  			if found {
  				return true
  			}
  		}
  		return false
  	}

  	// Best case: the first list is fully satisfied.
  	b.Run("array to map best", func(b *testing.B) {
  		m := map[string]struct{}{
  			"1ne": {},
  			"2wo": {},
  			"3hr": {},
  			"4ou": {},
  			"5iv": {},
  			"6ix": {},
  		}
  		a := [][]string{
  			{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  			{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  			{"one", "two", "3hr", "4ou", "5iv", "6ix"},
  			{"one", "two", "thr", "4ou", "5iv", "6ix"},
  			{"one", "two", "thr", "fou", "5iv", "6ix"},
  			{"one", "two", "thr", "fou", "fiv", "6ix"},
  			{"one", "two", "thr", "fou", "fiv", "six"},
  		}
  		for n := 0; n < b.N; n++ {
  			ml(m, a)
  		}
  	})

  	// Worst case: only the last list is fully satisfied.
  	b.Run("array to map worst", func(b *testing.B) {
  		m := map[string]struct{}{
  			"one": {},
  			"two": {},
  			"thr": {},
  			"fou": {},
  			"fiv": {},
  			"six": {},
  		}
  		a := [][]string{
  			{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  			{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  			{"one", "two", "3hr", "4ou", "5iv", "6ix"},
  			{"one", "two", "thr", "4ou", "5iv", "6ix"},
  			{"one", "two", "thr", "fou", "5iv", "6ix"},
  			{"one", "two", "thr", "fou", "fiv", "6ix"},
  			{"one", "two", "thr", "fou", "fiv", "six"},
  		}
  		for n := 0; n < b.N; n++ {
  			ml(m, a)
  		}
  	})

  	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  }
  367. func Test_parsePort(t *testing.T) {
  368. _, _, err := parsePort("")
  369. assert.EqualError(t, err, "was not a number; ``")
  370. _, _, err = parsePort(" ")
  371. assert.EqualError(t, err, "was not a number; ` `")
  372. _, _, err = parsePort("-")
  373. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  374. _, _, err = parsePort(" - ")
  375. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  376. _, _, err = parsePort("a-b")
  377. assert.EqualError(t, err, "beginning range was not a number; `a`")
  378. _, _, err = parsePort("1-b")
  379. assert.EqualError(t, err, "ending range was not a number; `b`")
  380. s, e, err := parsePort(" 1 - 2 ")
  381. assert.Equal(t, int32(1), s)
  382. assert.Equal(t, int32(2), e)
  383. assert.Nil(t, err)
  384. s, e, err = parsePort("0-1")
  385. assert.Equal(t, int32(0), s)
  386. assert.Equal(t, int32(0), e)
  387. assert.Nil(t, err)
  388. s, e, err = parsePort("9919")
  389. assert.Equal(t, int32(9919), s)
  390. assert.Equal(t, int32(9919), e)
  391. assert.Nil(t, err)
  392. s, e, err = parsePort("any")
  393. assert.Equal(t, int32(0), s)
  394. assert.Equal(t, int32(0), e)
  395. assert.Nil(t, err)
  396. }
  397. func TestNewFirewallFromConfig(t *testing.T) {
  398. // Test a bad rule definition
  399. c := &cert.NebulaCertificate{}
  400. conf := NewConfig()
  401. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  402. _, err := NewFirewallFromConfig(c, conf)
  403. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  404. // Test both port and code
  405. conf = NewConfig()
  406. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  407. _, err = NewFirewallFromConfig(c, conf)
  408. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  409. // Test missing host, group, cidr, ca_name and ca_sha
  410. conf = NewConfig()
  411. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  412. _, err = NewFirewallFromConfig(c, conf)
  413. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
  414. // Test code/port error
  415. conf = NewConfig()
  416. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  417. _, err = NewFirewallFromConfig(c, conf)
  418. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  419. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  420. _, err = NewFirewallFromConfig(c, conf)
  421. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  422. // Test proto error
  423. conf = NewConfig()
  424. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  425. _, err = NewFirewallFromConfig(c, conf)
  426. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  427. // Test cidr parse error
  428. conf = NewConfig()
  429. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  430. _, err = NewFirewallFromConfig(c, conf)
  431. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  432. // Test both group and groups
  433. conf = NewConfig()
  434. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  435. _, err = NewFirewallFromConfig(c, conf)
  436. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  437. }
  438. func TestAddFirewallRulesFromConfig(t *testing.T) {
  439. // Test adding tcp rule
  440. conf := NewConfig()
  441. mf := &mockFirewall{}
  442. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  443. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  444. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  445. // Test adding udp rule
  446. conf = NewConfig()
  447. mf = &mockFirewall{}
  448. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  449. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  450. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  451. // Test adding icmp rule
  452. conf = NewConfig()
  453. mf = &mockFirewall{}
  454. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  455. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  456. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  457. // Test adding any rule
  458. conf = NewConfig()
  459. mf = &mockFirewall{}
  460. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  461. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  462. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  463. // Test adding rule with ca_sha
  464. conf = NewConfig()
  465. mf = &mockFirewall{}
  466. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  467. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  468. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
  469. // Test adding rule with ca_name
  470. conf = NewConfig()
  471. mf = &mockFirewall{}
  472. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  473. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  474. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
  475. // Test single group
  476. conf = NewConfig()
  477. mf = &mockFirewall{}
  478. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  479. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  480. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  481. // Test single groups
  482. conf = NewConfig()
  483. mf = &mockFirewall{}
  484. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  485. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  486. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  487. // Test multiple AND groups
  488. conf = NewConfig()
  489. mf = &mockFirewall{}
  490. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  491. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  492. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
  493. // Test Add error
  494. conf = NewConfig()
  495. mf = &mockFirewall{}
  496. mf.nextCallReturn = errors.New("test error")
  497. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  498. assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
  499. }
  500. func TestTCPRTTTracking(t *testing.T) {
  501. b := make([]byte, 200)
  502. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  503. b[0] = 15
  504. b[60+12] = 15 << 4
  505. f := Firewall{
  506. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  507. }
  508. // Set SEQ to 1
  509. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  510. c := &conn{}
  511. setTCPRTTTracking(c, b)
  512. assert.Equal(t, uint32(1), c.Seq)
  513. // Bad ack - no ack flag
  514. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  515. assert.False(t, f.checkTCPRTT(c, b))
  516. // Bad ack, number is too low
  517. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  518. b[60+13] = uint8(0x10)
  519. assert.False(t, f.checkTCPRTT(c, b))
  520. // Good ack
  521. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  522. assert.True(t, f.checkTCPRTT(c, b))
  523. assert.Equal(t, uint32(0), c.Seq)
  524. // Set SEQ to 1
  525. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  526. c = &conn{}
  527. setTCPRTTTracking(c, b)
  528. assert.Equal(t, uint32(1), c.Seq)
  529. // Good acks
  530. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  531. assert.True(t, f.checkTCPRTT(c, b))
  532. assert.Equal(t, uint32(0), c.Seq)
  533. // Set SEQ to max uint32 - 20
  534. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  535. c = &conn{}
  536. setTCPRTTTracking(c, b)
  537. assert.Equal(t, ^uint32(0)-20, c.Seq)
  538. // Good acks
  539. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  540. assert.True(t, f.checkTCPRTT(c, b))
  541. assert.Equal(t, uint32(0), c.Seq)
  542. // Set SEQ to max uint32 / 2
  543. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  544. c = &conn{}
  545. setTCPRTTTracking(c, b)
  546. assert.Equal(t, ^uint32(0)/2, c.Seq)
  547. // Below
  548. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  549. assert.False(t, f.checkTCPRTT(c, b))
  550. assert.Equal(t, ^uint32(0)/2, c.Seq)
  551. // Halfway below
  552. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  553. assert.False(t, f.checkTCPRTT(c, b))
  554. assert.Equal(t, ^uint32(0)/2, c.Seq)
  555. // Halfway above is ok
  556. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  557. assert.True(t, f.checkTCPRTT(c, b))
  558. assert.Equal(t, uint32(0), c.Seq)
  559. // Set SEQ to max uint32
  560. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  561. c = &conn{}
  562. setTCPRTTTracking(c, b)
  563. assert.Equal(t, ^uint32(0), c.Seq)
  564. // Halfway + 1 above
  565. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  566. assert.False(t, f.checkTCPRTT(c, b))
  567. assert.Equal(t, ^uint32(0), c.Seq)
  568. // Halfway above
  569. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  570. assert.True(t, f.checkTCPRTT(c, b))
  571. assert.Equal(t, uint32(0), c.Seq)
  572. }
  // addRuleCall captures the arguments of a single AddRule invocation so tests
  // can compare exactly what was requested.
  type addRuleCall struct {
  	incoming           bool
  	proto              uint8
  	startPort, endPort int32
  	groups             []string
  	host               string
  	ip                 *net.IPNet
  	caName, caSha      string
  }

  // mockFirewall stands in for a firewall when testing AddFirewallRulesFromConfig.
  // It remembers the most recent AddRule call and can be primed to fail once.
  type mockFirewall struct {
  	// lastCall holds the arguments of the most recent AddRule call.
  	lastCall addRuleCall
  	// nextCallReturn, if non-nil, is returned by the next AddRule call and then reset to nil.
  	nextCallReturn error
  }

  // AddRule stores its arguments in mf.lastCall and returns any error primed in
  // mf.nextCallReturn, clearing it so subsequent calls succeed.
  func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  	pending := mf.nextCallReturn
  	mf.nextCallReturn = nil

  	mf.lastCall = addRuleCall{
  		caSha:     caSha,
  		caName:    caName,
  		ip:        ip,
  		host:      host,
  		groups:    groups,
  		endPort:   endPort,
  		startPort: startPort,
  		proto:     proto,
  		incoming:  incoming,
  	}
  	return pending
  }