firewall_test.go 41 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/slackhq/nebula/cert"
  10. "github.com/slackhq/nebula/config"
  11. "github.com/slackhq/nebula/firewall"
  12. "github.com/slackhq/nebula/test"
  13. "github.com/stretchr/testify/assert"
  14. "github.com/stretchr/testify/require"
  15. )
  16. func TestNewFirewall(t *testing.T) {
  17. l := test.NewLogger()
  18. c := &dummyCert{}
  19. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  20. conntrack := fw.Conntrack
  21. assert.NotNil(t, conntrack)
  22. assert.NotNil(t, conntrack.Conns)
  23. assert.NotNil(t, conntrack.TimerWheel)
  24. assert.NotNil(t, fw.InRules)
  25. assert.NotNil(t, fw.OutRules)
  26. assert.Equal(t, time.Second, fw.TCPTimeout)
  27. assert.Equal(t, time.Minute, fw.UDPTimeout)
  28. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  29. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  30. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  31. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  32. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  33. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  34. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  35. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  36. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  37. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  38. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  39. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  40. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  41. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  42. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  43. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  44. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  45. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  46. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  47. }
  48. func TestFirewall_AddRule(t *testing.T) {
  49. l := test.NewLogger()
  50. ob := &bytes.Buffer{}
  51. l.SetOutput(ob)
  52. c := &dummyCert{}
  53. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  54. assert.NotNil(t, fw.InRules)
  55. assert.NotNil(t, fw.OutRules)
  56. ti, err := netip.ParsePrefix("1.2.3.4/32")
  57. require.NoError(t, err)
  58. ti6, err := netip.ParsePrefix("fd12::34/128")
  59. require.NoError(t, err)
  60. require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  61. // An empty rule is any
  62. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  63. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  64. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  65. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  66. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  67. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  68. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  69. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  70. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  71. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
  72. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  73. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  74. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  75. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  76. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
  77. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  78. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  79. assert.True(t, ok)
  80. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  81. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
  82. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  83. _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
  84. assert.True(t, ok)
  85. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  86. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
  87. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  88. _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  89. assert.True(t, ok)
  90. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  91. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
  92. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  93. _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
  94. assert.True(t, ok)
  95. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  96. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
  97. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  98. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  99. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
  100. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  101. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  102. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
  103. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  104. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  105. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  106. require.NoError(t, err)
  107. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
  108. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  109. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  110. anyIp6, err := netip.ParsePrefix("::/0")
  111. require.NoError(t, err)
  112. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
  113. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  114. // Test error conditions
  115. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  116. require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  117. require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  118. }
  119. func TestFirewall_Drop(t *testing.T) {
  120. l := test.NewLogger()
  121. ob := &bytes.Buffer{}
  122. l.SetOutput(ob)
  123. p := firewall.Packet{
  124. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  125. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  126. LocalPort: 10,
  127. RemotePort: 90,
  128. Protocol: firewall.ProtoUDP,
  129. Fragment: false,
  130. }
  131. c := dummyCert{
  132. name: "host1",
  133. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  134. groups: []string{"default-group"},
  135. issuer: "signer-shasum",
  136. }
  137. h := HostInfo{
  138. ConnectionState: &ConnectionState{
  139. peerCert: &cert.CachedCertificate{
  140. Certificate: &c,
  141. InvertedGroups: map[string]struct{}{"default-group": {}},
  142. },
  143. },
  144. vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
  145. }
  146. h.buildNetworks(c.networks, c.unsafeNetworks)
  147. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  148. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  149. cp := cert.NewCAPool()
  150. // Drop outbound
  151. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  152. // Allow inbound
  153. resetConntrack(fw)
  154. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  155. // Allow outbound because conntrack
  156. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  157. // test remote mismatch
  158. oldRemote := p.RemoteAddr
  159. p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
  160. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  161. p.RemoteAddr = oldRemote
  162. // ensure signer doesn't get in the way of group checks
  163. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  164. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  165. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  166. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  167. // test caSha doesn't drop on match
  168. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  169. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  170. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  171. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  172. // ensure ca name doesn't get in the way of group checks
  173. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  174. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  175. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  176. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  177. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  178. // test caName doesn't drop on match
  179. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  180. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  181. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  182. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  183. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  184. }
  185. func TestFirewall_DropV6(t *testing.T) {
  186. l := test.NewLogger()
  187. ob := &bytes.Buffer{}
  188. l.SetOutput(ob)
  189. p := firewall.Packet{
  190. LocalAddr: netip.MustParseAddr("fd12::34"),
  191. RemoteAddr: netip.MustParseAddr("fd12::34"),
  192. LocalPort: 10,
  193. RemotePort: 90,
  194. Protocol: firewall.ProtoUDP,
  195. Fragment: false,
  196. }
  197. c := dummyCert{
  198. name: "host1",
  199. networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
  200. groups: []string{"default-group"},
  201. issuer: "signer-shasum",
  202. }
  203. h := HostInfo{
  204. ConnectionState: &ConnectionState{
  205. peerCert: &cert.CachedCertificate{
  206. Certificate: &c,
  207. InvertedGroups: map[string]struct{}{"default-group": {}},
  208. },
  209. },
  210. vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
  211. }
  212. h.buildNetworks(c.networks, c.unsafeNetworks)
  213. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  214. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  215. cp := cert.NewCAPool()
  216. // Drop outbound
  217. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  218. // Allow inbound
  219. resetConntrack(fw)
  220. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  221. // Allow outbound because conntrack
  222. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  223. // test remote mismatch
  224. oldRemote := p.RemoteAddr
  225. p.RemoteAddr = netip.MustParseAddr("fd12::56")
  226. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  227. p.RemoteAddr = oldRemote
  228. // ensure signer doesn't get in the way of group checks
  229. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  230. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  231. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  232. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  233. // test caSha doesn't drop on match
  234. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  235. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  236. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  237. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  238. // ensure ca name doesn't get in the way of group checks
  239. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  240. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  241. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  242. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  243. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  244. // test caName doesn't drop on match
  245. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  246. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  247. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  248. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  249. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  250. }
  251. func BenchmarkFirewallTable_match(b *testing.B) {
  252. f := &Firewall{}
  253. ft := FirewallTable{
  254. TCP: firewallPort{},
  255. }
  256. pfix := netip.MustParsePrefix("172.1.1.1/32")
  257. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
  258. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
  259. pfix6 := netip.MustParsePrefix("fd11::11/128")
  260. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
  261. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
  262. cp := cert.NewCAPool()
  263. b.Run("fail on proto", func(b *testing.B) {
  264. // This benchmark is showing us the cost of failing to match the protocol
  265. c := &cert.CachedCertificate{
  266. Certificate: &dummyCert{},
  267. }
  268. for n := 0; n < b.N; n++ {
  269. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  270. }
  271. })
  272. b.Run("pass proto, fail on port", func(b *testing.B) {
  273. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  274. c := &cert.CachedCertificate{
  275. Certificate: &dummyCert{},
  276. }
  277. for n := 0; n < b.N; n++ {
  278. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  279. }
  280. })
  281. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  282. c := &cert.CachedCertificate{
  283. Certificate: &dummyCert{},
  284. }
  285. ip := netip.MustParsePrefix("9.254.254.254/32")
  286. for n := 0; n < b.N; n++ {
  287. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  288. }
  289. })
  290. b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
  291. c := &cert.CachedCertificate{
  292. Certificate: &dummyCert{},
  293. }
  294. ip := netip.MustParsePrefix("fd99::99/128")
  295. for n := 0; n < b.N; n++ {
  296. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  297. }
  298. })
  299. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  300. c := &cert.CachedCertificate{
  301. Certificate: &dummyCert{
  302. name: "nope",
  303. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  304. },
  305. InvertedGroups: map[string]struct{}{"nope": {}},
  306. }
  307. for n := 0; n < b.N; n++ {
  308. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  309. }
  310. })
  311. b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  312. c := &cert.CachedCertificate{
  313. Certificate: &dummyCert{
  314. name: "nope",
  315. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  316. },
  317. InvertedGroups: map[string]struct{}{"nope": {}},
  318. }
  319. for n := 0; n < b.N; n++ {
  320. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  321. }
  322. })
  323. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  324. c := &cert.CachedCertificate{
  325. Certificate: &dummyCert{
  326. name: "nope",
  327. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  328. },
  329. InvertedGroups: map[string]struct{}{"nope": {}},
  330. }
  331. for n := 0; n < b.N; n++ {
  332. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  333. }
  334. })
  335. b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  336. c := &cert.CachedCertificate{
  337. Certificate: &dummyCert{
  338. name: "nope",
  339. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  340. },
  341. InvertedGroups: map[string]struct{}{"nope": {}},
  342. }
  343. for n := 0; n < b.N; n++ {
  344. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  345. }
  346. })
  347. b.Run("pass on group on any local cidr", func(b *testing.B) {
  348. c := &cert.CachedCertificate{
  349. Certificate: &dummyCert{
  350. name: "nope",
  351. },
  352. InvertedGroups: map[string]struct{}{"good-group": {}},
  353. }
  354. for n := 0; n < b.N; n++ {
  355. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  356. }
  357. })
  358. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  359. c := &cert.CachedCertificate{
  360. Certificate: &dummyCert{
  361. name: "nope",
  362. },
  363. InvertedGroups: map[string]struct{}{"good-group": {}},
  364. }
  365. for n := 0; n < b.N; n++ {
  366. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  367. }
  368. })
  369. b.Run("pass on group on specific local cidr6", func(b *testing.B) {
  370. c := &cert.CachedCertificate{
  371. Certificate: &dummyCert{
  372. name: "nope",
  373. },
  374. InvertedGroups: map[string]struct{}{"good-group": {}},
  375. }
  376. for n := 0; n < b.N; n++ {
  377. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  378. }
  379. })
  380. b.Run("pass on name", func(b *testing.B) {
  381. c := &cert.CachedCertificate{
  382. Certificate: &dummyCert{
  383. name: "good-host",
  384. },
  385. InvertedGroups: map[string]struct{}{"nope": {}},
  386. }
  387. for n := 0; n < b.N; n++ {
  388. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  389. }
  390. })
  391. }
  392. func TestFirewall_Drop2(t *testing.T) {
  393. l := test.NewLogger()
  394. ob := &bytes.Buffer{}
  395. l.SetOutput(ob)
  396. p := firewall.Packet{
  397. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  398. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  399. LocalPort: 10,
  400. RemotePort: 90,
  401. Protocol: firewall.ProtoUDP,
  402. Fragment: false,
  403. }
  404. network := netip.MustParsePrefix("1.2.3.4/24")
  405. c := cert.CachedCertificate{
  406. Certificate: &dummyCert{
  407. name: "host1",
  408. networks: []netip.Prefix{network},
  409. },
  410. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  411. }
  412. h := HostInfo{
  413. ConnectionState: &ConnectionState{
  414. peerCert: &c,
  415. },
  416. vpnAddrs: []netip.Addr{network.Addr()},
  417. }
  418. h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  419. c1 := cert.CachedCertificate{
  420. Certificate: &dummyCert{
  421. name: "host1",
  422. networks: []netip.Prefix{network},
  423. },
  424. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  425. }
  426. h1 := HostInfo{
  427. vpnAddrs: []netip.Addr{network.Addr()},
  428. ConnectionState: &ConnectionState{
  429. peerCert: &c1,
  430. },
  431. }
  432. h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  433. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  434. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  435. cp := cert.NewCAPool()
  436. // h1/c1 lacks the proper groups
  437. require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  438. // c has the proper groups
  439. resetConntrack(fw)
  440. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  441. }
  442. func TestFirewall_Drop3(t *testing.T) {
  443. l := test.NewLogger()
  444. ob := &bytes.Buffer{}
  445. l.SetOutput(ob)
  446. p := firewall.Packet{
  447. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  448. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  449. LocalPort: 1,
  450. RemotePort: 1,
  451. Protocol: firewall.ProtoUDP,
  452. Fragment: false,
  453. }
  454. network := netip.MustParsePrefix("1.2.3.4/24")
  455. c := cert.CachedCertificate{
  456. Certificate: &dummyCert{
  457. name: "host-owner",
  458. networks: []netip.Prefix{network},
  459. },
  460. }
  461. c1 := cert.CachedCertificate{
  462. Certificate: &dummyCert{
  463. name: "host1",
  464. networks: []netip.Prefix{network},
  465. issuer: "signer-sha-bad",
  466. },
  467. }
  468. h1 := HostInfo{
  469. ConnectionState: &ConnectionState{
  470. peerCert: &c1,
  471. },
  472. vpnAddrs: []netip.Addr{network.Addr()},
  473. }
  474. h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  475. c2 := cert.CachedCertificate{
  476. Certificate: &dummyCert{
  477. name: "host2",
  478. networks: []netip.Prefix{network},
  479. issuer: "signer-sha",
  480. },
  481. }
  482. h2 := HostInfo{
  483. ConnectionState: &ConnectionState{
  484. peerCert: &c2,
  485. },
  486. vpnAddrs: []netip.Addr{network.Addr()},
  487. }
  488. h2.buildNetworks(c2.Certificate.Networks(), c2.Certificate.UnsafeNetworks())
  489. c3 := cert.CachedCertificate{
  490. Certificate: &dummyCert{
  491. name: "host3",
  492. networks: []netip.Prefix{network},
  493. issuer: "signer-sha-bad",
  494. },
  495. }
  496. h3 := HostInfo{
  497. ConnectionState: &ConnectionState{
  498. peerCert: &c3,
  499. },
  500. vpnAddrs: []netip.Addr{network.Addr()},
  501. }
  502. h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
  503. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  504. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
  505. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
  506. cp := cert.NewCAPool()
  507. // c1 should pass because host match
  508. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  509. // c2 should pass because ca sha match
  510. resetConntrack(fw)
  511. require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  512. // c3 should fail because no match
  513. resetConntrack(fw)
  514. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  515. // Test a remote address match
  516. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  517. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
  518. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  519. }
  520. func TestFirewall_Drop3V6(t *testing.T) {
  521. l := test.NewLogger()
  522. ob := &bytes.Buffer{}
  523. l.SetOutput(ob)
  524. p := firewall.Packet{
  525. LocalAddr: netip.MustParseAddr("fd12::34"),
  526. RemoteAddr: netip.MustParseAddr("fd12::34"),
  527. LocalPort: 1,
  528. RemotePort: 1,
  529. Protocol: firewall.ProtoUDP,
  530. Fragment: false,
  531. }
  532. network := netip.MustParsePrefix("fd12::34/120")
  533. c := cert.CachedCertificate{
  534. Certificate: &dummyCert{
  535. name: "host-owner",
  536. networks: []netip.Prefix{network},
  537. },
  538. }
  539. h := HostInfo{
  540. ConnectionState: &ConnectionState{
  541. peerCert: &c,
  542. },
  543. vpnAddrs: []netip.Addr{network.Addr()},
  544. }
  545. h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  546. // Test a remote address match
  547. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  548. cp := cert.NewCAPool()
  549. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
  550. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  551. }
  552. func TestFirewall_DropConntrackReload(t *testing.T) {
  553. l := test.NewLogger()
  554. ob := &bytes.Buffer{}
  555. l.SetOutput(ob)
  556. p := firewall.Packet{
  557. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  558. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  559. LocalPort: 10,
  560. RemotePort: 90,
  561. Protocol: firewall.ProtoUDP,
  562. Fragment: false,
  563. }
  564. network := netip.MustParsePrefix("1.2.3.4/24")
  565. c := cert.CachedCertificate{
  566. Certificate: &dummyCert{
  567. name: "host1",
  568. networks: []netip.Prefix{network},
  569. groups: []string{"default-group"},
  570. issuer: "signer-shasum",
  571. },
  572. InvertedGroups: map[string]struct{}{"default-group": {}},
  573. }
  574. h := HostInfo{
  575. ConnectionState: &ConnectionState{
  576. peerCert: &c,
  577. },
  578. vpnAddrs: []netip.Addr{network.Addr()},
  579. }
  580. h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
  581. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  582. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  583. cp := cert.NewCAPool()
  584. // Drop outbound
  585. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  586. // Allow inbound
  587. resetConntrack(fw)
  588. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  589. // Allow outbound because conntrack
  590. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  591. oldFw := fw
  592. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  593. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  594. fw.Conntrack = oldFw.Conntrack
  595. fw.rulesVersion = oldFw.rulesVersion + 1
  596. // Allow outbound because conntrack and new rules allow port 10
  597. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  598. oldFw = fw
  599. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  600. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  601. fw.Conntrack = oldFw.Conntrack
  602. fw.rulesVersion = oldFw.rulesVersion + 1
  603. // Drop outbound because conntrack doesn't match new ruleset
  604. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  605. }
  606. func TestFirewall_DropIPSpoofing(t *testing.T) {
  607. l := test.NewLogger()
  608. ob := &bytes.Buffer{}
  609. l.SetOutput(ob)
  610. c := cert.CachedCertificate{
  611. Certificate: &dummyCert{
  612. name: "host-owner",
  613. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
  614. },
  615. }
  616. c1 := cert.CachedCertificate{
  617. Certificate: &dummyCert{
  618. name: "host",
  619. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
  620. unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
  621. },
  622. }
  623. h1 := HostInfo{
  624. ConnectionState: &ConnectionState{
  625. peerCert: &c1,
  626. },
  627. vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
  628. }
  629. h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
  630. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  631. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  632. cp := cert.NewCAPool()
  633. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
  634. p := firewall.Packet{
  635. LocalAddr: netip.MustParseAddr("192.0.2.1"),
  636. RemoteAddr: netip.MustParseAddr("192.0.2.3"),
  637. LocalPort: 1,
  638. RemotePort: 1,
  639. Protocol: firewall.ProtoUDP,
  640. Fragment: false,
  641. }
  642. assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
  643. }
  644. func BenchmarkLookup(b *testing.B) {
  645. ml := func(m map[string]struct{}, a [][]string) {
  646. for n := 0; n < b.N; n++ {
  647. for _, sg := range a {
  648. found := false
  649. for _, g := range sg {
  650. if _, ok := m[g]; !ok {
  651. found = false
  652. break
  653. }
  654. found = true
  655. }
  656. if found {
  657. return
  658. }
  659. }
  660. }
  661. }
  662. b.Run("array to map best", func(b *testing.B) {
  663. m := map[string]struct{}{
  664. "1ne": {},
  665. "2wo": {},
  666. "3hr": {},
  667. "4ou": {},
  668. "5iv": {},
  669. "6ix": {},
  670. }
  671. a := [][]string{
  672. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  673. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  674. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  675. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  676. {"one", "two", "thr", "fou", "5iv", "6ix"},
  677. {"one", "two", "thr", "fou", "fiv", "6ix"},
  678. {"one", "two", "thr", "fou", "fiv", "six"},
  679. }
  680. for n := 0; n < b.N; n++ {
  681. ml(m, a)
  682. }
  683. })
  684. b.Run("array to map worst", func(b *testing.B) {
  685. m := map[string]struct{}{
  686. "one": {},
  687. "two": {},
  688. "thr": {},
  689. "fou": {},
  690. "fiv": {},
  691. "six": {},
  692. }
  693. a := [][]string{
  694. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  695. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  696. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  697. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  698. {"one", "two", "thr", "fou", "5iv", "6ix"},
  699. {"one", "two", "thr", "fou", "fiv", "6ix"},
  700. {"one", "two", "thr", "fou", "fiv", "six"},
  701. }
  702. for n := 0; n < b.N; n++ {
  703. ml(m, a)
  704. }
  705. })
  706. }
  707. func Test_parsePort(t *testing.T) {
  708. _, _, err := parsePort("")
  709. require.EqualError(t, err, "was not a number; ``")
  710. _, _, err = parsePort(" ")
  711. require.EqualError(t, err, "was not a number; ` `")
  712. _, _, err = parsePort("-")
  713. require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  714. _, _, err = parsePort(" - ")
  715. require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  716. _, _, err = parsePort("a-b")
  717. require.EqualError(t, err, "beginning range was not a number; `a`")
  718. _, _, err = parsePort("1-b")
  719. require.EqualError(t, err, "ending range was not a number; `b`")
  720. s, e, err := parsePort(" 1 - 2 ")
  721. assert.Equal(t, int32(1), s)
  722. assert.Equal(t, int32(2), e)
  723. require.NoError(t, err)
  724. s, e, err = parsePort("0-1")
  725. assert.Equal(t, int32(0), s)
  726. assert.Equal(t, int32(0), e)
  727. require.NoError(t, err)
  728. s, e, err = parsePort("9919")
  729. assert.Equal(t, int32(9919), s)
  730. assert.Equal(t, int32(9919), e)
  731. require.NoError(t, err)
  732. s, e, err = parsePort("any")
  733. assert.Equal(t, int32(0), s)
  734. assert.Equal(t, int32(0), e)
  735. require.NoError(t, err)
  736. }
  737. func TestNewFirewallFromConfig(t *testing.T) {
  738. l := test.NewLogger()
  739. // Test a bad rule definition
  740. c := &dummyCert{}
  741. cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
  742. require.NoError(t, err)
  743. conf := config.NewC(l)
  744. conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
  745. _, err = NewFirewallFromConfig(l, cs, conf)
  746. require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  747. // Test both port and code
  748. conf = config.NewC(l)
  749. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
  750. _, err = NewFirewallFromConfig(l, cs, conf)
  751. require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  752. // Test missing host, group, cidr, ca_name and ca_sha
  753. conf = config.NewC(l)
  754. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
  755. _, err = NewFirewallFromConfig(l, cs, conf)
  756. require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  757. // Test code/port error
  758. conf = config.NewC(l)
  759. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
  760. _, err = NewFirewallFromConfig(l, cs, conf)
  761. require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  762. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
  763. _, err = NewFirewallFromConfig(l, cs, conf)
  764. require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  765. // Test proto error
  766. conf = config.NewC(l)
  767. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
  768. _, err = NewFirewallFromConfig(l, cs, conf)
  769. require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  770. // Test cidr parse error
  771. conf = config.NewC(l)
  772. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
  773. _, err = NewFirewallFromConfig(l, cs, conf)
  774. require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  775. // Test local_cidr parse error
  776. conf = config.NewC(l)
  777. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  778. _, err = NewFirewallFromConfig(l, cs, conf)
  779. require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  780. // Test both group and groups
  781. conf = config.NewC(l)
  782. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  783. _, err = NewFirewallFromConfig(l, cs, conf)
  784. require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  785. }
  786. func TestAddFirewallRulesFromConfig(t *testing.T) {
  787. l := test.NewLogger()
  788. // Test adding tcp rule
  789. conf := config.NewC(l)
  790. mf := &mockFirewall{}
  791. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
  792. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  793. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  794. // Test adding udp rule
  795. conf = config.NewC(l)
  796. mf = &mockFirewall{}
  797. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
  798. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  799. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  800. // Test adding icmp rule
  801. conf = config.NewC(l)
  802. mf = &mockFirewall{}
  803. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
  804. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  805. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  806. // Test adding any rule
  807. conf = config.NewC(l)
  808. mf = &mockFirewall{}
  809. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  810. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  811. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  812. // Test adding rule with cidr
  813. cidr := netip.MustParsePrefix("10.0.0.0/8")
  814. conf = config.NewC(l)
  815. mf = &mockFirewall{}
  816. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  817. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  818. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
  819. // Test adding rule with local_cidr
  820. conf = config.NewC(l)
  821. mf = &mockFirewall{}
  822. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  823. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  824. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
  825. // Test adding rule with cidr ipv6
  826. cidr6 := netip.MustParsePrefix("fd00::/8")
  827. conf = config.NewC(l)
  828. mf = &mockFirewall{}
  829. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
  830. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  831. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
  832. // Test adding rule with local_cidr ipv6
  833. conf = config.NewC(l)
  834. mf = &mockFirewall{}
  835. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
  836. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  837. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
  838. // Test adding rule with ca_sha
  839. conf = config.NewC(l)
  840. mf = &mockFirewall{}
  841. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  842. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  843. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
  844. // Test adding rule with ca_name
  845. conf = config.NewC(l)
  846. mf = &mockFirewall{}
  847. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
  848. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  849. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
  850. // Test single group
  851. conf = config.NewC(l)
  852. mf = &mockFirewall{}
  853. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
  854. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  855. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  856. // Test single groups
  857. conf = config.NewC(l)
  858. mf = &mockFirewall{}
  859. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
  860. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  861. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  862. // Test multiple AND groups
  863. conf = config.NewC(l)
  864. mf = &mockFirewall{}
  865. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  866. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  867. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  868. // Test Add error
  869. conf = config.NewC(l)
  870. mf = &mockFirewall{}
  871. mf.nextCallReturn = errors.New("test error")
  872. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  873. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  874. }
  875. func TestFirewall_convertRule(t *testing.T) {
  876. l := test.NewLogger()
  877. ob := &bytes.Buffer{}
  878. l.SetOutput(ob)
  879. // Ensure group array of 1 is converted and a warning is printed
  880. c := map[string]any{
  881. "group": []any{"group1"},
  882. }
  883. r, err := convertRule(l, c, "test", 1)
  884. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  885. require.NoError(t, err)
  886. assert.Equal(t, []string{"group1"}, r.Groups)
  887. // Ensure group array of > 1 is errord
  888. ob.Reset()
  889. c = map[string]any{
  890. "group": []any{"group1", "group2"},
  891. }
  892. r, err = convertRule(l, c, "test", 1)
  893. assert.Empty(t, ob.String())
  894. require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  895. // Make sure a well formed group is alright
  896. ob.Reset()
  897. c = map[string]any{
  898. "group": "group1",
  899. }
  900. r, err = convertRule(l, c, "test", 1)
  901. require.NoError(t, err)
  902. assert.Equal(t, []string{"group1"}, r.Groups)
  903. }
  904. func TestFirewall_convertRuleSanity(t *testing.T) {
  905. l := test.NewLogger()
  906. ob := &bytes.Buffer{}
  907. l.SetOutput(ob)
  908. noWarningPlease := []map[string]any{
  909. {"group": "group1"},
  910. {"groups": []any{"group2"}},
  911. {"host": "bob"},
  912. {"cidr": "1.1.1.1/1"},
  913. {"groups": []any{"group2"}, "host": "bob"},
  914. {"cidr": "1.1.1.1/1", "host": "bob"},
  915. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  916. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  917. }
  918. for _, c := range noWarningPlease {
  919. r, err := convertRule(l, c, "test", 1)
  920. require.NoError(t, err)
  921. require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
  922. }
  923. yesWarningPlease := []map[string]any{
  924. {"group": "group1"},
  925. {"groups": []any{"group2"}},
  926. {"cidr": "1.1.1.1/1"},
  927. {"groups": []any{"group2"}, "host": "bob"},
  928. {"cidr": "1.1.1.1/1", "host": "bob"},
  929. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  930. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  931. }
  932. for _, c := range yesWarningPlease {
  933. c["host"] = "any"
  934. r, err := convertRule(l, c, "test", 1)
  935. require.NoError(t, err)
  936. err = r.sanity()
  937. require.Error(t, err, "I wanted a warning: %+v", c)
  938. }
  939. //reset the list
  940. yesWarningPlease = []map[string]any{
  941. {"group": "group1"},
  942. {"groups": []any{"group2"}},
  943. {"cidr": "1.1.1.1/1"},
  944. {"groups": []any{"group2"}, "host": "bob"},
  945. {"cidr": "1.1.1.1/1", "host": "bob"},
  946. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  947. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  948. }
  949. for _, c := range yesWarningPlease {
  950. r, err := convertRule(l, c, "test", 1)
  951. require.NoError(t, err)
  952. r.Groups = append(r.Groups, "any")
  953. err = r.sanity()
  954. require.Error(t, err, "I wanted a warning: %+v", c)
  955. }
  956. }
  957. type addRuleCall struct {
  958. incoming bool
  959. proto uint8
  960. startPort int32
  961. endPort int32
  962. groups []string
  963. host string
  964. ip netip.Prefix
  965. localIp netip.Prefix
  966. caName string
  967. caSha string
  968. }
  969. type mockFirewall struct {
  970. lastCall addRuleCall
  971. nextCallReturn error
  972. }
  973. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
  974. mf.lastCall = addRuleCall{
  975. incoming: incoming,
  976. proto: proto,
  977. startPort: startPort,
  978. endPort: endPort,
  979. groups: groups,
  980. host: host,
  981. ip: ip,
  982. localIp: localIp,
  983. caName: caName,
  984. caSha: caSha,
  985. }
  986. err := mf.nextCallReturn
  987. mf.nextCallReturn = nil
  988. return err
  989. }
  990. func resetConntrack(fw *Firewall) {
  991. fw.Conntrack.Lock()
  992. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  993. fw.Conntrack.Unlock()
  994. }