firewall_test.go 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934
  1. package nebula
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "math"
  7. "net"
  8. "testing"
  9. "time"
  10. "github.com/rcrowley/go-metrics"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestNewFirewall(t *testing.T) {
  15. l := NewTestLogger()
  16. c := &cert.NebulaCertificate{}
  17. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  18. conntrack := fw.Conntrack
  19. assert.NotNil(t, conntrack)
  20. assert.NotNil(t, conntrack.Conns)
  21. assert.NotNil(t, conntrack.TimerWheel)
  22. assert.NotNil(t, fw.InRules)
  23. assert.NotNil(t, fw.OutRules)
  24. assert.Equal(t, time.Second, fw.TCPTimeout)
  25. assert.Equal(t, time.Minute, fw.UDPTimeout)
  26. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  27. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  28. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  29. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  30. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  33. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  34. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  35. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  36. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  37. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  38. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  39. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  40. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  41. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  42. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  43. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  44. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  45. }
  46. func TestFirewall_AddRule(t *testing.T) {
  47. l := NewTestLogger()
  48. ob := &bytes.Buffer{}
  49. l.SetOutput(ob)
  50. c := &cert.NebulaCertificate{}
  51. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  52. assert.NotNil(t, fw.InRules)
  53. assert.NotNil(t, fw.OutRules)
  54. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  55. assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
  56. // An empty rule is any
  57. assert.True(t, fw.InRules.TCP[1].Any.Any)
  58. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  59. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  60. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
  61. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
  62. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
  63. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  64. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
  65. assert.False(t, fw.InRules.UDP[1].Any.Any)
  66. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
  67. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  68. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
  69. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
  70. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
  71. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  72. assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
  73. assert.False(t, fw.InRules.ICMP[1].Any.Any)
  74. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  75. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  76. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
  77. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
  78. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
  79. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  80. assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
  81. assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
  82. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
  83. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
  84. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
  85. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  86. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
  87. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  88. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  89. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
  90. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  91. // Set any and clear fields
  92. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  93. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
  94. assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
  95. assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
  96. assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
  97. // run twice just to make sure
  98. //TODO: these ANY rules should clear the CA firewall portion
  99. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  100. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  101. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  102. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
  103. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
  104. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
  105. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
  106. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
  107. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  108. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  109. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  110. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  111. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  112. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
  113. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  114. // Test error conditions
  115. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  116. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
  117. assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
  118. }
  119. func TestFirewall_Drop(t *testing.T) {
  120. l := NewTestLogger()
  121. ob := &bytes.Buffer{}
  122. l.SetOutput(ob)
  123. p := FirewallPacket{
  124. ip2int(net.IPv4(1, 2, 3, 4)),
  125. ip2int(net.IPv4(1, 2, 3, 4)),
  126. 10,
  127. 90,
  128. fwProtoUDP,
  129. false,
  130. }
  131. ipNet := net.IPNet{
  132. IP: net.IPv4(1, 2, 3, 4),
  133. Mask: net.IPMask{255, 255, 255, 0},
  134. }
  135. c := cert.NebulaCertificate{
  136. Details: cert.NebulaCertificateDetails{
  137. Name: "host1",
  138. Ips: []*net.IPNet{&ipNet},
  139. Groups: []string{"default-group"},
  140. InvertedGroups: map[string]struct{}{"default-group": {}},
  141. Issuer: "signer-shasum",
  142. },
  143. }
  144. h := HostInfo{
  145. ConnectionState: &ConnectionState{
  146. peerCert: &c,
  147. },
  148. hostId: ip2int(ipNet.IP),
  149. }
  150. h.CreateRemoteCIDR(&c)
  151. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  152. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  153. cp := cert.NewCAPool()
  154. // Drop outbound
  155. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  156. // Allow inbound
  157. resetConntrack(fw)
  158. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  159. // Allow outbound because conntrack
  160. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  161. // test remote mismatch
  162. oldRemote := p.RemoteIP
  163. p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
  164. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
  165. p.RemoteIP = oldRemote
  166. // ensure signer doesn't get in the way of group checks
  167. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  168. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
  169. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
  170. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  171. // test caSha doesn't drop on match
  172. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  173. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
  174. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
  175. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  176. // ensure ca name doesn't get in the way of group checks
  177. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  178. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  179. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
  180. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
  181. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  182. // test caName doesn't drop on match
  183. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  184. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  185. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
  186. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
  187. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  188. }
  189. func BenchmarkFirewallTable_match(b *testing.B) {
  190. ft := FirewallTable{
  191. TCP: firewallPort{},
  192. }
  193. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  194. _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
  195. _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
  196. _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
  197. _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
  198. _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
  199. cp := cert.NewCAPool()
  200. b.Run("fail on proto", func(b *testing.B) {
  201. c := &cert.NebulaCertificate{}
  202. for n := 0; n < b.N; n++ {
  203. ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
  204. }
  205. })
  206. b.Run("fail on port", func(b *testing.B) {
  207. c := &cert.NebulaCertificate{}
  208. for n := 0; n < b.N; n++ {
  209. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
  210. }
  211. })
  212. b.Run("fail all group, name, and cidr", func(b *testing.B) {
  213. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  214. c := &cert.NebulaCertificate{
  215. Details: cert.NebulaCertificateDetails{
  216. InvertedGroups: map[string]struct{}{"nope": {}},
  217. Name: "nope",
  218. Ips: []*net.IPNet{ip},
  219. },
  220. }
  221. for n := 0; n < b.N; n++ {
  222. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  223. }
  224. })
  225. b.Run("pass on group", func(b *testing.B) {
  226. c := &cert.NebulaCertificate{
  227. Details: cert.NebulaCertificateDetails{
  228. InvertedGroups: map[string]struct{}{"good-group": {}},
  229. Name: "nope",
  230. },
  231. }
  232. for n := 0; n < b.N; n++ {
  233. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  234. }
  235. })
  236. b.Run("pass on name", func(b *testing.B) {
  237. c := &cert.NebulaCertificate{
  238. Details: cert.NebulaCertificateDetails{
  239. InvertedGroups: map[string]struct{}{"nope": {}},
  240. Name: "good-host",
  241. },
  242. }
  243. for n := 0; n < b.N; n++ {
  244. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  245. }
  246. })
  247. b.Run("pass on ip", func(b *testing.B) {
  248. ip := ip2int(net.IPv4(172, 1, 1, 1))
  249. c := &cert.NebulaCertificate{
  250. Details: cert.NebulaCertificateDetails{
  251. InvertedGroups: map[string]struct{}{"nope": {}},
  252. Name: "good-host",
  253. },
  254. }
  255. for n := 0; n < b.N; n++ {
  256. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  257. }
  258. })
  259. _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
  260. b.Run("pass on ip with any port", func(b *testing.B) {
  261. ip := ip2int(net.IPv4(172, 1, 1, 1))
  262. c := &cert.NebulaCertificate{
  263. Details: cert.NebulaCertificateDetails{
  264. InvertedGroups: map[string]struct{}{"nope": {}},
  265. Name: "good-host",
  266. },
  267. }
  268. for n := 0; n < b.N; n++ {
  269. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  270. }
  271. })
  272. }
  273. func TestFirewall_Drop2(t *testing.T) {
  274. l := NewTestLogger()
  275. ob := &bytes.Buffer{}
  276. l.SetOutput(ob)
  277. p := FirewallPacket{
  278. ip2int(net.IPv4(1, 2, 3, 4)),
  279. ip2int(net.IPv4(1, 2, 3, 4)),
  280. 10,
  281. 90,
  282. fwProtoUDP,
  283. false,
  284. }
  285. ipNet := net.IPNet{
  286. IP: net.IPv4(1, 2, 3, 4),
  287. Mask: net.IPMask{255, 255, 255, 0},
  288. }
  289. c := cert.NebulaCertificate{
  290. Details: cert.NebulaCertificateDetails{
  291. Name: "host1",
  292. Ips: []*net.IPNet{&ipNet},
  293. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  294. },
  295. }
  296. h := HostInfo{
  297. ConnectionState: &ConnectionState{
  298. peerCert: &c,
  299. },
  300. hostId: ip2int(ipNet.IP),
  301. }
  302. h.CreateRemoteCIDR(&c)
  303. c1 := cert.NebulaCertificate{
  304. Details: cert.NebulaCertificateDetails{
  305. Name: "host1",
  306. Ips: []*net.IPNet{&ipNet},
  307. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  308. },
  309. }
  310. h1 := HostInfo{
  311. ConnectionState: &ConnectionState{
  312. peerCert: &c1,
  313. },
  314. }
  315. h1.CreateRemoteCIDR(&c1)
  316. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  317. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
  318. cp := cert.NewCAPool()
  319. // h1/c1 lacks the proper groups
  320. assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
  321. // c has the proper groups
  322. resetConntrack(fw)
  323. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  324. }
  325. func TestFirewall_Drop3(t *testing.T) {
  326. l := NewTestLogger()
  327. ob := &bytes.Buffer{}
  328. l.SetOutput(ob)
  329. p := FirewallPacket{
  330. ip2int(net.IPv4(1, 2, 3, 4)),
  331. ip2int(net.IPv4(1, 2, 3, 4)),
  332. 1,
  333. 1,
  334. fwProtoUDP,
  335. false,
  336. }
  337. ipNet := net.IPNet{
  338. IP: net.IPv4(1, 2, 3, 4),
  339. Mask: net.IPMask{255, 255, 255, 0},
  340. }
  341. c := cert.NebulaCertificate{
  342. Details: cert.NebulaCertificateDetails{
  343. Name: "host-owner",
  344. Ips: []*net.IPNet{&ipNet},
  345. },
  346. }
  347. c1 := cert.NebulaCertificate{
  348. Details: cert.NebulaCertificateDetails{
  349. Name: "host1",
  350. Ips: []*net.IPNet{&ipNet},
  351. Issuer: "signer-sha-bad",
  352. },
  353. }
  354. h1 := HostInfo{
  355. ConnectionState: &ConnectionState{
  356. peerCert: &c1,
  357. },
  358. hostId: ip2int(ipNet.IP),
  359. }
  360. h1.CreateRemoteCIDR(&c1)
  361. c2 := cert.NebulaCertificate{
  362. Details: cert.NebulaCertificateDetails{
  363. Name: "host2",
  364. Ips: []*net.IPNet{&ipNet},
  365. Issuer: "signer-sha",
  366. },
  367. }
  368. h2 := HostInfo{
  369. ConnectionState: &ConnectionState{
  370. peerCert: &c2,
  371. },
  372. hostId: ip2int(ipNet.IP),
  373. }
  374. h2.CreateRemoteCIDR(&c2)
  375. c3 := cert.NebulaCertificate{
  376. Details: cert.NebulaCertificateDetails{
  377. Name: "host3",
  378. Ips: []*net.IPNet{&ipNet},
  379. Issuer: "signer-sha-bad",
  380. },
  381. }
  382. h3 := HostInfo{
  383. ConnectionState: &ConnectionState{
  384. peerCert: &c3,
  385. },
  386. hostId: ip2int(ipNet.IP),
  387. }
  388. h3.CreateRemoteCIDR(&c3)
  389. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  390. assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
  391. assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
  392. cp := cert.NewCAPool()
  393. // c1 should pass because host match
  394. assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
  395. // c2 should pass because ca sha match
  396. resetConntrack(fw)
  397. assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
  398. // c3 should fail because no match
  399. resetConntrack(fw)
  400. assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
  401. }
  402. func TestFirewall_DropConntrackReload(t *testing.T) {
  403. l := NewTestLogger()
  404. ob := &bytes.Buffer{}
  405. l.SetOutput(ob)
  406. p := FirewallPacket{
  407. ip2int(net.IPv4(1, 2, 3, 4)),
  408. ip2int(net.IPv4(1, 2, 3, 4)),
  409. 10,
  410. 90,
  411. fwProtoUDP,
  412. false,
  413. }
  414. ipNet := net.IPNet{
  415. IP: net.IPv4(1, 2, 3, 4),
  416. Mask: net.IPMask{255, 255, 255, 0},
  417. }
  418. c := cert.NebulaCertificate{
  419. Details: cert.NebulaCertificateDetails{
  420. Name: "host1",
  421. Ips: []*net.IPNet{&ipNet},
  422. Groups: []string{"default-group"},
  423. InvertedGroups: map[string]struct{}{"default-group": {}},
  424. Issuer: "signer-shasum",
  425. },
  426. }
  427. h := HostInfo{
  428. ConnectionState: &ConnectionState{
  429. peerCert: &c,
  430. },
  431. hostId: ip2int(ipNet.IP),
  432. }
  433. h.CreateRemoteCIDR(&c)
  434. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  435. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  436. cp := cert.NewCAPool()
  437. // Drop outbound
  438. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  439. // Allow inbound
  440. resetConntrack(fw)
  441. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  442. // Allow outbound because conntrack
  443. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  444. oldFw := fw
  445. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  446. assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
  447. fw.Conntrack = oldFw.Conntrack
  448. fw.rulesVersion = oldFw.rulesVersion + 1
  449. // Allow outbound because conntrack and new rules allow port 10
  450. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  451. oldFw = fw
  452. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  453. assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
  454. fw.Conntrack = oldFw.Conntrack
  455. fw.rulesVersion = oldFw.rulesVersion + 1
  456. // Drop outbound because conntrack doesn't match new ruleset
  457. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  458. }
  459. func BenchmarkLookup(b *testing.B) {
  460. ml := func(m map[string]struct{}, a [][]string) {
  461. for n := 0; n < b.N; n++ {
  462. for _, sg := range a {
  463. found := false
  464. for _, g := range sg {
  465. if _, ok := m[g]; !ok {
  466. found = false
  467. break
  468. }
  469. found = true
  470. }
  471. if found {
  472. return
  473. }
  474. }
  475. }
  476. }
  477. b.Run("array to map best", func(b *testing.B) {
  478. m := map[string]struct{}{
  479. "1ne": {},
  480. "2wo": {},
  481. "3hr": {},
  482. "4ou": {},
  483. "5iv": {},
  484. "6ix": {},
  485. }
  486. a := [][]string{
  487. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  488. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  489. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  490. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  491. {"one", "two", "thr", "fou", "5iv", "6ix"},
  492. {"one", "two", "thr", "fou", "fiv", "6ix"},
  493. {"one", "two", "thr", "fou", "fiv", "six"},
  494. }
  495. for n := 0; n < b.N; n++ {
  496. ml(m, a)
  497. }
  498. })
  499. b.Run("array to map worst", func(b *testing.B) {
  500. m := map[string]struct{}{
  501. "one": {},
  502. "two": {},
  503. "thr": {},
  504. "fou": {},
  505. "fiv": {},
  506. "six": {},
  507. }
  508. a := [][]string{
  509. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  510. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  511. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  512. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  513. {"one", "two", "thr", "fou", "5iv", "6ix"},
  514. {"one", "two", "thr", "fou", "fiv", "6ix"},
  515. {"one", "two", "thr", "fou", "fiv", "six"},
  516. }
  517. for n := 0; n < b.N; n++ {
  518. ml(m, a)
  519. }
  520. })
  521. //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  522. }
  523. func Test_parsePort(t *testing.T) {
  524. _, _, err := parsePort("")
  525. assert.EqualError(t, err, "was not a number; ``")
  526. _, _, err = parsePort(" ")
  527. assert.EqualError(t, err, "was not a number; ` `")
  528. _, _, err = parsePort("-")
  529. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  530. _, _, err = parsePort(" - ")
  531. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  532. _, _, err = parsePort("a-b")
  533. assert.EqualError(t, err, "beginning range was not a number; `a`")
  534. _, _, err = parsePort("1-b")
  535. assert.EqualError(t, err, "ending range was not a number; `b`")
  536. s, e, err := parsePort(" 1 - 2 ")
  537. assert.Equal(t, int32(1), s)
  538. assert.Equal(t, int32(2), e)
  539. assert.Nil(t, err)
  540. s, e, err = parsePort("0-1")
  541. assert.Equal(t, int32(0), s)
  542. assert.Equal(t, int32(0), e)
  543. assert.Nil(t, err)
  544. s, e, err = parsePort("9919")
  545. assert.Equal(t, int32(9919), s)
  546. assert.Equal(t, int32(9919), e)
  547. assert.Nil(t, err)
  548. s, e, err = parsePort("any")
  549. assert.Equal(t, int32(0), s)
  550. assert.Equal(t, int32(0), e)
  551. assert.Nil(t, err)
  552. }
  553. func TestNewFirewallFromConfig(t *testing.T) {
  554. l := NewTestLogger()
  555. // Test a bad rule definition
  556. c := &cert.NebulaCertificate{}
  557. conf := NewConfig(l)
  558. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  559. _, err := NewFirewallFromConfig(l, c, conf)
  560. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  561. // Test both port and code
  562. conf = NewConfig(l)
  563. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  564. _, err = NewFirewallFromConfig(l, c, conf)
  565. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  566. // Test missing host, group, cidr, ca_name and ca_sha
  567. conf = NewConfig(l)
  568. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  569. _, err = NewFirewallFromConfig(l, c, conf)
  570. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
  571. // Test code/port error
  572. conf = NewConfig(l)
  573. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  574. _, err = NewFirewallFromConfig(l, c, conf)
  575. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  576. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  577. _, err = NewFirewallFromConfig(l, c, conf)
  578. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  579. // Test proto error
  580. conf = NewConfig(l)
  581. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  582. _, err = NewFirewallFromConfig(l, c, conf)
  583. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  584. // Test cidr parse error
  585. conf = NewConfig(l)
  586. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  587. _, err = NewFirewallFromConfig(l, c, conf)
  588. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  589. // Test both group and groups
  590. conf = NewConfig(l)
  591. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  592. _, err = NewFirewallFromConfig(l, c, conf)
  593. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  594. }
  595. func TestAddFirewallRulesFromConfig(t *testing.T) {
  596. l := NewTestLogger()
  597. // Test adding tcp rule
  598. conf := NewConfig(l)
  599. mf := &mockFirewall{}
  600. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  601. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  602. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  603. // Test adding udp rule
  604. conf = NewConfig(l)
  605. mf = &mockFirewall{}
  606. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  607. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  608. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  609. // Test adding icmp rule
  610. conf = NewConfig(l)
  611. mf = &mockFirewall{}
  612. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  613. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  614. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  615. // Test adding any rule
  616. conf = NewConfig(l)
  617. mf = &mockFirewall{}
  618. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  619. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  620. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  621. // Test adding rule with ca_sha
  622. conf = NewConfig(l)
  623. mf = &mockFirewall{}
  624. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  625. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  626. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
  627. // Test adding rule with ca_name
  628. conf = NewConfig(l)
  629. mf = &mockFirewall{}
  630. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  631. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  632. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
  633. // Test single group
  634. conf = NewConfig(l)
  635. mf = &mockFirewall{}
  636. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  637. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  638. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  639. // Test single groups
  640. conf = NewConfig(l)
  641. mf = &mockFirewall{}
  642. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  643. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  644. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  645. // Test multiple AND groups
  646. conf = NewConfig(l)
  647. mf = &mockFirewall{}
  648. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  649. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  650. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
  651. // Test Add error
  652. conf = NewConfig(l)
  653. mf = &mockFirewall{}
  654. mf.nextCallReturn = errors.New("test error")
  655. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  656. assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  657. }
  658. func TestTCPRTTTracking(t *testing.T) {
  659. b := make([]byte, 200)
  660. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  661. b[0] = 15
  662. b[60+12] = 15 << 4
  663. f := Firewall{
  664. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  665. }
  666. // Set SEQ to 1
  667. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  668. c := &conn{}
  669. setTCPRTTTracking(c, b)
  670. assert.Equal(t, uint32(1), c.Seq)
  671. // Bad ack - no ack flag
  672. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  673. assert.False(t, f.checkTCPRTT(c, b))
  674. // Bad ack, number is too low
  675. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  676. b[60+13] = uint8(0x10)
  677. assert.False(t, f.checkTCPRTT(c, b))
  678. // Good ack
  679. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  680. assert.True(t, f.checkTCPRTT(c, b))
  681. assert.Equal(t, uint32(0), c.Seq)
  682. // Set SEQ to 1
  683. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  684. c = &conn{}
  685. setTCPRTTTracking(c, b)
  686. assert.Equal(t, uint32(1), c.Seq)
  687. // Good acks
  688. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  689. assert.True(t, f.checkTCPRTT(c, b))
  690. assert.Equal(t, uint32(0), c.Seq)
  691. // Set SEQ to max uint32 - 20
  692. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  693. c = &conn{}
  694. setTCPRTTTracking(c, b)
  695. assert.Equal(t, ^uint32(0)-20, c.Seq)
  696. // Good acks
  697. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  698. assert.True(t, f.checkTCPRTT(c, b))
  699. assert.Equal(t, uint32(0), c.Seq)
  700. // Set SEQ to max uint32 / 2
  701. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  702. c = &conn{}
  703. setTCPRTTTracking(c, b)
  704. assert.Equal(t, ^uint32(0)/2, c.Seq)
  705. // Below
  706. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  707. assert.False(t, f.checkTCPRTT(c, b))
  708. assert.Equal(t, ^uint32(0)/2, c.Seq)
  709. // Halfway below
  710. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  711. assert.False(t, f.checkTCPRTT(c, b))
  712. assert.Equal(t, ^uint32(0)/2, c.Seq)
  713. // Halfway above is ok
  714. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  715. assert.True(t, f.checkTCPRTT(c, b))
  716. assert.Equal(t, uint32(0), c.Seq)
  717. // Set SEQ to max uint32
  718. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  719. c = &conn{}
  720. setTCPRTTTracking(c, b)
  721. assert.Equal(t, ^uint32(0), c.Seq)
  722. // Halfway + 1 above
  723. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  724. assert.False(t, f.checkTCPRTT(c, b))
  725. assert.Equal(t, ^uint32(0), c.Seq)
  726. // Halfway above
  727. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  728. assert.True(t, f.checkTCPRTT(c, b))
  729. assert.Equal(t, uint32(0), c.Seq)
  730. }
  731. func TestFirewall_convertRule(t *testing.T) {
  732. l := NewTestLogger()
  733. ob := &bytes.Buffer{}
  734. l.SetOutput(ob)
  735. // Ensure group array of 1 is converted and a warning is printed
  736. c := map[interface{}]interface{}{
  737. "group": []interface{}{"group1"},
  738. }
  739. r, err := convertRule(l, c, "test", 1)
  740. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  741. assert.Nil(t, err)
  742. assert.Equal(t, "group1", r.Group)
  743. // Ensure group array of > 1 is errord
  744. ob.Reset()
  745. c = map[interface{}]interface{}{
  746. "group": []interface{}{"group1", "group2"},
  747. }
  748. r, err = convertRule(l, c, "test", 1)
  749. assert.Equal(t, "", ob.String())
  750. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  751. // Make sure a well formed group is alright
  752. ob.Reset()
  753. c = map[interface{}]interface{}{
  754. "group": "group1",
  755. }
  756. r, err = convertRule(l, c, "test", 1)
  757. assert.Nil(t, err)
  758. assert.Equal(t, "group1", r.Group)
  759. }
  760. type addRuleCall struct {
  761. incoming bool
  762. proto uint8
  763. startPort int32
  764. endPort int32
  765. groups []string
  766. host string
  767. ip *net.IPNet
  768. caName string
  769. caSha string
  770. }
  771. type mockFirewall struct {
  772. lastCall addRuleCall
  773. nextCallReturn error
  774. }
  775. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  776. mf.lastCall = addRuleCall{
  777. incoming: incoming,
  778. proto: proto,
  779. startPort: startPort,
  780. endPort: endPort,
  781. groups: groups,
  782. host: host,
  783. ip: ip,
  784. caName: caName,
  785. caSha: caSha,
  786. }
  787. err := mf.nextCallReturn
  788. mf.nextCallReturn = nil
  789. return err
  790. }
  791. func resetConntrack(fw *Firewall) {
  792. fw.Conntrack.Lock()
  793. fw.Conntrack.Conns = map[FirewallPacket]*conn{}
  794. fw.Conntrack.Unlock()
  795. }