Browse Source

Generic timerwheel (#804)

Nate Brown 2 years ago
parent
commit
5278b6f926
8 changed files with 116 additions and 431 deletions
  1. 10 14
      connection_manager.go
  2. 4 2
      firewall.go
  3. 5 6
      handshake_manager.go
  4. 2 2
      handshake_manager_test.go
  5. 63 32
      timeout.go
  6. 0 199
      timeout_system.go
  7. 0 156
      timeout_system_test.go
  8. 32 20
      timeout_test.go

+ 10 - 14
connection_manager.go

@@ -19,12 +19,12 @@ type connectionManager struct {
 	inLock       *sync.RWMutex
 	out          map[iputil.VpnIp]struct{}
 	outLock      *sync.RWMutex
-	TrafficTimer *SystemTimerWheel
+	TrafficTimer *LockingTimerWheel[iputil.VpnIp]
 	intf         *Interface
 
 	pendingDeletion      map[iputil.VpnIp]int
 	pendingDeletionLock  *sync.RWMutex
-	pendingDeletionTimer *SystemTimerWheel
+	pendingDeletionTimer *LockingTimerWheel[iputil.VpnIp]
 
 	checkInterval           int
 	pendingDeletionInterval int
@@ -40,11 +40,11 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
 		inLock:                  &sync.RWMutex{},
 		out:                     make(map[iputil.VpnIp]struct{}),
 		outLock:                 &sync.RWMutex{},
-		TrafficTimer:            NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
+		TrafficTimer:            NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60),
 		intf:                    intf,
 		pendingDeletion:         make(map[iputil.VpnIp]int),
 		pendingDeletionLock:     &sync.RWMutex{},
-		pendingDeletionTimer:    NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
+		pendingDeletionTimer:    NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60),
 		checkInterval:           checkInterval,
 		pendingDeletionInterval: pendingDeletionInterval,
 		l:                       l,
@@ -160,15 +160,13 @@ func (n *connectionManager) Run(ctx context.Context) {
 }
 
 func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
-	n.TrafficTimer.advance(now)
+	n.TrafficTimer.Advance(now)
 	for {
-		ep := n.TrafficTimer.Purge()
-		if ep == nil {
+		vpnIp, has := n.TrafficTimer.Purge()
+		if !has {
 			break
 		}
 
-		vpnIp := ep.(iputil.VpnIp)
-
 		// Check for traffic coming back in from this host.
 		traf := n.CheckIn(vpnIp)
 
@@ -214,15 +212,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
 }
 
 func (n *connectionManager) HandleDeletionTick(now time.Time) {
-	n.pendingDeletionTimer.advance(now)
+	n.pendingDeletionTimer.Advance(now)
 	for {
-		ep := n.pendingDeletionTimer.Purge()
-		if ep == nil {
+		vpnIp, has := n.pendingDeletionTimer.Purge()
+		if !has {
 			break
 		}
 
-		vpnIp := ep.(iputil.VpnIp)
-
 		hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
 		if err != nil {
 			n.l.Debugf("Not found in hostmap: %s", vpnIp)

+ 4 - 2
firewall.go

@@ -77,7 +77,7 @@ type FirewallConntrack struct {
 	sync.Mutex
 
 	Conns      map[firewall.Packet]*conn
-	TimerWheel *TimerWheel
+	TimerWheel *TimerWheel[firewall.Packet]
 }
 
 type FirewallTable struct {
@@ -145,7 +145,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
 	return &Firewall{
 		Conntrack: &FirewallConntrack{
 			Conns:      make(map[firewall.Packet]*conn),
-			TimerWheel: NewTimerWheel(min, max),
+			TimerWheel: NewTimerWheel[firewall.Packet](min, max),
 		},
 		InRules:        newFirewallTable(),
 		OutRules:       newFirewallTable(),
@@ -510,6 +510,7 @@ func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
 	conntrack := f.Conntrack
 	conntrack.Lock()
 	if _, ok := conntrack.Conns[fp]; !ok {
+		conntrack.TimerWheel.Advance(time.Now())
 		conntrack.TimerWheel.Add(fp, timeout)
 	}
 
@@ -537,6 +538,7 @@ func (f *Firewall) evict(p firewall.Packet) {
 
 	// Timeout is in the future, re-add the timer
 	if newT > 0 {
+		conntrack.TimerWheel.Advance(time.Now())
 		conntrack.TimerWheel.Add(p, newT)
 		return
 	}

+ 5 - 6
handshake_manager.go

@@ -47,7 +47,7 @@ type HandshakeManager struct {
 	lightHouse             *LightHouse
 	outside                *udp.Conn
 	config                 HandshakeConfig
-	OutboundHandshakeTimer *SystemTimerWheel
+	OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
 	messageMetrics         *MessageMetrics
 	metricInitiated        metrics.Counter
 	metricTimedOut         metrics.Counter
@@ -65,7 +65,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
 		outside:                outside,
 		config:                 config,
 		trigger:                make(chan iputil.VpnIp, config.triggerBuffer),
-		OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
+		OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
 		messageMetrics:         config.messageMetrics,
 		metricInitiated:        metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
 		metricTimedOut:         metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
@@ -90,13 +90,12 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
 }
 
 func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
-	c.OutboundHandshakeTimer.advance(now)
+	c.OutboundHandshakeTimer.Advance(now)
 	for {
-		ep := c.OutboundHandshakeTimer.Purge()
-		if ep == nil {
+		vpnIp, has := c.OutboundHandshakeTimer.Purge()
+		if !has {
 			break
 		}
-		vpnIp := ep.(iputil.VpnIp)
 		c.handleOutbound(vpnIp, f, false)
 	}
 }

+ 2 - 2
handshake_manager_test.go

@@ -106,8 +106,8 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
 	assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
 }
 
-func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
-	for _, i := range tw.wheel {
+func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
+	for _, i := range tw.t.wheel {
 		n := i.Head
 		for n != nil {
 			c++

+ 63 - 32
timeout.go

@@ -1,17 +1,14 @@
 package nebula
 
 import (
+	"sync"
 	"time"
-
-	"github.com/slackhq/nebula/firewall"
 )
 
 // How many timer objects should be cached
 const timerCacheMax = 50000
 
-var emptyFWPacket = firewall.Packet{}
-
-type TimerWheel struct {
+type TimerWheel[T any] struct {
 	// Current tick
 	current int
 
@@ -26,31 +23,38 @@ type TimerWheel struct {
 	wheelDuration time.Duration
 
 	// The actual wheel which is just a set of singly linked lists, head/tail pointers
-	wheel []*TimeoutList
+	wheel []*TimeoutList[T]
 
 	// Singly linked list of items that have timed out of the wheel
-	expired *TimeoutList
+	expired *TimeoutList[T]
 
 	// Item cache to avoid garbage collect
-	itemCache   *TimeoutItem
+	itemCache   *TimeoutItem[T]
 	itemsCached int
 }
 
+type LockingTimerWheel[T any] struct {
+	m sync.Mutex
+	t *TimerWheel[T]
+}
+
 // TimeoutList Represents a tick in the wheel
-type TimeoutList struct {
-	Head *TimeoutItem
-	Tail *TimeoutItem
+type TimeoutList[T any] struct {
+	Head *TimeoutItem[T]
+	Tail *TimeoutItem[T]
 }
 
 // TimeoutItem Represents an item within a tick
-type TimeoutItem struct {
-	Packet firewall.Packet
-	Next   *TimeoutItem
+type TimeoutItem[T any] struct {
+	Item T
+	Next *TimeoutItem[T]
 }
 
 // NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
 // Purge must be called once per entry to actually remove anything
-func NewTimerWheel(min, max time.Duration) *TimerWheel {
+// The TimerWheel does not handle concurrency on its own.
+// Locks around access to it must be used if multiple routines are manipulating it.
+func NewTimerWheel[T any](min, max time.Duration) *TimerWheel[T] {
 	//TODO provide an error
 	//if min >= max {
 	//	return nil
@@ -61,26 +65,31 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
 	// timeout
 	wLen := int((max / min) + 2)
 
-	tw := TimerWheel{
+	tw := TimerWheel[T]{
 		wheelLen:      wLen,
-		wheel:         make([]*TimeoutList, wLen),
+		wheel:         make([]*TimeoutList[T], wLen),
 		tickDuration:  min,
 		wheelDuration: max,
-		expired:       &TimeoutList{},
+		expired:       &TimeoutList[T]{},
 	}
 
 	for i := range tw.wheel {
-		tw.wheel[i] = &TimeoutList{}
+		tw.wheel[i] = &TimeoutList[T]{}
 	}
 
 	return &tw
 }
 
-// Add will add a firewall.Packet to the wheel in it's proper timeout
-func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
-	// Check and see if we should progress the tick
-	tw.advance(time.Now())
+// NewLockingTimerWheel is version of TimerWheel that is safe for concurrent use with a small performance penalty
+func NewLockingTimerWheel[T any](min, max time.Duration) *LockingTimerWheel[T] {
+	return &LockingTimerWheel[T]{
+		t: NewTimerWheel[T](min, max),
+	}
+}
 
+// Add will add an item to the wheel in its proper timeout.
+// Caller should Advance the wheel prior to ensure the proper slot is used.
+func (tw *TimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] {
 	i := tw.findWheel(timeout)
 
 	// Try to fetch off the cache
@@ -90,11 +99,11 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem
 		tw.itemsCached--
 		ti.Next = nil
 	} else {
-		ti = &TimeoutItem{}
+		ti = &TimeoutItem[T]{}
 	}
 
 	// Relink and return
-	ti.Packet = v
+	ti.Item = v
 	if tw.wheel[i].Tail == nil {
 		tw.wheel[i].Head = ti
 		tw.wheel[i].Tail = ti
@@ -106,9 +115,12 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem
 	return ti
 }
 
-func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
+// Purge removes and returns the first available expired item from the wheel and the 2nd argument is true.
+// If no item is available then an empty T is returned and the 2nd argument is false.
+func (tw *TimerWheel[T]) Purge() (T, bool) {
 	if tw.expired.Head == nil {
-		return emptyFWPacket, false
+		var na T
+		return na, false
 	}
 
 	ti := tw.expired.Head
@@ -128,11 +140,11 @@ func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
 		tw.itemsCached++
 	}
 
-	return ti.Packet, true
+	return ti.Item, true
 }
 
-// advance will move the wheel forward by proper number of ticks. The caller _should_ lock the wheel before calling this
-func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
+// findWheel find the next position in the wheel for the provided timeout given the current tick
+func (tw *TimerWheel[T]) findWheel(timeout time.Duration) (i int) {
 	if timeout < tw.tickDuration {
 		// Can't track anything below the set resolution
 		timeout = tw.tickDuration
@@ -154,8 +166,9 @@ func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
 	return tick
 }
 
-// advance will lock and move the wheel forward by proper number of ticks.
-func (tw *TimerWheel) advance(now time.Time) {
+// Advance will move the wheel forward by the appropriate number of ticks for the provided time and all items
+// passed over will be moved to the expired list. Calling Purge is necessary to remove them entirely.
+func (tw *TimerWheel[T]) Advance(now time.Time) {
 	if tw.lastTick == nil {
 		tw.lastTick = &now
 	}
@@ -192,3 +205,21 @@ func (tw *TimerWheel) advance(now time.Time) {
 	newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv))
 	tw.lastTick = &newTick
 }
+
+func (lw *LockingTimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] {
+	lw.m.Lock()
+	defer lw.m.Unlock()
+	return lw.t.Add(v, timeout)
+}
+
+func (lw *LockingTimerWheel[T]) Purge() (T, bool) {
+	lw.m.Lock()
+	defer lw.m.Unlock()
+	return lw.t.Purge()
+}
+
+func (lw *LockingTimerWheel[T]) Advance(now time.Time) {
+	lw.m.Lock()
+	defer lw.m.Unlock()
+	lw.t.Advance(now)
+}

+ 0 - 199
timeout_system.go

@@ -1,199 +0,0 @@
-package nebula
-
-import (
-	"sync"
-	"time"
-
-	"github.com/slackhq/nebula/iputil"
-)
-
-// How many timer objects should be cached
-const systemTimerCacheMax = 50000
-
-type SystemTimerWheel struct {
-	// Current tick
-	current int
-
-	// Cheat on finding the length of the wheel
-	wheelLen int
-
-	// Last time we ticked, since we are lazy ticking
-	lastTick *time.Time
-
-	// Durations of a tick and the entire wheel
-	tickDuration  time.Duration
-	wheelDuration time.Duration
-
-	// The actual wheel which is just a set of singly linked lists, head/tail pointers
-	wheel []*SystemTimeoutList
-
-	// Singly linked list of items that have timed out of the wheel
-	expired *SystemTimeoutList
-
-	// Item cache to avoid garbage collect
-	itemCache   *SystemTimeoutItem
-	itemsCached int
-
-	lock sync.Mutex
-}
-
-// SystemTimeoutList Represents a tick in the wheel
-type SystemTimeoutList struct {
-	Head *SystemTimeoutItem
-	Tail *SystemTimeoutItem
-}
-
-// SystemTimeoutItem Represents an item within a tick
-type SystemTimeoutItem struct {
-	Item iputil.VpnIp
-	Next *SystemTimeoutItem
-}
-
-// NewSystemTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
-// Purge must be called once per entry to actually remove anything
-func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
-	//TODO provide an error
-	//if min >= max {
-	//	return nil
-	//}
-
-	// Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full
-	// max duration, even if our current tick is at the maximum position and the next item to be added is at maximum
-	// timeout
-	wLen := int((max / min) + 2)
-
-	tw := SystemTimerWheel{
-		wheelLen:      wLen,
-		wheel:         make([]*SystemTimeoutList, wLen),
-		tickDuration:  min,
-		wheelDuration: max,
-		expired:       &SystemTimeoutList{},
-	}
-
-	for i := range tw.wheel {
-		tw.wheel[i] = &SystemTimeoutList{}
-	}
-
-	return &tw
-}
-
-func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem {
-	tw.lock.Lock()
-	defer tw.lock.Unlock()
-
-	// Check and see if we should progress the tick
-	//tw.advance(time.Now())
-
-	i := tw.findWheel(timeout)
-
-	// Try to fetch off the cache
-	ti := tw.itemCache
-	if ti != nil {
-		tw.itemCache = ti.Next
-		ti.Next = nil
-		tw.itemsCached--
-	} else {
-		ti = &SystemTimeoutItem{}
-	}
-
-	// Relink and return
-	ti.Item = v
-	ti.Next = tw.wheel[i].Head
-	tw.wheel[i].Head = ti
-
-	if tw.wheel[i].Tail == nil {
-		tw.wheel[i].Tail = ti
-	}
-
-	return ti
-}
-
-func (tw *SystemTimerWheel) Purge() interface{} {
-	tw.lock.Lock()
-	defer tw.lock.Unlock()
-
-	if tw.expired.Head == nil {
-		return nil
-	}
-
-	ti := tw.expired.Head
-	tw.expired.Head = ti.Next
-
-	if tw.expired.Head == nil {
-		tw.expired.Tail = nil
-	}
-
-	p := ti.Item
-
-	// Clear out the items references
-	ti.Item = 0
-	ti.Next = nil
-
-	// Maybe cache it for later
-	if tw.itemsCached < systemTimerCacheMax {
-		ti.Next = tw.itemCache
-		tw.itemCache = ti
-		tw.itemsCached++
-	}
-
-	return p
-}
-
-func (tw *SystemTimerWheel) findWheel(timeout time.Duration) (i int) {
-	if timeout < tw.tickDuration {
-		// Can't track anything below the set resolution
-		timeout = tw.tickDuration
-	} else if timeout > tw.wheelDuration {
-		// We aren't handling timeouts greater than the wheels duration
-		timeout = tw.wheelDuration
-	}
-
-	// Find the next highest, rounding up
-	tick := int(((timeout - 1) / tw.tickDuration) + 1)
-
-	// Add another tick since the current tick may almost be over then map it to the wheel from our
-	// current position
-	tick += tw.current + 1
-	if tick >= tw.wheelLen {
-		tick -= tw.wheelLen
-	}
-
-	return tick
-}
-
-func (tw *SystemTimerWheel) advance(now time.Time) {
-	tw.lock.Lock()
-	defer tw.lock.Unlock()
-
-	if tw.lastTick == nil {
-		tw.lastTick = &now
-	}
-
-	// We want to round down
-	ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration)
-	//l.Infoln("Ticks: ", ticks)
-	for i := 0; i < ticks; i++ {
-		tw.current++
-		//l.Infoln("Tick: ", tw.current)
-		if tw.current >= tw.wheelLen {
-			tw.current = 0
-		}
-
-		// We need to append the expired items as to not starve evicting the oldest ones
-		if tw.expired.Tail == nil {
-			tw.expired.Head = tw.wheel[tw.current].Head
-			tw.expired.Tail = tw.wheel[tw.current].Tail
-		} else {
-			tw.expired.Tail.Next = tw.wheel[tw.current].Head
-			if tw.wheel[tw.current].Tail != nil {
-				tw.expired.Tail = tw.wheel[tw.current].Tail
-			}
-		}
-
-		//l.Infoln("Head: ", tw.expired.Head, "Tail: ", tw.expired.Tail)
-		tw.wheel[tw.current].Head = nil
-		tw.wheel[tw.current].Tail = nil
-
-		tw.lastTick = &now
-	}
-}

+ 0 - 156
timeout_system_test.go

@@ -1,156 +0,0 @@
-package nebula
-
-import (
-	"net"
-	"testing"
-	"time"
-
-	"github.com/slackhq/nebula/iputil"
-	"github.com/stretchr/testify/assert"
-)
-
-func TestNewSystemTimerWheel(t *testing.T) {
-	// Make sure we get an object we expect
-	tw := NewSystemTimerWheel(time.Second, time.Second*10)
-	assert.Equal(t, 12, tw.wheelLen)
-	assert.Equal(t, 0, tw.current)
-	assert.Nil(t, tw.lastTick)
-	assert.Equal(t, time.Second*1, tw.tickDuration)
-	assert.Equal(t, time.Second*10, tw.wheelDuration)
-	assert.Len(t, tw.wheel, 12)
-
-	// Assert the math is correct
-	tw = NewSystemTimerWheel(time.Second*3, time.Second*10)
-	assert.Equal(t, 5, tw.wheelLen)
-
-	tw = NewSystemTimerWheel(time.Second*120, time.Minute*10)
-	assert.Equal(t, 7, tw.wheelLen)
-}
-
-func TestSystemTimerWheel_findWheel(t *testing.T) {
-	tw := NewSystemTimerWheel(time.Second, time.Second*10)
-	assert.Len(t, tw.wheel, 12)
-
-	// Current + tick + 1 since we don't know how far into current we are
-	assert.Equal(t, 2, tw.findWheel(time.Second*1))
-
-	// Scale up to min duration
-	assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
-
-	// Make sure we hit that last index
-	assert.Equal(t, 11, tw.findWheel(time.Second*10))
-
-	// Scale down to max duration
-	assert.Equal(t, 11, tw.findWheel(time.Second*11))
-
-	tw.current = 1
-	// Make sure we account for the current position properly
-	assert.Equal(t, 3, tw.findWheel(time.Second*1))
-	assert.Equal(t, 0, tw.findWheel(time.Second*10))
-
-	// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
-	for min := time.Duration(1); min < 100; min++ {
-		for max := min; max < 100; max++ {
-			tw = NewSystemTimerWheel(min, max)
-
-			for current := 0; current < tw.wheelLen; current++ {
-				tw.current = current
-				for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
-					tick := tw.findWheel(timeout)
-					if tick >= tw.wheelLen {
-						t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
-					}
-				}
-			}
-		}
-	}
-}
-
-func TestSystemTimerWheel_Add(t *testing.T) {
-	tw := NewSystemTimerWheel(time.Second, time.Second*10)
-
-	fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
-	tw.Add(fp1, time.Second*1)
-
-	// Make sure we set head and tail properly
-	assert.NotNil(t, tw.wheel[2])
-	assert.Equal(t, fp1, tw.wheel[2].Head.Item)
-	assert.Nil(t, tw.wheel[2].Head.Next)
-	assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
-	assert.Nil(t, tw.wheel[2].Tail.Next)
-
-	// Make sure we only modify head
-	fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
-	tw.Add(fp2, time.Second*1)
-	assert.Equal(t, fp2, tw.wheel[2].Head.Item)
-	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
-	assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
-	assert.Nil(t, tw.wheel[2].Tail.Next)
-
-	// Make sure we use free'd items first
-	tw.itemCache = &SystemTimeoutItem{}
-	tw.itemsCached = 1
-	tw.Add(fp2, time.Second*1)
-	assert.Nil(t, tw.itemCache)
-	assert.Equal(t, 0, tw.itemsCached)
-}
-
-func TestSystemTimerWheel_Purge(t *testing.T) {
-	// First advance should set the lastTick and do nothing else
-	tw := NewSystemTimerWheel(time.Second, time.Second*10)
-	assert.Nil(t, tw.lastTick)
-	tw.advance(time.Now())
-	assert.NotNil(t, tw.lastTick)
-	assert.Equal(t, 0, tw.current)
-
-	fps := []iputil.VpnIp{9, 10, 11, 12}
-
-	//fp1 := ip2int(net.ParseIP("1.2.3.4"))
-
-	tw.Add(fps[0], time.Second*1)
-	tw.Add(fps[1], time.Second*1)
-	tw.Add(fps[2], time.Second*2)
-	tw.Add(fps[3], time.Second*2)
-
-	ta := time.Now().Add(time.Second * 3)
-	lastTick := *tw.lastTick
-	tw.advance(ta)
-	assert.Equal(t, 3, tw.current)
-	assert.True(t, tw.lastTick.After(lastTick))
-
-	// Make sure we get all 4 packets back
-	for i := 0; i < 4; i++ {
-		assert.Contains(t, fps, tw.Purge())
-	}
-
-	// Make sure there aren't any leftover
-	assert.Nil(t, tw.Purge())
-	assert.Nil(t, tw.expired.Head)
-	assert.Nil(t, tw.expired.Tail)
-
-	// Make sure we cached the free'd items
-	assert.Equal(t, 4, tw.itemsCached)
-	ci := tw.itemCache
-	for i := 0; i < 4; i++ {
-		assert.NotNil(t, ci)
-		ci = ci.Next
-	}
-	assert.Nil(t, ci)
-
-	// Lets make sure we roll over properly
-	ta = ta.Add(time.Second * 5)
-	tw.advance(ta)
-	assert.Equal(t, 8, tw.current)
-
-	ta = ta.Add(time.Second * 2)
-	tw.advance(ta)
-	assert.Equal(t, 10, tw.current)
-
-	ta = ta.Add(time.Second * 1)
-	tw.advance(ta)
-	assert.Equal(t, 11, tw.current)
-
-	ta = ta.Add(time.Second * 1)
-	tw.advance(ta)
-	assert.Equal(t, 0, tw.current)
-}

+ 32 - 20
timeout_test.go

@@ -10,7 +10,7 @@ import (
 
 func TestNewTimerWheel(t *testing.T) {
 	// Make sure we get an object we expect
-	tw := NewTimerWheel(time.Second, time.Second*10)
+	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
 	assert.Equal(t, 12, tw.wheelLen)
 	assert.Equal(t, 0, tw.current)
 	assert.Nil(t, tw.lastTick)
@@ -19,15 +19,27 @@ func TestNewTimerWheel(t *testing.T) {
 	assert.Len(t, tw.wheel, 12)
 
 	// Assert the math is correct
-	tw = NewTimerWheel(time.Second*3, time.Second*10)
+	tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10)
 	assert.Equal(t, 5, tw.wheelLen)
 
-	tw = NewTimerWheel(time.Second*120, time.Minute*10)
+	tw = NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10)
 	assert.Equal(t, 7, tw.wheelLen)
+
+	// Test empty purge of non nil items
+	i, ok := tw.Purge()
+	assert.Equal(t, firewall.Packet{}, i)
+	assert.False(t, ok)
+
+	// Test empty purges of nil items
+	tw2 := NewTimerWheel[*int](time.Second, time.Second*10)
+	i2, ok := tw2.Purge()
+	assert.Nil(t, i2)
+	assert.False(t, ok)
+
 }
 
 func TestTimerWheel_findWheel(t *testing.T) {
-	tw := NewTimerWheel(time.Second, time.Second*10)
+	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
 	assert.Len(t, tw.wheel, 12)
 
 	// Current + tick + 1 since we don't know how far into current we are
@@ -49,28 +61,28 @@ func TestTimerWheel_findWheel(t *testing.T) {
 }
 
 func TestTimerWheel_Add(t *testing.T) {
-	tw := NewTimerWheel(time.Second, time.Second*10)
+	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
 
 	fp1 := firewall.Packet{}
 	tw.Add(fp1, time.Second*1)
 
 	// Make sure we set head and tail properly
 	assert.NotNil(t, tw.wheel[2])
-	assert.Equal(t, fp1, tw.wheel[2].Head.Packet)
+	assert.Equal(t, fp1, tw.wheel[2].Head.Item)
 	assert.Nil(t, tw.wheel[2].Head.Next)
-	assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
+	assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 
 	// Make sure we only modify head
 	fp2 := firewall.Packet{}
 	tw.Add(fp2, time.Second*1)
-	assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
-	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
-	assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
+	assert.Equal(t, fp2, tw.wheel[2].Head.Item)
+	assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
+	assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
 	assert.Nil(t, tw.wheel[2].Tail.Next)
 
 	// Make sure we use free'd items first
-	tw.itemCache = &TimeoutItem{}
+	tw.itemCache = &TimeoutItem[firewall.Packet]{}
 	tw.itemsCached = 1
 	tw.Add(fp2, time.Second*1)
 	assert.Nil(t, tw.itemCache)
@@ -79,7 +91,7 @@ func TestTimerWheel_Add(t *testing.T) {
 	// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
 	for min := time.Duration(1); min < 100; min++ {
 		for max := min; max < 100; max++ {
-			tw = NewTimerWheel(min, max)
+			tw = NewTimerWheel[firewall.Packet](min, max)
 
 			for current := 0; current < tw.wheelLen; current++ {
 				tw.current = current
@@ -96,9 +108,9 @@ func TestTimerWheel_Add(t *testing.T) {
 
 func TestTimerWheel_Purge(t *testing.T) {
 	// First advance should set the lastTick and do nothing else
-	tw := NewTimerWheel(time.Second, time.Second*10)
+	tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
 	assert.Nil(t, tw.lastTick)
-	tw.advance(time.Now())
+	tw.Advance(time.Now())
 	assert.NotNil(t, tw.lastTick)
 	assert.Equal(t, 0, tw.current)
 
@@ -116,7 +128,7 @@ func TestTimerWheel_Purge(t *testing.T) {
 
 	ta := time.Now().Add(time.Second * 3)
 	lastTick := *tw.lastTick
-	tw.advance(ta)
+	tw.Advance(ta)
 	assert.Equal(t, 3, tw.current)
 	assert.True(t, tw.lastTick.After(lastTick))
 
@@ -142,20 +154,20 @@ func TestTimerWheel_Purge(t *testing.T) {
 	}
 	assert.Nil(t, ci)
 
-	// Lets make sure we roll over properly
+	// Let's make sure we roll over properly
 	ta = ta.Add(time.Second * 5)
-	tw.advance(ta)
+	tw.Advance(ta)
 	assert.Equal(t, 8, tw.current)
 
 	ta = ta.Add(time.Second * 2)
-	tw.advance(ta)
+	tw.Advance(ta)
 	assert.Equal(t, 10, tw.current)
 
 	ta = ta.Add(time.Second * 1)
-	tw.advance(ta)
+	tw.Advance(ta)
 	assert.Equal(t, 11, tw.current)
 
 	ta = ta.Add(time.Second * 1)
-	tw.advance(ta)
+	tw.Advance(ta)
 	assert.Equal(t, 0, tw.current)
 }