firewall_test.go 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937
  1. package nebula
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "math"
  7. "net"
  8. "testing"
  9. "time"
  10. "github.com/rcrowley/go-metrics"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/stretchr/testify/assert"
  13. )
  14. func TestNewFirewall(t *testing.T) {
  15. c := &cert.NebulaCertificate{}
  16. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  17. conntrack := fw.Conntrack
  18. assert.NotNil(t, conntrack)
  19. assert.NotNil(t, conntrack.Conns)
  20. assert.NotNil(t, conntrack.TimerWheel)
  21. assert.NotNil(t, fw.InRules)
  22. assert.NotNil(t, fw.OutRules)
  23. assert.Equal(t, time.Second, fw.TCPTimeout)
  24. assert.Equal(t, time.Minute, fw.UDPTimeout)
  25. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  26. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  27. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  28. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  29. fw = NewFirewall(time.Second, time.Hour, time.Minute, c)
  30. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  31. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  32. fw = NewFirewall(time.Hour, time.Second, time.Minute, c)
  33. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  34. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  35. fw = NewFirewall(time.Hour, time.Minute, time.Second, c)
  36. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  37. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  38. fw = NewFirewall(time.Minute, time.Hour, time.Second, c)
  39. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  40. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  41. fw = NewFirewall(time.Minute, time.Second, time.Hour, c)
  42. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  43. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  44. }
  45. func TestFirewall_AddRule(t *testing.T) {
  46. ob := &bytes.Buffer{}
  47. out := l.Out
  48. l.SetOutput(ob)
  49. defer l.SetOutput(out)
  50. c := &cert.NebulaCertificate{}
  51. fw := NewFirewall(time.Second, time.Minute, time.Hour, c)
  52. assert.NotNil(t, fw.InRules)
  53. assert.NotNil(t, fw.OutRules)
  54. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  55. assert.Nil(t, fw.AddRule(true, fwProtoTCP, 1, 1, []string{}, "", nil, "", ""))
  56. // An empty rule is any
  57. assert.True(t, fw.InRules.TCP[1].Any.Any)
  58. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  59. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  60. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.left)
  61. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.right)
  62. assert.Nil(t, fw.InRules.TCP[1].Any.CIDR.root.value)
  63. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  64. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
  65. assert.False(t, fw.InRules.UDP[1].Any.Any)
  66. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
  67. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  68. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.left)
  69. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.right)
  70. assert.Nil(t, fw.InRules.UDP[1].Any.CIDR.root.value)
  71. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  72. assert.Nil(t, fw.AddRule(true, fwProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
  73. assert.False(t, fw.InRules.ICMP[1].Any.Any)
  74. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  75. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  76. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.left)
  77. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.right)
  78. assert.Nil(t, fw.InRules.ICMP[1].Any.CIDR.root.value)
  79. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  80. assert.Nil(t, fw.AddRule(false, fwProtoAny, 1, 1, []string{}, "", ti, "", ""))
  81. assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
  82. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
  83. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
  84. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(ip2int(ti.IP)))
  85. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  86. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
  87. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  88. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  89. assert.Nil(t, fw.AddRule(true, fwProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
  90. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  91. // Set any and clear fields
  92. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  93. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
  94. assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
  95. assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
  96. assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(ip2int(ti.IP)))
  97. // run twice just to make sure
  98. //TODO: these ANY rules should clear the CA firewall portion
  99. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  100. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  101. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  102. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
  103. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
  104. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.left)
  105. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.right)
  106. assert.Nil(t, fw.OutRules.AnyProto[0].Any.CIDR.root.value)
  107. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  108. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  109. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  110. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  111. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  112. assert.Nil(t, fw.AddRule(false, fwProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
  113. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  114. // Test error conditions
  115. fw = NewFirewall(time.Second, time.Minute, time.Hour, c)
  116. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
  117. assert.Error(t, fw.AddRule(true, fwProtoAny, 10, 0, []string{}, "", nil, "", ""))
  118. }
  119. func TestFirewall_Drop(t *testing.T) {
  120. ob := &bytes.Buffer{}
  121. out := l.Out
  122. l.SetOutput(ob)
  123. defer l.SetOutput(out)
  124. p := FirewallPacket{
  125. ip2int(net.IPv4(1, 2, 3, 4)),
  126. ip2int(net.IPv4(1, 2, 3, 4)),
  127. 10,
  128. 90,
  129. fwProtoUDP,
  130. false,
  131. }
  132. ipNet := net.IPNet{
  133. IP: net.IPv4(1, 2, 3, 4),
  134. Mask: net.IPMask{255, 255, 255, 0},
  135. }
  136. c := cert.NebulaCertificate{
  137. Details: cert.NebulaCertificateDetails{
  138. Name: "host1",
  139. Ips: []*net.IPNet{&ipNet},
  140. Groups: []string{"default-group"},
  141. InvertedGroups: map[string]struct{}{"default-group": {}},
  142. Issuer: "signer-shasum",
  143. },
  144. }
  145. h := HostInfo{
  146. ConnectionState: &ConnectionState{
  147. peerCert: &c,
  148. },
  149. hostId: ip2int(ipNet.IP),
  150. }
  151. h.CreateRemoteCIDR(&c)
  152. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  153. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  154. cp := cert.NewCAPool()
  155. // Drop outbound
  156. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  157. // Allow inbound
  158. resetConntrack(fw)
  159. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  160. // Allow outbound because conntrack
  161. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  162. // test remote mismatch
  163. oldRemote := p.RemoteIP
  164. p.RemoteIP = ip2int(net.IPv4(1, 2, 3, 10))
  165. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
  166. p.RemoteIP = oldRemote
  167. // ensure signer doesn't get in the way of group checks
  168. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  169. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
  170. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
  171. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  172. // test caSha doesn't drop on match
  173. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  174. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
  175. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
  176. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  177. // ensure ca name doesn't get in the way of group checks
  178. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  179. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  180. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
  181. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
  182. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  183. // test caName doesn't drop on match
  184. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  185. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  186. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
  187. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
  188. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  189. }
  190. func BenchmarkFirewallTable_match(b *testing.B) {
  191. ft := FirewallTable{
  192. TCP: firewallPort{},
  193. }
  194. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  195. _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
  196. _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
  197. _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
  198. _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
  199. _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
  200. cp := cert.NewCAPool()
  201. b.Run("fail on proto", func(b *testing.B) {
  202. c := &cert.NebulaCertificate{}
  203. for n := 0; n < b.N; n++ {
  204. ft.match(FirewallPacket{Protocol: fwProtoUDP}, true, c, cp)
  205. }
  206. })
  207. b.Run("fail on port", func(b *testing.B) {
  208. c := &cert.NebulaCertificate{}
  209. for n := 0; n < b.N; n++ {
  210. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 1}, true, c, cp)
  211. }
  212. })
  213. b.Run("fail all group, name, and cidr", func(b *testing.B) {
  214. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  215. c := &cert.NebulaCertificate{
  216. Details: cert.NebulaCertificateDetails{
  217. InvertedGroups: map[string]struct{}{"nope": {}},
  218. Name: "nope",
  219. Ips: []*net.IPNet{ip},
  220. },
  221. }
  222. for n := 0; n < b.N; n++ {
  223. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  224. }
  225. })
  226. b.Run("pass on group", func(b *testing.B) {
  227. c := &cert.NebulaCertificate{
  228. Details: cert.NebulaCertificateDetails{
  229. InvertedGroups: map[string]struct{}{"good-group": {}},
  230. Name: "nope",
  231. },
  232. }
  233. for n := 0; n < b.N; n++ {
  234. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  235. }
  236. })
  237. b.Run("pass on name", func(b *testing.B) {
  238. c := &cert.NebulaCertificate{
  239. Details: cert.NebulaCertificateDetails{
  240. InvertedGroups: map[string]struct{}{"nope": {}},
  241. Name: "good-host",
  242. },
  243. }
  244. for n := 0; n < b.N; n++ {
  245. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10}, true, c, cp)
  246. }
  247. })
  248. b.Run("pass on ip", func(b *testing.B) {
  249. ip := ip2int(net.IPv4(172, 1, 1, 1))
  250. c := &cert.NebulaCertificate{
  251. Details: cert.NebulaCertificateDetails{
  252. InvertedGroups: map[string]struct{}{"nope": {}},
  253. Name: "good-host",
  254. },
  255. }
  256. for n := 0; n < b.N; n++ {
  257. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  258. }
  259. })
  260. _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
  261. b.Run("pass on ip with any port", func(b *testing.B) {
  262. ip := ip2int(net.IPv4(172, 1, 1, 1))
  263. c := &cert.NebulaCertificate{
  264. Details: cert.NebulaCertificateDetails{
  265. InvertedGroups: map[string]struct{}{"nope": {}},
  266. Name: "good-host",
  267. },
  268. }
  269. for n := 0; n < b.N; n++ {
  270. ft.match(FirewallPacket{Protocol: fwProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  271. }
  272. })
  273. }
  274. func TestFirewall_Drop2(t *testing.T) {
  275. ob := &bytes.Buffer{}
  276. out := l.Out
  277. l.SetOutput(ob)
  278. defer l.SetOutput(out)
  279. p := FirewallPacket{
  280. ip2int(net.IPv4(1, 2, 3, 4)),
  281. ip2int(net.IPv4(1, 2, 3, 4)),
  282. 10,
  283. 90,
  284. fwProtoUDP,
  285. false,
  286. }
  287. ipNet := net.IPNet{
  288. IP: net.IPv4(1, 2, 3, 4),
  289. Mask: net.IPMask{255, 255, 255, 0},
  290. }
  291. c := cert.NebulaCertificate{
  292. Details: cert.NebulaCertificateDetails{
  293. Name: "host1",
  294. Ips: []*net.IPNet{&ipNet},
  295. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  296. },
  297. }
  298. h := HostInfo{
  299. ConnectionState: &ConnectionState{
  300. peerCert: &c,
  301. },
  302. hostId: ip2int(ipNet.IP),
  303. }
  304. h.CreateRemoteCIDR(&c)
  305. c1 := cert.NebulaCertificate{
  306. Details: cert.NebulaCertificateDetails{
  307. Name: "host1",
  308. Ips: []*net.IPNet{&ipNet},
  309. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  310. },
  311. }
  312. h1 := HostInfo{
  313. ConnectionState: &ConnectionState{
  314. peerCert: &c1,
  315. },
  316. }
  317. h1.CreateRemoteCIDR(&c1)
  318. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  319. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
  320. cp := cert.NewCAPool()
  321. // h1/c1 lacks the proper groups
  322. assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
  323. // c has the proper groups
  324. resetConntrack(fw)
  325. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  326. }
  327. func TestFirewall_Drop3(t *testing.T) {
  328. ob := &bytes.Buffer{}
  329. out := l.Out
  330. l.SetOutput(ob)
  331. defer l.SetOutput(out)
  332. p := FirewallPacket{
  333. ip2int(net.IPv4(1, 2, 3, 4)),
  334. ip2int(net.IPv4(1, 2, 3, 4)),
  335. 1,
  336. 1,
  337. fwProtoUDP,
  338. false,
  339. }
  340. ipNet := net.IPNet{
  341. IP: net.IPv4(1, 2, 3, 4),
  342. Mask: net.IPMask{255, 255, 255, 0},
  343. }
  344. c := cert.NebulaCertificate{
  345. Details: cert.NebulaCertificateDetails{
  346. Name: "host-owner",
  347. Ips: []*net.IPNet{&ipNet},
  348. },
  349. }
  350. c1 := cert.NebulaCertificate{
  351. Details: cert.NebulaCertificateDetails{
  352. Name: "host1",
  353. Ips: []*net.IPNet{&ipNet},
  354. Issuer: "signer-sha-bad",
  355. },
  356. }
  357. h1 := HostInfo{
  358. ConnectionState: &ConnectionState{
  359. peerCert: &c1,
  360. },
  361. hostId: ip2int(ipNet.IP),
  362. }
  363. h1.CreateRemoteCIDR(&c1)
  364. c2 := cert.NebulaCertificate{
  365. Details: cert.NebulaCertificateDetails{
  366. Name: "host2",
  367. Ips: []*net.IPNet{&ipNet},
  368. Issuer: "signer-sha",
  369. },
  370. }
  371. h2 := HostInfo{
  372. ConnectionState: &ConnectionState{
  373. peerCert: &c2,
  374. },
  375. hostId: ip2int(ipNet.IP),
  376. }
  377. h2.CreateRemoteCIDR(&c2)
  378. c3 := cert.NebulaCertificate{
  379. Details: cert.NebulaCertificateDetails{
  380. Name: "host3",
  381. Ips: []*net.IPNet{&ipNet},
  382. Issuer: "signer-sha-bad",
  383. },
  384. }
  385. h3 := HostInfo{
  386. ConnectionState: &ConnectionState{
  387. peerCert: &c3,
  388. },
  389. hostId: ip2int(ipNet.IP),
  390. }
  391. h3.CreateRemoteCIDR(&c3)
  392. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  393. assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
  394. assert.Nil(t, fw.AddRule(true, fwProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
  395. cp := cert.NewCAPool()
  396. // c1 should pass because host match
  397. assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
  398. // c2 should pass because ca sha match
  399. resetConntrack(fw)
  400. assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
  401. // c3 should fail because no match
  402. resetConntrack(fw)
  403. assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
  404. }
  405. func TestFirewall_DropConntrackReload(t *testing.T) {
  406. ob := &bytes.Buffer{}
  407. out := l.Out
  408. l.SetOutput(ob)
  409. defer l.SetOutput(out)
  410. p := FirewallPacket{
  411. ip2int(net.IPv4(1, 2, 3, 4)),
  412. ip2int(net.IPv4(1, 2, 3, 4)),
  413. 10,
  414. 90,
  415. fwProtoUDP,
  416. false,
  417. }
  418. ipNet := net.IPNet{
  419. IP: net.IPv4(1, 2, 3, 4),
  420. Mask: net.IPMask{255, 255, 255, 0},
  421. }
  422. c := cert.NebulaCertificate{
  423. Details: cert.NebulaCertificateDetails{
  424. Name: "host1",
  425. Ips: []*net.IPNet{&ipNet},
  426. Groups: []string{"default-group"},
  427. InvertedGroups: map[string]struct{}{"default-group": {}},
  428. Issuer: "signer-shasum",
  429. },
  430. }
  431. h := HostInfo{
  432. ConnectionState: &ConnectionState{
  433. peerCert: &c,
  434. },
  435. hostId: ip2int(ipNet.IP),
  436. }
  437. h.CreateRemoteCIDR(&c)
  438. fw := NewFirewall(time.Second, time.Minute, time.Hour, &c)
  439. assert.Nil(t, fw.AddRule(true, fwProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  440. cp := cert.NewCAPool()
  441. // Drop outbound
  442. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  443. // Allow inbound
  444. resetConntrack(fw)
  445. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  446. // Allow outbound because conntrack
  447. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  448. oldFw := fw
  449. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  450. assert.Nil(t, fw.AddRule(true, fwProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
  451. fw.Conntrack = oldFw.Conntrack
  452. fw.rulesVersion = oldFw.rulesVersion + 1
  453. // Allow outbound because conntrack and new rules allow port 10
  454. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  455. oldFw = fw
  456. fw = NewFirewall(time.Second, time.Minute, time.Hour, &c)
  457. assert.Nil(t, fw.AddRule(true, fwProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
  458. fw.Conntrack = oldFw.Conntrack
  459. fw.rulesVersion = oldFw.rulesVersion + 1
  460. // Drop outbound because conntrack doesn't match new ruleset
  461. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  462. }
  463. func BenchmarkLookup(b *testing.B) {
  464. ml := func(m map[string]struct{}, a [][]string) {
  465. for n := 0; n < b.N; n++ {
  466. for _, sg := range a {
  467. found := false
  468. for _, g := range sg {
  469. if _, ok := m[g]; !ok {
  470. found = false
  471. break
  472. }
  473. found = true
  474. }
  475. if found {
  476. return
  477. }
  478. }
  479. }
  480. }
  481. b.Run("array to map best", func(b *testing.B) {
  482. m := map[string]struct{}{
  483. "1ne": {},
  484. "2wo": {},
  485. "3hr": {},
  486. "4ou": {},
  487. "5iv": {},
  488. "6ix": {},
  489. }
  490. a := [][]string{
  491. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  492. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  493. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  494. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  495. {"one", "two", "thr", "fou", "5iv", "6ix"},
  496. {"one", "two", "thr", "fou", "fiv", "6ix"},
  497. {"one", "two", "thr", "fou", "fiv", "six"},
  498. }
  499. for n := 0; n < b.N; n++ {
  500. ml(m, a)
  501. }
  502. })
  503. b.Run("array to map worst", func(b *testing.B) {
  504. m := map[string]struct{}{
  505. "one": {},
  506. "two": {},
  507. "thr": {},
  508. "fou": {},
  509. "fiv": {},
  510. "six": {},
  511. }
  512. a := [][]string{
  513. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  514. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  515. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  516. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  517. {"one", "two", "thr", "fou", "5iv", "6ix"},
  518. {"one", "two", "thr", "fou", "fiv", "6ix"},
  519. {"one", "two", "thr", "fou", "fiv", "six"},
  520. }
  521. for n := 0; n < b.N; n++ {
  522. ml(m, a)
  523. }
  524. })
  525. //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  526. }
  527. func Test_parsePort(t *testing.T) {
  528. _, _, err := parsePort("")
  529. assert.EqualError(t, err, "was not a number; ``")
  530. _, _, err = parsePort(" ")
  531. assert.EqualError(t, err, "was not a number; ` `")
  532. _, _, err = parsePort("-")
  533. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  534. _, _, err = parsePort(" - ")
  535. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  536. _, _, err = parsePort("a-b")
  537. assert.EqualError(t, err, "beginning range was not a number; `a`")
  538. _, _, err = parsePort("1-b")
  539. assert.EqualError(t, err, "ending range was not a number; `b`")
  540. s, e, err := parsePort(" 1 - 2 ")
  541. assert.Equal(t, int32(1), s)
  542. assert.Equal(t, int32(2), e)
  543. assert.Nil(t, err)
  544. s, e, err = parsePort("0-1")
  545. assert.Equal(t, int32(0), s)
  546. assert.Equal(t, int32(0), e)
  547. assert.Nil(t, err)
  548. s, e, err = parsePort("9919")
  549. assert.Equal(t, int32(9919), s)
  550. assert.Equal(t, int32(9919), e)
  551. assert.Nil(t, err)
  552. s, e, err = parsePort("any")
  553. assert.Equal(t, int32(0), s)
  554. assert.Equal(t, int32(0), e)
  555. assert.Nil(t, err)
  556. }
  557. func TestNewFirewallFromConfig(t *testing.T) {
  558. // Test a bad rule definition
  559. c := &cert.NebulaCertificate{}
  560. conf := NewConfig()
  561. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  562. _, err := NewFirewallFromConfig(c, conf)
  563. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  564. // Test both port and code
  565. conf = NewConfig()
  566. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  567. _, err = NewFirewallFromConfig(c, conf)
  568. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  569. // Test missing host, group, cidr, ca_name and ca_sha
  570. conf = NewConfig()
  571. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  572. _, err = NewFirewallFromConfig(c, conf)
  573. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
  574. // Test code/port error
  575. conf = NewConfig()
  576. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  577. _, err = NewFirewallFromConfig(c, conf)
  578. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  579. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  580. _, err = NewFirewallFromConfig(c, conf)
  581. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  582. // Test proto error
  583. conf = NewConfig()
  584. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  585. _, err = NewFirewallFromConfig(c, conf)
  586. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  587. // Test cidr parse error
  588. conf = NewConfig()
  589. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  590. _, err = NewFirewallFromConfig(c, conf)
  591. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  592. // Test both group and groups
  593. conf = NewConfig()
  594. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  595. _, err = NewFirewallFromConfig(c, conf)
  596. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  597. }
  598. func TestAddFirewallRulesFromConfig(t *testing.T) {
  599. // Test adding tcp rule
  600. conf := NewConfig()
  601. mf := &mockFirewall{}
  602. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  603. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  604. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  605. // Test adding udp rule
  606. conf = NewConfig()
  607. mf = &mockFirewall{}
  608. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  609. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  610. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  611. // Test adding icmp rule
  612. conf = NewConfig()
  613. mf = &mockFirewall{}
  614. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  615. assert.Nil(t, AddFirewallRulesFromConfig(false, conf, mf))
  616. assert.Equal(t, addRuleCall{incoming: false, proto: fwProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  617. // Test adding any rule
  618. conf = NewConfig()
  619. mf = &mockFirewall{}
  620. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  621. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  622. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  623. // Test adding rule with ca_sha
  624. conf = NewConfig()
  625. mf = &mockFirewall{}
  626. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  627. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  628. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
  629. // Test adding rule with ca_name
  630. conf = NewConfig()
  631. mf = &mockFirewall{}
  632. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  633. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  634. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
  635. // Test single group
  636. conf = NewConfig()
  637. mf = &mockFirewall{}
  638. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  639. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  640. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  641. // Test single groups
  642. conf = NewConfig()
  643. mf = &mockFirewall{}
  644. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  645. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  646. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  647. // Test multiple AND groups
  648. conf = NewConfig()
  649. mf = &mockFirewall{}
  650. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  651. assert.Nil(t, AddFirewallRulesFromConfig(true, conf, mf))
  652. assert.Equal(t, addRuleCall{incoming: true, proto: fwProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
  653. // Test Add error
  654. conf = NewConfig()
  655. mf = &mockFirewall{}
  656. mf.nextCallReturn = errors.New("test error")
  657. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  658. assert.EqualError(t, AddFirewallRulesFromConfig(true, conf, mf), "firewall.inbound rule #0; `test error`")
  659. }
  660. func TestTCPRTTTracking(t *testing.T) {
  661. b := make([]byte, 200)
  662. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  663. b[0] = 15
  664. b[60+12] = 15 << 4
  665. f := Firewall{
  666. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  667. }
  668. // Set SEQ to 1
  669. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  670. c := &conn{}
  671. setTCPRTTTracking(c, b)
  672. assert.Equal(t, uint32(1), c.Seq)
  673. // Bad ack - no ack flag
  674. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  675. assert.False(t, f.checkTCPRTT(c, b))
  676. // Bad ack, number is too low
  677. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  678. b[60+13] = uint8(0x10)
  679. assert.False(t, f.checkTCPRTT(c, b))
  680. // Good ack
  681. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  682. assert.True(t, f.checkTCPRTT(c, b))
  683. assert.Equal(t, uint32(0), c.Seq)
  684. // Set SEQ to 1
  685. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  686. c = &conn{}
  687. setTCPRTTTracking(c, b)
  688. assert.Equal(t, uint32(1), c.Seq)
  689. // Good acks
  690. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  691. assert.True(t, f.checkTCPRTT(c, b))
  692. assert.Equal(t, uint32(0), c.Seq)
  693. // Set SEQ to max uint32 - 20
  694. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  695. c = &conn{}
  696. setTCPRTTTracking(c, b)
  697. assert.Equal(t, ^uint32(0)-20, c.Seq)
  698. // Good acks
  699. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  700. assert.True(t, f.checkTCPRTT(c, b))
  701. assert.Equal(t, uint32(0), c.Seq)
  702. // Set SEQ to max uint32 / 2
  703. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  704. c = &conn{}
  705. setTCPRTTTracking(c, b)
  706. assert.Equal(t, ^uint32(0)/2, c.Seq)
  707. // Below
  708. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  709. assert.False(t, f.checkTCPRTT(c, b))
  710. assert.Equal(t, ^uint32(0)/2, c.Seq)
  711. // Halfway below
  712. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  713. assert.False(t, f.checkTCPRTT(c, b))
  714. assert.Equal(t, ^uint32(0)/2, c.Seq)
  715. // Halfway above is ok
  716. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  717. assert.True(t, f.checkTCPRTT(c, b))
  718. assert.Equal(t, uint32(0), c.Seq)
  719. // Set SEQ to max uint32
  720. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  721. c = &conn{}
  722. setTCPRTTTracking(c, b)
  723. assert.Equal(t, ^uint32(0), c.Seq)
  724. // Halfway + 1 above
  725. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  726. assert.False(t, f.checkTCPRTT(c, b))
  727. assert.Equal(t, ^uint32(0), c.Seq)
  728. // Halfway above
  729. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  730. assert.True(t, f.checkTCPRTT(c, b))
  731. assert.Equal(t, uint32(0), c.Seq)
  732. }
  733. func TestFirewall_convertRule(t *testing.T) {
  734. ob := &bytes.Buffer{}
  735. out := l.Out
  736. l.SetOutput(ob)
  737. defer l.SetOutput(out)
  738. // Ensure group array of 1 is converted and a warning is printed
  739. c := map[interface{}]interface{}{
  740. "group": []interface{}{"group1"},
  741. }
  742. r, err := convertRule(c, "test", 1)
  743. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  744. assert.Nil(t, err)
  745. assert.Equal(t, "group1", r.Group)
  746. // Ensure group array of > 1 is errord
  747. ob.Reset()
  748. c = map[interface{}]interface{}{
  749. "group": []interface{}{"group1", "group2"},
  750. }
  751. r, err = convertRule(c, "test", 1)
  752. assert.Equal(t, "", ob.String())
  753. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  754. // Make sure a well formed group is alright
  755. ob.Reset()
  756. c = map[interface{}]interface{}{
  757. "group": "group1",
  758. }
  759. r, err = convertRule(c, "test", 1)
  760. assert.Nil(t, err)
  761. assert.Equal(t, "group1", r.Group)
  762. }
  763. type addRuleCall struct {
  764. incoming bool
  765. proto uint8
  766. startPort int32
  767. endPort int32
  768. groups []string
  769. host string
  770. ip *net.IPNet
  771. caName string
  772. caSha string
  773. }
  774. type mockFirewall struct {
  775. lastCall addRuleCall
  776. nextCallReturn error
  777. }
  778. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  779. mf.lastCall = addRuleCall{
  780. incoming: incoming,
  781. proto: proto,
  782. startPort: startPort,
  783. endPort: endPort,
  784. groups: groups,
  785. host: host,
  786. ip: ip,
  787. caName: caName,
  788. caSha: caSha,
  789. }
  790. err := mf.nextCallReturn
  791. mf.nextCallReturn = nil
  792. return err
  793. }
  794. func resetConntrack(fw *Firewall) {
  795. fw.Conntrack.Lock()
  796. fw.Conntrack.Conns = map[FirewallPacket]*conn{}
  797. fw.Conntrack.Unlock()
  798. }