Explorar o código

Fix possible panic in the timerwheels (#802)

Nate Brown %!s(int64=2) %!d(string=hai) anos
pai
achega
c177126ed0
Modificáronse 5 ficheiros con 78 adicións e 34 borrados
  1. 6 6
      firewall_test.go
  2. 7 6
      timeout.go
  3. 7 6
      timeout_system.go
  4. 29 8
      timeout_system_test.go
  5. 29 8
      timeout_test.go

+ 6 - 6
firewall_test.go

@@ -34,27 +34,27 @@ func TestNewFirewall(t *testing.T) {
 
 
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
+	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
 	fw = NewFirewall(l, time.Second, time.Hour, time.Minute, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
+	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
 	fw = NewFirewall(l, time.Hour, time.Second, time.Minute, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
+	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
 	fw = NewFirewall(l, time.Hour, time.Minute, time.Second, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
+	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
 	fw = NewFirewall(l, time.Minute, time.Hour, time.Second, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
+	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
 
 
 	fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
 	fw = NewFirewall(l, time.Minute, time.Second, time.Hour, c)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
 	assert.Equal(t, time.Hour, conntrack.TimerWheel.wheelDuration)
-	assert.Equal(t, 3601, conntrack.TimerWheel.wheelLen)
+	assert.Equal(t, 3602, conntrack.TimerWheel.wheelLen)
 }
 }
 
 
 func TestFirewall_AddRule(t *testing.T) {
 func TestFirewall_AddRule(t *testing.T) {

+ 7 - 6
timeout.go

@@ -36,19 +36,19 @@ type TimerWheel struct {
 	itemsCached int
 	itemsCached int
 }
 }
 
 
-// Represents a tick in the wheel
+// TimeoutList Represents a tick in the wheel
 type TimeoutList struct {
 type TimeoutList struct {
 	Head *TimeoutItem
 	Head *TimeoutItem
 	Tail *TimeoutItem
 	Tail *TimeoutItem
 }
 }
 
 
-// Represents an item within a tick
+// TimeoutItem Represents an item within a tick
 type TimeoutItem struct {
 type TimeoutItem struct {
 	Packet firewall.Packet
 	Packet firewall.Packet
 	Next   *TimeoutItem
 	Next   *TimeoutItem
 }
 }
 
 
-// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
+// NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
 // Purge must be called once per entry to actually remove anything
 // Purge must be called once per entry to actually remove anything
 func NewTimerWheel(min, max time.Duration) *TimerWheel {
 func NewTimerWheel(min, max time.Duration) *TimerWheel {
 	//TODO provide an error
 	//TODO provide an error
@@ -56,9 +56,10 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
 	//	return nil
 	//	return nil
 	//}
 	//}
 
 
-	// Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full
-	// max duration
-	wLen := int((max / min) + 1)
+	// Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full
+	// max duration, even if our current tick is at the maximum position and the next item to be added is at maximum
+	// timeout
+	wLen := int((max / min) + 2)
 
 
 	tw := TimerWheel{
 	tw := TimerWheel{
 		wheelLen:      wLen,
 		wheelLen:      wLen,

+ 7 - 6
timeout_system.go

@@ -37,19 +37,19 @@ type SystemTimerWheel struct {
 	lock sync.Mutex
 	lock sync.Mutex
 }
 }
 
 
-// Represents a tick in the wheel
+// SystemTimeoutList Represents a tick in the wheel
 type SystemTimeoutList struct {
 type SystemTimeoutList struct {
 	Head *SystemTimeoutItem
 	Head *SystemTimeoutItem
 	Tail *SystemTimeoutItem
 	Tail *SystemTimeoutItem
 }
 }
 
 
-// Represents an item within a tick
+// SystemTimeoutItem Represents an item within a tick
 type SystemTimeoutItem struct {
 type SystemTimeoutItem struct {
 	Item iputil.VpnIp
 	Item iputil.VpnIp
 	Next *SystemTimeoutItem
 	Next *SystemTimeoutItem
 }
 }
 
 
-// Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
+// NewSystemTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
 // Purge must be called once per entry to actually remove anything
 // Purge must be called once per entry to actually remove anything
 func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
 func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
 	//TODO provide an error
 	//TODO provide an error
@@ -57,9 +57,10 @@ func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
 	//	return nil
 	//	return nil
 	//}
 	//}
 
 
-	// Round down and add 1 so we can have the smallest # of ticks in the wheel and still account for a full
-	// max duration
-	wLen := int((max / min) + 1)
+	// Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full
+	// max duration, even if our current tick is at the maximum position and the next item to be added is at maximum
+	// timeout
+	wLen := int((max / min) + 2)
 
 
 	tw := SystemTimerWheel{
 	tw := SystemTimerWheel{
 		wheelLen:      wLen,
 		wheelLen:      wLen,

+ 29 - 8
timeout_system_test.go

@@ -12,24 +12,24 @@ import (
 func TestNewSystemTimerWheel(t *testing.T) {
 func TestNewSystemTimerWheel(t *testing.T) {
 	// Make sure we get an object we expect
 	// Make sure we get an object we expect
 	tw := NewSystemTimerWheel(time.Second, time.Second*10)
 	tw := NewSystemTimerWheel(time.Second, time.Second*10)
-	assert.Equal(t, 11, tw.wheelLen)
+	assert.Equal(t, 12, tw.wheelLen)
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)
 	assert.Nil(t, tw.lastTick)
 	assert.Nil(t, tw.lastTick)
 	assert.Equal(t, time.Second*1, tw.tickDuration)
 	assert.Equal(t, time.Second*1, tw.tickDuration)
 	assert.Equal(t, time.Second*10, tw.wheelDuration)
 	assert.Equal(t, time.Second*10, tw.wheelDuration)
-	assert.Len(t, tw.wheel, 11)
+	assert.Len(t, tw.wheel, 12)
 
 
 	// Assert the math is correct
 	// Assert the math is correct
 	tw = NewSystemTimerWheel(time.Second*3, time.Second*10)
 	tw = NewSystemTimerWheel(time.Second*3, time.Second*10)
-	assert.Equal(t, 4, tw.wheelLen)
+	assert.Equal(t, 5, tw.wheelLen)
 
 
 	tw = NewSystemTimerWheel(time.Second*120, time.Minute*10)
 	tw = NewSystemTimerWheel(time.Second*120, time.Minute*10)
-	assert.Equal(t, 6, tw.wheelLen)
+	assert.Equal(t, 7, tw.wheelLen)
 }
 }
 
 
 func TestSystemTimerWheel_findWheel(t *testing.T) {
 func TestSystemTimerWheel_findWheel(t *testing.T) {
 	tw := NewSystemTimerWheel(time.Second, time.Second*10)
 	tw := NewSystemTimerWheel(time.Second, time.Second*10)
-	assert.Len(t, tw.wheel, 11)
+	assert.Len(t, tw.wheel, 12)
 
 
 	// Current + tick + 1 since we don't know how far into current we are
 	// Current + tick + 1 since we don't know how far into current we are
 	assert.Equal(t, 2, tw.findWheel(time.Second*1))
 	assert.Equal(t, 2, tw.findWheel(time.Second*1))
@@ -38,15 +38,32 @@ func TestSystemTimerWheel_findWheel(t *testing.T) {
 	assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
 	assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
 
 
 	// Make sure we hit that last index
 	// Make sure we hit that last index
-	assert.Equal(t, 0, tw.findWheel(time.Second*10))
+	assert.Equal(t, 11, tw.findWheel(time.Second*10))
 
 
 	// Scale down to max duration
 	// Scale down to max duration
-	assert.Equal(t, 0, tw.findWheel(time.Second*11))
+	assert.Equal(t, 11, tw.findWheel(time.Second*11))
 
 
 	tw.current = 1
 	tw.current = 1
 	// Make sure we account for the current position properly
 	// Make sure we account for the current position properly
 	assert.Equal(t, 3, tw.findWheel(time.Second*1))
 	assert.Equal(t, 3, tw.findWheel(time.Second*1))
-	assert.Equal(t, 1, tw.findWheel(time.Second*10))
+	assert.Equal(t, 0, tw.findWheel(time.Second*10))
+
+	// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
+	for min := time.Duration(1); min < 100; min++ {
+		for max := min; max < 100; max++ {
+			tw = NewSystemTimerWheel(min, max)
+
+			for current := 0; current < tw.wheelLen; current++ {
+				tw.current = current
+				for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
+					tick := tw.findWheel(timeout)
+					if tick >= tw.wheelLen {
+						t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
+					}
+				}
+			}
+		}
+	}
 }
 }
 
 
 func TestSystemTimerWheel_Add(t *testing.T) {
 func TestSystemTimerWheel_Add(t *testing.T) {
@@ -129,6 +146,10 @@ func TestSystemTimerWheel_Purge(t *testing.T) {
 	tw.advance(ta)
 	tw.advance(ta)
 	assert.Equal(t, 10, tw.current)
 	assert.Equal(t, 10, tw.current)
 
 
+	ta = ta.Add(time.Second * 1)
+	tw.advance(ta)
+	assert.Equal(t, 11, tw.current)
+
 	ta = ta.Add(time.Second * 1)
 	ta = ta.Add(time.Second * 1)
 	tw.advance(ta)
 	tw.advance(ta)
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)

+ 29 - 8
timeout_test.go

@@ -11,24 +11,24 @@ import (
 func TestNewTimerWheel(t *testing.T) {
 func TestNewTimerWheel(t *testing.T) {
 	// Make sure we get an object we expect
 	// Make sure we get an object we expect
 	tw := NewTimerWheel(time.Second, time.Second*10)
 	tw := NewTimerWheel(time.Second, time.Second*10)
-	assert.Equal(t, 11, tw.wheelLen)
+	assert.Equal(t, 12, tw.wheelLen)
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)
 	assert.Nil(t, tw.lastTick)
 	assert.Nil(t, tw.lastTick)
 	assert.Equal(t, time.Second*1, tw.tickDuration)
 	assert.Equal(t, time.Second*1, tw.tickDuration)
 	assert.Equal(t, time.Second*10, tw.wheelDuration)
 	assert.Equal(t, time.Second*10, tw.wheelDuration)
-	assert.Len(t, tw.wheel, 11)
+	assert.Len(t, tw.wheel, 12)
 
 
 	// Assert the math is correct
 	// Assert the math is correct
 	tw = NewTimerWheel(time.Second*3, time.Second*10)
 	tw = NewTimerWheel(time.Second*3, time.Second*10)
-	assert.Equal(t, 4, tw.wheelLen)
+	assert.Equal(t, 5, tw.wheelLen)
 
 
 	tw = NewTimerWheel(time.Second*120, time.Minute*10)
 	tw = NewTimerWheel(time.Second*120, time.Minute*10)
-	assert.Equal(t, 6, tw.wheelLen)
+	assert.Equal(t, 7, tw.wheelLen)
 }
 }
 
 
 func TestTimerWheel_findWheel(t *testing.T) {
 func TestTimerWheel_findWheel(t *testing.T) {
 	tw := NewTimerWheel(time.Second, time.Second*10)
 	tw := NewTimerWheel(time.Second, time.Second*10)
-	assert.Len(t, tw.wheel, 11)
+	assert.Len(t, tw.wheel, 12)
 
 
 	// Current + tick + 1 since we don't know how far into current we are
 	// Current + tick + 1 since we don't know how far into current we are
 	assert.Equal(t, 2, tw.findWheel(time.Second*1))
 	assert.Equal(t, 2, tw.findWheel(time.Second*1))
@@ -37,15 +37,15 @@ func TestTimerWheel_findWheel(t *testing.T) {
 	assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
 	assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
 
 
 	// Make sure we hit that last index
 	// Make sure we hit that last index
-	assert.Equal(t, 0, tw.findWheel(time.Second*10))
+	assert.Equal(t, 11, tw.findWheel(time.Second*10))
 
 
 	// Scale down to max duration
 	// Scale down to max duration
-	assert.Equal(t, 0, tw.findWheel(time.Second*11))
+	assert.Equal(t, 11, tw.findWheel(time.Second*11))
 
 
 	tw.current = 1
 	tw.current = 1
 	// Make sure we account for the current position properly
 	// Make sure we account for the current position properly
 	assert.Equal(t, 3, tw.findWheel(time.Second*1))
 	assert.Equal(t, 3, tw.findWheel(time.Second*1))
-	assert.Equal(t, 1, tw.findWheel(time.Second*10))
+	assert.Equal(t, 0, tw.findWheel(time.Second*10))
 }
 }
 
 
 func TestTimerWheel_Add(t *testing.T) {
 func TestTimerWheel_Add(t *testing.T) {
@@ -75,6 +75,23 @@ func TestTimerWheel_Add(t *testing.T) {
 	tw.Add(fp2, time.Second*1)
 	tw.Add(fp2, time.Second*1)
 	assert.Nil(t, tw.itemCache)
 	assert.Nil(t, tw.itemCache)
 	assert.Equal(t, 0, tw.itemsCached)
 	assert.Equal(t, 0, tw.itemsCached)
+
+	// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
+	for min := time.Duration(1); min < 100; min++ {
+		for max := min; max < 100; max++ {
+			tw = NewTimerWheel(min, max)
+
+			for current := 0; current < tw.wheelLen; current++ {
+				tw.current = current
+				for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
+					tick := tw.findWheel(timeout)
+					if tick >= tw.wheelLen {
+						t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
+					}
+				}
+			}
+		}
+	}
 }
 }
 
 
 func TestTimerWheel_Purge(t *testing.T) {
 func TestTimerWheel_Purge(t *testing.T) {
@@ -134,6 +151,10 @@ func TestTimerWheel_Purge(t *testing.T) {
 	tw.advance(ta)
 	tw.advance(ta)
 	assert.Equal(t, 10, tw.current)
 	assert.Equal(t, 10, tw.current)
 
 
+	ta = ta.Add(time.Second * 1)
+	tw.advance(ta)
+	assert.Equal(t, 11, tw.current)
+
 	ta = ta.Add(time.Second * 1)
 	ta = ta.Add(time.Second * 1)
 	tw.advance(ta)
 	tw.advance(ta)
 	assert.Equal(t, 0, tw.current)
 	assert.Equal(t, 0, tw.current)