firewall_test.go 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/slackhq/nebula/cert"
  10. "github.com/slackhq/nebula/config"
  11. "github.com/slackhq/nebula/firewall"
  12. "github.com/slackhq/nebula/test"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. func TestNewFirewall(t *testing.T) {
  17. l := test.NewLogger()
  18. c := &dummyCert{}
  19. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  20. conntrack := fw.Conntrack
  21. assert.NotNil(t, conntrack)
  22. assert.NotNil(t, conntrack.Conns)
  23. assert.NotNil(t, conntrack.TimerWheel)
  24. assert.NotNil(t, fw.InRules)
  25. assert.NotNil(t, fw.OutRules)
  26. assert.Equal(t, time.Second, fw.TCPTimeout)
  27. assert.Equal(t, time.Minute, fw.UDPTimeout)
  28. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  29. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  30. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  31. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  32. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  33. conntrack = fw.Conntrack
  34. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  35. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  36. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  37. conntrack = fw.Conntrack
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  41. conntrack = fw.Conntrack
  42. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  43. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  44. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  45. conntrack = fw.Conntrack
  46. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  47. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  48. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  49. conntrack = fw.Conntrack
  50. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  51. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  52. }
  53. func TestFirewall_AddRule(t *testing.T) {
  54. l := test.NewLogger()
  55. ob := &bytes.Buffer{}
  56. l.SetOutput(ob)
  57. c := &dummyCert{}
  58. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  59. assert.NotNil(t, fw.InRules)
  60. assert.NotNil(t, fw.OutRules)
  61. ti, err := netip.ParsePrefix("1.2.3.4/32")
  62. require.NoError(t, err)
  63. require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  64. // An empty rule is any
  65. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  66. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  67. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  68. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  69. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  70. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  71. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  72. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  73. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  74. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
  75. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  76. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  77. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  78. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  79. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
  80. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  81. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  82. assert.True(t, ok)
  83. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  84. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
  85. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  86. _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  87. assert.True(t, ok)
  88. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  89. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
  90. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  91. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  92. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
  93. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  94. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  95. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
  96. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  97. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  98. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  99. require.NoError(t, err)
  100. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
  101. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  102. // Test error conditions
  103. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  104. require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  105. require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  106. }
  107. func TestFirewall_Drop(t *testing.T) {
  108. l := test.NewLogger()
  109. ob := &bytes.Buffer{}
  110. l.SetOutput(ob)
  111. p := firewall.Packet{
  112. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  113. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  114. LocalPort: 10,
  115. RemotePort: 90,
  116. Protocol: firewall.ProtoUDP,
  117. Fragment: false,
  118. }
  119. c := dummyCert{
  120. name: "host1",
  121. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  122. groups: []string{"default-group"},
  123. issuer: "signer-shasum",
  124. }
  125. h := HostInfo{
  126. ConnectionState: &ConnectionState{
  127. peerCert: &cert.CachedCertificate{
  128. Certificate: &c,
  129. InvertedGroups: map[string]struct{}{"default-group": {}},
  130. },
  131. },
  132. vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
  133. }
  134. h.buildNetworks(c.networks, c.unsafeNetworks)
  135. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  136. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  137. cp := cert.NewCAPool()
  138. // Drop outbound
  139. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  140. // Allow inbound
  141. resetConntrack(fw)
  142. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  143. // Allow outbound because conntrack
  144. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  145. // test remote mismatch
  146. oldRemote := p.RemoteAddr
  147. p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
  148. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  149. p.RemoteAddr = oldRemote
  150. // ensure signer doesn't get in the way of group checks
  151. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  152. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  153. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  154. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  155. // test caSha doesn't drop on match
  156. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  157. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  158. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  159. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  160. // ensure ca name doesn't get in the way of group checks
  161. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  162. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  163. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  164. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  165. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  166. // test caName doesn't drop on match
  167. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  168. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  169. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  170. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  171. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  172. }
  173. func BenchmarkFirewallTable_match(b *testing.B) {
  174. f := &Firewall{}
  175. ft := FirewallTable{
  176. TCP: firewallPort{},
  177. }
  178. pfix := netip.MustParsePrefix("172.1.1.1/32")
  179. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
  180. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
  181. cp := cert.NewCAPool()
  182. b.Run("fail on proto", func(b *testing.B) {
  183. // This benchmark is showing us the cost of failing to match the protocol
  184. c := &cert.CachedCertificate{
  185. Certificate: &dummyCert{},
  186. }
  187. for n := 0; n < b.N; n++ {
  188. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  189. }
  190. })
  191. b.Run("pass proto, fail on port", func(b *testing.B) {
  192. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  193. c := &cert.CachedCertificate{
  194. Certificate: &dummyCert{},
  195. }
  196. for n := 0; n < b.N; n++ {
  197. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  198. }
  199. })
  200. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  201. c := &cert.CachedCertificate{
  202. Certificate: &dummyCert{},
  203. }
  204. ip := netip.MustParsePrefix("9.254.254.254/32")
  205. for n := 0; n < b.N; n++ {
  206. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  207. }
  208. })
  209. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  210. c := &cert.CachedCertificate{
  211. Certificate: &dummyCert{
  212. name: "nope",
  213. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  214. },
  215. InvertedGroups: map[string]struct{}{"nope": {}},
  216. }
  217. for n := 0; n < b.N; n++ {
  218. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  219. }
  220. })
  221. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  222. c := &cert.CachedCertificate{
  223. Certificate: &dummyCert{
  224. name: "nope",
  225. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  226. },
  227. InvertedGroups: map[string]struct{}{"nope": {}},
  228. }
  229. for n := 0; n < b.N; n++ {
  230. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  231. }
  232. })
  233. b.Run("pass on group on any local cidr", func(b *testing.B) {
  234. c := &cert.CachedCertificate{
  235. Certificate: &dummyCert{
  236. name: "nope",
  237. },
  238. InvertedGroups: map[string]struct{}{"good-group": {}},
  239. }
  240. for n := 0; n < b.N; n++ {
  241. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  242. }
  243. })
  244. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  245. c := &cert.CachedCertificate{
  246. Certificate: &dummyCert{
  247. name: "nope",
  248. },
  249. InvertedGroups: map[string]struct{}{"good-group": {}},
  250. }
  251. for n := 0; n < b.N; n++ {
  252. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  253. }
  254. })
  255. b.Run("pass on name", func(b *testing.B) {
  256. c := &cert.CachedCertificate{
  257. Certificate: &dummyCert{
  258. name: "good-host",
  259. },
  260. InvertedGroups: map[string]struct{}{"nope": {}},
  261. }
  262. for n := 0; n < b.N; n++ {
  263. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  264. }
  265. })
  266. }
  267. func TestFirewall_Drop2(t *testing.T) {
  268. l := test.NewLogger()
  269. ob := &bytes.Buffer{}
  270. l.SetOutput(ob)
  271. p := firewall.Packet{
  272. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  273. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  274. LocalPort: 10,
  275. RemotePort: 90,
  276. Protocol: firewall.ProtoUDP,
  277. Fragment: false,
  278. }
  279. network := netip.MustParsePrefix("1.2.3.4/24")
  280. c := cert.CachedCertificate{
  281. Certificate: &dummyCert{
  282. name: "host1",
  283. networks: []netip.Prefix{network},
  284. },
  285. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  286. }
  287. h := HostInfo{
  288. ConnectionState: &ConnectionState{
  289. peerCert: &c,
  290. },
  291. vpnAddrs: []netip.Addr{network.Addr()},
  292. }
  293. h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  294. c1 := cert.CachedCertificate{
  295. Certificate: &dummyCert{
  296. name: "host1",
  297. networks: []netip.Prefix{network},
  298. },
  299. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  300. }
  301. h1 := HostInfo{
  302. vpnAddrs: []netip.Addr{network.Addr()},
  303. ConnectionState: &ConnectionState{
  304. peerCert: &c1,
  305. },
  306. }
  307. h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  308. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  309. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  310. cp := cert.NewCAPool()
  311. // h1/c1 lacks the proper groups
  312. require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  313. // c has the proper groups
  314. resetConntrack(fw)
  315. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  316. }
  317. func TestFirewall_Drop3(t *testing.T) {
  318. l := test.NewLogger()
  319. ob := &bytes.Buffer{}
  320. l.SetOutput(ob)
  321. p := firewall.Packet{
  322. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  323. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  324. LocalPort: 1,
  325. RemotePort: 1,
  326. Protocol: firewall.ProtoUDP,
  327. Fragment: false,
  328. }
  329. network := netip.MustParsePrefix("1.2.3.4/24")
  330. c := cert.CachedCertificate{
  331. Certificate: &dummyCert{
  332. name: "host-owner",
  333. networks: []netip.Prefix{network},
  334. },
  335. }
  336. c1 := cert.CachedCertificate{
  337. Certificate: &dummyCert{
  338. name: "host1",
  339. networks: []netip.Prefix{network},
  340. issuer: "signer-sha-bad",
  341. },
  342. }
  343. h1 := HostInfo{
  344. ConnectionState: &ConnectionState{
  345. peerCert: &c1,
  346. },
  347. vpnAddrs: []netip.Addr{network.Addr()},
  348. }
  349. h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  350. c2 := cert.CachedCertificate{
  351. Certificate: &dummyCert{
  352. name: "host2",
  353. networks: []netip.Prefix{network},
  354. issuer: "signer-sha",
  355. },
  356. }
  357. h2 := HostInfo{
  358. ConnectionState: &ConnectionState{
  359. peerCert: &c2,
  360. },
  361. vpnAddrs: []netip.Addr{network.Addr()},
  362. }
  363. h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
  364. c3 := cert.CachedCertificate{
  365. Certificate: &dummyCert{
  366. name: "host3",
  367. networks: []netip.Prefix{network},
  368. issuer: "signer-sha-bad",
  369. },
  370. }
  371. h3 := HostInfo{
  372. ConnectionState: &ConnectionState{
  373. peerCert: &c3,
  374. },
  375. vpnAddrs: []netip.Addr{network.Addr()},
  376. }
  377. h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
  378. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  379. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
  380. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
  381. cp := cert.NewCAPool()
  382. // c1 should pass because host match
  383. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  384. // c2 should pass because ca sha match
  385. resetConntrack(fw)
  386. require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  387. // c3 should fail because no match
  388. resetConntrack(fw)
  389. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  390. // Test a remote address match
  391. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  392. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
  393. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  394. }
  395. func TestFirewall_DropConntrackReload(t *testing.T) {
  396. l := test.NewLogger()
  397. ob := &bytes.Buffer{}
  398. l.SetOutput(ob)
  399. p := firewall.Packet{
  400. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  401. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  402. LocalPort: 10,
  403. RemotePort: 90,
  404. Protocol: firewall.ProtoUDP,
  405. Fragment: false,
  406. }
  407. network := netip.MustParsePrefix("1.2.3.4/24")
  408. c := cert.CachedCertificate{
  409. Certificate: &dummyCert{
  410. name: "host1",
  411. networks: []netip.Prefix{network},
  412. groups: []string{"default-group"},
  413. issuer: "signer-shasum",
  414. },
  415. InvertedGroups: map[string]struct{}{"default-group": {}},
  416. }
  417. h := HostInfo{
  418. ConnectionState: &ConnectionState{
  419. peerCert: &c,
  420. },
  421. vpnAddrs: []netip.Addr{network.Addr()},
  422. }
  423. h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  424. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  425. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  426. cp := cert.NewCAPool()
  427. // Drop outbound
  428. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  429. // Allow inbound
  430. resetConntrack(fw)
  431. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  432. // Allow outbound because conntrack
  433. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  434. oldFw := fw
  435. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  436. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  437. fw.Conntrack = oldFw.Conntrack
  438. fw.rulesVersion = oldFw.rulesVersion + 1
  439. // Allow outbound because conntrack and new rules allow port 10
  440. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  441. oldFw = fw
  442. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  443. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  444. fw.Conntrack = oldFw.Conntrack
  445. fw.rulesVersion = oldFw.rulesVersion + 1
  446. // Drop outbound because conntrack doesn't match new ruleset
  447. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  448. }
  // BenchmarkLookup times a "does the peer hold every group of at least one
  // rule" check, implemented as map lookups over a list of group lists.
  //
  // "best" uses a group set that satisfies the first candidate row; "worst" uses
  // one that is only satisfied by the last row, so every row is visited.
  func BenchmarkLookup(b *testing.B) {
  	// anyRowSatisfied reports whether every name in at least one row of
  	// candidates is a key of have. An empty row never counts as satisfied.
  	//
  	// It performs exactly one scan per call; iteration over b.N belongs to the
  	// sub-benchmark that calls it, not to this helper.
  	anyRowSatisfied := func(have map[string]struct{}, candidates [][]string) bool {
  		for _, row := range candidates {
  			satisfied := len(row) > 0
  			for _, name := range row {
  				if _, ok := have[name]; !ok {
  					satisfied = false
  					break
  				}
  			}
  			if satisfied {
  				return true
  			}
  		}
  		return false
  	}

  	// Rows progress from all-digit-prefixed names to all-plain names, one
  	// position at a time.
  	candidates := [][]string{
  		{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  		{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  		{"one", "two", "3hr", "4ou", "5iv", "6ix"},
  		{"one", "two", "thr", "4ou", "5iv", "6ix"},
  		{"one", "two", "thr", "fou", "5iv", "6ix"},
  		{"one", "two", "thr", "fou", "fiv", "6ix"},
  		{"one", "two", "thr", "fou", "fiv", "six"},
  	}

  	b.Run("array to map best", func(b *testing.B) {
  		have := map[string]struct{}{
  			"1ne": {},
  			"2wo": {},
  			"3hr": {},
  			"4ou": {},
  			"5iv": {},
  			"6ix": {},
  		}
  		for n := 0; n < b.N; n++ {
  			// Checking the result keeps the call observable and guards the
  			// benchmark against measuring a lookup that silently stopped matching.
  			if !anyRowSatisfied(have, candidates) {
  				b.Fatal("expected the first candidate row to be satisfied")
  			}
  		}
  	})

  	b.Run("array to map worst", func(b *testing.B) {
  		have := map[string]struct{}{
  			"one": {},
  			"two": {},
  			"thr": {},
  			"fou": {},
  			"fiv": {},
  			"six": {},
  		}
  		for n := 0; n < b.N; n++ {
  			if !anyRowSatisfied(have, candidates) {
  				b.Fatal("expected the last candidate row to be satisfied")
  			}
  		}
  	})
  }
  512. func Test_parsePort(t *testing.T) {
  513. _, _, err := parsePort("")
  514. require.EqualError(t, err, "was not a number; ``")
  515. _, _, err = parsePort(" ")
  516. require.EqualError(t, err, "was not a number; ` `")
  517. _, _, err = parsePort("-")
  518. require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  519. _, _, err = parsePort(" - ")
  520. require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  521. _, _, err = parsePort("a-b")
  522. require.EqualError(t, err, "beginning range was not a number; `a`")
  523. _, _, err = parsePort("1-b")
  524. require.EqualError(t, err, "ending range was not a number; `b`")
  525. s, e, err := parsePort(" 1 - 2 ")
  526. assert.Equal(t, int32(1), s)
  527. assert.Equal(t, int32(2), e)
  528. require.NoError(t, err)
  529. s, e, err = parsePort("0-1")
  530. assert.Equal(t, int32(0), s)
  531. assert.Equal(t, int32(0), e)
  532. require.NoError(t, err)
  533. s, e, err = parsePort("9919")
  534. assert.Equal(t, int32(9919), s)
  535. assert.Equal(t, int32(9919), e)
  536. require.NoError(t, err)
  537. s, e, err = parsePort("any")
  538. assert.Equal(t, int32(0), s)
  539. assert.Equal(t, int32(0), e)
  540. require.NoError(t, err)
  541. }
  542. func TestNewFirewallFromConfig(t *testing.T) {
  543. l := test.NewLogger()
  544. // Test a bad rule definition
  545. c := &dummyCert{}
  546. cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
  547. require.NoError(t, err)
  548. conf := config.NewC(l)
  549. conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
  550. _, err = NewFirewallFromConfig(l, cs, conf)
  551. require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  552. // Test both port and code
  553. conf = config.NewC(l)
  554. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
  555. _, err = NewFirewallFromConfig(l, cs, conf)
  556. require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  557. // Test missing host, group, cidr, ca_name and ca_sha
  558. conf = config.NewC(l)
  559. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
  560. _, err = NewFirewallFromConfig(l, cs, conf)
  561. require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  562. // Test code/port error
  563. conf = config.NewC(l)
  564. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
  565. _, err = NewFirewallFromConfig(l, cs, conf)
  566. require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  567. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
  568. _, err = NewFirewallFromConfig(l, cs, conf)
  569. require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  570. // Test proto error
  571. conf = config.NewC(l)
  572. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
  573. _, err = NewFirewallFromConfig(l, cs, conf)
  574. require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  575. // Test cidr parse error
  576. conf = config.NewC(l)
  577. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
  578. _, err = NewFirewallFromConfig(l, cs, conf)
  579. require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  580. // Test local_cidr parse error
  581. conf = config.NewC(l)
  582. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  583. _, err = NewFirewallFromConfig(l, cs, conf)
  584. require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  585. // Test both group and groups
  586. conf = config.NewC(l)
  587. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  588. _, err = NewFirewallFromConfig(l, cs, conf)
  589. require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  590. }
  591. func TestAddFirewallRulesFromConfig(t *testing.T) {
  592. l := test.NewLogger()
  593. // Test adding tcp rule
  594. conf := config.NewC(l)
  595. mf := &mockFirewall{}
  596. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
  597. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  598. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  599. // Test adding udp rule
  600. conf = config.NewC(l)
  601. mf = &mockFirewall{}
  602. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
  603. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  604. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  605. // Test adding icmp rule
  606. conf = config.NewC(l)
  607. mf = &mockFirewall{}
  608. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
  609. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  610. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  611. // Test adding any rule
  612. conf = config.NewC(l)
  613. mf = &mockFirewall{}
  614. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  615. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  616. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  617. // Test adding rule with cidr
  618. cidr := netip.MustParsePrefix("10.0.0.0/8")
  619. conf = config.NewC(l)
  620. mf = &mockFirewall{}
  621. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  622. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  623. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
  624. // Test adding rule with local_cidr
  625. conf = config.NewC(l)
  626. mf = &mockFirewall{}
  627. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  628. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  629. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
  630. // Test adding rule with ca_sha
  631. conf = config.NewC(l)
  632. mf = &mockFirewall{}
  633. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  634. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  635. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
  636. // Test adding rule with ca_name
  637. conf = config.NewC(l)
  638. mf = &mockFirewall{}
  639. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
  640. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  641. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
  642. // Test single group
  643. conf = config.NewC(l)
  644. mf = &mockFirewall{}
  645. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
  646. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  647. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  648. // Test single groups
  649. conf = config.NewC(l)
  650. mf = &mockFirewall{}
  651. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
  652. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  653. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  654. // Test multiple AND groups
  655. conf = config.NewC(l)
  656. mf = &mockFirewall{}
  657. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  658. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  659. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  660. // Test Add error
  661. conf = config.NewC(l)
  662. mf = &mockFirewall{}
  663. mf.nextCallReturn = errors.New("test error")
  664. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  665. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  666. }
  667. func TestFirewall_convertRule(t *testing.T) {
  668. l := test.NewLogger()
  669. ob := &bytes.Buffer{}
  670. l.SetOutput(ob)
  671. // Ensure group array of 1 is converted and a warning is printed
  672. c := map[string]any{
  673. "group": []any{"group1"},
  674. }
  675. r, err := convertRule(l, c, "test", 1)
  676. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  677. require.NoError(t, err)
  678. assert.Equal(t, "group1", r.Group)
  679. // Ensure group array of > 1 is errord
  680. ob.Reset()
  681. c = map[string]any{
  682. "group": []any{"group1", "group2"},
  683. }
  684. r, err = convertRule(l, c, "test", 1)
  685. assert.Empty(t, ob.String())
  686. require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  687. // Make sure a well formed group is alright
  688. ob.Reset()
  689. c = map[string]any{
  690. "group": "group1",
  691. }
  692. r, err = convertRule(l, c, "test", 1)
  693. require.NoError(t, err)
  694. assert.Equal(t, "group1", r.Group)
  695. }
  696. type addRuleCall struct {
  697. incoming bool
  698. proto uint8
  699. startPort int32
  700. endPort int32
  701. groups []string
  702. host string
  703. ip netip.Prefix
  704. localIp netip.Prefix
  705. caName string
  706. caSha string
  707. }
  708. type mockFirewall struct {
  709. lastCall addRuleCall
  710. nextCallReturn error
  711. }
  712. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
  713. mf.lastCall = addRuleCall{
  714. incoming: incoming,
  715. proto: proto,
  716. startPort: startPort,
  717. endPort: endPort,
  718. groups: groups,
  719. host: host,
  720. ip: ip,
  721. localIp: localIp,
  722. caName: caName,
  723. caSha: caSha,
  724. }
  725. err := mf.nextCallReturn
  726. mf.nextCallReturn = nil
  727. return err
  728. }
  729. func resetConntrack(fw *Firewall) {
  730. fw.Conntrack.Lock()
  731. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  732. fw.Conntrack.Unlock()
  733. }