firewall_test.go 57 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net/netip"
  7. "testing"
  8. "time"
  9. "github.com/gaissmai/bart"
  10. "github.com/sirupsen/logrus"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/slackhq/nebula/config"
  13. "github.com/slackhq/nebula/firewall"
  14. "github.com/slackhq/nebula/test"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. )
  18. func TestNewFirewall(t *testing.T) {
  19. l := test.NewLogger()
  20. c := &dummyCert{}
  21. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  22. conntrack := fw.Conntrack
  23. assert.NotNil(t, conntrack)
  24. assert.NotNil(t, conntrack.Conns)
  25. assert.NotNil(t, conntrack.TimerWheel)
  26. assert.NotNil(t, fw.InRules)
  27. assert.NotNil(t, fw.OutRules)
  28. assert.Equal(t, time.Second, fw.TCPTimeout)
  29. assert.Equal(t, time.Minute, fw.UDPTimeout)
  30. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  33. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  34. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c, netip.Addr{})
  35. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  36. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  37. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c, netip.Addr{})
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c, netip.Addr{})
  41. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  42. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  43. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c, netip.Addr{})
  44. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  45. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  46. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c, netip.Addr{})
  47. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  48. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  49. }
  50. func TestFirewall_AddRule(t *testing.T) {
  51. l := test.NewLogger()
  52. ob := &bytes.Buffer{}
  53. l.SetOutput(ob)
  54. c := &dummyCert{}
  55. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  56. assert.NotNil(t, fw.InRules)
  57. assert.NotNil(t, fw.OutRules)
  58. ti, err := netip.ParsePrefix("1.2.3.4/32")
  59. require.NoError(t, err)
  60. ti6, err := netip.ParsePrefix("fd12::34/128")
  61. require.NoError(t, err)
  62. require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", "", "", "", ""))
  63. // An empty rule is any
  64. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  65. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  66. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  67. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  68. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", ""))
  69. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  70. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  71. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  72. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  73. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", "", "", "", ""))
  74. //no matter what port is given for icmp, it should end up as "any"
  75. assert.Nil(t, fw.InRules.ICMP[firewall.PortAny].Any.Any)
  76. assert.Empty(t, fw.InRules.ICMP[firewall.PortAny].Any.Groups)
  77. assert.Contains(t, fw.InRules.ICMP[firewall.PortAny].Any.Hosts, "h1")
  78. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  79. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti.String(), "", "", ""))
  80. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  81. _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
  82. assert.True(t, ok)
  83. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  84. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti6.String(), "", "", ""))
  85. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  86. _, ok = fw.OutRules.AnyProto[1].Any.CIDR.Get(ti6)
  87. assert.True(t, ok)
  88. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  89. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti.String(), "", ""))
  90. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  91. ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
  92. assert.True(t, ok)
  93. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  94. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", "", ti6.String(), "", ""))
  95. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  96. ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti6)
  97. assert.True(t, ok)
  98. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  99. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "ca-name", ""))
  100. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  101. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  102. require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", "", "", "", "ca-sha"))
  103. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  104. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  105. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", "", "", "", ""))
  106. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  107. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  108. anyIp, err := netip.ParsePrefix("0.0.0.0/0")
  109. require.NoError(t, err)
  110. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp.String(), "", "", ""))
  111. assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
  112. table, ok := fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
  113. assert.True(t, table.Any)
  114. table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
  115. assert.False(t, ok)
  116. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  117. anyIp6, err := netip.ParsePrefix("::/0")
  118. require.NoError(t, err)
  119. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp6.String(), "", "", ""))
  120. assert.Nil(t, fw.OutRules.AnyProto[0].Any.Any)
  121. table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("9::9"))
  122. assert.True(t, table.Any)
  123. table, ok = fw.OutRules.AnyProto[0].Any.CIDR.Lookup(netip.MustParseAddr("1.1.1.1"))
  124. assert.False(t, ok)
  125. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  126. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "any", "", "", ""))
  127. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  128. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  129. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp.String(), "", ""))
  130. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  131. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
  132. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
  133. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  134. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", anyIp6.String(), "", ""))
  135. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  136. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("9::9")))
  137. assert.False(t, fw.OutRules.AnyProto[0].Any.Any.LocalCIDR.Lookup(netip.MustParseAddr("1.1.1.1")))
  138. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  139. require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", "", "any", "", ""))
  140. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  141. // Test error conditions
  142. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c, netip.Addr{})
  143. require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", "", "", "", ""))
  144. require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", "", "", "", ""))
  145. }
  146. func TestFirewall_Drop(t *testing.T) {
  147. l := test.NewLogger()
  148. ob := &bytes.Buffer{}
  149. l.SetOutput(ob)
  150. myVpnNetworksTable := new(bart.Lite)
  151. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  152. p := firewall.Packet{
  153. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  154. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  155. LocalPort: 10,
  156. RemotePort: 90,
  157. Protocol: firewall.ProtoUDP,
  158. Fragment: false,
  159. }
  160. c := dummyCert{
  161. name: "host1",
  162. networks: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/24")},
  163. groups: []string{"default-group"},
  164. issuer: "signer-shasum",
  165. }
  166. h := HostInfo{
  167. ConnectionState: &ConnectionState{
  168. peerCert: &cert.CachedCertificate{
  169. Certificate: &c,
  170. InvertedGroups: map[string]struct{}{"default-group": {}},
  171. },
  172. },
  173. vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")},
  174. }
  175. h.buildNetworks(myVpnNetworksTable, &c)
  176. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  177. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  178. cp := cert.NewCAPool()
  179. // Drop outbound
  180. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
  181. // Allow inbound
  182. resetConntrack(fw)
  183. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  184. // Allow outbound because conntrack
  185. require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
  186. // test remote mismatch
  187. oldRemote := p.RemoteAddr
  188. p.RemoteAddr = netip.MustParseAddr("1.2.3.10")
  189. assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
  190. p.RemoteAddr = oldRemote
  191. // ensure signer doesn't get in the way of group checks
  192. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  193. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
  194. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
  195. assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
  196. // test caSha doesn't drop on match
  197. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  198. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
  199. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
  200. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  201. // ensure ca name doesn't get in the way of group checks
  202. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  203. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  204. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
  205. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
  206. assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
  207. // test caName doesn't drop on match
  208. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  209. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  210. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
  211. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
  212. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  213. }
  214. func TestFirewall_DropV6(t *testing.T) {
  215. l := test.NewLogger()
  216. ob := &bytes.Buffer{}
  217. l.SetOutput(ob)
  218. myVpnNetworksTable := new(bart.Lite)
  219. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  220. p := firewall.Packet{
  221. LocalAddr: netip.MustParseAddr("fd12::34"),
  222. RemoteAddr: netip.MustParseAddr("fd12::34"),
  223. LocalPort: 10,
  224. RemotePort: 90,
  225. Protocol: firewall.ProtoUDP,
  226. Fragment: false,
  227. }
  228. c := dummyCert{
  229. name: "host1",
  230. networks: []netip.Prefix{netip.MustParsePrefix("fd12::34/120")},
  231. groups: []string{"default-group"},
  232. issuer: "signer-shasum",
  233. }
  234. h := HostInfo{
  235. ConnectionState: &ConnectionState{
  236. peerCert: &cert.CachedCertificate{
  237. Certificate: &c,
  238. InvertedGroups: map[string]struct{}{"default-group": {}},
  239. },
  240. },
  241. vpnAddrs: []netip.Addr{netip.MustParseAddr("fd12::34")},
  242. }
  243. h.buildNetworks(myVpnNetworksTable, &c)
  244. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  245. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  246. cp := cert.NewCAPool()
  247. // Drop outbound
  248. assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, nil, false, &h, cp, nil))
  249. // Allow inbound
  250. resetConntrack(fw)
  251. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  252. // Allow outbound because conntrack
  253. require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
  254. // test remote mismatch
  255. oldRemote := p.RemoteAddr
  256. p.RemoteAddr = netip.MustParseAddr("fd12::56")
  257. assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrInvalidRemoteIP)
  258. p.RemoteAddr = oldRemote
  259. // ensure signer doesn't get in the way of group checks
  260. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  261. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum"))
  262. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum-bad"))
  263. assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
  264. // test caSha doesn't drop on match
  265. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  266. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "", "signer-shasum-bad"))
  267. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "", "signer-shasum"))
  268. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  269. // ensure ca name doesn't get in the way of group checks
  270. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  271. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  272. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good", ""))
  273. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good-bad", ""))
  274. assert.Equal(t, fw.Drop(p, nil, true, &h, cp, nil), ErrNoMatchingRule)
  275. // test caName doesn't drop on match
  276. cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
  277. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  278. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", "", "", "ca-good-bad", ""))
  279. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", "", "", "ca-good", ""))
  280. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  281. }
  282. func BenchmarkFirewallTable_match(b *testing.B) {
  283. f := &Firewall{}
  284. ft := FirewallTable{
  285. TCP: firewallPort{},
  286. }
  287. pfix := netip.MustParsePrefix("172.1.1.1/32")
  288. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix.String(), "", "", "")
  289. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix.String(), "", "")
  290. pfix6 := netip.MustParsePrefix("fd11::11/128")
  291. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", pfix6.String(), "", "", "")
  292. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", "", pfix6.String(), "", "")
  293. cp := cert.NewCAPool()
  294. b.Run("fail on proto", func(b *testing.B) {
  295. // This benchmark is showing us the cost of failing to match the protocol
  296. c := &cert.CachedCertificate{
  297. Certificate: &dummyCert{},
  298. }
  299. for n := 0; n < b.N; n++ {
  300. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  301. }
  302. })
  303. b.Run("pass proto, fail on port", func(b *testing.B) {
  304. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  305. c := &cert.CachedCertificate{
  306. Certificate: &dummyCert{},
  307. }
  308. for n := 0; n < b.N; n++ {
  309. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  310. }
  311. })
  312. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  313. c := &cert.CachedCertificate{
  314. Certificate: &dummyCert{},
  315. }
  316. ip := netip.MustParsePrefix("9.254.254.254/32")
  317. for n := 0; n < b.N; n++ {
  318. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  319. }
  320. })
  321. b.Run("pass proto, port, fail on local CIDRv6", func(b *testing.B) {
  322. c := &cert.CachedCertificate{
  323. Certificate: &dummyCert{},
  324. }
  325. ip := netip.MustParsePrefix("fd99::99/128")
  326. for n := 0; n < b.N; n++ {
  327. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp))
  328. }
  329. })
  330. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  331. c := &cert.CachedCertificate{
  332. Certificate: &dummyCert{
  333. name: "nope",
  334. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  335. },
  336. InvertedGroups: map[string]struct{}{"nope": {}},
  337. }
  338. for n := 0; n < b.N; n++ {
  339. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  340. }
  341. })
  342. b.Run("pass proto, port, any local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  343. c := &cert.CachedCertificate{
  344. Certificate: &dummyCert{
  345. name: "nope",
  346. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  347. },
  348. InvertedGroups: map[string]struct{}{"nope": {}},
  349. }
  350. for n := 0; n < b.N; n++ {
  351. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  352. }
  353. })
  354. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  355. c := &cert.CachedCertificate{
  356. Certificate: &dummyCert{
  357. name: "nope",
  358. networks: []netip.Prefix{netip.MustParsePrefix("9.254.254.245/32")},
  359. },
  360. InvertedGroups: map[string]struct{}{"nope": {}},
  361. }
  362. for n := 0; n < b.N; n++ {
  363. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  364. }
  365. })
  366. b.Run("pass proto, port, specific local CIDRv6, fail all group, name, and cidr", func(b *testing.B) {
  367. c := &cert.CachedCertificate{
  368. Certificate: &dummyCert{
  369. name: "nope",
  370. networks: []netip.Prefix{netip.MustParsePrefix("fd99::99/128")},
  371. },
  372. InvertedGroups: map[string]struct{}{"nope": {}},
  373. }
  374. for n := 0; n < b.N; n++ {
  375. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  376. }
  377. })
  378. b.Run("pass on group on any local cidr", func(b *testing.B) {
  379. c := &cert.CachedCertificate{
  380. Certificate: &dummyCert{
  381. name: "nope",
  382. },
  383. InvertedGroups: map[string]struct{}{"good-group": {}},
  384. }
  385. for n := 0; n < b.N; n++ {
  386. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  387. }
  388. })
  389. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  390. c := &cert.CachedCertificate{
  391. Certificate: &dummyCert{
  392. name: "nope",
  393. },
  394. InvertedGroups: map[string]struct{}{"good-group": {}},
  395. }
  396. for n := 0; n < b.N; n++ {
  397. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp))
  398. }
  399. })
  400. b.Run("pass on group on specific local cidr6", func(b *testing.B) {
  401. c := &cert.CachedCertificate{
  402. Certificate: &dummyCert{
  403. name: "nope",
  404. },
  405. InvertedGroups: map[string]struct{}{"good-group": {}},
  406. }
  407. for n := 0; n < b.N; n++ {
  408. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix6.Addr()}, true, c, cp))
  409. }
  410. })
  411. b.Run("pass on name", func(b *testing.B) {
  412. c := &cert.CachedCertificate{
  413. Certificate: &dummyCert{
  414. name: "good-host",
  415. },
  416. InvertedGroups: map[string]struct{}{"nope": {}},
  417. }
  418. for n := 0; n < b.N; n++ {
  419. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  420. }
  421. })
  422. }
  423. func TestFirewall_Drop2(t *testing.T) {
  424. l := test.NewLogger()
  425. ob := &bytes.Buffer{}
  426. l.SetOutput(ob)
  427. myVpnNetworksTable := new(bart.Lite)
  428. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  429. p := firewall.Packet{
  430. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  431. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  432. LocalPort: 10,
  433. RemotePort: 90,
  434. Protocol: firewall.ProtoUDP,
  435. Fragment: false,
  436. }
  437. network := netip.MustParsePrefix("1.2.3.4/24")
  438. c := cert.CachedCertificate{
  439. Certificate: &dummyCert{
  440. name: "host1",
  441. networks: []netip.Prefix{network},
  442. },
  443. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  444. }
  445. h := HostInfo{
  446. ConnectionState: &ConnectionState{
  447. peerCert: &c,
  448. },
  449. vpnAddrs: []netip.Addr{network.Addr()},
  450. }
  451. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  452. c1 := cert.CachedCertificate{
  453. Certificate: &dummyCert{
  454. name: "host1",
  455. networks: []netip.Prefix{network},
  456. },
  457. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  458. }
  459. h1 := HostInfo{
  460. vpnAddrs: []netip.Addr{network.Addr()},
  461. ConnectionState: &ConnectionState{
  462. peerCert: &c1,
  463. },
  464. }
  465. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  466. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  467. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", "", "", "", ""))
  468. cp := cert.NewCAPool()
  469. // h1/c1 lacks the proper groups
  470. require.ErrorIs(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrNoMatchingRule)
  471. // c has the proper groups
  472. resetConntrack(fw)
  473. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  474. }
  475. func TestFirewall_Drop3(t *testing.T) {
  476. l := test.NewLogger()
  477. ob := &bytes.Buffer{}
  478. l.SetOutput(ob)
  479. myVpnNetworksTable := new(bart.Lite)
  480. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  481. p := firewall.Packet{
  482. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  483. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  484. LocalPort: 1,
  485. RemotePort: 1,
  486. Protocol: firewall.ProtoUDP,
  487. Fragment: false,
  488. }
  489. network := netip.MustParsePrefix("1.2.3.4/24")
  490. c := cert.CachedCertificate{
  491. Certificate: &dummyCert{
  492. name: "host-owner",
  493. networks: []netip.Prefix{network},
  494. },
  495. }
  496. c1 := cert.CachedCertificate{
  497. Certificate: &dummyCert{
  498. name: "host1",
  499. networks: []netip.Prefix{network},
  500. issuer: "signer-sha-bad",
  501. },
  502. }
  503. h1 := HostInfo{
  504. ConnectionState: &ConnectionState{
  505. peerCert: &c1,
  506. },
  507. vpnAddrs: []netip.Addr{network.Addr()},
  508. }
  509. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  510. c2 := cert.CachedCertificate{
  511. Certificate: &dummyCert{
  512. name: "host2",
  513. networks: []netip.Prefix{network},
  514. issuer: "signer-sha",
  515. },
  516. }
  517. h2 := HostInfo{
  518. ConnectionState: &ConnectionState{
  519. peerCert: &c2,
  520. },
  521. vpnAddrs: []netip.Addr{network.Addr()},
  522. }
  523. h2.buildNetworks(myVpnNetworksTable, c2.Certificate)
  524. c3 := cert.CachedCertificate{
  525. Certificate: &dummyCert{
  526. name: "host3",
  527. networks: []netip.Prefix{network},
  528. issuer: "signer-sha-bad",
  529. },
  530. }
  531. h3 := HostInfo{
  532. ConnectionState: &ConnectionState{
  533. peerCert: &c3,
  534. },
  535. vpnAddrs: []netip.Addr{network.Addr()},
  536. }
  537. h3.buildNetworks(myVpnNetworksTable, c3.Certificate)
  538. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  539. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", "", "", "", ""))
  540. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", "signer-sha"))
  541. cp := cert.NewCAPool()
  542. // c1 should pass because host match
  543. require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
  544. // c2 should pass because ca sha match
  545. resetConntrack(fw)
  546. require.NoError(t, fw.Drop(p, nil, true, &h2, cp, nil))
  547. // c3 should fail because no match
  548. resetConntrack(fw)
  549. assert.Equal(t, fw.Drop(p, nil, true, &h3, cp, nil), ErrNoMatchingRule)
  550. // Test a remote address match
  551. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  552. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "1.2.3.4/24", "", "", ""))
  553. require.NoError(t, fw.Drop(p, nil, true, &h1, cp, nil))
  554. }
  555. func TestFirewall_Drop3V6(t *testing.T) {
  556. l := test.NewLogger()
  557. ob := &bytes.Buffer{}
  558. l.SetOutput(ob)
  559. myVpnNetworksTable := new(bart.Lite)
  560. myVpnNetworksTable.Insert(netip.MustParsePrefix("fd00::/7"))
  561. p := firewall.Packet{
  562. LocalAddr: netip.MustParseAddr("fd12::34"),
  563. RemoteAddr: netip.MustParseAddr("fd12::34"),
  564. LocalPort: 1,
  565. RemotePort: 1,
  566. Protocol: firewall.ProtoUDP,
  567. Fragment: false,
  568. }
  569. network := netip.MustParsePrefix("fd12::34/120")
  570. c := cert.CachedCertificate{
  571. Certificate: &dummyCert{
  572. name: "host-owner",
  573. networks: []netip.Prefix{network},
  574. },
  575. }
  576. h := HostInfo{
  577. ConnectionState: &ConnectionState{
  578. peerCert: &c,
  579. },
  580. vpnAddrs: []netip.Addr{network.Addr()},
  581. }
  582. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  583. // Test a remote address match
  584. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  585. cp := cert.NewCAPool()
  586. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "fd12::34/120", "", "", ""))
  587. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  588. }
  589. func TestFirewall_DropConntrackReload(t *testing.T) {
  590. l := test.NewLogger()
  591. ob := &bytes.Buffer{}
  592. l.SetOutput(ob)
  593. myVpnNetworksTable := new(bart.Lite)
  594. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  595. p := firewall.Packet{
  596. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  597. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  598. LocalPort: 10,
  599. RemotePort: 90,
  600. Protocol: firewall.ProtoUDP,
  601. Fragment: false,
  602. }
  603. network := netip.MustParsePrefix("1.2.3.4/24")
  604. c := cert.CachedCertificate{
  605. Certificate: &dummyCert{
  606. name: "host1",
  607. networks: []netip.Prefix{network},
  608. groups: []string{"default-group"},
  609. issuer: "signer-shasum",
  610. },
  611. InvertedGroups: map[string]struct{}{"default-group": {}},
  612. }
  613. h := HostInfo{
  614. ConnectionState: &ConnectionState{
  615. peerCert: &c,
  616. },
  617. vpnAddrs: []netip.Addr{network.Addr()},
  618. }
  619. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  620. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  621. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  622. cp := cert.NewCAPool()
  623. // Drop outbound
  624. assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  625. // Allow inbound
  626. resetConntrack(fw)
  627. require.NoError(t, fw.Drop(p, nil, true, &h, cp, nil))
  628. // Allow outbound because conntrack
  629. require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
  630. oldFw := fw
  631. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  632. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", "", "", "", ""))
  633. fw.Conntrack = oldFw.Conntrack
  634. fw.rulesVersion = oldFw.rulesVersion + 1
  635. // Allow outbound because conntrack and new rules allow port 10
  636. require.NoError(t, fw.Drop(p, nil, false, &h, cp, nil))
  637. oldFw = fw
  638. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  639. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", "", "", "", ""))
  640. fw.Conntrack = oldFw.Conntrack
  641. fw.rulesVersion = oldFw.rulesVersion + 1
  642. // Drop outbound because conntrack doesn't match new ruleset
  643. assert.Equal(t, fw.Drop(p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  644. }
  645. func TestFirewall_ICMPPortBehavior(t *testing.T) {
  646. l := test.NewLogger()
  647. ob := &bytes.Buffer{}
  648. l.SetOutput(ob)
  649. myVpnNetworksTable := new(bart.Lite)
  650. myVpnNetworksTable.Insert(netip.MustParsePrefix("1.1.1.1/8"))
  651. network := netip.MustParsePrefix("1.2.3.4/24")
  652. c := cert.CachedCertificate{
  653. Certificate: &dummyCert{
  654. name: "host1",
  655. networks: []netip.Prefix{network},
  656. groups: []string{"default-group"},
  657. issuer: "signer-shasum",
  658. },
  659. InvertedGroups: map[string]struct{}{"default-group": {}},
  660. }
  661. h := HostInfo{
  662. ConnectionState: &ConnectionState{
  663. peerCert: &c,
  664. },
  665. vpnAddrs: []netip.Addr{network.Addr()},
  666. }
  667. h.buildNetworks(myVpnNetworksTable, c.Certificate)
  668. cp := cert.NewCAPool()
  669. templ := firewall.Packet{
  670. LocalAddr: netip.MustParseAddr("1.2.3.4"),
  671. RemoteAddr: netip.MustParseAddr("1.2.3.4"),
  672. Protocol: firewall.ProtoICMP,
  673. Fragment: false,
  674. }
  675. t.Run("ICMP allowed", func(t *testing.T) {
  676. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  677. require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 0, 0, []string{"any"}, "", "", "", "", ""))
  678. t.Run("zero ports", func(t *testing.T) {
  679. p := templ.Copy()
  680. p.LocalPort = 0
  681. p.RemotePort = 0
  682. // Drop outbound
  683. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  684. // Allow inbound
  685. resetConntrack(fw)
  686. require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
  687. //now also allow outbound
  688. require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
  689. })
  690. t.Run("nonzero ports", func(t *testing.T) {
  691. p := templ.Copy()
  692. p.LocalPort = 0xabcd
  693. p.RemotePort = 0x1234
  694. // Drop outbound
  695. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  696. // Allow inbound
  697. resetConntrack(fw)
  698. require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
  699. //now also allow outbound
  700. require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
  701. })
  702. })
  703. t.Run("Any proto, some ports allowed", func(t *testing.T) {
  704. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  705. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 80, 444, []string{"any"}, "", "", "", "", ""))
  706. t.Run("zero ports, still blocked", func(t *testing.T) {
  707. p := templ.Copy()
  708. p.LocalPort = 0
  709. p.RemotePort = 0
  710. // Drop outbound
  711. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  712. // Allow inbound
  713. resetConntrack(fw)
  714. assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
  715. //now also allow outbound
  716. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  717. })
  718. t.Run("nonzero ports, still blocked", func(t *testing.T) {
  719. p := templ.Copy()
  720. p.LocalPort = 0xabcd
  721. p.RemotePort = 0x1234
  722. // Drop outbound
  723. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  724. // Allow inbound
  725. resetConntrack(fw)
  726. assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
  727. //now also allow outbound
  728. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  729. })
  730. t.Run("nonzero, matching ports, still blocked", func(t *testing.T) {
  731. p := templ.Copy()
  732. p.LocalPort = 80
  733. p.RemotePort = 80
  734. // Drop outbound
  735. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  736. // Allow inbound
  737. resetConntrack(fw)
  738. assert.Equal(t, fw.Drop(*p, nil, true, &h, cp, nil), ErrNoMatchingRule)
  739. //now also allow outbound
  740. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  741. })
  742. })
  743. t.Run("Any proto, any port", func(t *testing.T) {
  744. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  745. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  746. t.Run("zero ports, allowed", func(t *testing.T) {
  747. resetConntrack(fw)
  748. p := templ.Copy()
  749. p.LocalPort = 0
  750. p.RemotePort = 0
  751. // Drop outbound
  752. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  753. // Allow inbound
  754. resetConntrack(fw)
  755. require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
  756. //now also allow outbound
  757. require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
  758. })
  759. t.Run("nonzero ports, allowed", func(t *testing.T) {
  760. resetConntrack(fw)
  761. p := templ.Copy()
  762. p.LocalPort = 0xabcd
  763. p.RemotePort = 0x1234
  764. // Drop outbound
  765. assert.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  766. // Allow inbound
  767. resetConntrack(fw)
  768. require.NoError(t, fw.Drop(*p, nil, true, &h, cp, nil))
  769. //now also allow outbound
  770. require.NoError(t, fw.Drop(*p, nil, false, &h, cp, nil))
  771. //different ID is blocked
  772. p.RemotePort++
  773. require.Equal(t, fw.Drop(*p, nil, false, &h, cp, nil), ErrNoMatchingRule)
  774. })
  775. })
  776. }
  777. func TestFirewall_DropIPSpoofing(t *testing.T) {
  778. l := test.NewLogger()
  779. ob := &bytes.Buffer{}
  780. l.SetOutput(ob)
  781. myVpnNetworksTable := new(bart.Lite)
  782. myVpnNetworksTable.Insert(netip.MustParsePrefix("192.0.2.1/24"))
  783. c := cert.CachedCertificate{
  784. Certificate: &dummyCert{
  785. name: "host-owner",
  786. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.1/24")},
  787. },
  788. }
  789. c1 := cert.CachedCertificate{
  790. Certificate: &dummyCert{
  791. name: "host",
  792. networks: []netip.Prefix{netip.MustParsePrefix("192.0.2.2/24")},
  793. unsafeNetworks: []netip.Prefix{netip.MustParsePrefix("198.51.100.0/24")},
  794. },
  795. }
  796. h1 := HostInfo{
  797. ConnectionState: &ConnectionState{
  798. peerCert: &c1,
  799. },
  800. vpnAddrs: []netip.Addr{c1.Certificate.Networks()[0].Addr()},
  801. }
  802. h1.buildNetworks(myVpnNetworksTable, c1.Certificate)
  803. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate, netip.Addr{})
  804. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", "", "", "", ""))
  805. cp := cert.NewCAPool()
  806. // Packet spoofed by `c1`. Note that the remote addr is not a valid one.
  807. p := firewall.Packet{
  808. LocalAddr: netip.MustParseAddr("192.0.2.1"),
  809. RemoteAddr: netip.MustParseAddr("192.0.2.3"),
  810. LocalPort: 1,
  811. RemotePort: 1,
  812. Protocol: firewall.ProtoUDP,
  813. Fragment: false,
  814. }
  815. assert.Equal(t, fw.Drop(p, nil, true, &h1, cp, nil), ErrInvalidRemoteIP)
  816. }
  817. func BenchmarkLookup(b *testing.B) {
  818. ml := func(m map[string]struct{}, a [][]string) {
  819. for n := 0; n < b.N; n++ {
  820. for _, sg := range a {
  821. found := false
  822. for _, g := range sg {
  823. if _, ok := m[g]; !ok {
  824. found = false
  825. break
  826. }
  827. found = true
  828. }
  829. if found {
  830. return
  831. }
  832. }
  833. }
  834. }
  835. b.Run("array to map best", func(b *testing.B) {
  836. m := map[string]struct{}{
  837. "1ne": {},
  838. "2wo": {},
  839. "3hr": {},
  840. "4ou": {},
  841. "5iv": {},
  842. "6ix": {},
  843. }
  844. a := [][]string{
  845. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  846. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  847. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  848. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  849. {"one", "two", "thr", "fou", "5iv", "6ix"},
  850. {"one", "two", "thr", "fou", "fiv", "6ix"},
  851. {"one", "two", "thr", "fou", "fiv", "six"},
  852. }
  853. for n := 0; n < b.N; n++ {
  854. ml(m, a)
  855. }
  856. })
  857. b.Run("array to map worst", func(b *testing.B) {
  858. m := map[string]struct{}{
  859. "one": {},
  860. "two": {},
  861. "thr": {},
  862. "fou": {},
  863. "fiv": {},
  864. "six": {},
  865. }
  866. a := [][]string{
  867. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  868. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  869. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  870. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  871. {"one", "two", "thr", "fou", "5iv", "6ix"},
  872. {"one", "two", "thr", "fou", "fiv", "6ix"},
  873. {"one", "two", "thr", "fou", "fiv", "six"},
  874. }
  875. for n := 0; n < b.N; n++ {
  876. ml(m, a)
  877. }
  878. })
  879. }
  880. func Test_parsePort(t *testing.T) {
  881. _, _, err := parsePort("")
  882. require.EqualError(t, err, "was not a number; ``")
  883. _, _, err = parsePort(" ")
  884. require.EqualError(t, err, "was not a number; ` `")
  885. _, _, err = parsePort("-")
  886. require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  887. _, _, err = parsePort(" - ")
  888. require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  889. _, _, err = parsePort("a-b")
  890. require.EqualError(t, err, "beginning range was not a number; `a`")
  891. _, _, err = parsePort("1-b")
  892. require.EqualError(t, err, "ending range was not a number; `b`")
  893. s, e, err := parsePort(" 1 - 2 ")
  894. assert.Equal(t, int32(1), s)
  895. assert.Equal(t, int32(2), e)
  896. require.NoError(t, err)
  897. s, e, err = parsePort("0-1")
  898. assert.Equal(t, int32(0), s)
  899. assert.Equal(t, int32(0), e)
  900. require.NoError(t, err)
  901. s, e, err = parsePort("9919")
  902. assert.Equal(t, int32(9919), s)
  903. assert.Equal(t, int32(9919), e)
  904. require.NoError(t, err)
  905. s, e, err = parsePort("any")
  906. assert.Equal(t, int32(0), s)
  907. assert.Equal(t, int32(0), e)
  908. require.NoError(t, err)
  909. }
  910. func TestNewFirewallFromConfig(t *testing.T) {
  911. l := test.NewLogger()
  912. // Test a bad rule definition
  913. c := &dummyCert{}
  914. cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil)
  915. require.NoError(t, err)
  916. conf := config.NewC(l)
  917. conf.Settings["firewall"] = map[string]any{"outbound": "asdf"}
  918. _, err = NewFirewallFromConfig(l, cs, conf)
  919. require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  920. // Test both port and code
  921. conf = config.NewC(l)
  922. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "code": "2"}}}
  923. _, err = NewFirewallFromConfig(l, cs, conf)
  924. require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  925. // Test missing host, group, cidr, ca_name and ca_sha
  926. conf = config.NewC(l)
  927. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{}}}
  928. _, err = NewFirewallFromConfig(l, cs, conf)
  929. require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  930. // Test code/port error
  931. conf = config.NewC(l)
  932. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "a", "host": "testh", "proto": "any"}}}
  933. _, err = NewFirewallFromConfig(l, cs, conf)
  934. require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  935. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "a", "host": "testh", "proto": "any"}}}
  936. _, err = NewFirewallFromConfig(l, cs, conf)
  937. require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  938. // Test proto error
  939. conf = config.NewC(l)
  940. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "host": "testh"}}}
  941. _, err = NewFirewallFromConfig(l, cs, conf)
  942. require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  943. // Test cidr parse error
  944. conf = config.NewC(l)
  945. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "cidr": "testh", "proto": "any"}}}
  946. _, err = NewFirewallFromConfig(l, cs, conf)
  947. require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  948. // Test local_cidr parse error
  949. conf = config.NewC(l)
  950. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  951. _, err = NewFirewallFromConfig(l, cs, conf)
  952. require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
  953. // Test both group and groups
  954. conf = config.NewC(l)
  955. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  956. _, err = NewFirewallFromConfig(l, cs, conf)
  957. require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  958. }
  959. func TestAddFirewallRulesFromConfig(t *testing.T) {
  960. l := test.NewLogger()
  961. // Test adding tcp rule
  962. conf := config.NewC(l)
  963. mf := &mockFirewall{}
  964. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "tcp", "host": "a"}}}
  965. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  966. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  967. // Test adding udp rule
  968. conf = config.NewC(l)
  969. mf = &mockFirewall{}
  970. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "udp", "host": "a"}}}
  971. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  972. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  973. // Test adding icmp rule
  974. conf = config.NewC(l)
  975. mf = &mockFirewall{}
  976. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"port": "1", "proto": "icmp", "host": "a"}}}
  977. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  978. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  979. // Test adding icmp rule no port
  980. conf = config.NewC(l)
  981. mf = &mockFirewall{}
  982. conf.Settings["firewall"] = map[string]any{"outbound": []any{map[string]any{"proto": "icmp", "host": "a"}}}
  983. require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  984. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: firewall.PortAny, endPort: firewall.PortAny, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  985. // Test adding any rule
  986. conf = config.NewC(l)
  987. mf = &mockFirewall{}
  988. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  989. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  990. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: "", localIp: ""}, mf.lastCall)
  991. // Test adding rule with cidr
  992. cidr := netip.MustParsePrefix("10.0.0.0/8")
  993. conf = config.NewC(l)
  994. mf = &mockFirewall{}
  995. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  996. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  997. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr.String(), localIp: ""}, mf.lastCall)
  998. // Test adding rule with local_cidr
  999. conf = config.NewC(l)
  1000. mf = &mockFirewall{}
  1001. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  1002. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1003. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr.String()}, mf.lastCall)
  1004. // Test adding rule with cidr ipv6
  1005. cidr6 := netip.MustParsePrefix("fd00::/8")
  1006. conf = config.NewC(l)
  1007. mf = &mockFirewall{}
  1008. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": cidr6.String()}}}
  1009. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1010. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr6.String(), localIp: ""}, mf.lastCall)
  1011. // Test adding rule with any cidr
  1012. conf = config.NewC(l)
  1013. mf = &mockFirewall{}
  1014. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "any"}}}
  1015. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1016. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "any", localIp: ""}, mf.lastCall)
  1017. // Test adding rule with junk cidr
  1018. conf = config.NewC(l)
  1019. mf = &mockFirewall{}
  1020. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "cidr": "junk/junk"}}}
  1021. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
  1022. // Test adding rule with local_cidr ipv6
  1023. conf = config.NewC(l)
  1024. mf = &mockFirewall{}
  1025. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": cidr6.String()}}}
  1026. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1027. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: cidr6.String()}, mf.lastCall)
  1028. // Test adding rule with any local_cidr
  1029. conf = config.NewC(l)
  1030. mf = &mockFirewall{}
  1031. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "any"}}}
  1032. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1033. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, localIp: "any"}, mf.lastCall)
  1034. // Test adding rule with junk local_cidr
  1035. conf = config.NewC(l)
  1036. mf = &mockFirewall{}
  1037. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "local_cidr": "junk/junk"}}}
  1038. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"junk/junk\"): ParseAddr(\"junk\"): unable to parse IP")
  1039. // Test adding rule with ca_sha
  1040. conf = config.NewC(l)
  1041. mf = &mockFirewall{}
  1042. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  1043. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1044. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caSha: "12312313123"}, mf.lastCall)
  1045. // Test adding rule with ca_name
  1046. conf = config.NewC(l)
  1047. mf = &mockFirewall{}
  1048. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "ca_name": "root01"}}}
  1049. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1050. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: "", localIp: "", caName: "root01"}, mf.lastCall)
  1051. // Test single group
  1052. conf = config.NewC(l)
  1053. mf = &mockFirewall{}
  1054. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "group": "a"}}}
  1055. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1056. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
  1057. // Test single groups
  1058. conf = config.NewC(l)
  1059. mf = &mockFirewall{}
  1060. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": "a"}}}
  1061. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1062. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: "", localIp: ""}, mf.lastCall)
  1063. // Test multiple AND groups
  1064. conf = config.NewC(l)
  1065. mf = &mockFirewall{}
  1066. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  1067. require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  1068. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: "", localIp: ""}, mf.lastCall)
  1069. // Test Add error
  1070. conf = config.NewC(l)
  1071. mf = &mockFirewall{}
  1072. mf.nextCallReturn = errors.New("test error")
  1073. conf.Settings["firewall"] = map[string]any{"inbound": []any{map[string]any{"port": "1", "proto": "any", "host": "a"}}}
  1074. require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  1075. }
  1076. func TestFirewall_convertRule(t *testing.T) {
  1077. l := test.NewLogger()
  1078. ob := &bytes.Buffer{}
  1079. l.SetOutput(ob)
  1080. // Ensure group array of 1 is converted and a warning is printed
  1081. c := map[string]any{
  1082. "group": []any{"group1"},
  1083. }
  1084. r, err := convertRule(l, c, "test", 1)
  1085. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  1086. require.NoError(t, err)
  1087. assert.Equal(t, []string{"group1"}, r.Groups)
  1088. // Ensure group array of > 1 is errord
  1089. ob.Reset()
  1090. c = map[string]any{
  1091. "group": []any{"group1", "group2"},
  1092. }
  1093. r, err = convertRule(l, c, "test", 1)
  1094. assert.Empty(t, ob.String())
  1095. require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  1096. // Make sure a well formed group is alright
  1097. ob.Reset()
  1098. c = map[string]any{
  1099. "group": "group1",
  1100. }
  1101. r, err = convertRule(l, c, "test", 1)
  1102. require.NoError(t, err)
  1103. assert.Equal(t, []string{"group1"}, r.Groups)
  1104. }
  1105. func TestFirewall_convertRuleSanity(t *testing.T) {
  1106. l := test.NewLogger()
  1107. ob := &bytes.Buffer{}
  1108. l.SetOutput(ob)
  1109. noWarningPlease := []map[string]any{
  1110. {"group": "group1"},
  1111. {"groups": []any{"group2"}},
  1112. {"host": "bob"},
  1113. {"cidr": "1.1.1.1/1"},
  1114. {"groups": []any{"group2"}, "host": "bob"},
  1115. {"cidr": "1.1.1.1/1", "host": "bob"},
  1116. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  1117. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  1118. }
  1119. for _, c := range noWarningPlease {
  1120. r, err := convertRule(l, c, "test", 1)
  1121. require.NoError(t, err)
  1122. require.NoError(t, r.sanity(), "should not generate a sanity warning, %+v", c)
  1123. }
  1124. yesWarningPlease := []map[string]any{
  1125. {"group": "group1"},
  1126. {"groups": []any{"group2"}},
  1127. {"cidr": "1.1.1.1/1"},
  1128. {"groups": []any{"group2"}, "host": "bob"},
  1129. {"cidr": "1.1.1.1/1", "host": "bob"},
  1130. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  1131. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  1132. }
  1133. for _, c := range yesWarningPlease {
  1134. c["host"] = "any"
  1135. r, err := convertRule(l, c, "test", 1)
  1136. require.NoError(t, err)
  1137. err = r.sanity()
  1138. require.Error(t, err, "I wanted a warning: %+v", c)
  1139. }
  1140. //reset the list
  1141. yesWarningPlease = []map[string]any{
  1142. {"group": "group1"},
  1143. {"groups": []any{"group2"}},
  1144. {"cidr": "1.1.1.1/1"},
  1145. {"groups": []any{"group2"}, "host": "bob"},
  1146. {"cidr": "1.1.1.1/1", "host": "bob"},
  1147. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1"},
  1148. {"groups": []any{"group2"}, "cidr": "1.1.1.1/1", "host": "bob"},
  1149. }
  1150. for _, c := range yesWarningPlease {
  1151. r, err := convertRule(l, c, "test", 1)
  1152. require.NoError(t, err)
  1153. r.Groups = append(r.Groups, "any")
  1154. err = r.sanity()
  1155. require.Error(t, err, "I wanted a warning: %+v", c)
  1156. }
  1157. }
  1158. type testcase struct {
  1159. h *HostInfo
  1160. p firewall.Packet
  1161. c cert.Certificate
  1162. err error
  1163. }
  1164. func (c *testcase) Test(t *testing.T, fw *Firewall) {
  1165. t.Helper()
  1166. cp := cert.NewCAPool()
  1167. resetConntrack(fw)
  1168. err := fw.Drop(c.p, nil, true, c.h, cp, nil)
  1169. if c.err == nil {
  1170. require.NoError(t, err, "failed to not drop remote address %s", c.p.RemoteAddr)
  1171. } else {
  1172. require.ErrorIs(t, c.err, err, "failed to drop remote address %s", c.p.RemoteAddr)
  1173. }
  1174. }
  1175. func buildHostinfo(setup testsetup, theirPrefixes ...netip.Prefix) *HostInfo {
  1176. c1 := dummyCert{
  1177. name: "host1",
  1178. networks: theirPrefixes,
  1179. groups: []string{"default-group"},
  1180. issuer: "signer-shasum",
  1181. }
  1182. h := HostInfo{
  1183. ConnectionState: &ConnectionState{
  1184. peerCert: &cert.CachedCertificate{
  1185. Certificate: &c1,
  1186. InvertedGroups: map[string]struct{}{"default-group": {}},
  1187. },
  1188. },
  1189. vpnAddrs: make([]netip.Addr, len(theirPrefixes)),
  1190. }
  1191. for i := range theirPrefixes {
  1192. h.vpnAddrs[i] = theirPrefixes[i].Addr()
  1193. }
  1194. h.buildNetworks(setup.myVpnNetworksTable, &c1)
  1195. return &h
  1196. }
  1197. func buildTestCase(setup testsetup, err error, theirPrefixes ...netip.Prefix) testcase {
  1198. h := buildHostinfo(setup, theirPrefixes...)
  1199. p := firewall.Packet{
  1200. LocalAddr: setup.c.Networks()[0].Addr(), //todo?
  1201. RemoteAddr: theirPrefixes[0].Addr(),
  1202. LocalPort: 10,
  1203. RemotePort: 90,
  1204. Protocol: firewall.ProtoUDP,
  1205. Fragment: false,
  1206. }
  1207. return testcase{
  1208. h: h,
  1209. p: p,
  1210. c: h.ConnectionState.peerCert.Certificate,
  1211. err: err,
  1212. }
  1213. }
  1214. type testsetup struct {
  1215. c dummyCert
  1216. myVpnNetworksTable *bart.Lite
  1217. fw *Firewall
  1218. }
  1219. func newSetup(t *testing.T, l *logrus.Logger, myPrefixes ...netip.Prefix) testsetup {
  1220. c := dummyCert{
  1221. name: "me",
  1222. networks: myPrefixes,
  1223. groups: []string{"default-group"},
  1224. issuer: "signer-shasum",
  1225. }
  1226. return newSetupFromCert(t, l, c)
  1227. }
  1228. func newSnatSetup(t *testing.T, l *logrus.Logger, myPrefix netip.Prefix, snatAddr netip.Addr) testsetup {
  1229. c := dummyCert{
  1230. name: "me",
  1231. networks: []netip.Prefix{myPrefix},
  1232. groups: []string{"default-group"},
  1233. issuer: "signer-shasum",
  1234. }
  1235. out := newSetupFromCert(t, l, c)
  1236. out.fw.snatAddr = snatAddr
  1237. return out
  1238. }
  1239. func newSetupFromCert(t *testing.T, l *logrus.Logger, c dummyCert) testsetup {
  1240. myVpnNetworksTable := new(bart.Lite)
  1241. for _, prefix := range c.Networks() {
  1242. myVpnNetworksTable.Insert(prefix)
  1243. }
  1244. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c, netip.Addr{})
  1245. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  1246. return testsetup{
  1247. c: c,
  1248. fw: fw,
  1249. myVpnNetworksTable: myVpnNetworksTable,
  1250. }
  1251. }
  1252. func TestFirewall_Drop_EnforceIPMatch(t *testing.T) {
  1253. t.Parallel()
  1254. l := test.NewLogger()
  1255. ob := &bytes.Buffer{}
  1256. l.SetOutput(ob)
  1257. myPrefix := netip.MustParsePrefix("1.1.1.1/8")
  1258. // for now, it's okay that these are all "incoming", the logic this test tries to check doesn't care about in/out
  1259. t.Run("allow inbound all matching", func(t *testing.T) {
  1260. t.Parallel()
  1261. setup := newSetup(t, l, myPrefix)
  1262. tc := buildTestCase(setup, nil, netip.MustParsePrefix("1.2.3.4/24"))
  1263. tc.Test(t, setup.fw)
  1264. })
  1265. t.Run("allow inbound local matching", func(t *testing.T) {
  1266. t.Parallel()
  1267. setup := newSetup(t, l, myPrefix)
  1268. tc := buildTestCase(setup, ErrInvalidLocalIP, netip.MustParsePrefix("1.2.3.4/24"))
  1269. tc.p.LocalAddr = netip.MustParseAddr("1.2.3.8")
  1270. tc.Test(t, setup.fw)
  1271. })
  1272. t.Run("block inbound remote mismatched", func(t *testing.T) {
  1273. t.Parallel()
  1274. setup := newSetup(t, l, myPrefix)
  1275. tc := buildTestCase(setup, ErrInvalidRemoteIP, netip.MustParsePrefix("1.2.3.4/24"))
  1276. tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
  1277. tc.Test(t, setup.fw)
  1278. })
  1279. t.Run("Block a vpn peer packet", func(t *testing.T) {
  1280. t.Parallel()
  1281. setup := newSetup(t, l, myPrefix)
  1282. tc := buildTestCase(setup, ErrPeerRejected, netip.MustParsePrefix("2.2.2.2/24"))
  1283. tc.Test(t, setup.fw)
  1284. })
  1285. twoPrefixes := []netip.Prefix{
  1286. netip.MustParsePrefix("1.2.3.4/24"), netip.MustParsePrefix("2.2.2.2/24"),
  1287. }
  1288. t.Run("allow inbound one matching", func(t *testing.T) {
  1289. t.Parallel()
  1290. setup := newSetup(t, l, myPrefix)
  1291. tc := buildTestCase(setup, nil, twoPrefixes...)
  1292. tc.Test(t, setup.fw)
  1293. })
  1294. t.Run("block inbound multimismatch", func(t *testing.T) {
  1295. t.Parallel()
  1296. setup := newSetup(t, l, myPrefix)
  1297. tc := buildTestCase(setup, ErrInvalidRemoteIP, twoPrefixes...)
  1298. tc.p.RemoteAddr = netip.MustParseAddr("9.9.9.9")
  1299. tc.Test(t, setup.fw)
  1300. })
  1301. t.Run("allow inbound 2nd one matching", func(t *testing.T) {
  1302. t.Parallel()
  1303. setup2 := newSetup(t, l, netip.MustParsePrefix("2.2.2.1/24"))
  1304. tc := buildTestCase(setup2, nil, twoPrefixes...)
  1305. tc.p.RemoteAddr = twoPrefixes[1].Addr()
  1306. tc.Test(t, setup2.fw)
  1307. })
  1308. t.Run("allow inbound unsafe route", func(t *testing.T) {
  1309. t.Parallel()
  1310. unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
  1311. c := dummyCert{
  1312. name: "me",
  1313. networks: []netip.Prefix{myPrefix},
  1314. unsafeNetworks: []netip.Prefix{unsafePrefix},
  1315. groups: []string{"default-group"},
  1316. issuer: "signer-shasum",
  1317. }
  1318. unsafeSetup := newSetupFromCert(t, l, c)
  1319. tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
  1320. tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
  1321. tc.err = ErrNoMatchingRule
  1322. tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
  1323. require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
  1324. tc.err = nil
  1325. tc.Test(t, unsafeSetup.fw) //should pass
  1326. })
  1327. }
  1328. type addRuleCall struct {
  1329. incoming bool
  1330. proto uint8
  1331. startPort int32
  1332. endPort int32
  1333. groups []string
  1334. host string
  1335. ip string
  1336. localIp string
  1337. caName string
  1338. caSha string
  1339. }
  1340. type mockFirewall struct {
  1341. lastCall addRuleCall
  1342. nextCallReturn error
  1343. }
  1344. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp, caName string, caSha string) error {
  1345. mf.lastCall = addRuleCall{
  1346. incoming: incoming,
  1347. proto: proto,
  1348. startPort: startPort,
  1349. endPort: endPort,
  1350. groups: groups,
  1351. host: host,
  1352. ip: ip,
  1353. localIp: localIp,
  1354. caName: caName,
  1355. caSha: caSha,
  1356. }
  1357. err := mf.nextCallReturn
  1358. mf.nextCallReturn = nil
  1359. return err
  1360. }
  1361. func resetConntrack(fw *Firewall) {
  1362. fw.Conntrack.Lock()
  1363. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  1364. fw.Conntrack.Unlock()
  1365. }
  1366. func TestFirewall_SNAT(t *testing.T) {
  1367. t.Parallel()
  1368. l := test.NewLogger()
  1369. ob := &bytes.Buffer{}
  1370. l.SetOutput(ob)
  1371. cp := cert.NewCAPool()
  1372. myPrefix := netip.MustParsePrefix("1.1.1.1/8")
  1373. MyCert := dummyCert{
  1374. name: "me",
  1375. networks: []netip.Prefix{myPrefix},
  1376. groups: []string{"default-group"},
  1377. issuer: "signer-shasum",
  1378. }
  1379. theirPrefix := netip.MustParsePrefix("1.2.2.2/8")
  1380. snatAddr := netip.MustParseAddr("169.254.55.96")
  1381. t.Run("allow inbound all matching", func(t *testing.T) {
  1382. t.Parallel()
  1383. myCert := MyCert.Copy()
  1384. setup := newSnatSetup(t, l, myPrefix, snatAddr)
  1385. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, myCert, snatAddr)
  1386. require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", "", "", ""))
  1387. resetConntrack(setup.fw)
  1388. h := buildHostinfo(setup, theirPrefix)
  1389. p := firewall.Packet{
  1390. LocalAddr: setup.c.Networks()[0].Addr(), //todo?
  1391. RemoteAddr: h.vpnAddrs[0],
  1392. LocalPort: 10,
  1393. RemotePort: 90,
  1394. Protocol: firewall.ProtoUDP,
  1395. Fragment: false,
  1396. }
  1397. require.NoError(t, setup.fw.Drop(p, nil, true, h, cp, nil))
  1398. })
  1399. //t.Run("allow inbound unsafe route", func(t *testing.T) {
  1400. // t.Parallel()
  1401. // unsafePrefix := netip.MustParsePrefix("192.168.0.0/24")
  1402. // c := dummyCert{
  1403. // name: "me",
  1404. // networks: []netip.Prefix{myPrefix},
  1405. // unsafeNetworks: []netip.Prefix{unsafePrefix},
  1406. // groups: []string{"default-group"},
  1407. // issuer: "signer-shasum",
  1408. // }
  1409. // unsafeSetup := newSetupFromCert(t, l, c)
  1410. // tc := buildTestCase(unsafeSetup, nil, twoPrefixes...)
  1411. // tc.p.LocalAddr = netip.MustParseAddr("192.168.0.3")
  1412. // tc.err = ErrNoMatchingRule
  1413. // tc.Test(t, unsafeSetup.fw) //should hit firewall and bounce off
  1414. // require.NoError(t, unsafeSetup.fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", "", unsafePrefix.String(), "", ""))
  1415. // tc.err = nil
  1416. // tc.Test(t, unsafeSetup.fw) //should pass
  1417. //})
  1418. }