firewall_test.go 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/slackhq/nebula/cert"
  10. "github.com/slackhq/nebula/config"
  11. "github.com/slackhq/nebula/firewall"
  12. "github.com/slackhq/nebula/test"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. func TestNewFirewall(t *testing.T) {
  17. l := test.NewLogger()
  18. c := &dummyCert{}
  19. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  20. conntrack := fw.Conntrack
  21. assert.NotNil(t, conntrack)
  22. assert.NotNil(t, conntrack.Conns)
  23. assert.NotNil(t, conntrack.TimerWheel)
  24. assert.NotNil(t, fw.InRules)
  25. assert.NotNil(t, fw.OutRules)
  26. assert.Equal(t, time.Second, fw.TCPTimeout)
  27. assert.Equal(t, time.Minute, fw.UDPTimeout)
  28. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  29. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  30. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  31. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  32. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  33. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  34. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  35. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  36. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  37. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  38. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  39. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  40. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  41. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  42. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  43. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  44. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  45. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  46. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  47. }
  48. func TestFirewall_AddRule(t *testing.T) {
  49. l := test.NewLogger()
  50. ob := &bytes.Buffer{}
  51. l.SetOutput(ob)
  52. c := &dummyCert{}
  53. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  54. assert.NotNil(t, fw.InRules)
  55. assert.NotNil(t, fw.OutRules)
  56. ti, err := netip.ParsePrefix("1.2.3.4/32")
  57. require.NoError(t, err)
  58. require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  59. // An empty rule is any
  60. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  61. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  62. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  63. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  64. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  65. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  66. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  67. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  68. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  69. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
  70. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  71. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  72. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  73. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  74. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
  75. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  76. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  77. assert.True(t, ok)
  78. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  79. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
  80. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  81. _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  82. assert.True(t, ok)
  83. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  84. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
  85. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  86. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  87. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
  88. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  89. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  90. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
  91. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  92. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  93. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  94. require.NoError(t, err)
  95. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
  96. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  97. // Test error conditions
  98. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  99. require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  100. require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  101. }
  102. func TestFirewall_Drop(t *testing.T) {
  103. l := test.NewLogger()
  104. ob := &bytes.Buffer{}
  105. l.SetOutput(ob)
  106. p := firewall.Packet{
  107. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  108. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  109. LocalPort: 10,
  110. RemotePort: 90,
  111. Protocol: firewall.ProtoUDP,
  112. Fragment: false,
  113. }
  114. c := dummyCert{
  115. name: "host1",
  116. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  117. groups: []string{"default-group"},
  118. issuer: "signer-shasum",
  119. }
  120. h := HostInfo{
  121. ConnectionState: &ConnectionState{
  122. peerCert: &cert.CachedCertificate{
  123. Certificate: &c,
  124. InvertedGroups: map[string]struct{}{"default-group": {}},
  125. },
  126. },
  127. vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
  128. }
  129. h.buildNetworks(c.networks, c.unsafeNetworks)
  130. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  131. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  132. cp := cert.NewCAPool()
  133. // Drop outbound
  134. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  135. // Allow inbound
  136. resetConntrack(fw)
  137. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  138. // Allow outbound because conntrack
  139. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  140. // test remote mismatch
  141. oldRemote := p.RemoteAddr
  142. p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
  143. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  144. p.RemoteAddr = oldRemote
  145. // ensure signer doesn't get in the way of group checks
  146. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  147. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  148. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  149. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  150. // test caSha doesn't drop on match
  151. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  152. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  153. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  154. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  155. // ensure ca name doesn't get in the way of group checks
  156. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  157. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  158. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  159. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  160. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  161. // test caName doesn't drop on match
  162. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  163. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  164. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  165. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  166. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  167. }
  168. func BenchmarkFirewallTable_match(b *testing.B) {
  169. f := &Firewall{}
  170. ft := FirewallTable{
  171. TCP: firewallPort{},
  172. }
  173. pfix := netip.MustParsePrefix("172.1.1.1/32")
  174. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
  175. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
  176. cp := cert.NewCAPool()
  177. b.Run("fail on proto", func(b *testing.B) {
  178. // This benchmark is showing us the cost of failing to match the protocol
  179. c := &cert.CachedCertificate{
  180. Certificate: &dummyCert{},
  181. }
  182. for n := 0; n < b.N; n++ {
  183. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  184. }
  185. })
  186. b.Run("pass proto, fail on port", func(b *testing.B) {
  187. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  188. c := &cert.CachedCertificate{
  189. Certificate: &dummyCert{},
  190. }
  191. for n := 0; n < b.N; n++ {
  192. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  193. }
  194. })
  195. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  196. c := &cert.CachedCertificate{
  197. Certificate: &dummyCert{},
  198. }
  199. ip := netip.MustParsePrefix("9.254.254.254/32")
  200. for n := 0; n < b.N; n++ {
  201. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  202. }
  203. })
  204. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  205. c := &cert.CachedCertificate{
  206. Certificate: &dummyCert{
  207. name: "nope",
  208. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  209. },
  210. InvertedGroups: map[string]struct{}{"nope": {}},
  211. }
  212. for n := 0; n < b.N; n++ {
  213. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  214. }
  215. })
  216. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  217. c := &cert.CachedCertificate{
  218. Certificate: &dummyCert{
  219. name: "nope",
  220. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  221. },
  222. InvertedGroups: map[string]struct{}{"nope": {}},
  223. }
  224. for n := 0; n < b.N; n++ {
  225. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  226. }
  227. })
  228. b.Run("pass on group on any local cidr", func(b *testing.B) {
  229. c := &cert.CachedCertificate{
  230. Certificate: &dummyCert{
  231. name: "nope",
  232. },
  233. InvertedGroups: map[string]struct{}{"good-group": {}},
  234. }
  235. for n := 0; n < b.N; n++ {
  236. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  237. }
  238. })
  239. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  240. c := &cert.CachedCertificate{
  241. Certificate: &dummyCert{
  242. name: "nope",
  243. },
  244. InvertedGroups: map[string]struct{}{"good-group": {}},
  245. }
  246. for n := 0; n < b.N; n++ {
  247. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  248. }
  249. })
  250. b.Run("pass on name", func(b *testing.B) {
  251. c := &cert.CachedCertificate{
  252. Certificate: &dummyCert{
  253. name: "good-host",
  254. },
  255. InvertedGroups: map[string]struct{}{"nope": {}},
  256. }
  257. for n := 0; n < b.N; n++ {
  258. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  259. }
  260. })
  261. }
  262. func TestFirewall_Drop2(t *testing.T) {
  263. l := test.NewLogger()
  264. ob := &bytes.Buffer{}
  265. l.SetOutput(ob)
  266. p := firewall.Packet{
  267. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  268. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  269. LocalPort: 10,
  270. RemotePort: 90,
  271. Protocol: firewall.ProtoUDP,
  272. Fragment: false,
  273. }
  274. network := netip.MustParsePrefix("1.2.3.4/24")
  275. c := cert.CachedCertificate{
  276. Certificate: &dummyCert{
  277. name: "host1",
  278. networks: []netip.Prefix{network},
  279. },
  280. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  281. }
  282. h := HostInfo{
  283. ConnectionState: &ConnectionState{
  284. peerCert: &c,
  285. },
  286. vpnAddrs: []netip.Addr{network.Addr()},
  287. }
  288. h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  289. c1 := cert.CachedCertificate{
  290. Certificate: &dummyCert{
  291. name: "host1",
  292. networks: []netip.Prefix{network},
  293. },
  294. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  295. }
  296. h1 := HostInfo{
  297. vpnAddrs: []netip.Addr{network.Addr()},
  298. ConnectionState: &ConnectionState{
  299. peerCert: &c1,
  300. },
  301. }
  302. h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  303. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  304. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  305. cp := cert.NewCAPool()
  306. // h1/c1 lacks the proper groups
  307. require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  308. // c has the proper groups
  309. resetConntrack(fw)
  310. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  311. }
  312. func TestFirewall_Drop3(t *testing.T) {
  313. l := test.NewLogger()
  314. ob := &bytes.Buffer{}
  315. l.SetOutput(ob)
  316. p := firewall.Packet{
  317. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  318. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  319. LocalPort: 1,
  320. RemotePort: 1,
  321. Protocol: firewall.ProtoUDP,
  322. Fragment: false,
  323. }
  324. network := netip.MustParsePrefix("1.2.3.4/24")
  325. c := cert.CachedCertificate{
  326. Certificate: &dummyCert{
  327. name: "host-owner",
  328. networks: []netip.Prefix{network},
  329. },
  330. }
  331. c1 := cert.CachedCertificate{
  332. Certificate: &dummyCert{
  333. name: "host1",
  334. networks: []netip.Prefix{network},
  335. issuer: "signer-sha-bad",
  336. },
  337. }
  338. h1 := HostInfo{
  339. ConnectionState: &ConnectionState{
  340. peerCert: &c1,
  341. },
  342. vpnAddrs: []netip.Addr{network.Addr()},
  343. }
  344. h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  345. c2 := cert.CachedCertificate{
  346. Certificate: &dummyCert{
  347. name: "host2",
  348. networks: []netip.Prefix{network},
  349. issuer: "signer-sha",
  350. },
  351. }
  352. h2 := HostInfo{
  353. ConnectionState: &ConnectionState{
  354. peerCert: &c2,
  355. },
  356. vpnAddrs: []netip.Addr{network.Addr()},
  357. }
  358. h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
  359. c3 := cert.CachedCertificate{
  360. Certificate: &dummyCert{
  361. name: "host3",
  362. networks: []netip.Prefix{network},
  363. issuer: "signer-sha-bad",
  364. },
  365. }
  366. h3 := HostInfo{
  367. ConnectionState: &ConnectionState{
  368. peerCert: &c3,
  369. },
  370. vpnAddrs: []netip.Addr{network.Addr()},
  371. }
  372. h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
  373. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  374. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
  375. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
  376. cp := cert.NewCAPool()
  377. // c1 should pass because host match
  378. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  379. // c2 should pass because ca sha match
  380. resetConntrack(fw)
  381. require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  382. // c3 should fail because no match
  383. resetConntrack(fw)
  384. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  385. // Test a remote address match
  386. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  387. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
  388. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  389. }
  390. func TestFirewall_DropConntrackReload(t *testing.T) {
  391. l := test.NewLogger()
  392. ob := &bytes.Buffer{}
  393. l.SetOutput(ob)
  394. p := firewall.Packet{
  395. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  396. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  397. LocalPort: 10,
  398. RemotePort: 90,
  399. Protocol: firewall.ProtoUDP,
  400. Fragment: false,
  401. }
  402. network := netip.MustParsePrefix("1.2.3.4/24")
  403. c := cert.CachedCertificate{
  404. Certificate: &dummyCert{
  405. name: "host1",
  406. networks: []netip.Prefix{network},
  407. groups: []string{"default-group"},
  408. issuer: "signer-shasum",
  409. },
  410. InvertedGroups: map[string]struct{}{"default-group": {}},
  411. }
  412. h := HostInfo{
  413. ConnectionState: &ConnectionState{
  414. peerCert: &c,
  415. },
  416. vpnAddrs: []netip.Addr{network.Addr()},
  417. }
  418. h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  419. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  420. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  421. cp := cert.NewCAPool()
  422. // Drop outbound
  423. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  424. // Allow inbound
  425. resetConntrack(fw)
  426. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  427. // Allow outbound because conntrack
  428. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  429. oldFw := fw
  430. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  431. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  432. fw.Conntrack = oldFw.Conntrack
  433. fw.rulesVersion = oldFw.rulesVersion + 1
  434. // Allow outbound because conntrack and new rules allow port 10
  435. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  436. oldFw = fw
  437. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  438. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  439. fw.Conntrack = oldFw.Conntrack
  440. fw.rulesVersion = oldFw.rulesVersion + 1
  441. // Drop outbound because conntrack doesn't match new ruleset
  442. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  443. }
  444. func BenchmarkLookup(b *testing.B) {
  445. ml := func(m map[string]struct{}, a [][]string) {
  446. for n := 0; n < b.N; n++ {
  447. for _, sg := range a {
  448. found := false
  449. for _, g := range sg {
  450. if _, ok := m[g]; !ok {
  451. found = false
  452. break
  453. }
  454. found = true
  455. }
  456. if found {
  457. return
  458. }
  459. }
  460. }
  461. }
  462. b.Run("array to map best", func(b *testing.B) {
  463. m := map[string]struct{}{
  464. "1ne": {},
  465. "2wo": {},
  466. "3hr": {},
  467. "4ou": {},
  468. "5iv": {},
  469. "6ix": {},
  470. }
  471. a := [][]string{
  472. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  473. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  474. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  475. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  476. {"one", "two", "thr", "fou", "5iv", "6ix"},
  477. {"one", "two", "thr", "fou", "fiv", "6ix"},
  478. {"one", "two", "thr", "fou", "fiv", "six"},
  479. }
  480. for n := 0; n < b.N; n++ {
  481. ml(m, a)
  482. }
  483. })
  484. b.Run("array to map worst", func(b *testing.B) {
  485. m := map[string]struct{}{
  486. "one": {},
  487. "two": {},
  488. "thr": {},
  489. "fou": {},
  490. "fiv": {},
  491. "six": {},
  492. }
  493. a := [][]string{
  494. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  495. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  496. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  497. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  498. {"one", "two", "thr", "fou", "5iv", "6ix"},
  499. {"one", "two", "thr", "fou", "fiv", "6ix"},
  500. {"one", "two", "thr", "fou", "fiv", "six"},
  501. }
  502. for n := 0; n < b.N; n++ {
  503. ml(m, a)
  504. }
  505. })
  506. }
  507. func Test_parsePort(t *testing.T) {
  508. _, _, err := parsePort("")
  509. require.EqualError(t, err, "was not a number; ``")
  510. _, _, err = parsePort(" ")
  511. require.EqualError(t, err, "was not a number; ` `")
  512. _, _, err = parsePort("-")
  513. require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  514. _, _, err = parsePort(" - ")
  515. require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  516. _, _, err = parsePort("a-b")
  517. require.EqualError(t, err, "beginning range was not a number; `a`")
  518. _, _, err = parsePort("1-b")
  519. require.EqualError(t, err, "ending range was not a number; `b`")
  520. s, e, err := parsePort(" 1 - 2 ")
  521. assert.Equal(t, int32(1), s)
  522. assert.Equal(t, int32(2), e)
  523. require.NoError(t, err)
  524. s, e, err = parsePort("0-1")
  525. assert.Equal(t, int32(0), s)
  526. assert.Equal(t, int32(0), e)
  527. require.NoError(t, err)
  528. s, e, err = parsePort("9919")
  529. assert.Equal(t, int32(9919), s)
  530. assert.Equal(t, int32(9919), e)
  531. require.NoError(t, err)
  532. s, e, err = parsePort("any")
  533. assert.Equal(t, int32(0), s)
  534. assert.Equal(t, int32(0), e)
  535. require.NoError(t, err)
  536. }
  537. func TestNewFirewallFromConfig(t *testing.T) {
  538. l := test.NewLogger()
  539. // Test a bad rule definition
  540. c := &dummyCert{}
  541. cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
  542. require.NoError(t, err)
  543. conf := config.NewC(l)
  544. conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
  545. _, err = NewFirewallFromConfig(l, cs, conf)
  546. require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  547. // Test both port and code
  548. conf = config.NewC(l)
  549. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
  550. _, err = NewFirewallFromConfig(l, cs, conf)
  551. require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  552. // Test missing host, group, cidr, ca_name and ca_sha
  553. conf = config.NewC(l)
  554. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
  555. _, err = NewFirewallFromConfig(l, cs, conf)
  556. require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  557. // Test code/port error
  558. conf = config.NewC(l)
  559. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
  560. _, err = NewFirewallFromConfig(l, cs, conf)
  561. require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  562. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
  563. _, err = NewFirewallFromConfig(l, cs, conf)
  564. require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  565. // Test proto error
  566. conf = config.NewC(l)
  567. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
  568. _, err = NewFirewallFromConfig(l, cs, conf)
  569. require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  570. // Test cidr parse error
  571. conf = config.NewC(l)
  572. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
  573. _, err = NewFirewallFromConfig(l, cs, conf)
  574. require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  575. // Test local_cidr parse error
  576. conf = config.NewC(l)
  577. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  578. _, err = NewFirewallFromConfig(l, cs, conf)
  579. require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  580. // Test both group and groups
  581. conf = config.NewC(l)
  582. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  583. _, err = NewFirewallFromConfig(l, cs, conf)
  584. require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  585. }
  586. func TestAddFirewallRulesFromConfig(t *testing.T) {
  587. l := test.NewLogger()
  588. // Test adding tcp rule
  589. conf := config.NewC(l)
  590. mf := &mockFirewall{}
  591. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
  592. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  593. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  594. // Test adding udp rule
  595. conf = config.NewC(l)
  596. mf = &mockFirewall{}
  597. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
  598. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  599. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  600. // Test adding icmp rule
  601. conf = config.NewC(l)
  602. mf = &mockFirewall{}
  603. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
  604. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  605. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  606. // Test adding any rule
  607. conf = config.NewC(l)
  608. mf = &mockFirewall{}
  609. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  610. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  611. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  612. // Test adding rule with cidr
  613. cidr := netip.MustParsePrefix("10.0.0.0/8")
  614. conf = config.NewC(l)
  615. mf = &mockFirewall{}
  616. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  617. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  618. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
  619. // Test adding rule with local_cidr
  620. conf = config.NewC(l)
  621. mf = &mockFirewall{}
  622. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  623. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  624. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
  625. // Test adding rule with ca_sha
  626. conf = config.NewC(l)
  627. mf = &mockFirewall{}
  628. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  629. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  630. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
  631. // Test adding rule with ca_name
  632. conf = config.NewC(l)
  633. mf = &mockFirewall{}
  634. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
  635. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  636. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
  637. // Test single group
  638. conf = config.NewC(l)
  639. mf = &mockFirewall{}
  640. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
  641. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  642. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  643. // Test single groups
  644. conf = config.NewC(l)
  645. mf = &mockFirewall{}
  646. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
  647. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  648. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  649. // Test multiple AND groups
  650. conf = config.NewC(l)
  651. mf = &mockFirewall{}
  652. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  653. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  654. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  655. // Test Add error
  656. conf = config.NewC(l)
  657. mf = &mockFirewall{}
  658. mf.nextCallReturn = errors.New("test error")
  659. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  660. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  661. }
  662. func TestFirewall_convertRule(t *testing.T) {
  663. l := test.NewLogger()
  664. ob := &bytes.Buffer{}
  665. l.SetOutput(ob)
  666. // Ensure group array of 1 is converted and a warning is printed
  667. c := map[string]any{
  668. "group": []any{"group1"},
  669. }
  670. r, err := convertRule(l, c, "test", 1)
  671. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  672. require.NoError(t, err)
  673. assert.Equal(t, "group1", r.Group)
  674. // Ensure group array of > 1 is errord
  675. ob.Reset()
  676. c = map[string]any{
  677. "group": []any{"group1", "group2"},
  678. }
  679. r, err = convertRule(l, c, "test", 1)
  680. assert.Empty(t, ob.String())
  681. require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  682. // Make sure a well formed group is alright
  683. ob.Reset()
  684. c = map[string]any{
  685. "group": "group1",
  686. }
  687. r, err = convertRule(l, c, "test", 1)
  688. require.NoError(t, err)
  689. assert.Equal(t, "group1", r.Group)
  690. }
  691. type addRuleCall struct {
  692. incoming bool
  693. proto uint8
  694. startPort int32
  695. endPort int32
  696. groups []string
  697. host string
  698. ip netip.Prefix
  699. localIp netip.Prefix
  700. caName string
  701. caSha string
  702. }
  703. type mockFirewall struct {
  704. lastCall addRuleCall
  705. nextCallReturn error
  706. }
  707. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
  708. mf.lastCall = addRuleCall{
  709. incoming: incoming,
  710. proto: proto,
  711. startPort: startPort,
  712. endPort: endPort,
  713. groups: groups,
  714. host: host,
  715. ip: ip,
  716. localIp: localIp,
  717. caName: caName,
  718. caSha: caSha,
  719. }
  720. err := mf.nextCallReturn
  721. mf.nextCallReturn = nil
  722. return err
  723. }
  724. func resetConntrack(fw *Firewall) {
  725. fw.Conntrack.Lock()
  726. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  727. fw.Conntrack.Unlock()
  728. }