firewall_test.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926
  1. package nebula
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "math"
  7. "net"
  8. "testing"
  9. "time"
  10. "github.com/rcrowley/go-metrics"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/slackhq/nebula/config"
  13. "github.com/slackhq/nebula/firewall"
  14. "github.com/slackhq/nebula/iputil"
  15. "github.com/slackhq/nebula/util"
  16. "github.com/stretchr/testify/assert"
  17. )
  18. func TestNewFirewall(t *testing.T) {
  19. l := util.NewTestLogger()
  20. c := &cert.NebulaCertificate{}
  21. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  22. conntrack := fw.Conntrack
  23. assert.NotNil(t, conntrack)
  24. assert.NotNil(t, conntrack.Conns)
  25. assert.NotNil(t, conntrack.TimerWheel)
  26. assert.NotNil(t, fw.InRules)
  27. assert.NotNil(t, fw.OutRules)
  28. assert.Equal(t, time.Second, fw.TCPTimeout)
  29. assert.Equal(t, time.Minute, fw.UDPTimeout)
  30. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  33. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  34. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  35. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  36. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  37. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  41. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  42. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  43. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  44. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  45. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  46. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  47. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  48. assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
  49. }
  50. func TestFirewall_AddRule(t *testing.T) {
  51. l := util.NewTestLogger()
  52. ob := &bytes.Buffer{}
  53. l.SetOutput(ob)
  54. c := &cert.NebulaCertificate{}
  55. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  56. assert.NotNil(t, fw.InRules)
  57. assert.NotNil(t, fw.OutRules)
  58. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  59. assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, "", ""))
  60. // An empty rule is any
  61. assert.True(t, fw.InRules.TCP[1].Any.Any)
  62. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  63. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  64. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  65. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", ""))
  66. assert.False(t, fw.InRules.UDP[1].Any.Any)
  67. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
  68. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  69. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  70. assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, "", ""))
  71. assert.False(t, fw.InRules.ICMP[1].Any.Any)
  72. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  73. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  74. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  75. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, "", ""))
  76. assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
  77. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
  78. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
  79. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
  80. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  81. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "ca-name", ""))
  82. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  83. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  84. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, "", "ca-sha"))
  85. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  86. // Set any and clear fields
  87. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  88. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, "", ""))
  89. assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
  90. assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
  91. assert.NotNil(t, fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP)))
  92. // run twice just to make sure
  93. //TODO: these ANY rules should clear the CA firewall portion
  94. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  95. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  96. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  97. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
  98. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
  99. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  100. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, "", ""))
  101. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  102. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  103. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  104. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, "", ""))
  105. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  106. // Test error conditions
  107. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  108. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, "", ""))
  109. assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, "", ""))
  110. }
  111. func TestFirewall_Drop(t *testing.T) {
  112. l := util.NewTestLogger()
  113. ob := &bytes.Buffer{}
  114. l.SetOutput(ob)
  115. p := firewall.Packet{
  116. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  117. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  118. 10,
  119. 90,
  120. firewall.ProtoUDP,
  121. false,
  122. }
  123. ipNet := net.IPNet{
  124. IP: net.IPv4(1, 2, 3, 4),
  125. Mask: net.IPMask{255, 255, 255, 0},
  126. }
  127. c := cert.NebulaCertificate{
  128. Details: cert.NebulaCertificateDetails{
  129. Name: "host1",
  130. Ips: []*net.IPNet{&ipNet},
  131. Groups: []string{"default-group"},
  132. InvertedGroups: map[string]struct{}{"default-group": {}},
  133. Issuer: "signer-shasum",
  134. },
  135. }
  136. h := HostInfo{
  137. ConnectionState: &ConnectionState{
  138. peerCert: &c,
  139. },
  140. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  141. }
  142. h.CreateRemoteCIDR(&c)
  143. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  144. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  145. cp := cert.NewCAPool()
  146. // Drop outbound
  147. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  148. // Allow inbound
  149. resetConntrack(fw)
  150. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  151. // Allow outbound because conntrack
  152. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  153. // test remote mismatch
  154. oldRemote := p.RemoteIP
  155. p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
  156. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
  157. p.RemoteIP = oldRemote
  158. // ensure signer doesn't get in the way of group checks
  159. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  160. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum"))
  161. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum-bad"))
  162. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  163. // test caSha doesn't drop on match
  164. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  165. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "", "signer-shasum-bad"))
  166. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "", "signer-shasum"))
  167. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  168. // ensure ca name doesn't get in the way of group checks
  169. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  170. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  171. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good", ""))
  172. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good-bad", ""))
  173. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  174. // test caName doesn't drop on match
  175. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  176. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  177. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, "ca-good-bad", ""))
  178. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, "ca-good", ""))
  179. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  180. }
  181. func BenchmarkFirewallTable_match(b *testing.B) {
  182. ft := FirewallTable{
  183. TCP: firewallPort{},
  184. }
  185. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  186. _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, "", "")
  187. _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, "", "")
  188. _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, "", "")
  189. _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, "", "")
  190. _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, "", "")
  191. cp := cert.NewCAPool()
  192. b.Run("fail on proto", func(b *testing.B) {
  193. c := &cert.NebulaCertificate{}
  194. for n := 0; n < b.N; n++ {
  195. ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
  196. }
  197. })
  198. b.Run("fail on port", func(b *testing.B) {
  199. c := &cert.NebulaCertificate{}
  200. for n := 0; n < b.N; n++ {
  201. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
  202. }
  203. })
  204. b.Run("fail all group, name, and cidr", func(b *testing.B) {
  205. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  206. c := &cert.NebulaCertificate{
  207. Details: cert.NebulaCertificateDetails{
  208. InvertedGroups: map[string]struct{}{"nope": {}},
  209. Name: "nope",
  210. Ips: []*net.IPNet{ip},
  211. },
  212. }
  213. for n := 0; n < b.N; n++ {
  214. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  215. }
  216. })
  217. b.Run("pass on group", func(b *testing.B) {
  218. c := &cert.NebulaCertificate{
  219. Details: cert.NebulaCertificateDetails{
  220. InvertedGroups: map[string]struct{}{"good-group": {}},
  221. Name: "nope",
  222. },
  223. }
  224. for n := 0; n < b.N; n++ {
  225. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  226. }
  227. })
  228. b.Run("pass on name", func(b *testing.B) {
  229. c := &cert.NebulaCertificate{
  230. Details: cert.NebulaCertificateDetails{
  231. InvertedGroups: map[string]struct{}{"nope": {}},
  232. Name: "good-host",
  233. },
  234. }
  235. for n := 0; n < b.N; n++ {
  236. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  237. }
  238. })
  239. b.Run("pass on ip", func(b *testing.B) {
  240. ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  241. c := &cert.NebulaCertificate{
  242. Details: cert.NebulaCertificateDetails{
  243. InvertedGroups: map[string]struct{}{"nope": {}},
  244. Name: "good-host",
  245. },
  246. }
  247. for n := 0; n < b.N; n++ {
  248. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  249. }
  250. })
  251. _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, "", "")
  252. b.Run("pass on ip with any port", func(b *testing.B) {
  253. ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  254. c := &cert.NebulaCertificate{
  255. Details: cert.NebulaCertificateDetails{
  256. InvertedGroups: map[string]struct{}{"nope": {}},
  257. Name: "good-host",
  258. },
  259. }
  260. for n := 0; n < b.N; n++ {
  261. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  262. }
  263. })
  264. }
  265. func TestFirewall_Drop2(t *testing.T) {
  266. l := util.NewTestLogger()
  267. ob := &bytes.Buffer{}
  268. l.SetOutput(ob)
  269. p := firewall.Packet{
  270. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  271. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  272. 10,
  273. 90,
  274. firewall.ProtoUDP,
  275. false,
  276. }
  277. ipNet := net.IPNet{
  278. IP: net.IPv4(1, 2, 3, 4),
  279. Mask: net.IPMask{255, 255, 255, 0},
  280. }
  281. c := cert.NebulaCertificate{
  282. Details: cert.NebulaCertificateDetails{
  283. Name: "host1",
  284. Ips: []*net.IPNet{&ipNet},
  285. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  286. },
  287. }
  288. h := HostInfo{
  289. ConnectionState: &ConnectionState{
  290. peerCert: &c,
  291. },
  292. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  293. }
  294. h.CreateRemoteCIDR(&c)
  295. c1 := cert.NebulaCertificate{
  296. Details: cert.NebulaCertificateDetails{
  297. Name: "host1",
  298. Ips: []*net.IPNet{&ipNet},
  299. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  300. },
  301. }
  302. h1 := HostInfo{
  303. ConnectionState: &ConnectionState{
  304. peerCert: &c1,
  305. },
  306. }
  307. h1.CreateRemoteCIDR(&c1)
  308. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  309. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, "", ""))
  310. cp := cert.NewCAPool()
  311. // h1/c1 lacks the proper groups
  312. assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
  313. // c has the proper groups
  314. resetConntrack(fw)
  315. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  316. }
  317. func TestFirewall_Drop3(t *testing.T) {
  318. l := util.NewTestLogger()
  319. ob := &bytes.Buffer{}
  320. l.SetOutput(ob)
  321. p := firewall.Packet{
  322. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  323. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  324. 1,
  325. 1,
  326. firewall.ProtoUDP,
  327. false,
  328. }
  329. ipNet := net.IPNet{
  330. IP: net.IPv4(1, 2, 3, 4),
  331. Mask: net.IPMask{255, 255, 255, 0},
  332. }
  333. c := cert.NebulaCertificate{
  334. Details: cert.NebulaCertificateDetails{
  335. Name: "host-owner",
  336. Ips: []*net.IPNet{&ipNet},
  337. },
  338. }
  339. c1 := cert.NebulaCertificate{
  340. Details: cert.NebulaCertificateDetails{
  341. Name: "host1",
  342. Ips: []*net.IPNet{&ipNet},
  343. Issuer: "signer-sha-bad",
  344. },
  345. }
  346. h1 := HostInfo{
  347. ConnectionState: &ConnectionState{
  348. peerCert: &c1,
  349. },
  350. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  351. }
  352. h1.CreateRemoteCIDR(&c1)
  353. c2 := cert.NebulaCertificate{
  354. Details: cert.NebulaCertificateDetails{
  355. Name: "host2",
  356. Ips: []*net.IPNet{&ipNet},
  357. Issuer: "signer-sha",
  358. },
  359. }
  360. h2 := HostInfo{
  361. ConnectionState: &ConnectionState{
  362. peerCert: &c2,
  363. },
  364. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  365. }
  366. h2.CreateRemoteCIDR(&c2)
  367. c3 := cert.NebulaCertificate{
  368. Details: cert.NebulaCertificateDetails{
  369. Name: "host3",
  370. Ips: []*net.IPNet{&ipNet},
  371. Issuer: "signer-sha-bad",
  372. },
  373. }
  374. h3 := HostInfo{
  375. ConnectionState: &ConnectionState{
  376. peerCert: &c3,
  377. },
  378. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  379. }
  380. h3.CreateRemoteCIDR(&c3)
  381. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  382. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, "", ""))
  383. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, "", "signer-sha"))
  384. cp := cert.NewCAPool()
  385. // c1 should pass because host match
  386. assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
  387. // c2 should pass because ca sha match
  388. resetConntrack(fw)
  389. assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
  390. // c3 should fail because no match
  391. resetConntrack(fw)
  392. assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
  393. }
  394. func TestFirewall_DropConntrackReload(t *testing.T) {
  395. l := util.NewTestLogger()
  396. ob := &bytes.Buffer{}
  397. l.SetOutput(ob)
  398. p := firewall.Packet{
  399. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  400. iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  401. 10,
  402. 90,
  403. firewall.ProtoUDP,
  404. false,
  405. }
  406. ipNet := net.IPNet{
  407. IP: net.IPv4(1, 2, 3, 4),
  408. Mask: net.IPMask{255, 255, 255, 0},
  409. }
  410. c := cert.NebulaCertificate{
  411. Details: cert.NebulaCertificateDetails{
  412. Name: "host1",
  413. Ips: []*net.IPNet{&ipNet},
  414. Groups: []string{"default-group"},
  415. InvertedGroups: map[string]struct{}{"default-group": {}},
  416. Issuer: "signer-shasum",
  417. },
  418. }
  419. h := HostInfo{
  420. ConnectionState: &ConnectionState{
  421. peerCert: &c,
  422. },
  423. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  424. }
  425. h.CreateRemoteCIDR(&c)
  426. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  427. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, "", ""))
  428. cp := cert.NewCAPool()
  429. // Drop outbound
  430. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  431. // Allow inbound
  432. resetConntrack(fw)
  433. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  434. // Allow outbound because conntrack
  435. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  436. oldFw := fw
  437. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  438. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, "", ""))
  439. fw.Conntrack = oldFw.Conntrack
  440. fw.rulesVersion = oldFw.rulesVersion + 1
  441. // Allow outbound because conntrack and new rules allow port 10
  442. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  443. oldFw = fw
  444. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  445. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, "", ""))
  446. fw.Conntrack = oldFw.Conntrack
  447. fw.rulesVersion = oldFw.rulesVersion + 1
  448. // Drop outbound because conntrack doesn't match new ruleset
  449. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  450. }
  func BenchmarkLookup(b *testing.B) {
	// ml makes one pass over a and stops at the first group set whose members
	// are all present in m. Each sub-benchmark supplies its own b.N loop. The
	// previous version also looped over b.N here, but that b was the parent
	// benchmark's, which nested a second timing loop inside every iteration.
	ml := func(m map[string]struct{}, a [][]string) {
		for _, sg := range a {
			found := false
			for _, g := range sg {
				if _, ok := m[g]; !ok {
					found = false
					break
				}
				found = true
			}
			if found {
				return
			}
		}
	}

	// Both sub-benchmarks search the same candidate group sets; only the map differs.
	a := [][]string{
		{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
		{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
		{"one", "two", "3hr", "4ou", "5iv", "6ix"},
		{"one", "two", "thr", "4ou", "5iv", "6ix"},
		{"one", "two", "thr", "fou", "5iv", "6ix"},
		{"one", "two", "thr", "fou", "fiv", "6ix"},
		{"one", "two", "thr", "fou", "fiv", "six"},
	}

	// Best case: the first group set matches.
	b.Run("array to map best", func(b *testing.B) {
		m := map[string]struct{}{
			"1ne": {},
			"2wo": {},
			"3hr": {},
			"4ou": {},
			"5iv": {},
			"6ix": {},
		}
		for n := 0; n < b.N; n++ {
			ml(m, a)
		}
	})

	// Worst case: only the last group set matches.
	b.Run("array to map worst", func(b *testing.B) {
		m := map[string]struct{}{
			"one": {},
			"two": {},
			"thr": {},
			"fou": {},
			"fiv": {},
			"six": {},
		}
		for n := 0; n < b.N; n++ {
			ml(m, a)
		}
	})
	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
}
  515. func Test_parsePort(t *testing.T) {
  516. _, _, err := parsePort("")
  517. assert.EqualError(t, err, "was not a number; ``")
  518. _, _, err = parsePort(" ")
  519. assert.EqualError(t, err, "was not a number; ` `")
  520. _, _, err = parsePort("-")
  521. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  522. _, _, err = parsePort(" - ")
  523. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  524. _, _, err = parsePort("a-b")
  525. assert.EqualError(t, err, "beginning range was not a number; `a`")
  526. _, _, err = parsePort("1-b")
  527. assert.EqualError(t, err, "ending range was not a number; `b`")
  528. s, e, err := parsePort(" 1 - 2 ")
  529. assert.Equal(t, int32(1), s)
  530. assert.Equal(t, int32(2), e)
  531. assert.Nil(t, err)
  532. s, e, err = parsePort("0-1")
  533. assert.Equal(t, int32(0), s)
  534. assert.Equal(t, int32(0), e)
  535. assert.Nil(t, err)
  536. s, e, err = parsePort("9919")
  537. assert.Equal(t, int32(9919), s)
  538. assert.Equal(t, int32(9919), e)
  539. assert.Nil(t, err)
  540. s, e, err = parsePort("any")
  541. assert.Equal(t, int32(0), s)
  542. assert.Equal(t, int32(0), e)
  543. assert.Nil(t, err)
  544. }
  545. func TestNewFirewallFromConfig(t *testing.T) {
  546. l := util.NewTestLogger()
  547. // Test a bad rule definition
  548. c := &cert.NebulaCertificate{}
  549. conf := config.NewC(l)
  550. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  551. _, err := NewFirewallFromConfig(l, c, conf)
  552. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  553. // Test both port and code
  554. conf = config.NewC(l)
  555. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  556. _, err = NewFirewallFromConfig(l, c, conf)
  557. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  558. // Test missing host, group, cidr, ca_name and ca_sha
  559. conf = config.NewC(l)
  560. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  561. _, err = NewFirewallFromConfig(l, c, conf)
  562. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, ca_name, or ca_sha must be provided")
  563. // Test code/port error
  564. conf = config.NewC(l)
  565. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  566. _, err = NewFirewallFromConfig(l, c, conf)
  567. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  568. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  569. _, err = NewFirewallFromConfig(l, c, conf)
  570. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  571. // Test proto error
  572. conf = config.NewC(l)
  573. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  574. _, err = NewFirewallFromConfig(l, c, conf)
  575. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  576. // Test cidr parse error
  577. conf = config.NewC(l)
  578. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  579. _, err = NewFirewallFromConfig(l, c, conf)
  580. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  581. // Test both group and groups
  582. conf = config.NewC(l)
  583. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  584. _, err = NewFirewallFromConfig(l, c, conf)
  585. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  586. }
  587. func TestAddFirewallRulesFromConfig(t *testing.T) {
  588. l := util.NewTestLogger()
  589. // Test adding tcp rule
  590. conf := config.NewC(l)
  591. mf := &mockFirewall{}
  592. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  593. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  594. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  595. // Test adding udp rule
  596. conf = config.NewC(l)
  597. mf = &mockFirewall{}
  598. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  599. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  600. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  601. // Test adding icmp rule
  602. conf = config.NewC(l)
  603. mf = &mockFirewall{}
  604. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  605. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  606. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  607. // Test adding any rule
  608. conf = config.NewC(l)
  609. mf = &mockFirewall{}
  610. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  611. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  612. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil}, mf.lastCall)
  613. // Test adding rule with ca_sha
  614. conf = config.NewC(l)
  615. mf = &mockFirewall{}
  616. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  617. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  618. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caSha: "12312313123"}, mf.lastCall)
  619. // Test adding rule with ca_name
  620. conf = config.NewC(l)
  621. mf = &mockFirewall{}
  622. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  623. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  624. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, caName: "root01"}, mf.lastCall)
  625. // Test single group
  626. conf = config.NewC(l)
  627. mf = &mockFirewall{}
  628. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  629. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  630. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  631. // Test single groups
  632. conf = config.NewC(l)
  633. mf = &mockFirewall{}
  634. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  635. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  636. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil}, mf.lastCall)
  637. // Test multiple AND groups
  638. conf = config.NewC(l)
  639. mf = &mockFirewall{}
  640. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  641. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  642. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil}, mf.lastCall)
  643. // Test Add error
  644. conf = config.NewC(l)
  645. mf = &mockFirewall{}
  646. mf.nextCallReturn = errors.New("test error")
  647. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  648. assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  649. }
  650. func TestTCPRTTTracking(t *testing.T) {
  651. b := make([]byte, 200)
  652. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  653. b[0] = 15
  654. b[60+12] = 15 << 4
  655. f := Firewall{
  656. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  657. }
  658. // Set SEQ to 1
  659. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  660. c := &conn{}
  661. setTCPRTTTracking(c, b)
  662. assert.Equal(t, uint32(1), c.Seq)
  663. // Bad ack - no ack flag
  664. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  665. assert.False(t, f.checkTCPRTT(c, b))
  666. // Bad ack, number is too low
  667. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  668. b[60+13] = uint8(0x10)
  669. assert.False(t, f.checkTCPRTT(c, b))
  670. // Good ack
  671. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  672. assert.True(t, f.checkTCPRTT(c, b))
  673. assert.Equal(t, uint32(0), c.Seq)
  674. // Set SEQ to 1
  675. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  676. c = &conn{}
  677. setTCPRTTTracking(c, b)
  678. assert.Equal(t, uint32(1), c.Seq)
  679. // Good acks
  680. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  681. assert.True(t, f.checkTCPRTT(c, b))
  682. assert.Equal(t, uint32(0), c.Seq)
  683. // Set SEQ to max uint32 - 20
  684. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  685. c = &conn{}
  686. setTCPRTTTracking(c, b)
  687. assert.Equal(t, ^uint32(0)-20, c.Seq)
  688. // Good acks
  689. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  690. assert.True(t, f.checkTCPRTT(c, b))
  691. assert.Equal(t, uint32(0), c.Seq)
  692. // Set SEQ to max uint32 / 2
  693. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  694. c = &conn{}
  695. setTCPRTTTracking(c, b)
  696. assert.Equal(t, ^uint32(0)/2, c.Seq)
  697. // Below
  698. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  699. assert.False(t, f.checkTCPRTT(c, b))
  700. assert.Equal(t, ^uint32(0)/2, c.Seq)
  701. // Halfway below
  702. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  703. assert.False(t, f.checkTCPRTT(c, b))
  704. assert.Equal(t, ^uint32(0)/2, c.Seq)
  705. // Halfway above is ok
  706. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  707. assert.True(t, f.checkTCPRTT(c, b))
  708. assert.Equal(t, uint32(0), c.Seq)
  709. // Set SEQ to max uint32
  710. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  711. c = &conn{}
  712. setTCPRTTTracking(c, b)
  713. assert.Equal(t, ^uint32(0), c.Seq)
  714. // Halfway + 1 above
  715. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  716. assert.False(t, f.checkTCPRTT(c, b))
  717. assert.Equal(t, ^uint32(0), c.Seq)
  718. // Halfway above
  719. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  720. assert.True(t, f.checkTCPRTT(c, b))
  721. assert.Equal(t, uint32(0), c.Seq)
  722. }
  723. func TestFirewall_convertRule(t *testing.T) {
  724. l := util.NewTestLogger()
  725. ob := &bytes.Buffer{}
  726. l.SetOutput(ob)
  727. // Ensure group array of 1 is converted and a warning is printed
  728. c := map[interface{}]interface{}{
  729. "group": []interface{}{"group1"},
  730. }
  731. r, err := convertRule(l, c, "test", 1)
  732. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  733. assert.Nil(t, err)
  734. assert.Equal(t, "group1", r.Group)
  735. // Ensure group array of > 1 is errord
  736. ob.Reset()
  737. c = map[interface{}]interface{}{
  738. "group": []interface{}{"group1", "group2"},
  739. }
  740. r, err = convertRule(l, c, "test", 1)
  741. assert.Equal(t, "", ob.String())
  742. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  743. // Make sure a well formed group is alright
  744. ob.Reset()
  745. c = map[interface{}]interface{}{
  746. "group": "group1",
  747. }
  748. r, err = convertRule(l, c, "test", 1)
  749. assert.Nil(t, err)
  750. assert.Equal(t, "group1", r.Group)
  751. }
  752. type addRuleCall struct {
  753. incoming bool
  754. proto uint8
  755. startPort int32
  756. endPort int32
  757. groups []string
  758. host string
  759. ip *net.IPNet
  760. caName string
  761. caSha string
  762. }
  763. type mockFirewall struct {
  764. lastCall addRuleCall
  765. nextCallReturn error
  766. }
  767. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  768. mf.lastCall = addRuleCall{
  769. incoming: incoming,
  770. proto: proto,
  771. startPort: startPort,
  772. endPort: endPort,
  773. groups: groups,
  774. host: host,
  775. ip: ip,
  776. caName: caName,
  777. caSha: caSha,
  778. }
  779. err := mf.nextCallReturn
  780. mf.nextCallReturn = nil
  781. return err
  782. }
  783. func resetConntrack(fw *Firewall) {
  784. fw.Conntrack.Lock()
  785. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  786. fw.Conntrack.Unlock()
  787. }