3
0

firewall_test.go 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908
  1. package nebula
  2. import (
  3. "bytes"
  4. "errors"
  5. "math"
  6. "net"
  7. "testing"
  8. "time"
  9. "github.com/slackhq/nebula/cert"
  10. "github.com/slackhq/nebula/config"
  11. "github.com/slackhq/nebula/firewall"
  12. "github.com/slackhq/nebula/iputil"
  13. "github.com/slackhq/nebula/test"
  14. "github.com/stretchr/testify/assert"
  15. )
  16. func TestNewFirewall(t *testing.T) {
  17. l := test.NewLogger()
  18. c := &cert.NebulaCertificate{}
  19. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  20. conntrack := fw.Conntrack
  21. assert.NotNil(t, conntrack)
  22. assert.NotNil(t, conntrack.Conns)
  23. assert.NotNil(t, conntrack.TimerWheel)
  24. assert.NotNil(t, fw.InRules)
  25. assert.NotNil(t, fw.OutRules)
  26. assert.Equal(t, time.Second, fw.TCPTimeout)
  27. assert.Equal(t, time.Minute, fw.UDPTimeout)
  28. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  29. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  30. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  31. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  32. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  33. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  34. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  35. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  36. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  37. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  38. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  39. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  40. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  41. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  42. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  43. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  44. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  45. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  46. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  47. }
  48. func TestFirewall_AddRule(t *testing.T) {
  49. l := test.NewLogger()
  50. ob := &bytes.Buffer{}
  51. l.SetOutput(ob)
  52. c := &cert.NebulaCertificate{}
  53. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  54. assert.NotNil(t, fw.InRules)
  55. assert.NotNil(t, fw.OutRules)
  56. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  57. assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
  58. // An empty rule is any
  59. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  60. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  61. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  62. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  63. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
  64. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  65. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  66. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  67. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  68. assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
  69. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  70. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  71. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  72. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  73. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
  74. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  75. ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
  76. assert.True(t, ok)
  77. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  78. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
  79. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  80. ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
  81. assert.True(t, ok)
  82. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  83. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
  84. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  85. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  86. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
  87. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  88. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  89. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
  90. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  91. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  92. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  93. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
  94. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  95. // Test error conditions
  96. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  97. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
  98. assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
  99. }
  100. func TestFirewall_Drop(t *testing.T) {
  101. l := test.NewLogger()
  102. ob := &bytes.Buffer{}
  103. l.SetOutput(ob)
  104. p := firewall.Packet{
  105. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  106. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  107. LocalPort: 10,
  108. RemotePort: 90,
  109. Protocol: firewall.ProtoUDP,
  110. Fragment: false,
  111. }
  112. ipNet := net.IPNet{
  113. IP: net.IPv4(1, 2, 3, 4),
  114. Mask: net.IPMask{255, 255, 255, 0},
  115. }
  116. c := cert.NebulaCertificate{
  117. Details: cert.NebulaCertificateDetails{
  118. Name: "host1",
  119. Ips: []*net.IPNet{&ipNet},
  120. Groups: []string{"default-group"},
  121. InvertedGroups: map[string]struct{}{"default-group": {}},
  122. Issuer: "signer-shasum",
  123. },
  124. }
  125. h := HostInfo{
  126. ConnectionState: &ConnectionState{
  127. peerCert: &c,
  128. },
  129. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  130. }
  131. h.CreateRemoteCIDR(&c)
  132. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  133. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
  134. cp := cert.NewCAPool()
  135. // Drop outbound
  136. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  137. // Allow inbound
  138. resetConntrack(fw)
  139. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  140. // Allow outbound because conntrack
  141. assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
  142. // test remote mismatch
  143. oldRemote := p.RemoteIP
  144. p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
  145. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP)
  146. p.RemoteIP = oldRemote
  147. // ensure signer doesn't get in the way of group checks
  148. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  149. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
  150. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
  151. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  152. // test caSha doesn't drop on match
  153. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  154. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
  155. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
  156. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  157. // ensure ca name doesn't get in the way of group checks
  158. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  159. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  160. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
  161. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
  162. assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
  163. // test caName doesn't drop on match
  164. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  165. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  166. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
  167. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
  168. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  169. }
  170. func BenchmarkFirewallTable_match(b *testing.B) {
  171. f := &Firewall{}
  172. ft := FirewallTable{
  173. TCP: firewallPort{},
  174. }
  175. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  176. goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
  177. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
  178. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
  179. cp := cert.NewCAPool()
  180. b.Run("fail on proto", func(b *testing.B) {
  181. // This benchmark is showing us the cost of failing to match the protocol
  182. c := &cert.NebulaCertificate{}
  183. for n := 0; n < b.N; n++ {
  184. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  185. }
  186. })
  187. b.Run("pass proto, fail on port", func(b *testing.B) {
  188. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  189. c := &cert.NebulaCertificate{}
  190. for n := 0; n < b.N; n++ {
  191. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  192. }
  193. })
  194. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  195. c := &cert.NebulaCertificate{}
  196. ip, _, _ := net.ParseCIDR("9.254.254.254/32")
  197. lip := iputil.Ip2VpnIp(ip)
  198. for n := 0; n < b.N; n++ {
  199. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
  200. }
  201. })
  202. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  203. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  204. c := &cert.NebulaCertificate{
  205. Details: cert.NebulaCertificateDetails{
  206. InvertedGroups: map[string]struct{}{"nope": {}},
  207. Name: "nope",
  208. Ips: []*net.IPNet{ip},
  209. },
  210. }
  211. for n := 0; n < b.N; n++ {
  212. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  213. }
  214. })
  215. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  216. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  217. c := &cert.NebulaCertificate{
  218. Details: cert.NebulaCertificateDetails{
  219. InvertedGroups: map[string]struct{}{"nope": {}},
  220. Name: "nope",
  221. Ips: []*net.IPNet{ip},
  222. },
  223. }
  224. for n := 0; n < b.N; n++ {
  225. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
  226. }
  227. })
  228. b.Run("pass on group on any local cidr", func(b *testing.B) {
  229. c := &cert.NebulaCertificate{
  230. Details: cert.NebulaCertificateDetails{
  231. InvertedGroups: map[string]struct{}{"good-group": {}},
  232. Name: "nope",
  233. },
  234. }
  235. for n := 0; n < b.N; n++ {
  236. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  237. }
  238. })
  239. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  240. c := &cert.NebulaCertificate{
  241. Details: cert.NebulaCertificateDetails{
  242. InvertedGroups: map[string]struct{}{"good-group": {}},
  243. Name: "nope",
  244. },
  245. }
  246. for n := 0; n < b.N; n++ {
  247. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
  248. }
  249. })
  250. b.Run("pass on name", func(b *testing.B) {
  251. c := &cert.NebulaCertificate{
  252. Details: cert.NebulaCertificateDetails{
  253. InvertedGroups: map[string]struct{}{"nope": {}},
  254. Name: "good-host",
  255. },
  256. }
  257. for n := 0; n < b.N; n++ {
  258. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  259. }
  260. })
  261. //
  262. //b.Run("pass on ip", func(b *testing.B) {
  263. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  264. // c := &cert.NebulaCertificate{
  265. // Details: cert.NebulaCertificateDetails{
  266. // InvertedGroups: map[string]struct{}{"nope": {}},
  267. // Name: "good-host",
  268. // },
  269. // }
  270. // for n := 0; n < b.N; n++ {
  271. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  272. // }
  273. //})
  274. //
  275. //b.Run("pass on local ip", func(b *testing.B) {
  276. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  277. // c := &cert.NebulaCertificate{
  278. // Details: cert.NebulaCertificateDetails{
  279. // InvertedGroups: map[string]struct{}{"nope": {}},
  280. // Name: "good-host",
  281. // },
  282. // }
  283. // for n := 0; n < b.N; n++ {
  284. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
  285. // }
  286. //})
  287. //
  288. //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
  289. //
  290. //b.Run("pass on ip with any port", func(b *testing.B) {
  291. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  292. // c := &cert.NebulaCertificate{
  293. // Details: cert.NebulaCertificateDetails{
  294. // InvertedGroups: map[string]struct{}{"nope": {}},
  295. // Name: "good-host",
  296. // },
  297. // }
  298. // for n := 0; n < b.N; n++ {
  299. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  300. // }
  301. //})
  302. //
  303. //b.Run("pass on local ip with any port", func(b *testing.B) {
  304. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  305. // c := &cert.NebulaCertificate{
  306. // Details: cert.NebulaCertificateDetails{
  307. // InvertedGroups: map[string]struct{}{"nope": {}},
  308. // Name: "good-host",
  309. // },
  310. // }
  311. // for n := 0; n < b.N; n++ {
  312. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
  313. // }
  314. //})
  315. }
  316. func TestFirewall_Drop2(t *testing.T) {
  317. l := test.NewLogger()
  318. ob := &bytes.Buffer{}
  319. l.SetOutput(ob)
  320. p := firewall.Packet{
  321. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  322. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  323. LocalPort: 10,
  324. RemotePort: 90,
  325. Protocol: firewall.ProtoUDP,
  326. Fragment: false,
  327. }
  328. ipNet := net.IPNet{
  329. IP: net.IPv4(1, 2, 3, 4),
  330. Mask: net.IPMask{255, 255, 255, 0},
  331. }
  332. c := cert.NebulaCertificate{
  333. Details: cert.NebulaCertificateDetails{
  334. Name: "host1",
  335. Ips: []*net.IPNet{&ipNet},
  336. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  337. },
  338. }
  339. h := HostInfo{
  340. ConnectionState: &ConnectionState{
  341. peerCert: &c,
  342. },
  343. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  344. }
  345. h.CreateRemoteCIDR(&c)
  346. c1 := cert.NebulaCertificate{
  347. Details: cert.NebulaCertificateDetails{
  348. Name: "host1",
  349. Ips: []*net.IPNet{&ipNet},
  350. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  351. },
  352. }
  353. h1 := HostInfo{
  354. ConnectionState: &ConnectionState{
  355. peerCert: &c1,
  356. },
  357. }
  358. h1.CreateRemoteCIDR(&c1)
  359. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  360. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
  361. cp := cert.NewCAPool()
  362. // h1/c1 lacks the proper groups
  363. assert.Error(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
  364. // c has the proper groups
  365. resetConntrack(fw)
  366. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  367. }
  368. func TestFirewall_Drop3(t *testing.T) {
  369. l := test.NewLogger()
  370. ob := &bytes.Buffer{}
  371. l.SetOutput(ob)
  372. p := firewall.Packet{
  373. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  374. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  375. LocalPort: 1,
  376. RemotePort: 1,
  377. Protocol: firewall.ProtoUDP,
  378. Fragment: false,
  379. }
  380. ipNet := net.IPNet{
  381. IP: net.IPv4(1, 2, 3, 4),
  382. Mask: net.IPMask{255, 255, 255, 0},
  383. }
  384. c := cert.NebulaCertificate{
  385. Details: cert.NebulaCertificateDetails{
  386. Name: "host-owner",
  387. Ips: []*net.IPNet{&ipNet},
  388. },
  389. }
  390. c1 := cert.NebulaCertificate{
  391. Details: cert.NebulaCertificateDetails{
  392. Name: "host1",
  393. Ips: []*net.IPNet{&ipNet},
  394. Issuer: "signer-sha-bad",
  395. },
  396. }
  397. h1 := HostInfo{
  398. ConnectionState: &ConnectionState{
  399. peerCert: &c1,
  400. },
  401. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  402. }
  403. h1.CreateRemoteCIDR(&c1)
  404. c2 := cert.NebulaCertificate{
  405. Details: cert.NebulaCertificateDetails{
  406. Name: "host2",
  407. Ips: []*net.IPNet{&ipNet},
  408. Issuer: "signer-sha",
  409. },
  410. }
  411. h2 := HostInfo{
  412. ConnectionState: &ConnectionState{
  413. peerCert: &c2,
  414. },
  415. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  416. }
  417. h2.CreateRemoteCIDR(&c2)
  418. c3 := cert.NebulaCertificate{
  419. Details: cert.NebulaCertificateDetails{
  420. Name: "host3",
  421. Ips: []*net.IPNet{&ipNet},
  422. Issuer: "signer-sha-bad",
  423. },
  424. }
  425. h3 := HostInfo{
  426. ConnectionState: &ConnectionState{
  427. peerCert: &c3,
  428. },
  429. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  430. }
  431. h3.CreateRemoteCIDR(&c3)
  432. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  433. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
  434. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
  435. cp := cert.NewCAPool()
  436. // c1 should pass because host match
  437. assert.NoError(t, fw.Drop(p, true, &h1, cp, nil))
  438. // c2 should pass because ca sha match
  439. resetConntrack(fw)
  440. assert.NoError(t, fw.Drop(p, true, &h2, cp, nil))
  441. // c3 should fail because no match
  442. resetConntrack(fw)
  443. assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
  444. }
  445. func TestFirewall_DropConntrackReload(t *testing.T) {
  446. l := test.NewLogger()
  447. ob := &bytes.Buffer{}
  448. l.SetOutput(ob)
  449. p := firewall.Packet{
  450. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  451. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  452. LocalPort: 10,
  453. RemotePort: 90,
  454. Protocol: firewall.ProtoUDP,
  455. Fragment: false,
  456. }
  457. ipNet := net.IPNet{
  458. IP: net.IPv4(1, 2, 3, 4),
  459. Mask: net.IPMask{255, 255, 255, 0},
  460. }
  461. c := cert.NebulaCertificate{
  462. Details: cert.NebulaCertificateDetails{
  463. Name: "host1",
  464. Ips: []*net.IPNet{&ipNet},
  465. Groups: []string{"default-group"},
  466. InvertedGroups: map[string]struct{}{"default-group": {}},
  467. Issuer: "signer-shasum",
  468. },
  469. }
  470. h := HostInfo{
  471. ConnectionState: &ConnectionState{
  472. peerCert: &c,
  473. },
  474. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  475. }
  476. h.CreateRemoteCIDR(&c)
  477. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  478. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
  479. cp := cert.NewCAPool()
  480. // Drop outbound
  481. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  482. // Allow inbound
  483. resetConntrack(fw)
  484. assert.NoError(t, fw.Drop(p, true, &h, cp, nil))
  485. // Allow outbound because conntrack
  486. assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
  487. oldFw := fw
  488. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  489. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
  490. fw.Conntrack = oldFw.Conntrack
  491. fw.rulesVersion = oldFw.rulesVersion + 1
  492. // Allow outbound because conntrack and new rules allow port 10
  493. assert.NoError(t, fw.Drop(p, false, &h, cp, nil))
  494. oldFw = fw
  495. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  496. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
  497. fw.Conntrack = oldFw.Conntrack
  498. fw.rulesVersion = oldFw.rulesVersion + 1
  499. // Drop outbound because conntrack doesn't match new ruleset
  500. assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
  501. }
  502. func BenchmarkLookup(b *testing.B) {
  503. ml := func(m map[string]struct{}, a [][]string) {
  504. for n := 0; n < b.N; n++ {
  505. for _, sg := range a {
  506. found := false
  507. for _, g := range sg {
  508. if _, ok := m[g]; !ok {
  509. found = false
  510. break
  511. }
  512. found = true
  513. }
  514. if found {
  515. return
  516. }
  517. }
  518. }
  519. }
  520. b.Run("array to map best", func(b *testing.B) {
  521. m := map[string]struct{}{
  522. "1ne": {},
  523. "2wo": {},
  524. "3hr": {},
  525. "4ou": {},
  526. "5iv": {},
  527. "6ix": {},
  528. }
  529. a := [][]string{
  530. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  531. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  532. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  533. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  534. {"one", "two", "thr", "fou", "5iv", "6ix"},
  535. {"one", "two", "thr", "fou", "fiv", "6ix"},
  536. {"one", "two", "thr", "fou", "fiv", "six"},
  537. }
  538. for n := 0; n < b.N; n++ {
  539. ml(m, a)
  540. }
  541. })
  542. b.Run("array to map worst", func(b *testing.B) {
  543. m := map[string]struct{}{
  544. "one": {},
  545. "two": {},
  546. "thr": {},
  547. "fou": {},
  548. "fiv": {},
  549. "six": {},
  550. }
  551. a := [][]string{
  552. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  553. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  554. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  555. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  556. {"one", "two", "thr", "fou", "5iv", "6ix"},
  557. {"one", "two", "thr", "fou", "fiv", "6ix"},
  558. {"one", "two", "thr", "fou", "fiv", "six"},
  559. }
  560. for n := 0; n < b.N; n++ {
  561. ml(m, a)
  562. }
  563. })
  564. //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  565. }
  566. func Test_parsePort(t *testing.T) {
  567. _, _, err := parsePort("")
  568. assert.EqualError(t, err, "was not a number; ``")
  569. _, _, err = parsePort(" ")
  570. assert.EqualError(t, err, "was not a number; ` `")
  571. _, _, err = parsePort("-")
  572. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  573. _, _, err = parsePort(" - ")
  574. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  575. _, _, err = parsePort("a-b")
  576. assert.EqualError(t, err, "beginning range was not a number; `a`")
  577. _, _, err = parsePort("1-b")
  578. assert.EqualError(t, err, "ending range was not a number; `b`")
  579. s, e, err := parsePort(" 1 - 2 ")
  580. assert.Equal(t, int32(1), s)
  581. assert.Equal(t, int32(2), e)
  582. assert.Nil(t, err)
  583. s, e, err = parsePort("0-1")
  584. assert.Equal(t, int32(0), s)
  585. assert.Equal(t, int32(0), e)
  586. assert.Nil(t, err)
  587. s, e, err = parsePort("9919")
  588. assert.Equal(t, int32(9919), s)
  589. assert.Equal(t, int32(9919), e)
  590. assert.Nil(t, err)
  591. s, e, err = parsePort("any")
  592. assert.Equal(t, int32(0), s)
  593. assert.Equal(t, int32(0), e)
  594. assert.Nil(t, err)
  595. }
  596. func TestNewFirewallFromConfig(t *testing.T) {
  597. l := test.NewLogger()
  598. // Test a bad rule definition
  599. c := &cert.NebulaCertificate{}
  600. conf := config.NewC(l)
  601. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  602. _, err := NewFirewallFromConfig(l, c, conf)
  603. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  604. // Test both port and code
  605. conf = config.NewC(l)
  606. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  607. _, err = NewFirewallFromConfig(l, c, conf)
  608. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  609. // Test missing host, group, cidr, ca_name and ca_sha
  610. conf = config.NewC(l)
  611. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  612. _, err = NewFirewallFromConfig(l, c, conf)
  613. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  614. // Test code/port error
  615. conf = config.NewC(l)
  616. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  617. _, err = NewFirewallFromConfig(l, c, conf)
  618. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  619. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  620. _, err = NewFirewallFromConfig(l, c, conf)
  621. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  622. // Test proto error
  623. conf = config.NewC(l)
  624. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  625. _, err = NewFirewallFromConfig(l, c, conf)
  626. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  627. // Test cidr parse error
  628. conf = config.NewC(l)
  629. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  630. _, err = NewFirewallFromConfig(l, c, conf)
  631. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  632. // Test local_cidr parse error
  633. conf = config.NewC(l)
  634. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  635. _, err = NewFirewallFromConfig(l, c, conf)
  636. assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
  637. // Test both group and groups
  638. conf = config.NewC(l)
  639. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  640. _, err = NewFirewallFromConfig(l, c, conf)
  641. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  642. }
  643. func TestAddFirewallRulesFromConfig(t *testing.T) {
  644. l := test.NewLogger()
  645. // Test adding tcp rule
  646. conf := config.NewC(l)
  647. mf := &mockFirewall{}
  648. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  649. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  650. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  651. // Test adding udp rule
  652. conf = config.NewC(l)
  653. mf = &mockFirewall{}
  654. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  655. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  656. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  657. // Test adding icmp rule
  658. conf = config.NewC(l)
  659. mf = &mockFirewall{}
  660. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  661. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  662. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  663. // Test adding any rule
  664. conf = config.NewC(l)
  665. mf = &mockFirewall{}
  666. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  667. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  668. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  669. // Test adding rule with cidr
  670. cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
  671. conf = config.NewC(l)
  672. mf = &mockFirewall{}
  673. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  674. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  675. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
  676. // Test adding rule with local_cidr
  677. conf = config.NewC(l)
  678. mf = &mockFirewall{}
  679. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  680. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  681. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
  682. // Test adding rule with ca_sha
  683. conf = config.NewC(l)
  684. mf = &mockFirewall{}
  685. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  686. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  687. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
  688. // Test adding rule with ca_name
  689. conf = config.NewC(l)
  690. mf = &mockFirewall{}
  691. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  692. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  693. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
  694. // Test single group
  695. conf = config.NewC(l)
  696. mf = &mockFirewall{}
  697. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  698. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  699. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
  700. // Test single groups
  701. conf = config.NewC(l)
  702. mf = &mockFirewall{}
  703. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  704. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  705. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
  706. // Test multiple AND groups
  707. conf = config.NewC(l)
  708. mf = &mockFirewall{}
  709. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  710. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  711. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
  712. // Test Add error
  713. conf = config.NewC(l)
  714. mf = &mockFirewall{}
  715. mf.nextCallReturn = errors.New("test error")
  716. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  717. assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  718. }
  719. func TestFirewall_convertRule(t *testing.T) {
  720. l := test.NewLogger()
  721. ob := &bytes.Buffer{}
  722. l.SetOutput(ob)
  723. // Ensure group array of 1 is converted and a warning is printed
  724. c := map[interface{}]interface{}{
  725. "group": []interface{}{"group1"},
  726. }
  727. r, err := convertRule(l, c, "test", 1)
  728. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  729. assert.Nil(t, err)
  730. assert.Equal(t, "group1", r.Group)
  731. // Ensure group array of > 1 is errord
  732. ob.Reset()
  733. c = map[interface{}]interface{}{
  734. "group": []interface{}{"group1", "group2"},
  735. }
  736. r, err = convertRule(l, c, "test", 1)
  737. assert.Equal(t, "", ob.String())
  738. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  739. // Make sure a well formed group is alright
  740. ob.Reset()
  741. c = map[interface{}]interface{}{
  742. "group": "group1",
  743. }
  744. r, err = convertRule(l, c, "test", 1)
  745. assert.Nil(t, err)
  746. assert.Equal(t, "group1", r.Group)
  747. }
  748. type addRuleCall struct {
  749. incoming bool
  750. proto uint8
  751. startPort int32
  752. endPort int32
  753. groups []string
  754. host string
  755. ip *net.IPNet
  756. localIp *net.IPNet
  757. caName string
  758. caSha string
  759. }
  760. type mockFirewall struct {
  761. lastCall addRuleCall
  762. nextCallReturn error
  763. }
  764. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
  765. mf.lastCall = addRuleCall{
  766. incoming: incoming,
  767. proto: proto,
  768. startPort: startPort,
  769. endPort: endPort,
  770. groups: groups,
  771. host: host,
  772. ip: ip,
  773. localIp: localIp,
  774. caName: caName,
  775. caSha: caSha,
  776. }
  777. err := mf.nextCallReturn
  778. mf.nextCallReturn = nil
  779. return err
  780. }
  781. func resetConntrack(fw *Firewall) {
  782. fw.Conntrack.Lock()
  783. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  784. fw.Conntrack.Unlock()
  785. }