2
0

firewall.go 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989
  1. package nebula
  2. import (
  3. "crypto/sha256"
  4. "encoding/binary"
  5. "encoding/hex"
  6. "encoding/json"
  7. "errors"
  8. "fmt"
  9. "net"
  10. "reflect"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "sync/atomic"
  15. "time"
  16. "github.com/rcrowley/go-metrics"
  17. "github.com/sirupsen/logrus"
  18. "github.com/slackhq/nebula/cert"
  19. )
  // Protocol numbers and special port keys used to index the firewall tables.
  const (
  	// IP protocol numbers the firewall can build rules for.
  	fwProtoAny  = 0 // When we want to handle HOPOPT (0) we can change this, if ever
  	fwProtoICMP = 1
  	fwProtoTCP  = 6
  	fwProtoUDP  = 17

  	// Reserved keys in a firewallPort map.
  	fwPortAny      = 0  // Special value for matching `port: any`
  	fwPortFragment = -1 // Special value for matching `port: fragment`
  )
  // Bit masks applied to the TCP flags byte (offset 13 of the TCP header).
  const (
  	tcpFIN = 0x01
  	tcpACK = 0x10
  )
  // FirewallInterface is the rule sink consumed by AddFirewallRulesFromConfig.
  // *Firewall satisfies it.
  type FirewallInterface interface {
  	// AddRule registers one rule for the given direction, protocol and
  	// inclusive port range.
  	AddRule(
  		incoming bool,
  		proto uint8,
  		startPort, endPort int32,
  		groups []string,
  		host string,
  		ip *net.IPNet,
  		caName, caSha string,
  	) error
  }
  // conn is the state kept for one tracked flow in the conntrack table.
  type conn struct {
  	// Expires is the time at which this conntrack entry stops being valid.
  	Expires time.Time
  	// Sent is when Seq was last set (only meaningful while tcp rtt tracking
  	// has a sequence number outstanding).
  	Sent time.Time
  	// Seq is the sequence number we are waiting to see acknowledged; 0 means
  	// nothing is being tracked.
  	Seq uint32

  	// The two fields below remember why the flow was admitted so it can be
  	// re-validated after a ruleset change. rulesVersion is a uint16 so that
  	// both fields pack for free after the uint32 above.
  	incoming     bool
  	rulesVersion uint16
  }
  43. // TODO: need conntrack max tracked connections handling
  44. type Firewall struct {
  45. Conntrack *FirewallConntrack
  46. InRules *FirewallTable
  47. OutRules *FirewallTable
  48. //TODO: we should have many more options for TCP, an option for ICMP, and mimic the kernel a bit better
  49. // https://www.kernel.org/doc/Documentation/networking/nf_conntrack-sysctl.txt
  50. TCPTimeout time.Duration //linux: 5 days max
  51. UDPTimeout time.Duration //linux: 180s max
  52. DefaultTimeout time.Duration //linux: 600s
  53. // Used to ensure we don't emit local packets for ips we don't own
  54. localIps *CIDRTree
  55. rules string
  56. rulesVersion uint16
  57. trackTCPRTT bool
  58. metricTCPRTT metrics.Histogram
  59. l *logrus.Logger
  60. }
  61. type FirewallConntrack struct {
  62. sync.Mutex
  63. Conns map[FirewallPacket]*conn
  64. TimerWheel *TimerWheel
  65. }
  66. type FirewallTable struct {
  67. TCP firewallPort
  68. UDP firewallPort
  69. ICMP firewallPort
  70. AnyProto firewallPort
  71. }
  72. func newFirewallTable() *FirewallTable {
  73. return &FirewallTable{
  74. TCP: firewallPort{},
  75. UDP: firewallPort{},
  76. ICMP: firewallPort{},
  77. AnyProto: firewallPort{},
  78. }
  79. }
  80. type FirewallCA struct {
  81. Any *FirewallRule
  82. CANames map[string]*FirewallRule
  83. CAShas map[string]*FirewallRule
  84. }
  85. type FirewallRule struct {
  86. // Any makes Hosts, Groups, and CIDR irrelevant
  87. Any bool
  88. Hosts map[string]struct{}
  89. Groups [][]string
  90. CIDR *CIDRTree
  91. }
  92. // Even though ports are uint16, int32 maps are faster for lookup
  93. // Plus we can use `-1` for fragment rules
  94. type firewallPort map[int32]*FirewallCA
  // FirewallPacket is the flow tuple extracted from a packet. It is comparable
  // and is used directly as the conntrack map key.
  type FirewallPacket struct {
  	LocalIP    uint32
  	RemoteIP   uint32
  	LocalPort  uint16
  	RemotePort uint16
  	Protocol   uint8
  	Fragment   bool
  }

  // Copy returns a pointer to a new FirewallPacket holding the same field
  // values as fp.
  func (fp *FirewallPacket) Copy() *FirewallPacket {
  	// Every field is a plain value, so a struct assignment duplicates it all.
  	dup := *fp
  	return &dup
  }
  113. func (fp FirewallPacket) MarshalJSON() ([]byte, error) {
  114. var proto string
  115. switch fp.Protocol {
  116. case fwProtoTCP:
  117. proto = "tcp"
  118. case fwProtoICMP:
  119. proto = "icmp"
  120. case fwProtoUDP:
  121. proto = "udp"
  122. default:
  123. proto = fmt.Sprintf("unknown %v", fp.Protocol)
  124. }
  125. return json.Marshal(m{
  126. "LocalIP": int2ip(fp.LocalIP).String(),
  127. "RemoteIP": int2ip(fp.RemoteIP).String(),
  128. "LocalPort": fp.LocalPort,
  129. "RemotePort": fp.RemotePort,
  130. "Protocol": proto,
  131. "Fragment": fp.Fragment,
  132. })
  133. }
  134. // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts.
  135. func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall {
  136. //TODO: error on 0 duration
  137. var min, max time.Duration
  138. if tcpTimeout < UDPTimeout {
  139. min = tcpTimeout
  140. max = UDPTimeout
  141. } else {
  142. min = UDPTimeout
  143. max = tcpTimeout
  144. }
  145. if defaultTimeout < min {
  146. min = defaultTimeout
  147. } else if defaultTimeout > max {
  148. max = defaultTimeout
  149. }
  150. localIps := NewCIDRTree()
  151. for _, ip := range c.Details.Ips {
  152. localIps.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
  153. }
  154. for _, n := range c.Details.Subnets {
  155. localIps.AddCIDR(n, struct{}{})
  156. }
  157. return &Firewall{
  158. Conntrack: &FirewallConntrack{
  159. Conns: make(map[FirewallPacket]*conn),
  160. TimerWheel: NewTimerWheel(min, max),
  161. },
  162. InRules: newFirewallTable(),
  163. OutRules: newFirewallTable(),
  164. TCPTimeout: tcpTimeout,
  165. UDPTimeout: UDPTimeout,
  166. DefaultTimeout: defaultTimeout,
  167. localIps: localIps,
  168. metricTCPRTT: metrics.GetOrRegisterHistogram("network.tcp.rtt", nil, metrics.NewExpDecaySample(1028, 0.015)),
  169. l: l,
  170. }
  171. }
  172. func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *Config) (*Firewall, error) {
  173. fw := NewFirewall(
  174. l,
  175. c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12),
  176. c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3),
  177. c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10),
  178. nc,
  179. //TODO: max_connections
  180. )
  181. err := AddFirewallRulesFromConfig(l, false, c, fw)
  182. if err != nil {
  183. return nil, err
  184. }
  185. err = AddFirewallRulesFromConfig(l, true, c, fw)
  186. if err != nil {
  187. return nil, err
  188. }
  189. return fw, nil
  190. }
  191. // AddRule properly creates the in memory rule structure for a firewall table.
  192. func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  193. // Under gomobile, stringing a nil pointer with fmt causes an abort in debug mode for iOS
  194. // https://github.com/golang/go/issues/14131
  195. sIp := ""
  196. if ip != nil {
  197. sIp = ip.String()
  198. }
  199. // We need this rule string because we generate a hash. Removing this will break firewall reload.
  200. ruleString := fmt.Sprintf(
  201. "incoming: %v, proto: %v, startPort: %v, endPort: %v, groups: %v, host: %v, ip: %v, caName: %v, caSha: %s",
  202. incoming, proto, startPort, endPort, groups, host, sIp, caName, caSha,
  203. )
  204. f.rules += ruleString + "\n"
  205. direction := "incoming"
  206. if !incoming {
  207. direction = "outgoing"
  208. }
  209. f.l.WithField("firewallRule", m{"direction": direction, "proto": proto, "startPort": startPort, "endPort": endPort, "groups": groups, "host": host, "ip": sIp, "caName": caName, "caSha": caSha}).
  210. Info("Firewall rule added")
  211. var (
  212. ft *FirewallTable
  213. fp firewallPort
  214. )
  215. if incoming {
  216. ft = f.InRules
  217. } else {
  218. ft = f.OutRules
  219. }
  220. switch proto {
  221. case fwProtoTCP:
  222. fp = ft.TCP
  223. case fwProtoUDP:
  224. fp = ft.UDP
  225. case fwProtoICMP:
  226. fp = ft.ICMP
  227. case fwProtoAny:
  228. fp = ft.AnyProto
  229. default:
  230. return fmt.Errorf("unknown protocol %v", proto)
  231. }
  232. return fp.addRule(startPort, endPort, groups, host, ip, caName, caSha)
  233. }
  234. // GetRuleHash returns a hash representation of all inbound and outbound rules
  235. func (f *Firewall) GetRuleHash() string {
  236. sum := sha256.Sum256([]byte(f.rules))
  237. return hex.EncodeToString(sum[:])
  238. }
  239. func AddFirewallRulesFromConfig(l *logrus.Logger, inbound bool, config *Config, fw FirewallInterface) error {
  240. var table string
  241. if inbound {
  242. table = "firewall.inbound"
  243. } else {
  244. table = "firewall.outbound"
  245. }
  246. r := config.Get(table)
  247. if r == nil {
  248. return nil
  249. }
  250. rs, ok := r.([]interface{})
  251. if !ok {
  252. return fmt.Errorf("%s failed to parse, should be an array of rules", table)
  253. }
  254. for i, t := range rs {
  255. var groups []string
  256. r, err := convertRule(l, t, table, i)
  257. if err != nil {
  258. return fmt.Errorf("%s rule #%v; %s", table, i, err)
  259. }
  260. if r.Code != "" && r.Port != "" {
  261. return fmt.Errorf("%s rule #%v; only one of port or code should be provided", table, i)
  262. }
  263. if r.Host == "" && len(r.Groups) == 0 && r.Group == "" && r.Cidr == "" && r.CAName == "" && r.CASha == "" {
  264. return fmt.Errorf("%s rule #%v; at least one of host, group, cidr, ca_name, or ca_sha must be provided", table, i)
  265. }
  266. if len(r.Groups) > 0 {
  267. groups = r.Groups
  268. }
  269. if r.Group != "" {
  270. // Check if we have both groups and group provided in the rule config
  271. if len(groups) > 0 {
  272. return fmt.Errorf("%s rule #%v; only one of group or groups should be defined, both provided", table, i)
  273. }
  274. groups = []string{r.Group}
  275. }
  276. var sPort, errPort string
  277. if r.Code != "" {
  278. errPort = "code"
  279. sPort = r.Code
  280. } else {
  281. errPort = "port"
  282. sPort = r.Port
  283. }
  284. startPort, endPort, err := parsePort(sPort)
  285. if err != nil {
  286. return fmt.Errorf("%s rule #%v; %s %s", table, i, errPort, err)
  287. }
  288. var proto uint8
  289. switch r.Proto {
  290. case "any":
  291. proto = fwProtoAny
  292. case "tcp":
  293. proto = fwProtoTCP
  294. case "udp":
  295. proto = fwProtoUDP
  296. case "icmp":
  297. proto = fwProtoICMP
  298. default:
  299. return fmt.Errorf("%s rule #%v; proto was not understood; `%s`", table, i, r.Proto)
  300. }
  301. var cidr *net.IPNet
  302. if r.Cidr != "" {
  303. _, cidr, err = net.ParseCIDR(r.Cidr)
  304. if err != nil {
  305. return fmt.Errorf("%s rule #%v; cidr did not parse; %s", table, i, err)
  306. }
  307. }
  308. err = fw.AddRule(inbound, proto, startPort, endPort, groups, r.Host, cidr, r.CAName, r.CASha)
  309. if err != nil {
  310. return fmt.Errorf("%s rule #%v; `%s`", table, i, err)
  311. }
  312. }
  313. return nil
  314. }
  // Reasons Drop can give for rejecting a packet.
  var (
  	// ErrInvalidRemoteIP means the remote address is not covered by the
  	// remote host's certificate.
  	ErrInvalidRemoteIP = errors.New("remote IP is not in remote certificate subnets")
  	// ErrInvalidLocalIP means the local address is not one this node handles.
  	ErrInvalidLocalIP = errors.New("local IP is not in list of handled local IPs")
  	// ErrNoMatchingRule means no rule in the relevant table allowed the packet.
  	ErrNoMatchingRule = errors.New("no matching rule in firewall table")
  )
  318. // Drop returns an error if the packet should be dropped, explaining why. It
  319. // returns nil if the packet should not be dropped.
  320. func (f *Firewall) Drop(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) error {
  321. // Check if we spoke to this tuple, if we did then allow this packet
  322. if f.inConns(packet, fp, incoming, h, caPool, localCache) {
  323. return nil
  324. }
  325. // Make sure remote address matches nebula certificate
  326. if remoteCidr := h.remoteCidr; remoteCidr != nil {
  327. if remoteCidr.Contains(fp.RemoteIP) == nil {
  328. return ErrInvalidRemoteIP
  329. }
  330. } else {
  331. // Simple case: Certificate has one IP and no subnets
  332. if fp.RemoteIP != h.hostId {
  333. return ErrInvalidRemoteIP
  334. }
  335. }
  336. // Make sure we are supposed to be handling this local ip address
  337. if f.localIps.Contains(fp.LocalIP) == nil {
  338. return ErrInvalidLocalIP
  339. }
  340. table := f.OutRules
  341. if incoming {
  342. table = f.InRules
  343. }
  344. // We now know which firewall table to check against
  345. if !table.match(fp, incoming, h.ConnectionState.peerCert, caPool) {
  346. return ErrNoMatchingRule
  347. }
  348. // We always want to conntrack since it is a faster operation
  349. f.addConn(packet, fp, incoming)
  350. return nil
  351. }
  352. // Destroy cleans up any known cyclical references so the object can be free'd my GC. This should be called if a new
  353. // firewall object is created
  354. func (f *Firewall) Destroy() {
  355. //TODO: clean references if/when needed
  356. }
  357. func (f *Firewall) EmitStats() {
  358. conntrack := f.Conntrack
  359. conntrack.Lock()
  360. conntrackCount := len(conntrack.Conns)
  361. conntrack.Unlock()
  362. metrics.GetOrRegisterGauge("firewall.conntrack.count", nil).Update(int64(conntrackCount))
  363. metrics.GetOrRegisterGauge("firewall.rules.version", nil).Update(int64(f.rulesVersion))
  364. }
  365. func (f *Firewall) inConns(packet []byte, fp FirewallPacket, incoming bool, h *HostInfo, caPool *cert.NebulaCAPool, localCache ConntrackCache) bool {
  366. if localCache != nil {
  367. if _, ok := localCache[fp]; ok {
  368. return true
  369. }
  370. }
  371. conntrack := f.Conntrack
  372. conntrack.Lock()
  373. // Purge every time we test
  374. ep, has := conntrack.TimerWheel.Purge()
  375. if has {
  376. f.evict(ep)
  377. }
  378. c, ok := conntrack.Conns[fp]
  379. if !ok {
  380. conntrack.Unlock()
  381. return false
  382. }
  383. if c.rulesVersion != f.rulesVersion {
  384. // This conntrack entry was for an older rule set, validate
  385. // it still passes with the current rule set
  386. table := f.OutRules
  387. if c.incoming {
  388. table = f.InRules
  389. }
  390. // We now know which firewall table to check against
  391. if !table.match(fp, c.incoming, h.ConnectionState.peerCert, caPool) {
  392. if f.l.Level >= logrus.DebugLevel {
  393. h.logger(f.l).
  394. WithField("fwPacket", fp).
  395. WithField("incoming", c.incoming).
  396. WithField("rulesVersion", f.rulesVersion).
  397. WithField("oldRulesVersion", c.rulesVersion).
  398. Debugln("dropping old conntrack entry, does not match new ruleset")
  399. }
  400. delete(conntrack.Conns, fp)
  401. conntrack.Unlock()
  402. return false
  403. }
  404. if f.l.Level >= logrus.DebugLevel {
  405. h.logger(f.l).
  406. WithField("fwPacket", fp).
  407. WithField("incoming", c.incoming).
  408. WithField("rulesVersion", f.rulesVersion).
  409. WithField("oldRulesVersion", c.rulesVersion).
  410. Debugln("keeping old conntrack entry, does match new ruleset")
  411. }
  412. c.rulesVersion = f.rulesVersion
  413. }
  414. switch fp.Protocol {
  415. case fwProtoTCP:
  416. c.Expires = time.Now().Add(f.TCPTimeout)
  417. if incoming {
  418. f.checkTCPRTT(c, packet)
  419. } else {
  420. setTCPRTTTracking(c, packet)
  421. }
  422. case fwProtoUDP:
  423. c.Expires = time.Now().Add(f.UDPTimeout)
  424. default:
  425. c.Expires = time.Now().Add(f.DefaultTimeout)
  426. }
  427. conntrack.Unlock()
  428. if localCache != nil {
  429. localCache[fp] = struct{}{}
  430. }
  431. return true
  432. }
  433. func (f *Firewall) addConn(packet []byte, fp FirewallPacket, incoming bool) {
  434. var timeout time.Duration
  435. c := &conn{}
  436. switch fp.Protocol {
  437. case fwProtoTCP:
  438. timeout = f.TCPTimeout
  439. if !incoming {
  440. setTCPRTTTracking(c, packet)
  441. }
  442. case fwProtoUDP:
  443. timeout = f.UDPTimeout
  444. default:
  445. timeout = f.DefaultTimeout
  446. }
  447. conntrack := f.Conntrack
  448. conntrack.Lock()
  449. if _, ok := conntrack.Conns[fp]; !ok {
  450. conntrack.TimerWheel.Add(fp, timeout)
  451. }
  452. // Record which rulesVersion allowed this connection, so we can retest after
  453. // firewall reload
  454. c.incoming = incoming
  455. c.rulesVersion = f.rulesVersion
  456. c.Expires = time.Now().Add(timeout)
  457. conntrack.Conns[fp] = c
  458. conntrack.Unlock()
  459. }
  460. // Evict checks if a conntrack entry has expired, if so it is removed, if not it is re-added to the wheel
  461. // Caller must own the connMutex lock!
  462. func (f *Firewall) evict(p FirewallPacket) {
  463. //TODO: report a stat if the tcp rtt tracking was never resolved?
  464. // Are we still tracking this conn?
  465. conntrack := f.Conntrack
  466. t, ok := conntrack.Conns[p]
  467. if !ok {
  468. return
  469. }
  470. newT := t.Expires.Sub(time.Now())
  471. // Timeout is in the future, re-add the timer
  472. if newT > 0 {
  473. conntrack.TimerWheel.Add(p, newT)
  474. return
  475. }
  476. // This conn is done
  477. delete(conntrack.Conns, p)
  478. }
  479. func (ft *FirewallTable) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
  480. if ft.AnyProto.match(p, incoming, c, caPool) {
  481. return true
  482. }
  483. switch p.Protocol {
  484. case fwProtoTCP:
  485. if ft.TCP.match(p, incoming, c, caPool) {
  486. return true
  487. }
  488. case fwProtoUDP:
  489. if ft.UDP.match(p, incoming, c, caPool) {
  490. return true
  491. }
  492. case fwProtoICMP:
  493. if ft.ICMP.match(p, incoming, c, caPool) {
  494. return true
  495. }
  496. }
  497. return false
  498. }
  499. func (fp firewallPort) addRule(startPort int32, endPort int32, groups []string, host string, ip *net.IPNet, caName string, caSha string) error {
  500. if startPort > endPort {
  501. return fmt.Errorf("start port was lower than end port")
  502. }
  503. for i := startPort; i <= endPort; i++ {
  504. if _, ok := fp[i]; !ok {
  505. fp[i] = &FirewallCA{
  506. CANames: make(map[string]*FirewallRule),
  507. CAShas: make(map[string]*FirewallRule),
  508. }
  509. }
  510. if err := fp[i].addRule(groups, host, ip, caName, caSha); err != nil {
  511. return err
  512. }
  513. }
  514. return nil
  515. }
  516. func (fp firewallPort) match(p FirewallPacket, incoming bool, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
  517. // We don't have any allowed ports, bail
  518. if fp == nil {
  519. return false
  520. }
  521. var port int32
  522. if p.Fragment {
  523. port = fwPortFragment
  524. } else if incoming {
  525. port = int32(p.LocalPort)
  526. } else {
  527. port = int32(p.RemotePort)
  528. }
  529. if fp[port].match(p, c, caPool) {
  530. return true
  531. }
  532. return fp[fwPortAny].match(p, c, caPool)
  533. }
  534. func (fc *FirewallCA) addRule(groups []string, host string, ip *net.IPNet, caName, caSha string) error {
  535. fr := func() *FirewallRule {
  536. return &FirewallRule{
  537. Hosts: make(map[string]struct{}),
  538. Groups: make([][]string, 0),
  539. CIDR: NewCIDRTree(),
  540. }
  541. }
  542. if caSha == "" && caName == "" {
  543. if fc.Any == nil {
  544. fc.Any = fr()
  545. }
  546. return fc.Any.addRule(groups, host, ip)
  547. }
  548. if caSha != "" {
  549. if _, ok := fc.CAShas[caSha]; !ok {
  550. fc.CAShas[caSha] = fr()
  551. }
  552. err := fc.CAShas[caSha].addRule(groups, host, ip)
  553. if err != nil {
  554. return err
  555. }
  556. }
  557. if caName != "" {
  558. if _, ok := fc.CANames[caName]; !ok {
  559. fc.CANames[caName] = fr()
  560. }
  561. err := fc.CANames[caName].addRule(groups, host, ip)
  562. if err != nil {
  563. return err
  564. }
  565. }
  566. return nil
  567. }
  568. func (fc *FirewallCA) match(p FirewallPacket, c *cert.NebulaCertificate, caPool *cert.NebulaCAPool) bool {
  569. if fc == nil {
  570. return false
  571. }
  572. if fc.Any.match(p, c) {
  573. return true
  574. }
  575. if t, ok := fc.CAShas[c.Details.Issuer]; ok {
  576. if t.match(p, c) {
  577. return true
  578. }
  579. }
  580. s, err := caPool.GetCAForCert(c)
  581. if err != nil {
  582. return false
  583. }
  584. return fc.CANames[s.Details.Name].match(p, c)
  585. }
  586. func (fr *FirewallRule) addRule(groups []string, host string, ip *net.IPNet) error {
  587. if fr.Any {
  588. return nil
  589. }
  590. if fr.isAny(groups, host, ip) {
  591. fr.Any = true
  592. // If it's any we need to wipe out any pre-existing rules to save on memory
  593. fr.Groups = make([][]string, 0)
  594. fr.Hosts = make(map[string]struct{})
  595. fr.CIDR = NewCIDRTree()
  596. } else {
  597. if len(groups) > 0 {
  598. fr.Groups = append(fr.Groups, groups)
  599. }
  600. if host != "" {
  601. fr.Hosts[host] = struct{}{}
  602. }
  603. if ip != nil {
  604. fr.CIDR.AddCIDR(ip, struct{}{})
  605. }
  606. }
  607. return nil
  608. }
  609. func (fr *FirewallRule) isAny(groups []string, host string, ip *net.IPNet) bool {
  610. if len(groups) == 0 && host == "" && ip == nil {
  611. return true
  612. }
  613. for _, group := range groups {
  614. if group == "any" {
  615. return true
  616. }
  617. }
  618. if host == "any" {
  619. return true
  620. }
  621. if ip != nil && ip.Contains(net.IPv4(0, 0, 0, 0)) {
  622. return true
  623. }
  624. return false
  625. }
  626. func (fr *FirewallRule) match(p FirewallPacket, c *cert.NebulaCertificate) bool {
  627. if fr == nil {
  628. return false
  629. }
  630. // Shortcut path for if groups, hosts, or cidr contained an `any`
  631. if fr.Any {
  632. return true
  633. }
  634. // Need any of group, host, or cidr to match
  635. for _, sg := range fr.Groups {
  636. found := false
  637. for _, g := range sg {
  638. if _, ok := c.Details.InvertedGroups[g]; !ok {
  639. found = false
  640. break
  641. }
  642. found = true
  643. }
  644. if found {
  645. return true
  646. }
  647. }
  648. if fr.Hosts != nil {
  649. if _, ok := fr.Hosts[c.Details.Name]; ok {
  650. return true
  651. }
  652. }
  653. if fr.CIDR != nil && fr.CIDR.Contains(p.RemoteIP) != nil {
  654. return true
  655. }
  656. // No host, group, or cidr matched, bye bye
  657. return false
  658. }
  // rule is the string form of one firewall rule as read from config, before
  // ports, protocol and CIDR are parsed.
  type rule struct {
  	Port   string // `port`
  	Code   string // `code`
  	Proto  string // `proto`
  	Host   string // `host`
  	Group  string // `group`
  	Groups []string // `groups`
  	Cidr   string // `cidr`
  	CAName string // `ca_name`
  	CASha  string // `ca_sha`
  }
  670. func convertRule(l *logrus.Logger, p interface{}, table string, i int) (rule, error) {
  671. r := rule{}
  672. m, ok := p.(map[interface{}]interface{})
  673. if !ok {
  674. return r, errors.New("could not parse rule")
  675. }
  676. toString := func(k string, m map[interface{}]interface{}) string {
  677. v, ok := m[k]
  678. if !ok {
  679. return ""
  680. }
  681. return fmt.Sprintf("%v", v)
  682. }
  683. r.Port = toString("port", m)
  684. r.Code = toString("code", m)
  685. r.Proto = toString("proto", m)
  686. r.Host = toString("host", m)
  687. r.Cidr = toString("cidr", m)
  688. r.CAName = toString("ca_name", m)
  689. r.CASha = toString("ca_sha", m)
  690. // Make sure group isn't an array
  691. if v, ok := m["group"].([]interface{}); ok {
  692. if len(v) > 1 {
  693. return r, errors.New("group should contain a single value, an array with more than one entry was provided")
  694. }
  695. l.Warnf("%s rule #%v; group was an array with a single value, converting to simple value", table, i)
  696. m["group"] = v[0]
  697. }
  698. r.Group = toString("group", m)
  699. if rg, ok := m["groups"]; ok {
  700. switch reflect.TypeOf(rg).Kind() {
  701. case reflect.Slice:
  702. v := reflect.ValueOf(rg)
  703. r.Groups = make([]string, v.Len())
  704. for i := 0; i < v.Len(); i++ {
  705. r.Groups[i] = v.Index(i).Interface().(string)
  706. }
  707. case reflect.String:
  708. r.Groups = []string{rg.(string)}
  709. default:
  710. r.Groups = []string{fmt.Sprintf("%v", rg)}
  711. }
  712. }
  713. return r, nil
  714. }
  715. func parsePort(s string) (startPort, endPort int32, err error) {
  716. if s == "any" {
  717. startPort = fwPortAny
  718. endPort = fwPortAny
  719. } else if s == "fragment" {
  720. startPort = fwPortFragment
  721. endPort = fwPortFragment
  722. } else if strings.Contains(s, `-`) {
  723. sPorts := strings.SplitN(s, `-`, 2)
  724. sPorts[0] = strings.Trim(sPorts[0], " ")
  725. sPorts[1] = strings.Trim(sPorts[1], " ")
  726. if len(sPorts) != 2 || sPorts[0] == "" || sPorts[1] == "" {
  727. return 0, 0, fmt.Errorf("appears to be a range but could not be parsed; `%s`", s)
  728. }
  729. rStartPort, err := strconv.Atoi(sPorts[0])
  730. if err != nil {
  731. return 0, 0, fmt.Errorf("beginning range was not a number; `%s`", sPorts[0])
  732. }
  733. rEndPort, err := strconv.Atoi(sPorts[1])
  734. if err != nil {
  735. return 0, 0, fmt.Errorf("ending range was not a number; `%s`", sPorts[1])
  736. }
  737. startPort = int32(rStartPort)
  738. endPort = int32(rEndPort)
  739. if startPort == fwPortAny {
  740. endPort = fwPortAny
  741. }
  742. } else {
  743. rPort, err := strconv.Atoi(s)
  744. if err != nil {
  745. return 0, 0, fmt.Errorf("was not a number; `%s`", s)
  746. }
  747. startPort = int32(rPort)
  748. endPort = startPort
  749. }
  750. return
  751. }
  752. //TODO: write tests for these
  753. func setTCPRTTTracking(c *conn, p []byte) {
  754. if c.Seq != 0 {
  755. return
  756. }
  757. ihl := int(p[0]&0x0f) << 2
  758. // Don't track FIN packets
  759. if p[ihl+13]&tcpFIN != 0 {
  760. return
  761. }
  762. c.Seq = binary.BigEndian.Uint32(p[ihl+4 : ihl+8])
  763. c.Sent = time.Now()
  764. }
  765. func (f *Firewall) checkTCPRTT(c *conn, p []byte) bool {
  766. if c.Seq == 0 {
  767. return false
  768. }
  769. ihl := int(p[0]&0x0f) << 2
  770. if p[ihl+13]&tcpACK == 0 {
  771. return false
  772. }
  773. // Deal with wrap around, signed int cuts the ack window in half
  774. // 0 is a bad ack, no data acknowledged
  775. // positive number is a bad ack, ack is over half the window away
  776. if int32(c.Seq-binary.BigEndian.Uint32(p[ihl+8:ihl+12])) >= 0 {
  777. return false
  778. }
  779. f.metricTCPRTT.Update(time.Since(c.Sent).Nanoseconds())
  780. c.Seq = 0
  781. return true
  782. }
  783. // ConntrackCache is used as a local routine cache to know if a given flow
  784. // has been seen in the conntrack table.
  785. type ConntrackCache map[FirewallPacket]struct{}
  786. type ConntrackCacheTicker struct {
  787. cacheV uint64
  788. cacheTick uint64
  789. cache ConntrackCache
  790. }
  791. func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
  792. if d == 0 {
  793. return nil
  794. }
  795. c := &ConntrackCacheTicker{
  796. cache: ConntrackCache{},
  797. }
  798. go c.tick(d)
  799. return c
  800. }
  801. func (c *ConntrackCacheTicker) tick(d time.Duration) {
  802. for {
  803. time.Sleep(d)
  804. atomic.AddUint64(&c.cacheTick, 1)
  805. }
  806. }
  807. // Get checks if the cache ticker has moved to the next version before returning
  808. // the map. If it has moved, we reset the map.
  809. func (c *ConntrackCacheTicker) Get(l *logrus.Logger) ConntrackCache {
  810. if c == nil {
  811. return nil
  812. }
  813. if tick := atomic.LoadUint64(&c.cacheTick); tick != c.cacheV {
  814. c.cacheV = tick
  815. if ll := len(c.cache); ll > 0 {
  816. if l.Level == logrus.DebugLevel {
  817. l.WithField("len", ll).Debug("resetting conntrack cache")
  818. }
  819. c.cache = make(ConntrackCache, ll)
  820. }
  821. }
  822. return c.cache
  823. }