firewall_test.go 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867
  1. package nebula
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "math"
  7. "net"
  8. "testing"
  9. "time"
  10. "github.com/rcrowley/go-metrics"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestNewFirewall(t *testing.T) {
  15. c := &cert.NebulaCertificate{}
  16. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  17. assert.NotNil(t, fw.Conns)
  18. assert.NotNil(t, fw.InRules)
  19. assert.NotNil(t, fw.OutRules)
  20. assert.NotNil(t, fw.TimerWheel)
  21. assert.Equal(t, time.Second, fw.TCPTimeout)
  22. assert.Equal(t, time.Minute, fw.UDPTimeout)
  23. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  24. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  25. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  26. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  27. fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
  28. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  29. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  30. fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
  31. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  32. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  33. fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
  34. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  35. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  36. fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
  37. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  38. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  39. fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
  40. assert.Equal(t, time.Hour, fw.TimerWheel.wheelDuration)
  41. assert.Equal(t, 3601, fw.TimerWheel.wheelLen)
  42. }
  43. func TestFirewall_AddRule(t *testing.T) {
  44. ob := &bytes.Buffer{}
  45. out := l.Out
  46. l.SetOutput(ob)
  47. defer l.SetOutput(out)
  48. c := &cert.NebulaCertificate{}
  49. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  50. assert.NotNil(t, fw.InRules)
  51. assert.NotNil(t, fw.OutRules)
  52. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  53. assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
  54. // An empty rule is any
  55. assert.True(t, fw.InRules.TCP[1].Any.Any)
  56. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  57. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  58. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
  59. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
  60. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
  61. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  62. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
  63. assert.False(t, fw.InRules.UDP[1].Any.Any)
  64. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
  65. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  66. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
  67. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
  68. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
  69. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  70. assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
  71. assert.False(t, fw.InRules.ICMP[1].Any.Any)
  72. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  73. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  74. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
  75. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
  76. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
  77. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  78. assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
  79. assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
  80. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
  81. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
  82. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
  83. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  84. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
  85. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  86. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  87. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
  88. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  89. // Set any and clear fields
  90. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  91. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
  92. assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
  93. assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
  94. assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
  95. // run twice just to make sure
  96. //TODO: these ANY rules should clear the CA firewall portion
  97. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  98. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  99. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  100. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
  101. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
  102. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
  103. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
  104. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
  105. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  106. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  107. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  108. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  109. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  110. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
  111. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  112. // Test error conditions
  113. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  114. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
  115. assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
  116. }
  117. func TestFirewall_Drop(t *testing.T) {
  118. ob := &bytes.Buffer{}
  119. out := l.Out
  120. l.SetOutput(ob)
  121. defer l.SetOutput(out)
  122. p := FirewallPacket{
  123. ip2int(net.IPv4(1, 2, 3, 4)),
  124. ip2int(net.IPv4(1, 2, 3, 4)),
  125. 10,
  126. 90,
  127. fwProtoUDP,
  128. false,
  129. }
  130. ipNet := net.IPNet{
  131. IP: net.IPv4(1, 2, 3, 4),
  132. Mask: net.IPMask{255, 255, 255, 0},
  133. }
  134. c := cert.NebulaCertificate{
  135. Details: cert.NebulaCertificateDetails{
  136. Name: "host1",
  137. Ips: []*net.IPNet{&ipNet},
  138. Groups: []string{"default-group"},
  139. InvertedGroups: map[string]struct{}{"default-group": {}},
  140. Issuer: "signer-shasum",
  141. },
  142. }
  143. h := HostInfo{
  144. ConnectionState: &ConnectionState{
  145. peerCert: &c,
  146. },
  147. hostId: ip2int(ipNet.IP),
  148. }
  149. h.CreateRemoteCIDR(&c)
  150. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  151. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  152. cp := cert.NewCAPool()
  153. // Drop outbound
  154. assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
  155. // Allow inbound
  156. resetConntrack(fw)
  157. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  158. // Allow outbound because conntrack
  159. assert.False(t, fw.Drop([]byte{}, p, false, &h, cp))
  160. // test remote mismatch
  161. oldRemote := p.RemoteIP
  162. p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
  163. assert.True(t, fw.Drop([]byte{}, p, false, &h, cp))
  164. p.RemoteIP = oldRemote
  165. // ensure signer doesn't get in the way of group checks
  166. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  167. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
  168. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
  169. assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
  170. // test caSha doesn't drop on match
  171. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  172. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
  173. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
  174. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  175. // ensure ca name doesn't get in the way of group checks
  176. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  177. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  178. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
  179. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
  180. assert.True(t, fw.Drop([]byte{}, p, true, &h, cp))
  181. // test caName doesn't drop on match
  182. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  183. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  184. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
  185. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
  186. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  187. }
  188. func BenchmarkFirewallTable_match(b *testing.B) {
  189. ft := FirewallTable{
  190. TCP: firewallPort{},
  191. }
  192. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  193. _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
  194. _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
  195. _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
  196. _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
  197. _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
  198. cp := cert.NewCAPool()
  199. b.Run("fail on proto", func(b *testing.B) {
  200. c := &cert.NebulaCertificate{}
  201. for n := 0; n < b.N; n++ {
  202. ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
  203. }
  204. })
  205. b.Run("fail on port", func(b *testing.B) {
  206. c := &cert.NebulaCertificate{}
  207. for n := 0; n < b.N; n++ {
  208. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
  209. }
  210. })
  211. b.Run("fail all group, name, and cidr", func(b *testing.B) {
  212. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  213. c := &cert.NebulaCertificate{
  214. Details: cert.NebulaCertificateDetails{
  215. InvertedGroups: map[string]struct{}{"nope": {}},
  216. Name: "nope",
  217. Ips: []*net.IPNet{ip},
  218. },
  219. }
  220. for n := 0; n < b.N; n++ {
  221. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  222. }
  223. })
  224. b.Run("pass on group", func(b *testing.B) {
  225. c := &cert.NebulaCertificate{
  226. Details: cert.NebulaCertificateDetails{
  227. InvertedGroups: map[string]struct{}{"good-group": {}},
  228. Name: "nope",
  229. },
  230. }
  231. for n := 0; n < b.N; n++ {
  232. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  233. }
  234. })
  235. b.Run("pass on name", func(b *testing.B) {
  236. c := &cert.NebulaCertificate{
  237. Details: cert.NebulaCertificateDetails{
  238. InvertedGroups: map[string]struct{}{"nope": {}},
  239. Name: "good-host",
  240. },
  241. }
  242. for n := 0; n < b.N; n++ {
  243. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  244. }
  245. })
  246. b.Run("pass on ip", func(b *testing.B) {
  247. ip := ip2int(net.IPv4(172, 1, 1, 1))
  248. c := &cert.NebulaCertificate{
  249. Details: cert.NebulaCertificateDetails{
  250. InvertedGroups: map[string]struct{}{"nope": {}},
  251. Name: "good-host",
  252. },
  253. }
  254. for n := 0; n < b.N; n++ {
  255. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  256. }
  257. })
  258. _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
  259. b.Run("pass on ip with any port", func(b *testing.B) {
  260. ip := ip2int(net.IPv4(172, 1, 1, 1))
  261. c := &cert.NebulaCertificate{
  262. Details: cert.NebulaCertificateDetails{
  263. InvertedGroups: map[string]struct{}{"nope": {}},
  264. Name: "good-host",
  265. },
  266. }
  267. for n := 0; n < b.N; n++ {
  268. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  269. }
  270. })
  271. }
  272. func TestFirewall_Drop2(t *testing.T) {
  273. ob := &bytes.Buffer{}
  274. out := l.Out
  275. l.SetOutput(ob)
  276. defer l.SetOutput(out)
  277. p := FirewallPacket{
  278. ip2int(net.IPv4(1, 2, 3, 4)),
  279. ip2int(net.IPv4(1, 2, 3, 4)),
  280. 10,
  281. 90,
  282. fwProtoUDP,
  283. false,
  284. }
  285. ipNet := net.IPNet{
  286. IP: net.IPv4(1, 2, 3, 4),
  287. Mask: net.IPMask{255, 255, 255, 0},
  288. }
  289. c := cert.NebulaCertificate{
  290. Details: cert.NebulaCertificateDetails{
  291. Name: "host1",
  292. Ips: []*net.IPNet{&ipNet},
  293. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  294. },
  295. }
  296. h := HostInfo{
  297. ConnectionState: &ConnectionState{
  298. peerCert: &c,
  299. },
  300. hostId: ip2int(ipNet.IP),
  301. }
  302. h.CreateRemoteCIDR(&c)
  303. c1 := cert.NebulaCertificate{
  304. Details: cert.NebulaCertificateDetails{
  305. Name: "host1",
  306. Ips: []*net.IPNet{&ipNet},
  307. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  308. },
  309. }
  310. h1 := HostInfo{
  311. ConnectionState: &ConnectionState{
  312. peerCert: &c1,
  313. },
  314. }
  315. h1.CreateRemoteCIDR(&c1)
  316. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  317. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
  318. cp := cert.NewCAPool()
  319. // h1/c1 lacks the proper groups
  320. assert.True(t, fw.Drop([]byte{}, p, true, &h1, cp))
  321. // c has the proper groups
  322. resetConntrack(fw)
  323. assert.False(t, fw.Drop([]byte{}, p, true, &h, cp))
  324. }
  325. func TestFirewall_Drop3(t *testing.T) {
  326. ob := &bytes.Buffer{}
  327. out := l.Out
  328. l.SetOutput(ob)
  329. defer l.SetOutput(out)
  330. p := FirewallPacket{
  331. ip2int(net.IPv4(1, 2, 3, 4)),
  332. ip2int(net.IPv4(1, 2, 3, 4)),
  333. 1,
  334. 1,
  335. fwProtoUDP,
  336. false,
  337. }
  338. ipNet := net.IPNet{
  339. IP: net.IPv4(1, 2, 3, 4),
  340. Mask: net.IPMask{255, 255, 255, 0},
  341. }
  342. c := cert.NebulaCertificate{
  343. Details: cert.NebulaCertificateDetails{
  344. Name: "host-owner",
  345. Ips: []*net.IPNet{&ipNet},
  346. },
  347. }
  348. c1 := cert.NebulaCertificate{
  349. Details: cert.NebulaCertificateDetails{
  350. Name: "host1",
  351. Ips: []*net.IPNet{&ipNet},
  352. Issuer: "signer-sha-bad",
  353. },
  354. }
  355. h1 := HostInfo{
  356. ConnectionState: &ConnectionState{
  357. peerCert: &c1,
  358. },
  359. hostId: ip2int(ipNet.IP),
  360. }
  361. h1.CreateRemoteCIDR(&c1)
  362. c2 := cert.NebulaCertificate{
  363. Details: cert.NebulaCertificateDetails{
  364. Name: "host2",
  365. Ips: []*net.IPNet{&ipNet},
  366. Issuer: "signer-sha",
  367. },
  368. }
  369. h2 := HostInfo{
  370. ConnectionState: &ConnectionState{
  371. peerCert: &c2,
  372. },
  373. hostId: ip2int(ipNet.IP),
  374. }
  375. h2.CreateRemoteCIDR(&c2)
  376. c3 := cert.NebulaCertificate{
  377. Details: cert.NebulaCertificateDetails{
  378. Name: "host3",
  379. Ips: []*net.IPNet{&ipNet},
  380. Issuer: "signer-sha-bad",
  381. },
  382. }
  383. h3 := HostInfo{
  384. ConnectionState: &ConnectionState{
  385. peerCert: &c3,
  386. },
  387. hostId: ip2int(ipNet.IP),
  388. }
  389. h3.CreateRemoteCIDR(&c3)
  390. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  391. assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
  392. assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
  393. cp := cert.NewCAPool()
  394. // c1 should pass because host match
  395. assert.False(t, fw.Drop([]byte{}, p, true, &h1, cp))
  396. // c2 should pass because ca sha match
  397. resetConntrack(fw)
  398. assert.False(t, fw.Drop([]byte{}, p, true, &h2, cp))
  399. // c3 should fail because no match
  400. resetConntrack(fw)
  401. assert.True(t, fw.Drop([]byte{}, p, true, &h3, cp))
  402. }
  // anyGroupSetMatches reports whether at least one non-empty set in sets has
// every one of its groups present in m. An empty set never counts as a match.
// This is the lookup strategy measured by BenchmarkLookup.
func anyGroupSetMatches(m map[string]struct{}, sets [][]string) bool {
	for _, set := range sets {
		if len(set) == 0 {
			continue
		}
		matched := true
		for _, g := range set {
			if _, ok := m[g]; !ok {
				matched = false
				break
			}
		}
		if matched {
			return true
		}
	}
	return false
}

// BenchmarkLookup compares the best case (the first group set matches) with
// the worst case (only the last group set matches) of anyGroupSetMatches.
//
// Each sub-benchmark drives its own b.N loop. The lookup must not iterate over
// any benchmark's N itself, and the result is checked so the work is used.
func BenchmarkLookup(b *testing.B) {
	sets := [][]string{
		{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
		{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
		{"one", "two", "3hr", "4ou", "5iv", "6ix"},
		{"one", "two", "thr", "4ou", "5iv", "6ix"},
		{"one", "two", "thr", "fou", "5iv", "6ix"},
		{"one", "two", "thr", "fou", "fiv", "6ix"},
		{"one", "two", "thr", "fou", "fiv", "six"},
	}

	b.Run("array to map best", func(b *testing.B) {
		m := map[string]struct{}{
			"1ne": {},
			"2wo": {},
			"3hr": {},
			"4ou": {},
			"5iv": {},
			"6ix": {},
		}
		for n := 0; n < b.N; n++ {
			if !anyGroupSetMatches(m, sets) {
				b.Fatal("expected a group set to match")
			}
		}
	})

	b.Run("array to map worst", func(b *testing.B) {
		m := map[string]struct{}{
			"one": {},
			"two": {},
			"thr": {},
			"fou": {},
			"fiv": {},
			"six": {},
		}
		for n := 0; n < b.N; n++ {
			if !anyGroupSetMatches(m, sets) {
				b.Fatal("expected a group set to match")
			}
		}
	})

	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
}
  467. func Test_parsePort(t *testing.T) {
  468. _, _, err := parsePort("")
  469. assert.EqualError(t, err, "was not a number; ``")
  470. _, _, err = parsePort(" ")
  471. assert.EqualError(t, err, "was not a number; ` `")
  472. _, _, err = parsePort("-")
  473. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  474. _, _, err = parsePort(" - ")
  475. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  476. _, _, err = parsePort("a-b")
  477. assert.EqualError(t, err, "beginning range was not a number; `a`")
  478. _, _, err = parsePort("1-b")
  479. assert.EqualError(t, err, "ending range was not a number; `b`")
  480. s, e, err := parsePort(" 1 - 2 ")
  481. assert.Equal(t, int32(1), s)
  482. assert.Equal(t, int32(2), e)
  483. assert.Nil(t, err)
  484. s, e, err = parsePort("0-1")
  485. assert.Equal(t, int32(0), s)
  486. assert.Equal(t, int32(0), e)
  487. assert.Nil(t, err)
  488. s, e, err = parsePort("9919")
  489. assert.Equal(t, int32(9919), s)
  490. assert.Equal(t, int32(9919), e)
  491. assert.Nil(t, err)
  492. s, e, err = parsePort("any")
  493. assert.Equal(t, int32(0), s)
  494. assert.Equal(t, int32(0), e)
  495. assert.Nil(t, err)
  496. }
  497. func TestNewFirewallFromConfig(t *testing.T) {
  498. // Test a bad rule definition
  499. c := &cert.NebulaCertificate{}
  500. conf := NewConfig()
  501. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  502. _, err := NewFirewallFromConfig(c, conf)
  503. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  504. // Test both port and code
  505. conf = NewConfig()
  506. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  507. _, err = NewFirewallFromConfig(c, conf)
  508. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  509. // Test missing host, group, cidr, ca_name and ca_sha
  510. conf = NewConfig()
  511. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  512. _, err = NewFirewallFromConfig(c, conf)
  513. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
  514. // Test code/port error
  515. conf = NewConfig()
  516. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  517. _, err = NewFirewallFromConfig(c, conf)
  518. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  519. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  520. _, err = NewFirewallFromConfig(c, conf)
  521. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  522. // Test proto error
  523. conf = NewConfig()
  524. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  525. _, err = NewFirewallFromConfig(c, conf)
  526. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  527. // Test cidr parse error
  528. conf = NewConfig()
  529. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  530. _, err = NewFirewallFromConfig(c, conf)
  531. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  532. // Test both group and groups
  533. conf = NewConfig()
  534. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  535. _, err = NewFirewallFromConfig(c, conf)
  536. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  537. }
  538. func TestAddFirewallRulesFromConfig(t *testing.T) {
  539. // Test adding tcp rule
  540. conf := NewConfig()
  541. mf := &mockFirewall{}
  542. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  543. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  544. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  545. // Test adding udp rule
  546. conf = NewConfig()
  547. mf = &mockFirewall{}
  548. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  549. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  550. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  551. // Test adding icmp rule
  552. conf = NewConfig()
  553. mf = &mockFirewall{}
  554. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  555. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  556. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  557. // Test adding any rule
  558. conf = NewConfig()
  559. mf = &mockFirewall{}
  560. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  561. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  562. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  563. // Test adding rule with ca_sha
  564. conf = NewConfig()
  565. mf = &mockFirewall{}
  566. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  567. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  568. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
  569. // Test adding rule with ca_name
  570. conf = NewConfig()
  571. mf = &mockFirewall{}
  572. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  573. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  574. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
  575. // Test single group
  576. conf = NewConfig()
  577. mf = &mockFirewall{}
  578. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  579. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  580. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  581. // Test single groups
  582. conf = NewConfig()
  583. mf = &mockFirewall{}
  584. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  585. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  586. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  587. // Test multiple AND groups
  588. conf = NewConfig()
  589. mf = &mockFirewall{}
  590. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  591. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  592. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
  593. // Test Add error
  594. conf = NewConfig()
  595. mf = &mockFirewall{}
  596. mf.nextCallReturn = errors.New("test error")
  597. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  598. assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
  599. }
  600. func TestTCPRTTTracking(t *testing.T) {
  601. b := make([]byte, 200)
  602. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  603. b[0] = 15
  604. b[60+12] = 15 << 4
  605. f := Firewall{
  606. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  607. }
  608. // Set SEQ to 1
  609. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  610. c := &conn{}
  611. setTCPRTTTracking(c, b)
  612. assert.Equal(t, uint32(1), c.Seq)
  613. // Bad ack - no ack flag
  614. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  615. assert.False(t, f.checkTCPRTT(c, b))
  616. // Bad ack, number is too low
  617. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  618. b[60+13] = uint8(0x10)
  619. assert.False(t, f.checkTCPRTT(c, b))
  620. // Good ack
  621. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  622. assert.True(t, f.checkTCPRTT(c, b))
  623. assert.Equal(t, uint32(0), c.Seq)
  624. // Set SEQ to 1
  625. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  626. c = &conn{}
  627. setTCPRTTTracking(c, b)
  628. assert.Equal(t, uint32(1), c.Seq)
  629. // Good acks
  630. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  631. assert.True(t, f.checkTCPRTT(c, b))
  632. assert.Equal(t, uint32(0), c.Seq)
  633. // Set SEQ to max uint32 - 20
  634. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  635. c = &conn{}
  636. setTCPRTTTracking(c, b)
  637. assert.Equal(t, ^uint32(0)-20, c.Seq)
  638. // Good acks
  639. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  640. assert.True(t, f.checkTCPRTT(c, b))
  641. assert.Equal(t, uint32(0), c.Seq)
  642. // Set SEQ to max uint32 / 2
  643. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  644. c = &conn{}
  645. setTCPRTTTracking(c, b)
  646. assert.Equal(t, ^uint32(0)/2, c.Seq)
  647. // Below
  648. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  649. assert.False(t, f.checkTCPRTT(c, b))
  650. assert.Equal(t, ^uint32(0)/2, c.Seq)
  651. // Halfway below
  652. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  653. assert.False(t, f.checkTCPRTT(c, b))
  654. assert.Equal(t, ^uint32(0)/2, c.Seq)
  655. // Halfway above is ok
  656. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  657. assert.True(t, f.checkTCPRTT(c, b))
  658. assert.Equal(t, uint32(0), c.Seq)
  659. // Set SEQ to max uint32
  660. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  661. c = &conn{}
  662. setTCPRTTTracking(c, b)
  663. assert.Equal(t, ^uint32(0), c.Seq)
  664. // Halfway + 1 above
  665. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  666. assert.False(t, f.checkTCPRTT(c, b))
  667. assert.Equal(t, ^uint32(0), c.Seq)
  668. // Halfway above
  669. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  670. assert.True(t, f.checkTCPRTT(c, b))
  671. assert.Equal(t, uint32(0), c.Seq)
  672. }
  673. func TestFirewall_convertRule(t *testing.T) {
  674. ob := &bytes.Buffer{}
  675. out := l.Out
  676. l.SetOutput(ob)
  677. defer l.SetOutput(out)
  678. // Ensure group array of 1 is converted and a warning is printed
  679. c := map[interface{}]interface{}{
  680. "group": []interface{}{"group1"},
  681. }
  682. r, err := convertRule(c, "test", 1)
  683. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  684. assert.Nil(t, err)
  685. assert.Equal(t, "group1", r.Group)
  686. // Ensure group array of > 1 is errord
  687. ob.Reset()
  688. c = map[interface{}]interface{}{
  689. "group": []interface{}{"group1", "group2"},
  690. }
  691. r, err = convertRule(c, "test", 1)
  692. assert.Equal(t, "", ob.String())
  693. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  694. // Make sure a well formed group is alright
  695. ob.Reset()
  696. c = map[interface{}]interface{}{
  697. "group": "group1",
  698. }
  699. r, err = convertRule(c, "test", 1)
  700. assert.Nil(t, err)
  701. assert.Equal(t, "group1", r.Group)
  702. }
  // addRuleCall captures the arguments of one AddRule invocation on
// mockFirewall, so tests can compare a whole call with assert.Equal.
type addRuleCall struct {
	incoming           bool
	proto              uint8
	startPort, endPort int32
	groups             []string
	host               string
	ip                 *net.IPNet
	caName, caSha      string
}
  714. type mockFirewall struct {
  715. lastCall addRuleCall
  716. nextCallReturn error
  717. }
  718. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  719. mf.lastCall = addRuleCall{
  720. incoming: incoming,
  721. proto: proto,
  722. startPort: startPort,
  723. endPort: endPort,
  724. groups: groups,
  725. host: host,
  726. ip: ip,
  727. caName: caName,
  728. caSha: caSha,
  729. }
  730. err := mf.nextCallReturn
  731. mf.nextCallReturn = nil
  732. return err
  733. }
  734. func resetConntrack(fw *Firewall) {
  735. fw.connMutex.Lock()
  736. fw.Conns = map[FirewallPacket]*conn{}
  737. fw.connMutex.Unlock()
  738. }