firewall.go 24 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021
  1. package nebula
  2. import (
  3. "crypto/sha256"
  4. "encoding/binary"
  5. "encoding/hex"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "net"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/rcrowley/go-metrics"
  17. "github.com/sirupsen/logrus"
  18. "github.com/slackhq/nebula/cert"
  19. )
  20. const (
  21. fwProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever
  22. fwProtoTCP = 6
  23. fwProtoUDP = 17
  24. fwProtoICMP = 1
  25. fwPortAny = 0 // Special value for matching `port: any`
  26. fwPortFragment = -1 // Special value for matching `port: fragment`
  27. )
  28. const tcpACK = 0x10
  29. const tcpFIN = 0x01
  30. type FirewallInterface interface {
  31. AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error
  32. }
  33. type conn struct {
  34. Expires time.Time // Time when this conntrack entry will expire
  35. Sent time.Time // If tcp rtt tracking is enabled this will be when Seq was last set
  36. Seq uint32 // If tcp rtt tracking is enabled this will be the seq we are looking for an ack
  37. // record why the original connection passed the firewall, so we can re-validate
  38. // after ruleset changes. Note, rulesVersion is a uint16 so that these two
  39. // fields pack for free after the uint32 above
  40. incoming bool
  41. rulesVersion uint16
  42. }
  43. // TODO: need conntrack max tracked connections handling
  44. type Firewall struct {
  45. Conntrack *FirewallConntrack
  46. InRules *FirewallTable
  47. OutRules *FirewallTable
  48. //TODO: we should have many more options for TCP, an option for ICMP, and mimic the kernel a bit better
  49. // https://www.kernel.org/doc/Documentation/networking/nf_conntrack-sysctl.txt
  50. TCPTimeout time.Duration //linux: 5 days max
  51. UDPTimeout time.Duration //linux: 180s max
  52. DefaultTimeout time.Duration //linux: 600s
  53. // Used to ensure we don't emit local packets for ips we don't own
  54. localIps *CIDRTree
  55. rules string
  56. rulesVersion uint16
  57. trackTCPRTT bool
  58. metricTCPRTT metrics.Histogram
  59. incomingMetrics firewallMetrics
  60. outgoingMetrics firewallMetrics
  61. l *logrus.Logger
  62. }
  63. type firewallMetrics struct {
  64. droppedLocalIP metrics.Counter
  65. droppedRemoteIP metrics.Counter
  66. droppedNoRule metrics.Counter
  67. }
  68. type FirewallConntrack struct {
  69. sync.Mutex
  70. Conns map[FirewallPacket]*conn
  71. TimerWheel *TimerWheel
  72. }
  73. type FirewallTable struct {
  74. TCP firewallPort
  75. UDP firewallPort
  76. ICMP firewallPort
  77. AnyProto firewallPort
  78. }
  79. func newFirewallTable() *FirewallTable {
  80. return &FirewallTable{
  81. TCP: firewallPort{},
  82. UDP: firewallPort{},
  83. ICMP: firewallPort{},
  84. AnyProto: firewallPort{},
  85. }
  86. }
  87. type FirewallCA struct {
  88. Any *FirewallRule
  89. CANames map[string]*FirewallRule
  90. CAShas map[string]*FirewallRule
  91. }
  92. type FirewallRule struct {
  93. // Any makes Hosts, Groups, and CIDR irrelevant
  94. Any bool
  95. Hosts map[string]struct{}
  96. Groups [][]string
  97. CIDR *CIDRTree
  98. }
  99. // Even though ports are uint16, int32 maps are faster for lookup
  100. // Plus we can use `-1` for fragment rules
  101. type firewallPort map[int32]*FirewallCA
  102. type FirewallPacket struct {
  103. LocalIP uint32
  104. RemoteIP uint32
  105. LocalPort uint16
  106. RemotePort uint16
  107. Protocol uint8
  108. Fragment bool
  109. }
  110. func (fp *FirewallPacket) Copy() *FirewallPacket {
  111. return &FirewallPacket{
  112. LocalIP: fp.LocalIP,
  113. RemoteIP: fp.RemoteIP,
  114. LocalPort: fp.LocalPort,
  115. RemotePort: fp.RemotePort,
  116. Protocol: fp.Protocol,
  117. Fragment: fp.Fragment,
  118. }
  119. }
  120. func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
  121. var proto string
  122. switch fp.Protocol {
  123. case fwProtoTCP:
  124. proto = "tcp"
  125. case fwProtoICMP:
  126. proto = "icmp"
  127. case fwProtoUDP:
  128. proto = "udp"
  129. default:
  130. proto = fmt.Sprintf("unknown %v", fp.Protocol)
  131. }
  132. return json.Marshal(m{
  133. "LocalIP": int2ip(fp.LocalIP).String(),
  134. "RemoteIP": int2ip(fp.RemoteIP).String(),
  135. "LocalPort": fp.LocalPort,
  136. "RemotePort": fp.RemotePort,
  137. "Protocol": proto,
  138. "Fragment": fp.Fragment,
  139. })
  140. }
  141. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
  142. func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
  143. //TODO: error on 0 duration
  144. var min, max time.Duration
  145. if tcpTimeout < UDPTimeout {
  146. min = tcpTimeout
  147. max = UDPTimeout
  148. } else {
  149. min = UDPTimeout
  150. max = tcpTimeout
  151. }
  152. if defaultTimeout < min {
  153. min = defaultTimeout
  154. } else if defaultTimeout > max {
  155. max = defaultTimeout
  156. }
  157. localIps := NewCIDRTree()
  158. for _, ip := range c.Details.Ips {
  159. localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
  160. }
  161. for _, n := range c.Details.Subnets {
  162. localIps.AddCIDR(n, struct{}{})
  163. }
  164. return &Firewall{
  165. Conntrack: &FirewallConntrack{
  166. Conns: make(map[FirewallPacket]*conn),
  167. TimerWheel: NewTimerWheel(min, max),
  168. },
  169. InRules: newFirewallTable(),
  170. OutRules: newFirewallTable(),
  171. TCPTimeout: tcpTimeout,
  172. UDPTimeout: UDPTimeout,
  173. DefaultTimeout: defaultTimeout,
  174. localIps: localIps,
  175. l: l,
  176. metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
  177. incomingMetrics: firewallMetrics{
  178. droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil),
  179. droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil),
  180. droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil),
  181. },
  182. outgoingMetrics: firewallMetrics{
  183. droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil),
  184. droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil),
  185. droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil),
  186. },
  187. }
  188. }
  189. func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
  190. fw := NewFirewall(
  191. l,
  192. c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
  193. c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
  194. c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
  195. nc,
  196. //TODO: max_connections
  197. )
  198. err := AddFirewallRulesFromConfig(l, false, c, fw)
  199. if err != nil {
  200. return nil, err
  201. }
  202. err = AddFirewallRulesFromConfig(l, true, c, fw)
  203. if err != nil {
  204. return nil, err
  205. }
  206. return fw, nil
  207. }
  208. // AddRule properly creates the in memory rule structure for a firewall table.
  209. func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  210. // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
  211. // https://github.com/golang/go/issues/14131
  212. sIp := ""
  213. if ip != nil {
  214. sIp = ip.String()
  215. }
  216. // We need this rule string because we generate a hash. Removing this will break firewall reload.
  217. ruleString := fmt.Sprintf(
  218. "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
  219. incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
  220. )
  221. f.rules += ruleString + "\n"
  222. direction := "incoming"
  223. if !incoming {
  224. direction = "outgoing"
  225. }
  226. f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
  227. Info("Firewall rule added")
  228. var (
  229. ft *FirewallTable
  230. fp firewallPort
  231. )
  232. if incoming {
  233. ft = f.InRules
  234. } else {
  235. ft = f.OutRules
  236. }
  237. switch proto {
  238. case fwProtoTCP:
  239. fp = ft.TCP
  240. case fwProtoUDP:
  241. fp = ft.UDP
  242. case fwProtoICMP:
  243. fp = ft.ICMP
  244. case fwProtoAny:
  245. fp = ft.AnyProto
  246. default:
  247. return fmt.Errorf("unknown protocol %v", proto)
  248. }
  249. return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
  250. }
  251. // GetRuleHash returns a hash representation of all inbound and outbound rules
  252. func (f *Firewall) GetRuleHash() string {
  253. sum := sha256.Sum256([]byte(f.rules))
  254. return hex.EncodeToString(sum[:])
  255. }
  256. func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
  257. var table string
  258. if inbound {
  259. table = "firewall.inbound"
  260. } else {
  261. table = "firewall.outbound"
  262. }
  263. r := config.Get(table)
  264. if r == nil {
  265. return nil
  266. }
  267. rs, ok := r.([]interface{})
  268. if !ok {
  269. return fmt.Errorf("%s failed to parse, should be an array of rules", table)
  270. }
  271. for i, t := range rs {
  272. var groups []string
  273. r, err := convertRule(l, t, table, i)
  274. if err != nil {
  275. return fmt.Errorf("%s rule #%v; %s", table, i, err)
  276. }
  277. if r.Code != "" && r.Port != "" {
  278. return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
  279. }
  280. if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
  281. return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
  282. }
  283. if len(r.Groups) > 0 {
  284. groups = r.Groups
  285. }
  286. if r.Group != "" {
  287. // Check if we have both groups and group provided in the rule config
  288. if len(groups) > 0 {
  289. return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
  290. }
  291. groups = []string{r.Group}
  292. }
  293. var sPort, errPort string
  294. if r.Code != "" {
  295. errPort = "code"
  296. sPort = r.Code
  297. } else {
  298. errPort = "port"
  299. sPort = r.Port
  300. }
  301. startPort, endPort, err := parsePort(sPort)
  302. if err != nil {
  303. return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
  304. }
  305. var proto uint8
  306. switch r.Proto {
  307. case "any":
  308. proto = fwProtoAny
  309. case "tcp":
  310. proto = fwProtoTCP
  311. case "udp":
  312. proto = fwProtoUDP
  313. case "icmp":
  314. proto = fwProtoICMP
  315. default:
  316. return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
  317. }
  318. var cidr *net.IPNet
  319. if r.Cidr != "" {
  320. _, cidr, err = net.ParseCIDR(r.Cidr)
  321. if err != nil {
  322. return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
  323. }
  324. }
  325. err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
  326. if err != nil {
  327. return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
  328. }
  329. }
  330. return nil
  331. }
  332. var ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
  333. var ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
  334. var ErrNoMatchingRule = errors.New("no matching rule in firewall table")
  335. // Drop returns an error if the packet should be dropped, explaining why. It
  336. // returns nil if the packet should not be dropped.
  337. func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
  338. // Check if we spoke to this tuple, if we did then allow this packet
  339. if f.inConns(packet, fp, incoming, h, caPool, localCache) {
  340. return nil
  341. }
  342. // Make sure remote address matches nebula certificate
  343. if remoteCidr := h.remoteCidr; remoteCidr != nil {
  344. if remoteCidr.Contains(fp.RemoteIP) == nil {
  345. f.metrics(incoming).droppedRemoteIP.Inc(1)
  346. return ErrInvalidRemoteIP
  347. }
  348. } else {
  349. // Simple case: Certificate has one IP and no subnets
  350. if fp.RemoteIP != h.hostId {
  351. f.metrics(incoming).droppedRemoteIP.Inc(1)
  352. return ErrInvalidRemoteIP
  353. }
  354. }
  355. // Make sure we are supposed to be handling this local ip address
  356. if f.localIps.Contains(fp.LocalIP) == nil {
  357. f.metrics(incoming).droppedLocalIP.Inc(1)
  358. return ErrInvalidLocalIP
  359. }
  360. table := f.OutRules
  361. if incoming {
  362. table = f.InRules
  363. }
  364. // We now know which firewall table to check against
  365. if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
  366. f.metrics(incoming).droppedNoRule.Inc(1)
  367. return ErrNoMatchingRule
  368. }
  369. // We always want to conntrack since it is a faster operation
  370. f.addConn(packet, fp, incoming)
  371. return nil
  372. }
  373. func (f *Firewall) metrics(incoming bool) firewallMetrics {
  374. if incoming {
  375. return f.incomingMetrics
  376. } else {
  377. return f.outgoingMetrics
  378. }
  379. }
  380. // Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
  381. // firewall object is created
  382. func (f *Firewall) Destroy() {
  383. //TODO: clean references if/when needed
  384. }
  385. func (f *Firewall) EmitStats() {
  386. conntrack := f.Conntrack
  387. conntrack.Lock()
  388. conntrackCount := len(conntrack.Conns)
  389. conntrack.Unlock()
  390. metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
  391. metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
  392. }
  393. func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
  394. if localCache != nil {
  395. if _, ok := localCache[fp]; ok {
  396. return true
  397. }
  398. }
  399. conntrack := f.Conntrack
  400. conntrack.Lock()
  401. // Purge every time we test
  402. ep, has := conntrack.TimerWheel.Purge()
  403. if has {
  404. f.evict(ep)
  405. }
  406. c, ok := conntrack.Conns[fp]
  407. if !ok {
  408. conntrack.Unlock()
  409. return false
  410. }
  411. if c.rulesVersion != f.rulesVersion {
  412. // This conntrack entry was for an older rule set, validate
  413. // it still passes with the current rule set
  414. table := f.OutRules
  415. if c.incoming {
  416. table = f.InRules
  417. }
  418. // We now know which firewall table to check against
  419. if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
  420. if f.l.Level >= logrus.DebugLevel {
  421. h.logger(f.l).
  422. WithField("fwPacket", fp).
  423. WithField("incoming", c.incoming).
  424. WithField("rulesVersion", f.rulesVersion).
  425. WithField("oldRulesVersion", c.rulesVersion).
  426. Debugln("dropping old conntrack entry, does not match new ruleset")
  427. }
  428. delete(conntrack.Conns, fp)
  429. conntrack.Unlock()
  430. return false
  431. }
  432. if f.l.Level >= logrus.DebugLevel {
  433. h.logger(f.l).
  434. WithField("fwPacket", fp).
  435. WithField("incoming", c.incoming).
  436. WithField("rulesVersion", f.rulesVersion).
  437. WithField("oldRulesVersion", c.rulesVersion).
  438. Debugln("keeping old conntrack entry, does match new ruleset")
  439. }
  440. c.rulesVersion = f.rulesVersion
  441. }
  442. switch fp.Protocol {
  443. case fwProtoTCP:
  444. c.Expires = time.Now().Add(f.TCPTimeout)
  445. if incoming {
  446. f.checkTCPRTT(c, packet)
  447. } else {
  448. setTCPRTTTracking(c, packet)
  449. }
  450. case fwProtoUDP:
  451. c.Expires = time.Now().Add(f.UDPTimeout)
  452. default:
  453. c.Expires = time.Now().Add(f.DefaultTimeout)
  454. }
  455. conntrack.Unlock()
  456. if localCache != nil {
  457. localCache[fp] = struct{}{}
  458. }
  459. return true
  460. }
  461. func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
  462. var timeout time.Duration
  463. c := &conn{}
  464. switch fp.Protocol {
  465. case fwProtoTCP:
  466. timeout = f.TCPTimeout
  467. if !incoming {
  468. setTCPRTTTracking(c, packet)
  469. }
  470. case fwProtoUDP:
  471. timeout = f.UDPTimeout
  472. default:
  473. timeout = f.DefaultTimeout
  474. }
  475. conntrack := f.Conntrack
  476. conntrack.Lock()
  477. if _, ok := conntrack.Conns[fp]; !ok {
  478. conntrack.TimerWheel.Add(fp, timeout)
  479. }
  480. // Record which rulesVersion allowed this connection, so we can retest after
  481. // firewall reload
  482. c.incoming = incoming
  483. c.rulesVersion = f.rulesVersion
  484. c.Expires = time.Now().Add(timeout)
  485. conntrack.Conns[fp] = c
  486. conntrack.Unlock()
  487. }
  488. // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
  489. // Caller must own the connMutex lock!
  490. func (f *Firewall) evict(p FirewallPacket) {
  491. //TODO: report a stat if the tcp rtt tracking was never resolved?
  492. // Are we still tracking this conn?
  493. conntrack := f.Conntrack
  494. t, ok := conntrack.Conns[p]
  495. if !ok {
  496. return
  497. }
  498. newT := t.Expires.Sub(time.Now())
  499. // Timeout is in the future, re-add the timer
  500. if newT > 0 {
  501. conntrack.TimerWheel.Add(p, newT)
  502. return
  503. }
  504. // This conn is done
  505. delete(conntrack.Conns, p)
  506. }
  507. func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
  508. if ft.AnyProto.match(p, incoming, c, caPool) {
  509. return true
  510. }
  511. switch p.Protocol {
  512. case fwProtoTCP:
  513. if ft.TCP.match(p, incoming, c, caPool) {
  514. return true
  515. }
  516. case fwProtoUDP:
  517. if ft.UDP.match(p, incoming, c, caPool) {
  518. return true
  519. }
  520. case fwProtoICMP:
  521. if ft.ICMP.match(p, incoming, c, caPool) {
  522. return true
  523. }
  524. }
  525. return false
  526. }
  527. func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  528. if startPort > endPort {
  529. return fmt.Errorf("start port was lower than end port")
  530. }
  531. for i := startPort; i <= endPort; i++ {
  532. if _, ok := fp[i]; !ok {
  533. fp[i] = &FirewallCA{
  534. CANames: make(map[string]*FirewallRule),
  535. CAShas: make(map[string]*FirewallRule),
  536. }
  537. }
  538. if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
  539. return err
  540. }
  541. }
  542. return nil
  543. }
  544. func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
  545. // We don't have any allowed ports, bail
  546. if fp == nil {
  547. return false
  548. }
  549. var port int32
  550. if p.Fragment {
  551. port = fwPortFragment
  552. } else if incoming {
  553. port = int32(p.LocalPort)
  554. } else {
  555. port = int32(p.RemotePort)
  556. }
  557. if fp[port].match(p, c, caPool) {
  558. return true
  559. }
  560. return fp[fwPortAny].match(p, c, caPool)
  561. }
  562. func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
  563. fr := func() *FirewallRule {
  564. return &FirewallRule{
  565. Hosts: make(map[string]struct{}),
  566. Groups: make([][]string, 0),
  567. CIDR: NewCIDRTree(),
  568. }
  569. }
  570. if caSha == "" && caName == "" {
  571. if fc.Any == nil {
  572. fc.Any = fr()
  573. }
  574. return fc.Any.addRule(groups, host, ip)
  575. }
  576. if caSha != "" {
  577. if _, ok := fc.CAShas[caSha]; !ok {
  578. fc.CAShas[caSha] = fr()
  579. }
  580. err := fc.CAShas[caSha].addRule(groups, host, ip)
  581. if err != nil {
  582. return err
  583. }
  584. }
  585. if caName != "" {
  586. if _, ok := fc.CANames[caName]; !ok {
  587. fc.CANames[caName] = fr()
  588. }
  589. err := fc.CANames[caName].addRule(groups, host, ip)
  590. if err != nil {
  591. return err
  592. }
  593. }
  594. return nil
  595. }
  596. func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
  597. if fc == nil {
  598. return false
  599. }
  600. if fc.Any.match(p, c) {
  601. return true
  602. }
  603. if t, ok := fc.CAShas[c.Details.Issuer]; ok {
  604. if t.match(p, c) {
  605. return true
  606. }
  607. }
  608. s, err := caPool.GetCAForCert(c)
  609. if err != nil {
  610. return false
  611. }
  612. return fc.CANames[s.Details.Name].match(p, c)
  613. }
  614. func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
  615. if fr.Any {
  616. return nil
  617. }
  618. if fr.isAny(groups, host, ip) {
  619. fr.Any = true
  620. // If it's any we need to wipe out any pre-existing rules to save on memory
  621. fr.Groups = make([][]string, 0)
  622. fr.Hosts = make(map[string]struct{})
  623. fr.CIDR = NewCIDRTree()
  624. } else {
  625. if len(groups) > 0 {
  626. fr.Groups = append(fr.Groups, groups)
  627. }
  628. if host != "" {
  629. fr.Hosts[host] = struct{}{}
  630. }
  631. if ip != nil {
  632. fr.CIDR.AddCIDR(ip, struct{}{})
  633. }
  634. }
  635. return nil
  636. }
  637. func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
  638. if len(groups) == 0 && host == "" && ip == nil {
  639. return true
  640. }
  641. for _, group := range groups {
  642. if group == "any" {
  643. return true
  644. }
  645. }
  646. if host == "any" {
  647. return true
  648. }
  649. if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
  650. return true
  651. }
  652. return false
  653. }
  654. func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
  655. if fr == nil {
  656. return false
  657. }
  658. // Shortcut path for if groups, hosts, or cidr contained an `any`
  659. if fr.Any {
  660. return true
  661. }
  662. // Need any of group, host, or cidr to match
  663. for _, sg := range fr.Groups {
  664. found := false
  665. for _, g := range sg {
  666. if _, ok := c.Details.InvertedGroups[g]; !ok {
  667. found = false
  668. break
  669. }
  670. found = true
  671. }
  672. if found {
  673. return true
  674. }
  675. }
  676. if fr.Hosts != nil {
  677. if _, ok := fr.Hosts[c.Details.Name]; ok {
  678. return true
  679. }
  680. }
  681. if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
  682. return true
  683. }
  684. // No host, group, or cidr matched, bye bye
  685. return false
  686. }
  687. type rule struct {
  688. Port string
  689. Code string
  690. Proto string
  691. Host string
  692. Group string
  693. Groups []string
  694. Cidr string
  695. CAName string
  696. CASha string
  697. }
  698. func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
  699. r := rule{}
  700. m, ok := p.(map[interface{}]interface{})
  701. if !ok {
  702. return r, errors.New("could not parse rule")
  703. }
  704. toString := func(k string, m map[interface{}]interface{}) string {
  705. v, ok := m[k]
  706. if !ok {
  707. return ""
  708. }
  709. return fmt.Sprintf("%v", v)
  710. }
  711. r.Port = toString("port", m)
  712. r.Code = toString("code", m)
  713. r.Proto = toString("proto", m)
  714. r.Host = toString("host", m)
  715. r.Cidr = toString("cidr", m)
  716. r.CAName = toString("ca_name", m)
  717. r.CASha = toString("ca_sha", m)
  718. // Make sure group isn't an array
  719. if v, ok := m["group"].([]interface{}); ok {
  720. if len(v) > 1 {
  721. return r, errors.New("group should contain a single value, an array with more than one entry was provided")
  722. }
  723. l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
  724. m["group"] = v[0]
  725. }
  726. r.Group = toString("group", m)
  727. if rg, ok := m["groups"]; ok {
  728. switch reflect.TypeOf(rg).Kind() {
  729. case reflect.Slice:
  730. v := reflect.ValueOf(rg)
  731. r.Groups = make([]string, v.Len())
  732. for i := 0; i < v.Len(); i++ {
  733. r.Groups[i] = v.Index(i).Interface().(string)
  734. }
  735. case reflect.String:
  736. r.Groups = []string{rg.(string)}
  737. default:
  738. r.Groups = []string{fmt.Sprintf("%v", rg)}
  739. }
  740. }
  741. return r, nil
  742. }
  743. func parsePort(s string) (startPort, endPort int32, err error) {
  744. if s == "any" {
  745. startPort = fwPortAny
  746. endPort = fwPortAny
  747. } else if s == "fragment" {
  748. startPort = fwPortFragment
  749. endPort = fwPortFragment
  750. } else if strings.Contains(s, `-`) {
  751. sPorts := strings.SplitN(s, `-`, 2)
  752. sPorts[0] = strings.Trim(sPorts[0], " ")
  753. sPorts[1] = strings.Trim(sPorts[1], " ")
  754. if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
  755. return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
  756. }
  757. rStartPort, err := strconv.Atoi(sPorts[0])
  758. if err != nil {
  759. return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
  760. }
  761. rEndPort, err := strconv.Atoi(sPorts[1])
  762. if err != nil {
  763. return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
  764. }
  765. startPort = int32(rStartPort)
  766. endPort = int32(rEndPort)
  767. if startPort == fwPortAny {
  768. endPort = fwPortAny
  769. }
  770. } else {
  771. rPort, err := strconv.Atoi(s)
  772. if err != nil {
  773. return 0, 0, fmt.Errorf("was not a number; `%s`", s)
  774. }
  775. startPort = int32(rPort)
  776. endPort = startPort
  777. }
  778. return
  779. }
  780. //TODO: write tests for these
  781. func setTCPRTTTracking(c *conn, p []byte) {
  782. if c.Seq != 0 {
  783. return
  784. }
  785. ihl := int(p[0]&0x0f) << 2
  786. // Don't track FIN packets
  787. if p[ihl+13]&tcpFIN != 0 {
  788. return
  789. }
  790. c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8])
  791. c.Sent = time.Now()
  792. }
  793. func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
  794. if c.Seq == 0 {
  795. return false
  796. }
  797. ihl := int(p[0]&0x0f) << 2
  798. if p[ihl+13]&tcpACK == 0 {
  799. return false
  800. }
  801. // Deal with wrap around, signed int cuts the ack window in half
  802. // 0 is a bad ack, no data acknowledged
  803. // positive number is a bad ack, ack is over half the window away
  804. if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 {
  805. return false
  806. }
  807. f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds())
  808. c.Seq = 0
  809. return true
  810. }
  811. // ConntrackCache is used as a local routine cache to know if a given flow
  812. // has been seen in the conntrack table.
  813. type ConntrackCache map[FirewallPacket]struct{}
  814. type ConntrackCacheTicker struct {
  815. cacheV uint64
  816. cacheTick uint64
  817. cache ConntrackCache
  818. }
  819. func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
  820. if d == 0 {
  821. return nil
  822. }
  823. c := &ConntrackCacheTicker{
  824. cache: ConntrackCache{},
  825. }
  826. go c.tick(d)
  827. return c
  828. }
  829. func (c *ConntrackCacheTicker) tick(d time.Duration) {
  830. for {
  831. time.Sleep(d)
  832. atomic.AddUint64(&c.cacheTick, 1)
  833. }
  834. }
  835. // Get checks if the cache ticker has moved to the next version before returning
  836. // the map. If it has moved, we reset the map.
  837. func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
  838. if c == nil {
  839. return nil
  840. }
  841. if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
  842. c.cacheV = tick
  843. if ll := len(c.cache); ll > 0 {
  844. if l.Level == logrus.DebugLevel {
  845. l.WithField("len", ll).Debug("resetting conntrack cache")
  846. }
  847. c.cache = make(ConntrackCache, ll)
  848. }
  849. }
  850. return c.cache
  851. }