firewall_test.go 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/gaissmai/bart"
  10. "github.com/sirupsen/logrus"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/slackhq/nebula/config"
  13. "github.com/slackhq/nebula/firewall"
  14. "github.com/slackhq/nebula/test"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. )
  18. func TestNewFirewall(t *testing.T) {
  19. l := test.NewLogger()
  20. c := &dummyCert{}
  21. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  22. conntrack := fw.Conntrack
  23. assert.NotNil(t, conntrack)
  24. assert.NotNil(t, conntrack.Conns)
  25. assert.NotNil(t, conntrack.TimerWheel)
  26. assert.NotNil(t, fw.InRules)
  27. assert.NotNil(t, fw.OutRules)
  28. assert.Equal(t, time.Second, fw.TCPTimeout)
  29. assert.Equal(t, time.Minute, fw.UDPTimeout)
  30. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  33. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  34. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  35. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  36. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  37. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  41. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  42. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  43. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  44. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  45. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  46. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  47. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  48. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  49. }
  50. func TestFirewall_AddRule(t *testing.T) {
  51. l := test.NewLogger()
  52. ob := &bytes.Buffer{}
  53. l.SetOutput(ob)
  54. c := &dummyCert{}
  55. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  56. assert.NotNil(t, fw.InRules)
  57. assert.NotNil(t, fw.OutRules)
  58. ti, err := netip.ParsePrefix("1.2.3.4/32")
  59. require.NoError(t, err)
  60. ti6, err := netip.ParsePrefix("fd12::34/128")
  61. require.NoError(t, err)
  62. require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  63. // An empty rule is any
  64. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  65. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  66. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  67. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  68. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  69. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  70. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  71. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  72. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  73. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
  74. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  75. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  76. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  77. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  78. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
  79. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  80. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  81. assert.True(t, ok)
  82. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  83. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6, netip.Prefix{}, "", ""))
  84. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  85. _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
  86. assert.True(t, ok)
  87. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  88. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
  89. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  90. ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  91. assert.True(t, ok)
  92. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  93. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti6, "", ""))
  94. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  95. ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
  96. assert.True(t, ok)
  97. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  98. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
  99. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  100. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  101. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
  102. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  103. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  104. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
  105. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  106. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  107. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  108. require.NoError(t, err)
  109. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
  110. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  111. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  112. anyIp6, err := netip.ParsePrefix("::/0")
  113. require.NoError(t, err)
  114. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6, netip.Prefix{}, "", ""))
  115. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  116. // Test error conditions
  117. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  118. require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  119. require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  120. }
  121. func TestFirewall_Drop(t *testing.T) {
  122. l := test.NewLogger()
  123. ob := &bytes.Buffer{}
  124. l.SetOutput(ob)
  125. myVpnNetworksTable := new(bart.Lite)
  126. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  127. p := firewall.Packet{
  128. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  129. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  130. LocalPort: 10,
  131. RemotePort: 90,
  132. Protocol: firewall.ProtoUDP,
  133. Fragment: false,
  134. }
  135. c := dummyCert{
  136. name: "host1",
  137. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  138. groups: []string{"default-group"},
  139. issuer: "signer-shasum",
  140. }
  141. h := HostInfo{
  142. ConnectionState: &ConnectionState{
  143. peerCert: &cert.CachedCertificate{
  144. Certificate: &c,
  145. InvertedGroups: map[string]struct{}{"default-group": {}},
  146. },
  147. },
  148. vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
  149. }
  150. h.buildNetworks(myVpnNetworksTable, &c)
  151. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  152. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  153. cp := cert.NewCAPool()
  154. // Drop outbound
  155. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  156. // Allow inbound
  157. resetConntrack(fw)
  158. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  159. // Allow outbound because conntrack
  160. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  161. // test remote mismatch
  162. oldRemote := p.RemoteAddr
  163. p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
  164. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  165. p.RemoteAddr = oldRemote
  166. // ensure signer doesn't get in the way of group checks
  167. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  168. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  169. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  170. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  171. // test caSha doesn't drop on match
  172. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  173. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  174. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  175. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  176. // ensure ca name doesn't get in the way of group checks
  177. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  178. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  179. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  180. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  181. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  182. // test caName doesn't drop on match
  183. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  184. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  185. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  186. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  187. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  188. }
  189. func TestFirewall_DropV6(t *testing.T) {
  190. l := test.NewLogger()
  191. ob := &bytes.Buffer{}
  192. l.SetOutput(ob)
  193. myVpnNetworksTable := new(bart.Lite)
  194. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  195. p := firewall.Packet{
  196. LocalAddr: netip.MustParseAddr("fd12::34"),
  197. RemoteAddr: netip.MustParseAddr("fd12::34"),
  198. LocalPort: 10,
  199. RemotePort: 90,
  200. Protocol: firewall.ProtoUDP,
  201. Fragment: false,
  202. }
  203. c := dummyCert{
  204. name: "host1",
  205. networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
  206. groups: []string{"default-group"},
  207. issuer: "signer-shasum",
  208. }
  209. h := HostInfo{
  210. ConnectionState: &ConnectionState{
  211. peerCert: &cert.CachedCertificate{
  212. Certificate: &c,
  213. InvertedGroups: map[string]struct{}{"default-group": {}},
  214. },
  215. },
  216. vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
  217. }
  218. h.buildNetworks(myVpnNetworksTable, &c)
  219. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  220. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  221. cp := cert.NewCAPool()
  222. // Drop outbound
  223. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
  224. // Allow inbound
  225. resetConntrack(fw)
  226. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  227. // Allow outbound because conntrack
  228. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  229. // test remote mismatch
  230. oldRemote := p.RemoteAddr
  231. p.RemoteAddr = netip.MustParseAddr("fd12::56")
  232. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  233. p.RemoteAddr = oldRemote
  234. // ensure signer doesn't get in the way of group checks
  235. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  236. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  237. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  238. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  239. // test caSha doesn't drop on match
  240. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  241. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
  242. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
  243. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  244. // ensure ca name doesn't get in the way of group checks
  245. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  246. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  247. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  248. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  249. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  250. // test caName doesn't drop on match
  251. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  252. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  253. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
  254. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
  255. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  256. }
  257. func BenchmarkFirewallTable_match(b *testing.B) {
  258. f := &Firewall{}
  259. ft := FirewallTable{
  260. TCP: firewallPort{},
  261. }
  262. pfix := netip.MustParsePrefix("172.1.1.1/32")
  263. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix, netip.Prefix{}, "", "")
  264. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix, "", "")
  265. pfix6 := netip.MustParsePrefix("fd11::11/128")
  266. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6, netip.Prefix{}, "", "")
  267. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", netip.Prefix{}, pfix6, "", "")
  268. cp := cert.NewCAPool()
  269. b.Run("fail on proto", func(b *testing.B) {
  270. // This benchmark is showing us the cost of failing to match the protocol
  271. c := &cert.CachedCertificate{
  272. Certificate: &dummyCert{},
  273. }
  274. for n := 0; n < b.N; n++ {
  275. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  276. }
  277. })
  278. b.Run("pass proto, fail on port", func(b *testing.B) {
  279. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  280. c := &cert.CachedCertificate{
  281. Certificate: &dummyCert{},
  282. }
  283. for n := 0; n < b.N; n++ {
  284. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  285. }
  286. })
  287. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  288. c := &cert.CachedCertificate{
  289. Certificate: &dummyCert{},
  290. }
  291. ip := netip.MustParsePrefix("9.254.254.254/32")
  292. for n := 0; n < b.N; n++ {
  293. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  294. }
  295. })
  296. b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
  297. c := &cert.CachedCertificate{
  298. Certificate: &dummyCert{},
  299. }
  300. ip := netip.MustParsePrefix("fd99::99/128")
  301. for n := 0; n < b.N; n++ {
  302. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  303. }
  304. })
  305. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  306. c := &cert.CachedCertificate{
  307. Certificate: &dummyCert{
  308. name: "nope",
  309. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  310. },
  311. InvertedGroups: map[string]struct{}{"nope": {}},
  312. }
  313. for n := 0; n < b.N; n++ {
  314. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  315. }
  316. })
  317. b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  318. c := &cert.CachedCertificate{
  319. Certificate: &dummyCert{
  320. name: "nope",
  321. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  322. },
  323. InvertedGroups: map[string]struct{}{"nope": {}},
  324. }
  325. for n := 0; n < b.N; n++ {
  326. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  327. }
  328. })
  329. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  330. c := &cert.CachedCertificate{
  331. Certificate: &dummyCert{
  332. name: "nope",
  333. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  334. },
  335. InvertedGroups: map[string]struct{}{"nope": {}},
  336. }
  337. for n := 0; n < b.N; n++ {
  338. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  339. }
  340. })
  341. b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  342. c := &cert.CachedCertificate{
  343. Certificate: &dummyCert{
  344. name: "nope",
  345. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  346. },
  347. InvertedGroups: map[string]struct{}{"nope": {}},
  348. }
  349. for n := 0; n < b.N; n++ {
  350. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  351. }
  352. })
  353. b.Run("pass on group on any local cidr", func(b *testing.B) {
  354. c := &cert.CachedCertificate{
  355. Certificate: &dummyCert{
  356. name: "nope",
  357. },
  358. InvertedGroups: map[string]struct{}{"good-group": {}},
  359. }
  360. for n := 0; n < b.N; n++ {
  361. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  362. }
  363. })
  364. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  365. c := &cert.CachedCertificate{
  366. Certificate: &dummyCert{
  367. name: "nope",
  368. },
  369. InvertedGroups: map[string]struct{}{"good-group": {}},
  370. }
  371. for n := 0; n < b.N; n++ {
  372. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  373. }
  374. })
  375. b.Run("pass on group on specific local cidr6", func(b *testing.B) {
  376. c := &cert.CachedCertificate{
  377. Certificate: &dummyCert{
  378. name: "nope",
  379. },
  380. InvertedGroups: map[string]struct{}{"good-group": {}},
  381. }
  382. for n := 0; n < b.N; n++ {
  383. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  384. }
  385. })
  386. b.Run("pass on name", func(b *testing.B) {
  387. c := &cert.CachedCertificate{
  388. Certificate: &dummyCert{
  389. name: "good-host",
  390. },
  391. InvertedGroups: map[string]struct{}{"nope": {}},
  392. }
  393. for n := 0; n < b.N; n++ {
  394. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  395. }
  396. })
  397. }
  398. func TestFirewall_Drop2(t *testing.T) {
  399. l := test.NewLogger()
  400. ob := &bytes.Buffer{}
  401. l.SetOutput(ob)
  402. myVpnNetworksTable := new(bart.Lite)
  403. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  404. p := firewall.Packet{
  405. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  406. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  407. LocalPort: 10,
  408. RemotePort: 90,
  409. Protocol: firewall.ProtoUDP,
  410. Fragment: false,
  411. }
  412. network := netip.MustParsePrefix("1.2.3.4/24")
  413. c := cert.CachedCertificate{
  414. Certificate: &dummyCert{
  415. name: "host1",
  416. networks: []netip.Prefix{network},
  417. },
  418. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  419. }
  420. h := HostInfo{
  421. ConnectionState: &ConnectionState{
  422. peerCert: &c,
  423. },
  424. vpnAddrs: []netip.Addr{network.Addr()},
  425. }
  426. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  427. c1 := cert.CachedCertificate{
  428. Certificate: &dummyCert{
  429. name: "host1",
  430. networks: []netip.Prefix{network},
  431. },
  432. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  433. }
  434. h1 := HostInfo{
  435. vpnAddrs: []netip.Addr{network.Addr()},
  436. ConnectionState: &ConnectionState{
  437. peerCert: &c1,
  438. },
  439. }
  440. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  441. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  442. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  443. cp := cert.NewCAPool()
  444. // h1/c1 lacks the proper groups
  445. require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  446. // c has the proper groups
  447. resetConntrack(fw)
  448. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  449. }
  450. func TestFirewall_Drop3(t *testing.T) {
  451. l := test.NewLogger()
  452. ob := &bytes.Buffer{}
  453. l.SetOutput(ob)
  454. myVpnNetworksTable := new(bart.Lite)
  455. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  456. p := firewall.Packet{
  457. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  458. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  459. LocalPort: 1,
  460. RemotePort: 1,
  461. Protocol: firewall.ProtoUDP,
  462. Fragment: false,
  463. }
  464. network := netip.MustParsePrefix("1.2.3.4/24")
  465. c := cert.CachedCertificate{
  466. Certificate: &dummyCert{
  467. name: "host-owner",
  468. networks: []netip.Prefix{network},
  469. },
  470. }
  471. c1 := cert.CachedCertificate{
  472. Certificate: &dummyCert{
  473. name: "host1",
  474. networks: []netip.Prefix{network},
  475. issuer: "signer-sha-bad",
  476. },
  477. }
  478. h1 := HostInfo{
  479. ConnectionState: &ConnectionState{
  480. peerCert: &c1,
  481. },
  482. vpnAddrs: []netip.Addr{network.Addr()},
  483. }
  484. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  485. c2 := cert.CachedCertificate{
  486. Certificate: &dummyCert{
  487. name: "host2",
  488. networks: []netip.Prefix{network},
  489. issuer: "signer-sha",
  490. },
  491. }
  492. h2 := HostInfo{
  493. ConnectionState: &ConnectionState{
  494. peerCert: &c2,
  495. },
  496. vpnAddrs: []netip.Addr{network.Addr()},
  497. }
  498. h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
  499. c3 := cert.CachedCertificate{
  500. Certificate: &dummyCert{
  501. name: "host3",
  502. networks: []netip.Prefix{network},
  503. issuer: "signer-sha-bad",
  504. },
  505. }
  506. h3 := HostInfo{
  507. ConnectionState: &ConnectionState{
  508. peerCert: &c3,
  509. },
  510. vpnAddrs: []netip.Addr{network.Addr()},
  511. }
  512. h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
  513. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  514. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
  515. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
  516. cp := cert.NewCAPool()
  517. // c1 should pass because host match
  518. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  519. // c2 should pass because ca sha match
  520. resetConntrack(fw)
  521. require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  522. // c3 should fail because no match
  523. resetConntrack(fw)
  524. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  525. // Test a remote address match
  526. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  527. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
  528. require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  529. }
  530. func TestFirewall_Drop3V6(t *testing.T) {
  531. l := test.NewLogger()
  532. ob := &bytes.Buffer{}
  533. l.SetOutput(ob)
  534. myVpnNetworksTable := new(bart.Lite)
  535. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  536. p := firewall.Packet{
  537. LocalAddr: netip.MustParseAddr("fd12::34"),
  538. RemoteAddr: netip.MustParseAddr("fd12::34"),
  539. LocalPort: 1,
  540. RemotePort: 1,
  541. Protocol: firewall.ProtoUDP,
  542. Fragment: false,
  543. }
  544. network := netip.MustParsePrefix("fd12::34/120")
  545. c := cert.CachedCertificate{
  546. Certificate: &dummyCert{
  547. name: "host-owner",
  548. networks: []netip.Prefix{network},
  549. },
  550. }
  551. h := HostInfo{
  552. ConnectionState: &ConnectionState{
  553. peerCert: &c,
  554. },
  555. vpnAddrs: []netip.Addr{network.Addr()},
  556. }
  557. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  558. // Test a remote address match
  559. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  560. cp := cert.NewCAPool()
  561. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("fd12::34/120"), netip.Prefix{}, "", ""))
  562. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  563. }
  564. func TestFirewall_DropConntrackReload(t *testing.T) {
  565. l := test.NewLogger()
  566. ob := &bytes.Buffer{}
  567. l.SetOutput(ob)
  568. myVpnNetworksTable := new(bart.Lite)
  569. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  570. p := firewall.Packet{
  571. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  572. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  573. LocalPort: 10,
  574. RemotePort: 90,
  575. Protocol: firewall.ProtoUDP,
  576. Fragment: false,
  577. }
  578. network := netip.MustParsePrefix("1.2.3.4/24")
  579. c := cert.CachedCertificate{
  580. Certificate: &dummyCert{
  581. name: "host1",
  582. networks: []netip.Prefix{network},
  583. groups: []string{"default-group"},
  584. issuer: "signer-shasum",
  585. },
  586. InvertedGroups: map[string]struct{}{"default-group": {}},
  587. }
  588. h := HostInfo{
  589. ConnectionState: &ConnectionState{
  590. peerCert: &c,
  591. },
  592. vpnAddrs: []netip.Addr{network.Addr()},
  593. }
  594. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  595. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  596. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  597. cp := cert.NewCAPool()
  598. // Drop outbound
  599. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  600. // Allow inbound
  601. resetConntrack(fw)
  602. require.NoError(t, fw.Drop(p, true, &h, cp, nil))
  603. // Allow outbound because conntrack
  604. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  605. oldFw := fw
  606. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  607. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  608. fw.Conntrack = oldFw.Conntrack
  609. fw.rulesVersion = oldFw.rulesVersion + 1
  610. // Allow outbound because conntrack and new rules allow port 10
  611. require.NoError(t, fw.Drop(p, false, &h, cp, nil))
  612. oldFw = fw
  613. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  614. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  615. fw.Conntrack = oldFw.Conntrack
  616. fw.rulesVersion = oldFw.rulesVersion + 1
  617. // Drop outbound because conntrack doesn't match new ruleset
  618. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  619. }
  620. func TestFirewall_DropIPSpoofing(t *testing.T) {
  621. l := test.NewLogger()
  622. ob := &bytes.Buffer{}
  623. l.SetOutput(ob)
  624. myVpnNetworksTable := new(bart.Lite)
  625. myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
  626. c := cert.CachedCertificate{
  627. Certificate: &dummyCert{
  628. name: "host-owner",
  629. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
  630. },
  631. }
  632. c1 := cert.CachedCertificate{
  633. Certificate: &dummyCert{
  634. name: "host",
  635. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
  636. unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
  637. },
  638. }
  639. h1 := HostInfo{
  640. ConnectionState: &ConnectionState{
  641. peerCert: &c1,
  642. },
  643. vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
  644. }
  645. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  646. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
  647. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  648. cp := cert.NewCAPool()
  649. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
  650. p := firewall.Packet{
  651. LocalAddr: netip.MustParseAddr("192.0.2.1"),
  652. RemoteAddr: netip.MustParseAddr("192.0.2.3"),
  653. LocalPort: 1,
  654. RemotePort: 1,
  655. Protocol: firewall.ProtoUDP,
  656. Fragment: false,
  657. }
  658. assert.Equal(t, fw.Drop(p, true, &h1, cp, nil), ErrInvalidRemoteIP)
  659. }
  660. func BenchmarkLookup(b *testing.B) {
  661. ml := func(m map[string]struct{}, a [][]string) {
  662. for n := 0; n < b.N; n++ {
  663. for _, sg := range a {
  664. found := false
  665. for _, g := range sg {
  666. if _, ok := m[g]; !ok {
  667. found = false
  668. break
  669. }
  670. found = true
  671. }
  672. if found {
  673. return
  674. }
  675. }
  676. }
  677. }
  678. b.Run("array to map best", func(b *testing.B) {
  679. m := map[string]struct{}{
  680. "1ne": {},
  681. "2wo": {},
  682. "3hr": {},
  683. "4ou": {},
  684. "5iv": {},
  685. "6ix": {},
  686. }
  687. a := [][]string{
  688. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  689. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  690. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  691. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  692. {"one", "two", "thr", "fou", "5iv", "6ix"},
  693. {"one", "two", "thr", "fou", "fiv", "6ix"},
  694. {"one", "two", "thr", "fou", "fiv", "six"},
  695. }
  696. for n := 0; n < b.N; n++ {
  697. ml(m, a)
  698. }
  699. })
  700. b.Run("array to map worst", func(b *testing.B) {
  701. m := map[string]struct{}{
  702. "one": {},
  703. "two": {},
  704. "thr": {},
  705. "fou": {},
  706. "fiv": {},
  707. "six": {},
  708. }
  709. a := [][]string{
  710. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  711. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  712. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  713. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  714. {"one", "two", "thr", "fou", "5iv", "6ix"},
  715. {"one", "two", "thr", "fou", "fiv", "6ix"},
  716. {"one", "two", "thr", "fou", "fiv", "six"},
  717. }
  718. for n := 0; n < b.N; n++ {
  719. ml(m, a)
  720. }
  721. })
  722. }
  723. func Test_parsePort(t *testing.T) {
  724. _, _, err := parsePort("")
  725. require.EqualError(t, err, "was not a number; ``")
  726. _, _, err = parsePort(" ")
  727. require.EqualError(t, err, "was not a number; ` `")
  728. _, _, err = parsePort("-")
  729. require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  730. _, _, err = parsePort(" - ")
  731. require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  732. _, _, err = parsePort("a-b")
  733. require.EqualError(t, err, "beginning range was not a number; `a`")
  734. _, _, err = parsePort("1-b")
  735. require.EqualError(t, err, "ending range was not a number; `b`")
  736. s, e, err := parsePort(" 1 - 2 ")
  737. assert.Equal(t, int32(1), s)
  738. assert.Equal(t, int32(2), e)
  739. require.NoError(t, err)
  740. s, e, err = parsePort("0-1")
  741. assert.Equal(t, int32(0), s)
  742. assert.Equal(t, int32(0), e)
  743. require.NoError(t, err)
  744. s, e, err = parsePort("9919")
  745. assert.Equal(t, int32(9919), s)
  746. assert.Equal(t, int32(9919), e)
  747. require.NoError(t, err)
  748. s, e, err = parsePort("any")
  749. assert.Equal(t, int32(0), s)
  750. assert.Equal(t, int32(0), e)
  751. require.NoError(t, err)
  752. }
  753. func TestNewFirewallFromConfig(t *testing.T) {
  754. l := test.NewLogger()
  755. // Test a bad rule definition
  756. c := &dummyCert{}
  757. cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
  758. require.NoError(t, err)
  759. conf := config.NewC(l)
  760. conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
  761. _, err = NewFirewallFromConfig(l, cs, conf)
  762. require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  763. // Test both port and code
  764. conf = config.NewC(l)
  765. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
  766. _, err = NewFirewallFromConfig(l, cs, conf)
  767. require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  768. // Test missing host, group, cidr, ca_name and ca_sha
  769. conf = config.NewC(l)
  770. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
  771. _, err = NewFirewallFromConfig(l, cs, conf)
  772. require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  773. // Test code/port error
  774. conf = config.NewC(l)
  775. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh"}}}
  776. _, err = NewFirewallFromConfig(l, cs, conf)
  777. require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  778. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh"}}}
  779. _, err = NewFirewallFromConfig(l, cs, conf)
  780. require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  781. // Test proto error
  782. conf = config.NewC(l)
  783. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
  784. _, err = NewFirewallFromConfig(l, cs, conf)
  785. require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  786. // Test cidr parse error
  787. conf = config.NewC(l)
  788. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
  789. _, err = NewFirewallFromConfig(l, cs, conf)
  790. require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  791. // Test local_cidr parse error
  792. conf = config.NewC(l)
  793. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  794. _, err = NewFirewallFromConfig(l, cs, conf)
  795. require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  796. // Test both group and groups
  797. conf = config.NewC(l)
  798. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  799. _, err = NewFirewallFromConfig(l, cs, conf)
  800. require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  801. }
  802. func TestAddFirewallRulesFromConfig(t *testing.T) {
  803. l := test.NewLogger()
  804. // Test adding tcp rule
  805. conf := config.NewC(l)
  806. mf := &mockFirewall{}
  807. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
  808. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  809. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  810. // Test adding udp rule
  811. conf = config.NewC(l)
  812. mf = &mockFirewall{}
  813. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
  814. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  815. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  816. // Test adding icmp rule
  817. conf = config.NewC(l)
  818. mf = &mockFirewall{}
  819. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
  820. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  821. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  822. // Test adding any rule
  823. conf = config.NewC(l)
  824. mf = &mockFirewall{}
  825. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  826. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  827. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  828. // Test adding rule with cidr
  829. cidr := netip.MustParsePrefix("10.0.0.0/8")
  830. conf = config.NewC(l)
  831. mf = &mockFirewall{}
  832. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  833. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  834. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
  835. // Test adding rule with local_cidr
  836. conf = config.NewC(l)
  837. mf = &mockFirewall{}
  838. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  839. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  840. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
  841. // Test adding rule with cidr ipv6
  842. cidr6 := netip.MustParsePrefix("fd00::/8")
  843. conf = config.NewC(l)
  844. mf = &mockFirewall{}
  845. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
  846. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  847. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6, localIp: netip.Prefix{}}, mf.lastCall)
  848. // Test adding rule with local_cidr ipv6
  849. conf = config.NewC(l)
  850. mf = &mockFirewall{}
  851. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
  852. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  853. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr6}, mf.lastCall)
  854. // Test adding rule with ca_sha
  855. conf = config.NewC(l)
  856. mf = &mockFirewall{}
  857. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  858. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  859. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
  860. // Test adding rule with ca_name
  861. conf = config.NewC(l)
  862. mf = &mockFirewall{}
  863. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
  864. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  865. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
  866. // Test single group
  867. conf = config.NewC(l)
  868. mf = &mockFirewall{}
  869. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
  870. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  871. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  872. // Test single groups
  873. conf = config.NewC(l)
  874. mf = &mockFirewall{}
  875. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
  876. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  877. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  878. // Test multiple AND groups
  879. conf = config.NewC(l)
  880. mf = &mockFirewall{}
  881. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  882. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  883. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
  884. // Test Add error
  885. conf = config.NewC(l)
  886. mf = &mockFirewall{}
  887. mf.nextCallReturn = errors.New("test error")
  888. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  889. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  890. }
  891. func TestFirewall_convertRule(t *testing.T) {
  892. l := test.NewLogger()
  893. ob := &bytes.Buffer{}
  894. l.SetOutput(ob)
  895. // Ensure group array of 1 is converted and a warning is printed
  896. c := map[string]any{
  897. "group": []any{"group1"},
  898. }
  899. r, err := convertRule(l, c, "test", 1)
  900. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  901. require.NoError(t, err)
  902. assert.Equal(t, []string{"group1"}, r.Groups)
  903. // Ensure group array of > 1 is errord
  904. ob.Reset()
  905. c = map[string]any{
  906. "group": []any{"group1", "group2"},
  907. }
  908. r, err = convertRule(l, c, "test", 1)
  909. assert.Empty(t, ob.String())
  910. require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  911. // Make sure a well formed group is alright
  912. ob.Reset()
  913. c = map[string]any{
  914. "group": "group1",
  915. }
  916. r, err = convertRule(l, c, "test", 1)
  917. require.NoError(t, err)
  918. assert.Equal(t, []string{"group1"}, r.Groups)
  919. }
  920. func TestFirewall_convertRuleSanity(t *testing.T) {
  921. l := test.NewLogger()
  922. ob := &bytes.Buffer{}
  923. l.SetOutput(ob)
  924. noWarningPlease := []map[string]any{
  925. {"group": "group1"},
  926. {"groups": []any{"group2"}},
  927. {"host": "bob"},
  928. {"cidr": "1.1.1.1/1"},
  929. {"groups": []any{"group2"}, "host": "bob"},
  930. {"cidr": "1.1.1.1/1", "host": "bob"},
  931. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  932. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  933. }
  934. for _, c := range noWarningPlease {
  935. r, err := convertRule(l, c, "test", 1)
  936. require.NoError(t, err)
  937. require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
  938. }
  939. yesWarningPlease := []map[string]any{
  940. {"group": "group1"},
  941. {"groups": []any{"group2"}},
  942. {"cidr": "1.1.1.1/1"},
  943. {"groups": []any{"group2"}, "host": "bob"},
  944. {"cidr": "1.1.1.1/1", "host": "bob"},
  945. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  946. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  947. }
  948. for _, c := range yesWarningPlease {
  949. c["host"] = "any"
  950. r, err := convertRule(l, c, "test", 1)
  951. require.NoError(t, err)
  952. err = r.sanity()
  953. require.Error(t, err, "I wanted a warning: %+v", c)
  954. }
  955. //reset the list
  956. yesWarningPlease = []map[string]any{
  957. {"group": "group1"},
  958. {"groups": []any{"group2"}},
  959. {"cidr": "1.1.1.1/1"},
  960. {"groups": []any{"group2"}, "host": "bob"},
  961. {"cidr": "1.1.1.1/1", "host": "bob"},
  962. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  963. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  964. }
  965. for _, c := range yesWarningPlease {
  966. r, err := convertRule(l, c, "test", 1)
  967. require.NoError(t, err)
  968. r.Groups = append(r.Groups, "any")
  969. err = r.sanity()
  970. require.Error(t, err, "I wanted a warning: %+v", c)
  971. }
  972. }
  973. type testcase struct {
  974. h *HostInfo
  975. p firewall.Packet
  976. c cert.Certificate
  977. err error
  978. }
  979. func (c *testcase) Test(t *testing.T, fw *Firewall) {
  980. t.Helper()
  981. cp := cert.NewCAPool()
  982. resetConntrack(fw)
  983. err := fw.Drop(c.p, true, c.h, cp, nil)
  984. if c.err == nil {
  985. require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
  986. } else {
  987. require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
  988. }
  989. }
  990. func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
  991. c1 := dummyCert{
  992. name: "host1",
  993. networks: theirPrefixes,
  994. groups: []string{"default-group"},
  995. issuer: "signer-shasum",
  996. }
  997. h := HostInfo{
  998. ConnectionState: &ConnectionState{
  999. peerCert: &cert.CachedCertificate{
  1000. Certificate: &c1,
  1001. InvertedGroups: map[string]struct{}{"default-group": {}},
  1002. },
  1003. },
  1004. vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
  1005. }
  1006. for i := range theirPrefixes {
  1007. h.vpnAddrs[i] = theirPrefixes[i].Addr()
  1008. }
  1009. h.buildNetworks(setup.myVpnNetworksTable, &c1)
  1010. p := firewall.Packet{
  1011. LocalAddr: setup.c.Networks()[0].Addr(), //todo?
  1012. RemoteAddr: theirPrefixes[0].Addr(),
  1013. LocalPort: 10,
  1014. RemotePort: 90,
  1015. Protocol: firewall.ProtoUDP,
  1016. Fragment: false,
  1017. }
  1018. return testcase{
  1019. h: &h,
  1020. p: p,
  1021. c: &c1,
  1022. err: err,
  1023. }
  1024. }
  1025. type testsetup struct {
  1026. c dummyCert
  1027. myVpnNetworksTable *bart.Lite
  1028. fw *Firewall
  1029. }
  1030. func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
  1031. c := dummyCert{
  1032. name: "me",
  1033. networks: myPrefixes,
  1034. groups: []string{"default-group"},
  1035. issuer: "signer-shasum",
  1036. }
  1037. return newSetupFromCert(t, l, c)
  1038. }
  1039. func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
  1040. myVpnNetworksTable := new(bart.Lite)
  1041. for _, prefix := range c.Networks() {
  1042. myVpnNetworksTable.Insert(prefix)
  1043. }
  1044. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  1045. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
  1046. return testsetup{
  1047. c: c,
  1048. fw: fw,
  1049. myVpnNetworksTable: myVpnNetworksTable,
  1050. }
  1051. }
  1052. func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
  1053. t.Parallel()
  1054. l := test.NewLogger()
  1055. ob := &bytes.Buffer{}
  1056. l.SetOutput(ob)
  1057. myPrefix := netip.MustParsePrefix("1.1.1.1/8")
  1058. // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
  1059. t.Run("allow inbound all matching", func(t *testing.T) {
  1060. t.Parallel()
  1061. setup := newSetup(t, l, myPrefix)
  1062. tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
  1063. tc.Test(t, setup.fw)
  1064. })
  1065. t.Run("allow inbound local matching", func(t *testing.T) {
  1066. t.Parallel()
  1067. setup := newSetup(t, l, myPrefix)
  1068. tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
  1069. tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
  1070. tc.Test(t, setup.fw)
  1071. })
  1072. t.Run("block inbound remote mismatched", func(t *testing.T) {
  1073. t.Parallel()
  1074. setup := newSetup(t, l, myPrefix)
  1075. tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
  1076. tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
  1077. tc.Test(t, setup.fw)
  1078. })
  1079. t.Run("Block a vpn peer packet", func(t *testing.T) {
  1080. t.Parallel()
  1081. setup := newSetup(t, l, myPrefix)
  1082. tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
  1083. tc.Test(t, setup.fw)
  1084. })
  1085. twoPrefixes := []netip.Prefix{
  1086. netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
  1087. }
  1088. t.Run("allow inbound one matching", func(t *testing.T) {
  1089. t.Parallel()
  1090. setup := newSetup(t, l, myPrefix)
  1091. tc := buildTestCase(setup, nil, twoPrefixes...)
  1092. tc.Test(t, setup.fw)
  1093. })
  1094. t.Run("block inbound multimismatch", func(t *testing.T) {
  1095. t.Parallel()
  1096. setup := newSetup(t, l, myPrefix)
  1097. tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
  1098. tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
  1099. tc.Test(t, setup.fw)
  1100. })
  1101. t.Run("allow inbound 2nd one matching", func(t *testing.T) {
  1102. t.Parallel()
  1103. setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
  1104. tc := buildTestCase(setup2, nil, twoPrefixes...)
  1105. tc.p.RemoteAddr = twoPrefixes[1].Addr()
  1106. tc.Test(t, setup2.fw)
  1107. })
  1108. t.Run("allow inbound unsafe route", func(t *testing.T) {
  1109. t.Parallel()
  1110. unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
  1111. c := dummyCert{
  1112. name: "me",
  1113. networks: []netip.Prefix{myPrefix},
  1114. unsafeNetworks: []netip.Prefix{unsafePrefix},
  1115. groups: []string{"default-group"},
  1116. issuer: "signer-shasum",
  1117. }
  1118. unsafeSetup := newSetupFromCert(t, l, c)
  1119. tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
  1120. tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
  1121. tc.err = ErrNoMatchingRule
  1122. tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
  1123. require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, unsafePrefix, "", ""))
  1124. tc.err = nil
  1125. tc.Test(t, unsafeSetup.fw) //should pass
  1126. })
  1127. }
  1128. type addRuleCall struct {
  1129. incoming bool
  1130. proto uint8
  1131. startPort int32
  1132. endPort int32
  1133. groups []string
  1134. host string
  1135. ip netip.Prefix
  1136. localIp netip.Prefix
  1137. caName string
  1138. caSha string
  1139. }
  1140. type mockFirewall struct {
  1141. lastCall addRuleCall
  1142. nextCallReturn error
  1143. }
  1144. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip netip.Prefix, localIp netip.Prefix, caName string, caSha string) error {
  1145. mf.lastCall = addRuleCall{
  1146. incoming: incoming,
  1147. proto: proto,
  1148. startPort: startPort,
  1149. endPort: endPort,
  1150. groups: groups,
  1151. host: host,
  1152. ip: ip,
  1153. localIp: localIp,
  1154. caName: caName,
  1155. caSha: caSha,
  1156. }
  1157. err := mf.nextCallReturn
  1158. mf.nextCallReturn = nil
  1159. return err
  1160. }
  1161. func resetConntrack(fw *Firewall) {
  1162. fw.Conntrack.Lock()
  1163. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  1164. fw.Conntrack.Unlock()
  1165. }