cache.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package firewall
  2. import (
  3. "sync"
  4. "sync/atomic"
  5. "time"
  6. "github.com/sirupsen/logrus"
  7. )
  8. // ConntrackCache is used as a local routine cache to know if a given flow
  9. // has been seen in the conntrack table.
  10. type ConntrackCache struct {
  11. mu sync.Mutex
  12. entries map[Packet]struct{}
  13. }
  14. func newConntrackCache() *ConntrackCache {
  15. return &ConntrackCache{entries: make(map[Packet]struct{})}
  16. }
  17. func (c *ConntrackCache) Has(p Packet) bool {
  18. if c == nil {
  19. return false
  20. }
  21. c.mu.Lock()
  22. _, ok := c.entries[p]
  23. c.mu.Unlock()
  24. return ok
  25. }
  26. func (c *ConntrackCache) Add(p Packet) {
  27. if c == nil {
  28. return
  29. }
  30. c.mu.Lock()
  31. c.entries[p] = struct{}{}
  32. c.mu.Unlock()
  33. }
  34. func (c *ConntrackCache) Len() int {
  35. if c == nil {
  36. return 0
  37. }
  38. c.mu.Lock()
  39. l := len(c.entries)
  40. c.mu.Unlock()
  41. return l
  42. }
  43. func (c *ConntrackCache) Reset(capHint int) {
  44. if c == nil {
  45. return
  46. }
  47. c.mu.Lock()
  48. c.entries = make(map[Packet]struct{}, capHint)
  49. c.mu.Unlock()
  50. }
  51. type ConntrackCacheTicker struct {
  52. cacheV uint64
  53. cacheTick atomic.Uint64
  54. cache *ConntrackCache
  55. }
  56. func NewConntrackCacheTicker(d time.Duration) *ConntrackCacheTicker {
  57. if d == 0 {
  58. return nil
  59. }
  60. c := &ConntrackCacheTicker{cache: newConntrackCache()}
  61. go c.tick(d)
  62. return c
  63. }
  64. func (c *ConntrackCacheTicker) tick(d time.Duration) {
  65. for {
  66. time.Sleep(d)
  67. c.cacheTick.Add(1)
  68. }
  69. }
  70. // Get checks if the cache ticker has moved to the next version before returning
  71. // the map. If it has moved, we reset the map.
  72. func (c *ConntrackCacheTicker) Get(l *logrus.Logger) *ConntrackCache {
  73. if c == nil {
  74. return nil
  75. }
  76. if tick := c.cacheTick.Load(); tick != c.cacheV {
  77. c.cacheV = tick
  78. if ll := c.cache.Len(); ll > 0 {
  79. if l.Level == logrus.DebugLevel {
  80. l.WithField("len", ll).Debug("resetting conntrack cache")
  81. }
  82. c.cache.Reset(ll)
  83. }
  84. }
  85. return c.cache
  86. }