firewall_test.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001
  1. package nebula
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "math"
  7. "net"
  8. "testing"
  9. "time"
  10. "github.com/rcrowley/go-metrics"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/slackhq/nebula/config"
  13. "github.com/slackhq/nebula/firewall"
  14. "github.com/slackhq/nebula/iputil"
  15. "github.com/slackhq/nebula/test"
  16. "github.com/stretchr/testify/assert"
  17. )
  18. func TestNewFirewall(t *testing.T) {
  19. l := test.NewLogger()
  20. c := &cert.NebulaCertificate{}
  21. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  22. conntrack := fw.Conntrack
  23. assert.NotNil(t, conntrack)
  24. assert.NotNil(t, conntrack.Conns)
  25. assert.NotNil(t, conntrack.TimerWheel)
  26. assert.NotNil(t, fw.InRules)
  27. assert.NotNil(t, fw.OutRules)
  28. assert.Equal(t, time.Second, fw.TCPTimeout)
  29. assert.Equal(t, time.Minute, fw.UDPTimeout)
  30. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  33. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  34. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  35. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  36. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  37. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  41. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  42. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  43. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  44. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  45. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  46. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  47. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  48. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  49. }
  50. func TestFirewall_AddRule(t *testing.T) {
  51. l := test.NewLogger()
  52. ob := &bytes.Buffer{}
  53. l.SetOutput(ob)
  54. c := &cert.NebulaCertificate{}
  55. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  56. assert.NotNil(t, fw.InRules)
  57. assert.NotNil(t, fw.OutRules)
  58. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  59. assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
  60. // An empty rule is any
  61. assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
  62. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  63. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  64. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  65. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
  66. assert.Nil(t, fw.InRules.UDP[1].Any.Any)
  67. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
  68. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  69. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  70. assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
  71. assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
  72. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  73. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  74. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  75. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
  76. assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
  77. ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.GetCIDR(ti)
  78. assert.True(t, ok)
  79. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  80. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
  81. assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
  82. ok, _ = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.GetCIDR(ti)
  83. assert.True(t, ok)
  84. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  85. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
  86. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  87. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  88. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
  89. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  90. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  91. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
  92. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  93. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  94. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  95. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
  96. assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
  97. // Test error conditions
  98. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  99. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
  100. assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
  101. }
  102. func TestFirewall_Drop(t *testing.T) {
  103. l := test.NewLogger()
  104. ob := &bytes.Buffer{}
  105. l.SetOutput(ob)
  106. p := firewall.Packet{
  107. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  108. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  109. LocalPort: 10,
  110. RemotePort: 90,
  111. Protocol: firewall.ProtoUDP,
  112. Fragment: false,
  113. }
  114. ipNet := net.IPNet{
  115. IP: net.IPv4(1, 2, 3, 4),
  116. Mask: net.IPMask{255, 255, 255, 0},
  117. }
  118. c := cert.NebulaCertificate{
  119. Details: cert.NebulaCertificateDetails{
  120. Name: "host1",
  121. Ips: []*net.IPNet{&ipNet},
  122. Groups: []string{"default-group"},
  123. InvertedGroups: map[string]struct{}{"default-group": {}},
  124. Issuer: "signer-shasum",
  125. },
  126. }
  127. h := HostInfo{
  128. ConnectionState: &ConnectionState{
  129. peerCert: &c,
  130. },
  131. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  132. }
  133. h.CreateRemoteCIDR(&c)
  134. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  135. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
  136. cp := cert.NewCAPool()
  137. // Drop outbound
  138. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  139. // Allow inbound
  140. resetConntrack(fw)
  141. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  142. // Allow outbound because conntrack
  143. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  144. // test remote mismatch
  145. oldRemote := p.RemoteIP
  146. p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
  147. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
  148. p.RemoteIP = oldRemote
  149. // ensure signer doesn't get in the way of group checks
  150. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  151. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
  152. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
  153. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  154. // test caSha doesn't drop on match
  155. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  156. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
  157. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
  158. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  159. // ensure ca name doesn't get in the way of group checks
  160. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  161. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  162. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
  163. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
  164. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  165. // test caName doesn't drop on match
  166. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  167. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  168. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
  169. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
  170. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  171. }
  172. func BenchmarkFirewallTable_match(b *testing.B) {
  173. f := &Firewall{}
  174. ft := FirewallTable{
  175. TCP: firewallPort{},
  176. }
  177. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  178. goodLocalCIDRIP := iputil.Ip2VpnIp(n.IP)
  179. _ = ft.TCP.addRule(f, 10, 10, []string{"good-group"}, "good-host", n, nil, "", "")
  180. _ = ft.TCP.addRule(f, 100, 100, []string{"good-group"}, "good-host", nil, n, "", "")
  181. cp := cert.NewCAPool()
  182. b.Run("fail on proto", func(b *testing.B) {
  183. // This benchmark is showing us the cost of failing to match the protocol
  184. c := &cert.NebulaCertificate{}
  185. for n := 0; n < b.N; n++ {
  186. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp))
  187. }
  188. })
  189. b.Run("pass proto, fail on port", func(b *testing.B) {
  190. // This benchmark is showing us the cost of matching a specific protocol but failing to match the port
  191. c := &cert.NebulaCertificate{}
  192. for n := 0; n < b.N; n++ {
  193. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp))
  194. }
  195. })
  196. b.Run("pass proto, port, fail on local CIDR", func(b *testing.B) {
  197. c := &cert.NebulaCertificate{}
  198. ip, _, _ := net.ParseCIDR("9.254.254.254/32")
  199. lip := iputil.Ip2VpnIp(ip)
  200. for n := 0; n < b.N; n++ {
  201. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: lip}, true, c, cp))
  202. }
  203. })
  204. b.Run("pass proto, port, any local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  205. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  206. c := &cert.NebulaCertificate{
  207. Details: cert.NebulaCertificateDetails{
  208. InvertedGroups: map[string]struct{}{"nope": {}},
  209. Name: "nope",
  210. Ips: []*net.IPNet{ip},
  211. },
  212. }
  213. for n := 0; n < b.N; n++ {
  214. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  215. }
  216. })
  217. b.Run("pass proto, port, specific local CIDR, fail all group, name, and cidr", func(b *testing.B) {
  218. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  219. c := &cert.NebulaCertificate{
  220. Details: cert.NebulaCertificateDetails{
  221. InvertedGroups: map[string]struct{}{"nope": {}},
  222. Name: "nope",
  223. Ips: []*net.IPNet{ip},
  224. },
  225. }
  226. for n := 0; n < b.N; n++ {
  227. assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
  228. }
  229. })
  230. b.Run("pass on group on any local cidr", func(b *testing.B) {
  231. c := &cert.NebulaCertificate{
  232. Details: cert.NebulaCertificateDetails{
  233. InvertedGroups: map[string]struct{}{"good-group": {}},
  234. Name: "nope",
  235. },
  236. }
  237. for n := 0; n < b.N; n++ {
  238. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp))
  239. }
  240. })
  241. b.Run("pass on group on specific local cidr", func(b *testing.B) {
  242. c := &cert.NebulaCertificate{
  243. Details: cert.NebulaCertificateDetails{
  244. InvertedGroups: map[string]struct{}{"good-group": {}},
  245. Name: "nope",
  246. },
  247. }
  248. for n := 0; n < b.N; n++ {
  249. assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: goodLocalCIDRIP}, true, c, cp))
  250. }
  251. })
  252. b.Run("pass on name", func(b *testing.B) {
  253. c := &cert.NebulaCertificate{
  254. Details: cert.NebulaCertificateDetails{
  255. InvertedGroups: map[string]struct{}{"nope": {}},
  256. Name: "good-host",
  257. },
  258. }
  259. for n := 0; n < b.N; n++ {
  260. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  261. }
  262. })
  263. //
  264. //b.Run("pass on ip", func(b *testing.B) {
  265. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  266. // c := &cert.NebulaCertificate{
  267. // Details: cert.NebulaCertificateDetails{
  268. // InvertedGroups: map[string]struct{}{"nope": {}},
  269. // Name: "good-host",
  270. // },
  271. // }
  272. // for n := 0; n < b.N; n++ {
  273. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  274. // }
  275. //})
  276. //
  277. //b.Run("pass on local ip", func(b *testing.B) {
  278. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  279. // c := &cert.NebulaCertificate{
  280. // Details: cert.NebulaCertificateDetails{
  281. // InvertedGroups: map[string]struct{}{"nope": {}},
  282. // Name: "good-host",
  283. // },
  284. // }
  285. // for n := 0; n < b.N; n++ {
  286. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
  287. // }
  288. //})
  289. //
  290. //_ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
  291. //
  292. //b.Run("pass on ip with any port", func(b *testing.B) {
  293. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  294. // c := &cert.NebulaCertificate{
  295. // Details: cert.NebulaCertificateDetails{
  296. // InvertedGroups: map[string]struct{}{"nope": {}},
  297. // Name: "good-host",
  298. // },
  299. // }
  300. // for n := 0; n < b.N; n++ {
  301. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  302. // }
  303. //})
  304. //
  305. //b.Run("pass on local ip with any port", func(b *testing.B) {
  306. // ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  307. // c := &cert.NebulaCertificate{
  308. // Details: cert.NebulaCertificateDetails{
  309. // InvertedGroups: map[string]struct{}{"nope": {}},
  310. // Name: "good-host",
  311. // },
  312. // }
  313. // for n := 0; n < b.N; n++ {
  314. // ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
  315. // }
  316. //})
  317. }
  318. func TestFirewall_Drop2(t *testing.T) {
  319. l := test.NewLogger()
  320. ob := &bytes.Buffer{}
  321. l.SetOutput(ob)
  322. p := firewall.Packet{
  323. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  324. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  325. LocalPort: 10,
  326. RemotePort: 90,
  327. Protocol: firewall.ProtoUDP,
  328. Fragment: false,
  329. }
  330. ipNet := net.IPNet{
  331. IP: net.IPv4(1, 2, 3, 4),
  332. Mask: net.IPMask{255, 255, 255, 0},
  333. }
  334. c := cert.NebulaCertificate{
  335. Details: cert.NebulaCertificateDetails{
  336. Name: "host1",
  337. Ips: []*net.IPNet{&ipNet},
  338. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  339. },
  340. }
  341. h := HostInfo{
  342. ConnectionState: &ConnectionState{
  343. peerCert: &c,
  344. },
  345. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  346. }
  347. h.CreateRemoteCIDR(&c)
  348. c1 := cert.NebulaCertificate{
  349. Details: cert.NebulaCertificateDetails{
  350. Name: "host1",
  351. Ips: []*net.IPNet{&ipNet},
  352. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  353. },
  354. }
  355. h1 := HostInfo{
  356. ConnectionState: &ConnectionState{
  357. peerCert: &c1,
  358. },
  359. }
  360. h1.CreateRemoteCIDR(&c1)
  361. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  362. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
  363. cp := cert.NewCAPool()
  364. // h1/c1 lacks the proper groups
  365. assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
  366. // c has the proper groups
  367. resetConntrack(fw)
  368. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  369. }
  370. func TestFirewall_Drop3(t *testing.T) {
  371. l := test.NewLogger()
  372. ob := &bytes.Buffer{}
  373. l.SetOutput(ob)
  374. p := firewall.Packet{
  375. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  376. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  377. LocalPort: 1,
  378. RemotePort: 1,
  379. Protocol: firewall.ProtoUDP,
  380. Fragment: false,
  381. }
  382. ipNet := net.IPNet{
  383. IP: net.IPv4(1, 2, 3, 4),
  384. Mask: net.IPMask{255, 255, 255, 0},
  385. }
  386. c := cert.NebulaCertificate{
  387. Details: cert.NebulaCertificateDetails{
  388. Name: "host-owner",
  389. Ips: []*net.IPNet{&ipNet},
  390. },
  391. }
  392. c1 := cert.NebulaCertificate{
  393. Details: cert.NebulaCertificateDetails{
  394. Name: "host1",
  395. Ips: []*net.IPNet{&ipNet},
  396. Issuer: "signer-sha-bad",
  397. },
  398. }
  399. h1 := HostInfo{
  400. ConnectionState: &ConnectionState{
  401. peerCert: &c1,
  402. },
  403. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  404. }
  405. h1.CreateRemoteCIDR(&c1)
  406. c2 := cert.NebulaCertificate{
  407. Details: cert.NebulaCertificateDetails{
  408. Name: "host2",
  409. Ips: []*net.IPNet{&ipNet},
  410. Issuer: "signer-sha",
  411. },
  412. }
  413. h2 := HostInfo{
  414. ConnectionState: &ConnectionState{
  415. peerCert: &c2,
  416. },
  417. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  418. }
  419. h2.CreateRemoteCIDR(&c2)
  420. c3 := cert.NebulaCertificate{
  421. Details: cert.NebulaCertificateDetails{
  422. Name: "host3",
  423. Ips: []*net.IPNet{&ipNet},
  424. Issuer: "signer-sha-bad",
  425. },
  426. }
  427. h3 := HostInfo{
  428. ConnectionState: &ConnectionState{
  429. peerCert: &c3,
  430. },
  431. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  432. }
  433. h3.CreateRemoteCIDR(&c3)
  434. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  435. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
  436. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
  437. cp := cert.NewCAPool()
  438. // c1 should pass because host match
  439. assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
  440. // c2 should pass because ca sha match
  441. resetConntrack(fw)
  442. assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
  443. // c3 should fail because no match
  444. resetConntrack(fw)
  445. assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
  446. }
  447. func TestFirewall_DropConntrackReload(t *testing.T) {
  448. l := test.NewLogger()
  449. ob := &bytes.Buffer{}
  450. l.SetOutput(ob)
  451. p := firewall.Packet{
  452. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  453. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  454. LocalPort: 10,
  455. RemotePort: 90,
  456. Protocol: firewall.ProtoUDP,
  457. Fragment: false,
  458. }
  459. ipNet := net.IPNet{
  460. IP: net.IPv4(1, 2, 3, 4),
  461. Mask: net.IPMask{255, 255, 255, 0},
  462. }
  463. c := cert.NebulaCertificate{
  464. Details: cert.NebulaCertificateDetails{
  465. Name: "host1",
  466. Ips: []*net.IPNet{&ipNet},
  467. Groups: []string{"default-group"},
  468. InvertedGroups: map[string]struct{}{"default-group": {}},
  469. Issuer: "signer-shasum",
  470. },
  471. }
  472. h := HostInfo{
  473. ConnectionState: &ConnectionState{
  474. peerCert: &c,
  475. },
  476. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  477. }
  478. h.CreateRemoteCIDR(&c)
  479. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  480. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
  481. cp := cert.NewCAPool()
  482. // Drop outbound
  483. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  484. // Allow inbound
  485. resetConntrack(fw)
  486. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  487. // Allow outbound because conntrack
  488. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  489. oldFw := fw
  490. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  491. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
  492. fw.Conntrack = oldFw.Conntrack
  493. fw.rulesVersion = oldFw.rulesVersion + 1
  494. // Allow outbound because conntrack and new rules allow port 10
  495. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  496. oldFw = fw
  497. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  498. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
  499. fw.Conntrack = oldFw.Conntrack
  500. fw.rulesVersion = oldFw.rulesVersion + 1
  501. // Drop outbound because conntrack doesn't match new ruleset
  502. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  503. }
  504. func BenchmarkLookup(b *testing.B) {
  505. ml := func(m map[string]struct{}, a [][]string) {
  506. for n := 0; n < b.N; n++ {
  507. for _, sg := range a {
  508. found := false
  509. for _, g := range sg {
  510. if _, ok := m[g]; !ok {
  511. found = false
  512. break
  513. }
  514. found = true
  515. }
  516. if found {
  517. return
  518. }
  519. }
  520. }
  521. }
  522. b.Run("array to map best", func(b *testing.B) {
  523. m := map[string]struct{}{
  524. "1ne": {},
  525. "2wo": {},
  526. "3hr": {},
  527. "4ou": {},
  528. "5iv": {},
  529. "6ix": {},
  530. }
  531. a := [][]string{
  532. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  533. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  534. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  535. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  536. {"one", "two", "thr", "fou", "5iv", "6ix"},
  537. {"one", "two", "thr", "fou", "fiv", "6ix"},
  538. {"one", "two", "thr", "fou", "fiv", "six"},
  539. }
  540. for n := 0; n < b.N; n++ {
  541. ml(m, a)
  542. }
  543. })
  544. b.Run("array to map worst", func(b *testing.B) {
  545. m := map[string]struct{}{
  546. "one": {},
  547. "two": {},
  548. "thr": {},
  549. "fou": {},
  550. "fiv": {},
  551. "six": {},
  552. }
  553. a := [][]string{
  554. {"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  555. {"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  556. {"one", "two", "3hr", "4ou", "5iv", "6ix"},
  557. {"one", "two", "thr", "4ou", "5iv", "6ix"},
  558. {"one", "two", "thr", "fou", "5iv", "6ix"},
  559. {"one", "two", "thr", "fou", "fiv", "6ix"},
  560. {"one", "two", "thr", "fou", "fiv", "six"},
  561. }
  562. for n := 0; n < b.N; n++ {
  563. ml(m, a)
  564. }
  565. })
  566. //TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  567. }
  568. func Test_parsePort(t *testing.T) {
  569. _, _, err := parsePort("")
  570. assert.EqualError(t, err, "was not a number; ``")
  571. _, _, err = parsePort(" ")
  572. assert.EqualError(t, err, "was not a number; ` `")
  573. _, _, err = parsePort("-")
  574. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  575. _, _, err = parsePort(" - ")
  576. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  577. _, _, err = parsePort("a-b")
  578. assert.EqualError(t, err, "beginning range was not a number; `a`")
  579. _, _, err = parsePort("1-b")
  580. assert.EqualError(t, err, "ending range was not a number; `b`")
  581. s, e, err := parsePort(" 1 - 2 ")
  582. assert.Equal(t, int32(1), s)
  583. assert.Equal(t, int32(2), e)
  584. assert.Nil(t, err)
  585. s, e, err = parsePort("0-1")
  586. assert.Equal(t, int32(0), s)
  587. assert.Equal(t, int32(0), e)
  588. assert.Nil(t, err)
  589. s, e, err = parsePort("9919")
  590. assert.Equal(t, int32(9919), s)
  591. assert.Equal(t, int32(9919), e)
  592. assert.Nil(t, err)
  593. s, e, err = parsePort("any")
  594. assert.Equal(t, int32(0), s)
  595. assert.Equal(t, int32(0), e)
  596. assert.Nil(t, err)
  597. }
  598. func TestNewFirewallFromConfig(t *testing.T) {
  599. l := test.NewLogger()
  600. // Test a bad rule definition
  601. c := &cert.NebulaCertificate{}
  602. conf := config.NewC(l)
  603. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  604. _, err := NewFirewallFromConfig(l, c, conf)
  605. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  606. // Test both port and code
  607. conf = config.NewC(l)
  608. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  609. _, err = NewFirewallFromConfig(l, c, conf)
  610. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  611. // Test missing host, group, cidr, ca_name and ca_sha
  612. conf = config.NewC(l)
  613. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  614. _, err = NewFirewallFromConfig(l, c, conf)
  615. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  616. // Test code/port error
  617. conf = config.NewC(l)
  618. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  619. _, err = NewFirewallFromConfig(l, c, conf)
  620. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  621. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  622. _, err = NewFirewallFromConfig(l, c, conf)
  623. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  624. // Test proto error
  625. conf = config.NewC(l)
  626. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  627. _, err = NewFirewallFromConfig(l, c, conf)
  628. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  629. // Test cidr parse error
  630. conf = config.NewC(l)
  631. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  632. _, err = NewFirewallFromConfig(l, c, conf)
  633. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  634. // Test local_cidr parse error
  635. conf = config.NewC(l)
  636. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  637. _, err = NewFirewallFromConfig(l, c, conf)
  638. assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
  639. // Test both group and groups
  640. conf = config.NewC(l)
  641. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  642. _, err = NewFirewallFromConfig(l, c, conf)
  643. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  644. }
  645. func TestAddFirewallRulesFromConfig(t *testing.T) {
  646. l := test.NewLogger()
  647. // Test adding tcp rule
  648. conf := config.NewC(l)
  649. mf := &mockFirewall{}
  650. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  651. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  652. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  653. // Test adding udp rule
  654. conf = config.NewC(l)
  655. mf = &mockFirewall{}
  656. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  657. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  658. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  659. // Test adding icmp rule
  660. conf = config.NewC(l)
  661. mf = &mockFirewall{}
  662. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  663. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  664. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  665. // Test adding any rule
  666. conf = config.NewC(l)
  667. mf = &mockFirewall{}
  668. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  669. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  670. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  671. // Test adding rule with cidr
  672. cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
  673. conf = config.NewC(l)
  674. mf = &mockFirewall{}
  675. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  676. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  677. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
  678. // Test adding rule with local_cidr
  679. conf = config.NewC(l)
  680. mf = &mockFirewall{}
  681. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  682. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  683. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
  684. // Test adding rule with ca_sha
  685. conf = config.NewC(l)
  686. mf = &mockFirewall{}
  687. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  688. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  689. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
  690. // Test adding rule with ca_name
  691. conf = config.NewC(l)
  692. mf = &mockFirewall{}
  693. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  694. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  695. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
  696. // Test single group
  697. conf = config.NewC(l)
  698. mf = &mockFirewall{}
  699. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  700. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  701. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
  702. // Test single groups
  703. conf = config.NewC(l)
  704. mf = &mockFirewall{}
  705. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  706. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  707. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
  708. // Test multiple AND groups
  709. conf = config.NewC(l)
  710. mf = &mockFirewall{}
  711. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  712. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  713. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
  714. // Test Add error
  715. conf = config.NewC(l)
  716. mf = &mockFirewall{}
  717. mf.nextCallReturn = errors.New("test error")
  718. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  719. assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  720. }
  721. func TestTCPRTTTracking(t *testing.T) {
  722. b := make([]byte, 200)
  723. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  724. b[0] = 15
  725. b[60+12] = 15 << 4
  726. f := Firewall{
  727. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  728. }
  729. // Set SEQ to 1
  730. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  731. c := &conn{}
  732. setTCPRTTTracking(c, b)
  733. assert.Equal(t, uint32(1), c.Seq)
  734. // Bad ack - no ack flag
  735. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  736. assert.False(t, f.checkTCPRTT(c, b))
  737. // Bad ack, number is too low
  738. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  739. b[60+13] = uint8(0x10)
  740. assert.False(t, f.checkTCPRTT(c, b))
  741. // Good ack
  742. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  743. assert.True(t, f.checkTCPRTT(c, b))
  744. assert.Equal(t, uint32(0), c.Seq)
  745. // Set SEQ to 1
  746. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  747. c = &conn{}
  748. setTCPRTTTracking(c, b)
  749. assert.Equal(t, uint32(1), c.Seq)
  750. // Good acks
  751. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  752. assert.True(t, f.checkTCPRTT(c, b))
  753. assert.Equal(t, uint32(0), c.Seq)
  754. // Set SEQ to max uint32 - 20
  755. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  756. c = &conn{}
  757. setTCPRTTTracking(c, b)
  758. assert.Equal(t, ^uint32(0)-20, c.Seq)
  759. // Good acks
  760. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  761. assert.True(t, f.checkTCPRTT(c, b))
  762. assert.Equal(t, uint32(0), c.Seq)
  763. // Set SEQ to max uint32 / 2
  764. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  765. c = &conn{}
  766. setTCPRTTTracking(c, b)
  767. assert.Equal(t, ^uint32(0)/2, c.Seq)
  768. // Below
  769. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  770. assert.False(t, f.checkTCPRTT(c, b))
  771. assert.Equal(t, ^uint32(0)/2, c.Seq)
  772. // Halfway below
  773. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  774. assert.False(t, f.checkTCPRTT(c, b))
  775. assert.Equal(t, ^uint32(0)/2, c.Seq)
  776. // Halfway above is ok
  777. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  778. assert.True(t, f.checkTCPRTT(c, b))
  779. assert.Equal(t, uint32(0), c.Seq)
  780. // Set SEQ to max uint32
  781. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  782. c = &conn{}
  783. setTCPRTTTracking(c, b)
  784. assert.Equal(t, ^uint32(0), c.Seq)
  785. // Halfway + 1 above
  786. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  787. assert.False(t, f.checkTCPRTT(c, b))
  788. assert.Equal(t, ^uint32(0), c.Seq)
  789. // Halfway above
  790. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  791. assert.True(t, f.checkTCPRTT(c, b))
  792. assert.Equal(t, uint32(0), c.Seq)
  793. }
  794. func TestFirewall_convertRule(t *testing.T) {
  795. l := test.NewLogger()
  796. ob := &bytes.Buffer{}
  797. l.SetOutput(ob)
  798. // Ensure group array of 1 is converted and a warning is printed
  799. c := map[interface{}]interface{}{
  800. "group": []interface{}{"group1"},
  801. }
  802. r, err := convertRule(l, c, "test", 1)
  803. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  804. assert.Nil(t, err)
  805. assert.Equal(t, "group1", r.Group)
  806. // Ensure group array of > 1 is errord
  807. ob.Reset()
  808. c = map[interface{}]interface{}{
  809. "group": []interface{}{"group1", "group2"},
  810. }
  811. r, err = convertRule(l, c, "test", 1)
  812. assert.Equal(t, "", ob.String())
  813. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  814. // Make sure a well formed group is alright
  815. ob.Reset()
  816. c = map[interface{}]interface{}{
  817. "group": "group1",
  818. }
  819. r, err = convertRule(l, c, "test", 1)
  820. assert.Nil(t, err)
  821. assert.Equal(t, "group1", r.Group)
  822. }
  // addRuleCall records the arguments of a single AddRule invocation so tests can
  // compare them against the call they expected.
  type addRuleCall struct {
  	incoming           bool
  	proto              uint8
  	startPort, endPort int32
  	groups             []string
  	host               string
  	ip, localIp        *net.IPNet
  	caName, caSha      string
  }
  835. type mockFirewall struct {
  836. lastCall addRuleCall
  837. nextCallReturn error
  838. }
  839. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
  840. mf.lastCall = addRuleCall{
  841. incoming: incoming,
  842. proto: proto,
  843. startPort: startPort,
  844. endPort: endPort,
  845. groups: groups,
  846. host: host,
  847. ip: ip,
  848. localIp: localIp,
  849. caName: caName,
  850. caSha: caSha,
  851. }
  852. err := mf.nextCallReturn
  853. mf.nextCallReturn = nil
  854. return err
  855. }
  856. func resetConntrack(fw *Firewall) {
  857. fw.Conntrack.Lock()
  858. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  859. fw.Conntrack.Unlock()
  860. }