firewall_test.go 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987
  1. package nebula
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "math"
  7. "net"
  8. "testing"
  9. "time"
  10. "github.com/rcrowley/go-metrics"
  11. "github.com/slackhq/nebula/cert"
  12. "github.com/slackhq/nebula/config"
  13. "github.com/slackhq/nebula/firewall"
  14. "github.com/slackhq/nebula/iputil"
  15. "github.com/slackhq/nebula/test"
  16. "github.com/stretchr/testify/assert"
  17. )
  18. func TestNewFirewall(t *testing.T) {
  19. l := test.NewLogger()
  20. c := &cert.NebulaCertificate{}
  21. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  22. conntrack := fw.Conntrack
  23. assert.NotNil(t, conntrack)
  24. assert.NotNil(t, conntrack.Conns)
  25. assert.NotNil(t, conntrack.TimerWheel)
  26. assert.NotNil(t, fw.InRules)
  27. assert.NotNil(t, fw.OutRules)
  28. assert.Equal(t, time.Second, fw.TCPTimeout)
  29. assert.Equal(t, time.Minute, fw.UDPTimeout)
  30. assert.Equal(t, time.Hour, fw.DefaultTimeout)
  31. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  32. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  33. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  34. fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
  35. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  36. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  37. fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
  38. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  39. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  40. fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
  41. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  42. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  43. fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
  44. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  45. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  46. fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
  47. assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
  48. assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
  49. }
  50. func TestFirewall_AddRule(t *testing.T) {
  51. l := test.NewLogger()
  52. ob := &bytes.Buffer{}
  53. l.SetOutput(ob)
  54. c := &cert.NebulaCertificate{}
  55. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  56. assert.NotNil(t, fw.InRules)
  57. assert.NotNil(t, fw.OutRules)
  58. _, ti, _ := net.ParseCIDR("1.2.3.4/32")
  59. assert.Nil(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", nil, nil, "", ""))
  60. // An empty rule is any
  61. assert.True(t, fw.InRules.TCP[1].Any.Any)
  62. assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
  63. assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
  64. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  65. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", ""))
  66. assert.False(t, fw.InRules.UDP[1].Any.Any)
  67. assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0], "g1")
  68. assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
  69. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  70. assert.Nil(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", nil, nil, "", ""))
  71. assert.False(t, fw.InRules.ICMP[1].Any.Any)
  72. assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
  73. assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
  74. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  75. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, nil, "", ""))
  76. assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
  77. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
  78. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
  79. ok, _ := fw.OutRules.AnyProto[1].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
  80. assert.True(t, ok)
  81. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  82. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", nil, ti, "", ""))
  83. assert.False(t, fw.OutRules.AnyProto[1].Any.Any)
  84. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Groups)
  85. assert.Empty(t, fw.OutRules.AnyProto[1].Any.Hosts)
  86. ok, _ = fw.OutRules.AnyProto[1].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
  87. assert.True(t, ok)
  88. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  89. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "ca-name", ""))
  90. assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
  91. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  92. assert.Nil(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", nil, nil, "", "ca-sha"))
  93. assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
  94. // Set any and clear fields
  95. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  96. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"g1", "g2"}, "h1", ti, ti, "", ""))
  97. assert.Equal(t, []string{"g1", "g2"}, fw.OutRules.AnyProto[0].Any.Groups[0])
  98. assert.Contains(t, fw.OutRules.AnyProto[0].Any.Hosts, "h1")
  99. ok, _ = fw.OutRules.AnyProto[0].Any.CIDR.Match(iputil.Ip2VpnIp(ti.IP))
  100. assert.True(t, ok)
  101. ok, _ = fw.OutRules.AnyProto[0].Any.LocalCIDR.Match(iputil.Ip2VpnIp(ti.IP))
  102. assert.True(t, ok)
  103. // run twice just to make sure
  104. //TODO: these ANY rules should clear the CA firewall portion
  105. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
  106. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
  107. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  108. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Groups)
  109. assert.Empty(t, fw.OutRules.AnyProto[0].Any.Hosts)
  110. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  111. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", nil, nil, "", ""))
  112. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  113. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  114. _, anyIp, _ := net.ParseCIDR("0.0.0.0/0")
  115. assert.Nil(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, nil, "", ""))
  116. assert.True(t, fw.OutRules.AnyProto[0].Any.Any)
  117. // Test error conditions
  118. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
  119. assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", nil, nil, "", ""))
  120. assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", nil, nil, "", ""))
  121. }
  122. func TestFirewall_Drop(t *testing.T) {
  123. l := test.NewLogger()
  124. ob := &bytes.Buffer{}
  125. l.SetOutput(ob)
  126. p := firewall.Packet{
  127. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  128. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  129. LocalPort: 10,
  130. RemotePort: 90,
  131. Protocol: firewall.ProtoUDP,
  132. Fragment: false,
  133. }
  134. ipNet := net.IPNet{
  135. IP: net.IPv4(1, 2, 3, 4),
  136. Mask: net.IPMask{255, 255, 255, 0},
  137. }
  138. c := cert.NebulaCertificate{
  139. Details: cert.NebulaCertificateDetails{
  140. Name: "host1",
  141. Ips: []*net.IPNet{&ipNet},
  142. Groups: []string{"default-group"},
  143. InvertedGroups: map[string]struct{}{"default-group": {}},
  144. Issuer: "signer-shasum",
  145. },
  146. }
  147. h := HostInfo{
  148. ConnectionState: &ConnectionState{
  149. peerCert: &c,
  150. },
  151. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  152. }
  153. h.CreateRemoteCIDR(&c)
  154. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  155. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
  156. cp := cert.NewCAPool()
  157. // Drop outbound
  158. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  159. // Allow inbound
  160. resetConntrack(fw)
  161. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  162. // Allow outbound because conntrack
  163. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  164. // test remote mismatch
  165. oldRemote := p.RemoteIP
  166. p.RemoteIP = iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 10))
  167. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrInvalidRemoteIP)
  168. p.RemoteIP = oldRemote
  169. // ensure signer doesn't get in the way of group checks
  170. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  171. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum"))
  172. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum-bad"))
  173. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  174. // test caSha doesn't drop on match
  175. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  176. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "", "signer-shasum-bad"))
  177. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "", "signer-shasum"))
  178. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  179. // ensure ca name doesn't get in the way of group checks
  180. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  181. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  182. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good", ""))
  183. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good-bad", ""))
  184. assert.Equal(t, fw.Drop([]byte{}, p, true, &h, cp, nil), ErrNoMatchingRule)
  185. // test caName doesn't drop on match
  186. cp.CAs["signer-shasum"] = &cert.NebulaCertificate{Details: cert.NebulaCertificateDetails{Name: "ca-good"}}
  187. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  188. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", nil, nil, "ca-good-bad", ""))
  189. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", nil, nil, "ca-good", ""))
  190. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  191. }
  192. func BenchmarkFirewallTable_match(b *testing.B) {
  193. ft := FirewallTable{
  194. TCP: firewallPort{},
  195. }
  196. _, n, _ := net.ParseCIDR("172.1.1.1/32")
  197. _ = ft.TCP.addRule(10, 10, []string{"good-group"}, "good-host", n, n, "", "")
  198. _ = ft.TCP.addRule(10, 10, []string{"good-group2"}, "good-host", n, n, "", "")
  199. _ = ft.TCP.addRule(10, 10, []string{"good-group3"}, "good-host", n, n, "", "")
  200. _ = ft.TCP.addRule(10, 10, []string{"good-group4"}, "good-host", n, n, "", "")
  201. _ = ft.TCP.addRule(10, 10, []string{"good-group, good-group1"}, "good-host", n, n, "", "")
  202. cp := cert.NewCAPool()
  203. b.Run("fail on proto", func(b *testing.B) {
  204. c := &cert.NebulaCertificate{}
  205. for n := 0; n < b.N; n++ {
  206. ft.match(firewall.Packet{Protocol: firewall.ProtoUDP}, true, c, cp)
  207. }
  208. })
  209. b.Run("fail on port", func(b *testing.B) {
  210. c := &cert.NebulaCertificate{}
  211. for n := 0; n < b.N; n++ {
  212. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 1}, true, c, cp)
  213. }
  214. })
  215. b.Run("fail all group, name, and cidr", func(b *testing.B) {
  216. _, ip, _ := net.ParseCIDR("9.254.254.254/32")
  217. c := &cert.NebulaCertificate{
  218. Details: cert.NebulaCertificateDetails{
  219. InvertedGroups: map[string]struct{}{"nope": {}},
  220. Name: "nope",
  221. Ips: []*net.IPNet{ip},
  222. },
  223. }
  224. for n := 0; n < b.N; n++ {
  225. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  226. }
  227. })
  228. b.Run("pass on group", func(b *testing.B) {
  229. c := &cert.NebulaCertificate{
  230. Details: cert.NebulaCertificateDetails{
  231. InvertedGroups: map[string]struct{}{"good-group": {}},
  232. Name: "nope",
  233. },
  234. }
  235. for n := 0; n < b.N; n++ {
  236. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  237. }
  238. })
  239. b.Run("pass on name", func(b *testing.B) {
  240. c := &cert.NebulaCertificate{
  241. Details: cert.NebulaCertificateDetails{
  242. InvertedGroups: map[string]struct{}{"nope": {}},
  243. Name: "good-host",
  244. },
  245. }
  246. for n := 0; n < b.N; n++ {
  247. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10}, true, c, cp)
  248. }
  249. })
  250. b.Run("pass on ip", func(b *testing.B) {
  251. ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  252. c := &cert.NebulaCertificate{
  253. Details: cert.NebulaCertificateDetails{
  254. InvertedGroups: map[string]struct{}{"nope": {}},
  255. Name: "good-host",
  256. },
  257. }
  258. for n := 0; n < b.N; n++ {
  259. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, RemoteIP: ip}, true, c, cp)
  260. }
  261. })
  262. b.Run("pass on local ip", func(b *testing.B) {
  263. ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  264. c := &cert.NebulaCertificate{
  265. Details: cert.NebulaCertificateDetails{
  266. InvertedGroups: map[string]struct{}{"nope": {}},
  267. Name: "good-host",
  268. },
  269. }
  270. for n := 0; n < b.N; n++ {
  271. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 10, LocalIP: ip}, true, c, cp)
  272. }
  273. })
  274. _ = ft.TCP.addRule(0, 0, []string{"good-group"}, "good-host", n, n, "", "")
  275. b.Run("pass on ip with any port", func(b *testing.B) {
  276. ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  277. c := &cert.NebulaCertificate{
  278. Details: cert.NebulaCertificateDetails{
  279. InvertedGroups: map[string]struct{}{"nope": {}},
  280. Name: "good-host",
  281. },
  282. }
  283. for n := 0; n < b.N; n++ {
  284. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, RemoteIP: ip}, true, c, cp)
  285. }
  286. })
  287. b.Run("pass on local ip with any port", func(b *testing.B) {
  288. ip := iputil.Ip2VpnIp(net.IPv4(172, 1, 1, 1))
  289. c := &cert.NebulaCertificate{
  290. Details: cert.NebulaCertificateDetails{
  291. InvertedGroups: map[string]struct{}{"nope": {}},
  292. Name: "good-host",
  293. },
  294. }
  295. for n := 0; n < b.N; n++ {
  296. ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip}, true, c, cp)
  297. }
  298. })
  299. }
  300. func TestFirewall_Drop2(t *testing.T) {
  301. l := test.NewLogger()
  302. ob := &bytes.Buffer{}
  303. l.SetOutput(ob)
  304. p := firewall.Packet{
  305. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  306. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  307. LocalPort: 10,
  308. RemotePort: 90,
  309. Protocol: firewall.ProtoUDP,
  310. Fragment: false,
  311. }
  312. ipNet := net.IPNet{
  313. IP: net.IPv4(1, 2, 3, 4),
  314. Mask: net.IPMask{255, 255, 255, 0},
  315. }
  316. c := cert.NebulaCertificate{
  317. Details: cert.NebulaCertificateDetails{
  318. Name: "host1",
  319. Ips: []*net.IPNet{&ipNet},
  320. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group": {}},
  321. },
  322. }
  323. h := HostInfo{
  324. ConnectionState: &ConnectionState{
  325. peerCert: &c,
  326. },
  327. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  328. }
  329. h.CreateRemoteCIDR(&c)
  330. c1 := cert.NebulaCertificate{
  331. Details: cert.NebulaCertificateDetails{
  332. Name: "host1",
  333. Ips: []*net.IPNet{&ipNet},
  334. InvertedGroups: map[string]struct{}{"default-group": {}, "test-group-not": {}},
  335. },
  336. }
  337. h1 := HostInfo{
  338. ConnectionState: &ConnectionState{
  339. peerCert: &c1,
  340. },
  341. }
  342. h1.CreateRemoteCIDR(&c1)
  343. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  344. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", nil, nil, "", ""))
  345. cp := cert.NewCAPool()
  346. // h1/c1 lacks the proper groups
  347. assert.Error(t, fw.Drop([]byte{}, p, true, &h1, cp, nil), ErrNoMatchingRule)
  348. // c has the proper groups
  349. resetConntrack(fw)
  350. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  351. }
  352. func TestFirewall_Drop3(t *testing.T) {
  353. l := test.NewLogger()
  354. ob := &bytes.Buffer{}
  355. l.SetOutput(ob)
  356. p := firewall.Packet{
  357. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  358. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  359. LocalPort: 1,
  360. RemotePort: 1,
  361. Protocol: firewall.ProtoUDP,
  362. Fragment: false,
  363. }
  364. ipNet := net.IPNet{
  365. IP: net.IPv4(1, 2, 3, 4),
  366. Mask: net.IPMask{255, 255, 255, 0},
  367. }
  368. c := cert.NebulaCertificate{
  369. Details: cert.NebulaCertificateDetails{
  370. Name: "host-owner",
  371. Ips: []*net.IPNet{&ipNet},
  372. },
  373. }
  374. c1 := cert.NebulaCertificate{
  375. Details: cert.NebulaCertificateDetails{
  376. Name: "host1",
  377. Ips: []*net.IPNet{&ipNet},
  378. Issuer: "signer-sha-bad",
  379. },
  380. }
  381. h1 := HostInfo{
  382. ConnectionState: &ConnectionState{
  383. peerCert: &c1,
  384. },
  385. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  386. }
  387. h1.CreateRemoteCIDR(&c1)
  388. c2 := cert.NebulaCertificate{
  389. Details: cert.NebulaCertificateDetails{
  390. Name: "host2",
  391. Ips: []*net.IPNet{&ipNet},
  392. Issuer: "signer-sha",
  393. },
  394. }
  395. h2 := HostInfo{
  396. ConnectionState: &ConnectionState{
  397. peerCert: &c2,
  398. },
  399. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  400. }
  401. h2.CreateRemoteCIDR(&c2)
  402. c3 := cert.NebulaCertificate{
  403. Details: cert.NebulaCertificateDetails{
  404. Name: "host3",
  405. Ips: []*net.IPNet{&ipNet},
  406. Issuer: "signer-sha-bad",
  407. },
  408. }
  409. h3 := HostInfo{
  410. ConnectionState: &ConnectionState{
  411. peerCert: &c3,
  412. },
  413. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  414. }
  415. h3.CreateRemoteCIDR(&c3)
  416. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  417. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", nil, nil, "", ""))
  418. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", nil, nil, "", "signer-sha"))
  419. cp := cert.NewCAPool()
  420. // c1 should pass because host match
  421. assert.NoError(t, fw.Drop([]byte{}, p, true, &h1, cp, nil))
  422. // c2 should pass because ca sha match
  423. resetConntrack(fw)
  424. assert.NoError(t, fw.Drop([]byte{}, p, true, &h2, cp, nil))
  425. // c3 should fail because no match
  426. resetConntrack(fw)
  427. assert.Equal(t, fw.Drop([]byte{}, p, true, &h3, cp, nil), ErrNoMatchingRule)
  428. }
  429. func TestFirewall_DropConntrackReload(t *testing.T) {
  430. l := test.NewLogger()
  431. ob := &bytes.Buffer{}
  432. l.SetOutput(ob)
  433. p := firewall.Packet{
  434. LocalIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  435. RemoteIP: iputil.Ip2VpnIp(net.IPv4(1, 2, 3, 4)),
  436. LocalPort: 10,
  437. RemotePort: 90,
  438. Protocol: firewall.ProtoUDP,
  439. Fragment: false,
  440. }
  441. ipNet := net.IPNet{
  442. IP: net.IPv4(1, 2, 3, 4),
  443. Mask: net.IPMask{255, 255, 255, 0},
  444. }
  445. c := cert.NebulaCertificate{
  446. Details: cert.NebulaCertificateDetails{
  447. Name: "host1",
  448. Ips: []*net.IPNet{&ipNet},
  449. Groups: []string{"default-group"},
  450. InvertedGroups: map[string]struct{}{"default-group": {}},
  451. Issuer: "signer-shasum",
  452. },
  453. }
  454. h := HostInfo{
  455. ConnectionState: &ConnectionState{
  456. peerCert: &c,
  457. },
  458. vpnIp: iputil.Ip2VpnIp(ipNet.IP),
  459. }
  460. h.CreateRemoteCIDR(&c)
  461. fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  462. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", nil, nil, "", ""))
  463. cp := cert.NewCAPool()
  464. // Drop outbound
  465. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  466. // Allow inbound
  467. resetConntrack(fw)
  468. assert.NoError(t, fw.Drop([]byte{}, p, true, &h, cp, nil))
  469. // Allow outbound because conntrack
  470. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  471. oldFw := fw
  472. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  473. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", nil, nil, "", ""))
  474. fw.Conntrack = oldFw.Conntrack
  475. fw.rulesVersion = oldFw.rulesVersion + 1
  476. // Allow outbound because conntrack and new rules allow port 10
  477. assert.NoError(t, fw.Drop([]byte{}, p, false, &h, cp, nil))
  478. oldFw = fw
  479. fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
  480. assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", nil, nil, "", ""))
  481. fw.Conntrack = oldFw.Conntrack
  482. fw.rulesVersion = oldFw.rulesVersion + 1
  483. // Drop outbound because conntrack doesn't match new ruleset
  484. assert.Equal(t, fw.Drop([]byte{}, p, false, &h, cp, nil), ErrNoMatchingRule)
  485. }
  // BenchmarkLookup measures how long it takes to find, among a list of
  // candidate groups, the first group whose names are all keys of a map. The
  // "best" case matches the first candidate; the "worst" case matches only the
  // last one.
  func BenchmarkLookup(b *testing.B) {
  	// firstCovered walks groups in order and stops at the first one whose
  	// names are all present in set. It performs exactly one scan per call;
  	// iteration is left to the calling sub-benchmark's own b.N loop, so no
  	// second loop over the parent benchmark's b.N is nested inside the timed
  	// region.
  	firstCovered := func(set map[string]struct{}, groups [][]string) {
  		for _, group := range groups {
  			// An empty group never counts as covered.
  			covered := len(group) > 0
  			for _, name := range group {
  				if _, ok := set[name]; !ok {
  					covered = false
  					break
  				}
  			}
  			if covered {
  				return
  			}
  		}
  	}

  	// Both cases scan the same candidates; only the map contents differ.
  	candidates := [][]string{
  		{"1ne", "2wo", "3hr", "4ou", "5iv", "6ix"},
  		{"one", "2wo", "3hr", "4ou", "5iv", "6ix"},
  		{"one", "two", "3hr", "4ou", "5iv", "6ix"},
  		{"one", "two", "thr", "4ou", "5iv", "6ix"},
  		{"one", "two", "thr", "fou", "5iv", "6ix"},
  		{"one", "two", "thr", "fou", "fiv", "6ix"},
  		{"one", "two", "thr", "fou", "fiv", "six"},
  	}

  	b.Run("array to map best", func(b *testing.B) {
  		m := map[string]struct{}{
  			"1ne": {},
  			"2wo": {},
  			"3hr": {},
  			"4ou": {},
  			"5iv": {},
  			"6ix": {},
  		}

  		for n := 0; n < b.N; n++ {
  			firstCovered(m, candidates)
  		}
  	})

  	b.Run("array to map worst", func(b *testing.B) {
  		m := map[string]struct{}{
  			"one": {},
  			"two": {},
  			"thr": {},
  			"fou": {},
  			"fiv": {},
  			"six": {},
  		}

  		for n := 0; n < b.N; n++ {
  			firstCovered(m, candidates)
  		}
  	})

  	//TODO: only way array lookup in array will help is if both are sorted, then maybe it's faster
  }
  550. func Test_parsePort(t *testing.T) {
  551. _, _, err := parsePort("")
  552. assert.EqualError(t, err, "was not a number; ``")
  553. _, _, err = parsePort(" ")
  554. assert.EqualError(t, err, "was not a number; ` `")
  555. _, _, err = parsePort("-")
  556. assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
  557. _, _, err = parsePort(" - ")
  558. assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
  559. _, _, err = parsePort("a-b")
  560. assert.EqualError(t, err, "beginning range was not a number; `a`")
  561. _, _, err = parsePort("1-b")
  562. assert.EqualError(t, err, "ending range was not a number; `b`")
  563. s, e, err := parsePort(" 1 - 2 ")
  564. assert.Equal(t, int32(1), s)
  565. assert.Equal(t, int32(2), e)
  566. assert.Nil(t, err)
  567. s, e, err = parsePort("0-1")
  568. assert.Equal(t, int32(0), s)
  569. assert.Equal(t, int32(0), e)
  570. assert.Nil(t, err)
  571. s, e, err = parsePort("9919")
  572. assert.Equal(t, int32(9919), s)
  573. assert.Equal(t, int32(9919), e)
  574. assert.Nil(t, err)
  575. s, e, err = parsePort("any")
  576. assert.Equal(t, int32(0), s)
  577. assert.Equal(t, int32(0), e)
  578. assert.Nil(t, err)
  579. }
  580. func TestNewFirewallFromConfig(t *testing.T) {
  581. l := test.NewLogger()
  582. // Test a bad rule definition
  583. c := &cert.NebulaCertificate{}
  584. conf := config.NewC(l)
  585. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
  586. _, err := NewFirewallFromConfig(l, c, conf)
  587. assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
  588. // Test both port and code
  589. conf = config.NewC(l)
  590. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
  591. _, err = NewFirewallFromConfig(l, c, conf)
  592. assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
  593. // Test missing host, group, cidr, ca_name and ca_sha
  594. conf = config.NewC(l)
  595. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
  596. _, err = NewFirewallFromConfig(l, c, conf)
  597. assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
  598. // Test code/port error
  599. conf = config.NewC(l)
  600. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
  601. _, err = NewFirewallFromConfig(l, c, conf)
  602. assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
  603. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
  604. _, err = NewFirewallFromConfig(l, c, conf)
  605. assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
  606. // Test proto error
  607. conf = config.NewC(l)
  608. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
  609. _, err = NewFirewallFromConfig(l, c, conf)
  610. assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
  611. // Test cidr parse error
  612. conf = config.NewC(l)
  613. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
  614. _, err = NewFirewallFromConfig(l, c, conf)
  615. assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; invalid CIDR address: testh")
  616. // Test local_cidr parse error
  617. conf = config.NewC(l)
  618. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
  619. _, err = NewFirewallFromConfig(l, c, conf)
  620. assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; invalid CIDR address: testh")
  621. // Test both group and groups
  622. conf = config.NewC(l)
  623. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
  624. _, err = NewFirewallFromConfig(l, c, conf)
  625. assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
  626. }
  627. func TestAddFirewallRulesFromConfig(t *testing.T) {
  628. l := test.NewLogger()
  629. // Test adding tcp rule
  630. conf := config.NewC(l)
  631. mf := &mockFirewall{}
  632. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
  633. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  634. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  635. // Test adding udp rule
  636. conf = config.NewC(l)
  637. mf = &mockFirewall{}
  638. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
  639. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  640. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  641. // Test adding icmp rule
  642. conf = config.NewC(l)
  643. mf = &mockFirewall{}
  644. conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
  645. assert.Nil(t, AddFirewallRulesFromConfig(l, false, conf, mf))
  646. assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  647. // Test adding any rule
  648. conf = config.NewC(l)
  649. mf = &mockFirewall{}
  650. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  651. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  652. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: nil, localIp: nil}, mf.lastCall)
  653. // Test adding rule with cidr
  654. cidr := &net.IPNet{IP: net.ParseIP("10.0.0.0").To4(), Mask: net.IPv4Mask(255, 0, 0, 0)}
  655. conf = config.NewC(l)
  656. mf = &mockFirewall{}
  657. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
  658. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  659. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: nil}, mf.lastCall)
  660. // Test adding rule with local_cidr
  661. conf = config.NewC(l)
  662. mf = &mockFirewall{}
  663. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
  664. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  665. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: cidr}, mf.lastCall)
  666. // Test adding rule with ca_sha
  667. conf = config.NewC(l)
  668. mf = &mockFirewall{}
  669. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
  670. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  671. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caSha: "12312313123"}, mf.lastCall)
  672. // Test adding rule with ca_name
  673. conf = config.NewC(l)
  674. mf = &mockFirewall{}
  675. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
  676. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  677. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: nil, localIp: nil, caName: "root01"}, mf.lastCall)
  678. // Test single group
  679. conf = config.NewC(l)
  680. mf = &mockFirewall{}
  681. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
  682. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  683. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
  684. // Test single groups
  685. conf = config.NewC(l)
  686. mf = &mockFirewall{}
  687. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
  688. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  689. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: nil, localIp: nil}, mf.lastCall)
  690. // Test multiple AND groups
  691. conf = config.NewC(l)
  692. mf = &mockFirewall{}
  693. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
  694. assert.Nil(t, AddFirewallRulesFromConfig(l, true, conf, mf))
  695. assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: nil, localIp: nil}, mf.lastCall)
  696. // Test Add error
  697. conf = config.NewC(l)
  698. mf = &mockFirewall{}
  699. mf.nextCallReturn = errors.New("test error")
  700. conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
  701. assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
  702. }
  703. func TestTCPRTTTracking(t *testing.T) {
  704. b := make([]byte, 200)
  705. // Max ip IHL (60 bytes) and tcp IHL (60 bytes)
  706. b[0] = 15
  707. b[60+12] = 15 << 4
  708. f := Firewall{
  709. metricTCPRTT: metrics.GetOrRegisterHistogram("nope", nil, metrics.NewExpDecaySample(1028, 0.015)),
  710. }
  711. // Set SEQ to 1
  712. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  713. c := &conn{}
  714. setTCPRTTTracking(c, b)
  715. assert.Equal(t, uint32(1), c.Seq)
  716. // Bad ack - no ack flag
  717. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  718. assert.False(t, f.checkTCPRTT(c, b))
  719. // Bad ack, number is too low
  720. binary.BigEndian.PutUint32(b[60+8:60+12], 0)
  721. b[60+13] = uint8(0x10)
  722. assert.False(t, f.checkTCPRTT(c, b))
  723. // Good ack
  724. binary.BigEndian.PutUint32(b[60+8:60+12], 80)
  725. assert.True(t, f.checkTCPRTT(c, b))
  726. assert.Equal(t, uint32(0), c.Seq)
  727. // Set SEQ to 1
  728. binary.BigEndian.PutUint32(b[60+4:60+8], 1)
  729. c = &conn{}
  730. setTCPRTTTracking(c, b)
  731. assert.Equal(t, uint32(1), c.Seq)
  732. // Good acks
  733. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  734. assert.True(t, f.checkTCPRTT(c, b))
  735. assert.Equal(t, uint32(0), c.Seq)
  736. // Set SEQ to max uint32 - 20
  737. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)-20)
  738. c = &conn{}
  739. setTCPRTTTracking(c, b)
  740. assert.Equal(t, ^uint32(0)-20, c.Seq)
  741. // Good acks
  742. binary.BigEndian.PutUint32(b[60+8:60+12], 81)
  743. assert.True(t, f.checkTCPRTT(c, b))
  744. assert.Equal(t, uint32(0), c.Seq)
  745. // Set SEQ to max uint32 / 2
  746. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0)/2)
  747. c = &conn{}
  748. setTCPRTTTracking(c, b)
  749. assert.Equal(t, ^uint32(0)/2, c.Seq)
  750. // Below
  751. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2-1)
  752. assert.False(t, f.checkTCPRTT(c, b))
  753. assert.Equal(t, ^uint32(0)/2, c.Seq)
  754. // Halfway below
  755. binary.BigEndian.PutUint32(b[60+8:60+12], uint32(0))
  756. assert.False(t, f.checkTCPRTT(c, b))
  757. assert.Equal(t, ^uint32(0)/2, c.Seq)
  758. // Halfway above is ok
  759. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0))
  760. assert.True(t, f.checkTCPRTT(c, b))
  761. assert.Equal(t, uint32(0), c.Seq)
  762. // Set SEQ to max uint32
  763. binary.BigEndian.PutUint32(b[60+4:60+8], ^uint32(0))
  764. c = &conn{}
  765. setTCPRTTTracking(c, b)
  766. assert.Equal(t, ^uint32(0), c.Seq)
  767. // Halfway + 1 above
  768. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2+1)
  769. assert.False(t, f.checkTCPRTT(c, b))
  770. assert.Equal(t, ^uint32(0), c.Seq)
  771. // Halfway above
  772. binary.BigEndian.PutUint32(b[60+8:60+12], ^uint32(0)/2)
  773. assert.True(t, f.checkTCPRTT(c, b))
  774. assert.Equal(t, uint32(0), c.Seq)
  775. }
  776. func TestFirewall_convertRule(t *testing.T) {
  777. l := test.NewLogger()
  778. ob := &bytes.Buffer{}
  779. l.SetOutput(ob)
  780. // Ensure group array of 1 is converted and a warning is printed
  781. c := map[interface{}]interface{}{
  782. "group": []interface{}{"group1"},
  783. }
  784. r, err := convertRule(l, c, "test", 1)
  785. assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
  786. assert.Nil(t, err)
  787. assert.Equal(t, "group1", r.Group)
  788. // Ensure group array of > 1 is errord
  789. ob.Reset()
  790. c = map[interface{}]interface{}{
  791. "group": []interface{}{"group1", "group2"},
  792. }
  793. r, err = convertRule(l, c, "test", 1)
  794. assert.Equal(t, "", ob.String())
  795. assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
  796. // Make sure a well formed group is alright
  797. ob.Reset()
  798. c = map[interface{}]interface{}{
  799. "group": "group1",
  800. }
  801. r, err = convertRule(l, c, "test", 1)
  802. assert.Nil(t, err)
  803. assert.Equal(t, "group1", r.Group)
  804. }
  805. type addRuleCall struct {
  806. incoming bool
  807. proto uint8
  808. startPort int32
  809. endPort int32
  810. groups []string
  811. host string
  812. ip *net.IPNet
  813. localIp *net.IPNet
  814. caName string
  815. caSha string
  816. }
  817. type mockFirewall struct {
  818. lastCall addRuleCall
  819. nextCallReturn error
  820. }
  821. func (mf *mockFirewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, localIp *net.IPNet, caName string, caSha string) error {
  822. mf.lastCall = addRuleCall{
  823. incoming: incoming,
  824. proto: proto,
  825. startPort: startPort,
  826. endPort: endPort,
  827. groups: groups,
  828. host: host,
  829. ip: ip,
  830. localIp: localIp,
  831. caName: caName,
  832. caSha: caSha,
  833. }
  834. err := mf.nextCallReturn
  835. mf.nextCallReturn = nil
  836. return err
  837. }
  838. func resetConntrack(fw *Firewall) {
  839. fw.Conntrack.Lock()
  840. fw.Conntrack.Conns = map[firewall.Packet]*conn{}
  841. fw.Conntrack.Unlock()
  842. }