Przeglądaj źródła

add more tests around bits counters (#1441)

Co-authored-by: Nate Brown <[email protected]>
Wade Simmons 3 tygodni temu
rodzic
commit
27ea667aee
2 zmienionych plików z 118 dodań i 100 usunięć
  1. 32 77
      bits.go
  2. 86 23
      bits_test.go

+ 32 - 77
bits.go

@@ -9,14 +9,13 @@ type Bits struct {
 	length             uint64
 	current            uint64
 	bits               []bool
-	firstSeen          bool
 	lostCounter        metrics.Counter
 	dupeCounter        metrics.Counter
 	outOfWindowCounter metrics.Counter
 }
 
 func NewBits(bits uint64) *Bits {
-	return &Bits{
+	b := &Bits{
 		length:             bits,
 		bits:               make([]bool, bits, bits),
 		current:            0,
@@ -24,34 +23,37 @@ func NewBits(bits uint64) *Bits {
 		dupeCounter:        metrics.GetOrRegisterCounter("network.packets.duplicate", nil),
 		outOfWindowCounter: metrics.GetOrRegisterCounter("network.packets.out_of_window", nil),
 	}
+
+	// There is no counter value 0, mark it to avoid counting a lost packet later.
+	b.bits[0] = true
+	b.current = 0
+	return b
 }
 
-func (b *Bits) Check(l logrus.FieldLogger, i uint64) bool {
+func (b *Bits) Check(l *logrus.Logger, i uint64) bool {
 	// If i is the next number, return true.
-	if i > b.current || (i == 0 && b.firstSeen == false && b.current < b.length) {
+	if i > b.current {
 		return true
 	}
 
-	// If i is within the window, check if it's been set already. The first window will fail this check
-	if i > b.current-b.length {
-		return !b.bits[i%b.length]
-	}
-
-	// If i is within the first window
-	if i < b.length {
+	// If i is within the window, check if it's been set already.
+	if i > b.current-b.length || i < b.length && b.current < b.length {
 		return !b.bits[i%b.length]
 	}
 
 	// Not within the window
-	l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
+	if l.Level >= logrus.DebugLevel {
+		l.Debugf("rejected a packet (top) %d %d\n", b.current, i)
+	}
 	return false
 }
 
 func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 	// If i is the next number, return true and update current.
 	if i == b.current+1 {
-		// Report missed packets, we can only understand what was missed after the first window has been gone through
-		if i > b.length && b.bits[i%b.length] == false {
+		// Check if the oldest bit was lost since we are shifting the window by 1 and occupying it with this counter
+		// The very first window can only be tracked as lost once we are on the 2nd window or greater
+		if b.bits[i%b.length] == false && i > b.length {
 			b.lostCounter.Inc(1)
 		}
 		b.bits[i%b.length] = true
@@ -59,61 +61,32 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 		return true
 	}
 
-	// If i packet is greater than current but less than the maximum length of our bitmap,
-	// flip everything in between to false and move ahead.
-	if i > b.current && i < b.current+b.length {
-		// In between current and i need to be zero'd to allow those packets to come in later
-		for n := b.current + 1; n < i; n++ {
+	// If i is a jump, adjust the window, record lost, update current, and return true
+	if i > b.current {
+		lost := int64(0)
+		// Zero out the bits between the current and the new counter value, limited by the window size,
+		// since the window is shifting
+		for n := b.current + 1; n <= min(i, b.current+b.length); n++ {
+			if b.bits[n%b.length] == false && n > b.length {
+				lost++
+			}
 			b.bits[n%b.length] = false
 		}
 
-		b.bits[i%b.length] = true
-		b.current = i
-		//l.Debugf("missed %d packets between %d and %d\n", i-b.current, i, b.current)
-		return true
-	}
-
-	// If i is greater than the delta between current and the total length of our bitmap,
-	// just flip everything in the map and move ahead.
-	if i >= b.current+b.length {
-		// The current window loss will be accounted for later, only record the jump as loss up until then
-		lost := maxInt64(0, int64(i-b.current-b.length))
-		//TODO: explain this
-		if b.current == 0 {
-			lost++
-		}
-
-		for n := range b.bits {
-			// Don't want to count the first window as a loss
-			//TODO: this is likely wrong, we are wanting to track only the bit slots that we aren't going to track anymore and this is marking everything as missed
-			//if b.bits[n] == false {
-			//	lost++
-			//}
-			b.bits[n] = false
-		}
-
+		// Only record any skipped packets as a result of the window moving further than the window length
+		// Any loss within the new window will be accounted for in future calls
+		lost += max(0, int64(i-b.current-b.length))
 		b.lostCounter.Inc(lost)
 
-		if l.Level >= logrus.DebugLevel {
-			l.WithField("receiveWindow", m{"accepted": true, "currentCounter": b.current, "incomingCounter": i, "reason": "window shifting"}).
-				Debug("Receive window")
-		}
 		b.bits[i%b.length] = true
 		b.current = i
 		return true
 	}
 
-	// Allow for the 0 packet to come in within the first window
-	if i == 0 && b.firstSeen == false && b.current < b.length {
-		b.firstSeen = true
-		b.bits[i%b.length] = true
-		return true
-	}
-
-	// If i is within the window of current minus length (the total pat window size),
-	// allow it and flip to true but to NOT change current. We also have to account for the first window
-	if ((b.current >= b.length && i > b.current-b.length) || (b.current < b.length && i < b.length)) && i <= b.current {
-		if b.current == i {
+	// If i is within the current window but below the current counter,
+	// Check to see if it's a duplicate
+	if i > b.current-b.length || i < b.length && b.current < b.length {
+		if b.current == i || b.bits[i%b.length] == true {
 			if l.Level >= logrus.DebugLevel {
 				l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "duplicate"}).
 					Debug("Receive window")
@@ -122,18 +95,8 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 			return false
 		}
 
-		if b.bits[i%b.length] == true {
-			if l.Level >= logrus.DebugLevel {
-				l.WithField("receiveWindow", m{"accepted": false, "currentCounter": b.current, "incomingCounter": i, "reason": "old duplicate"}).
-					Debug("Receive window")
-			}
-			b.dupeCounter.Inc(1)
-			return false
-		}
-
 		b.bits[i%b.length] = true
 		return true
-
 	}
 
 	// In all other cases, fail and don't change current.
@@ -147,11 +110,3 @@ func (b *Bits) Update(l *logrus.Logger, i uint64) bool {
 	}
 	return false
 }
-
-func maxInt64(a, b int64) int64 {
-	if a > b {
-		return a
-	}
-
-	return b
-}

+ 86 - 23
bits_test.go

@@ -15,48 +15,41 @@ func TestBits(t *testing.T) {
 	assert.Len(t, b.bits, 10)
 
 	// This is initialized to zero - receive one. This should work.
-
 	assert.True(t, b.Check(l, 1))
-	u := b.Update(l, 1)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 1))
 	assert.EqualValues(t, 1, b.current)
-	g := []bool{false, true, false, false, false, false, false, false, false, false}
+	g := []bool{true, true, false, false, false, false, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Receive two
 	assert.True(t, b.Check(l, 2))
-	u = b.Update(l, 2)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 2))
 	assert.EqualValues(t, 2, b.current)
-	g = []bool{false, true, true, false, false, false, false, false, false, false}
+	g = []bool{true, true, true, false, false, false, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Receive two again - it will fail
 	assert.False(t, b.Check(l, 2))
-	u = b.Update(l, 2)
-	assert.False(t, u)
+	assert.False(t, b.Update(l, 2))
 	assert.EqualValues(t, 2, b.current)
 
 	// Jump ahead to 15, which should clear everything and set the 6th element
 	assert.True(t, b.Check(l, 15))
-	u = b.Update(l, 15)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 15))
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, false, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Mark 14, which is allowed because it is in the window
 	assert.True(t, b.Check(l, 14))
-	u = b.Update(l, 14)
-	assert.True(t, u)
+	assert.True(t, b.Update(l, 14))
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, true, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
 
 	// Mark 5, which is not allowed because it is not in the window
 	assert.False(t, b.Check(l, 5))
-	u = b.Update(l, 5)
-	assert.False(t, u)
+	assert.False(t, b.Update(l, 5))
 	assert.EqualValues(t, 15, b.current)
 	g = []bool{false, false, false, false, true, true, false, false, false, false}
 	assert.Equal(t, g, b.bits)
@@ -69,10 +62,29 @@ func TestBits(t *testing.T) {
 
 	// Walk through a few windows in order
 	b = NewBits(10)
-	for i := uint64(0); i <= 100; i++ {
+	for i := uint64(1); i <= 100; i++ {
 		assert.True(t, b.Check(l, i), "Error while checking %v", i)
 		assert.True(t, b.Update(l, i), "Error while updating %v", i)
 	}
+
+	assert.False(t, b.Check(l, 1), "Out of window check")
+}
+
+func TestBitsLargeJumps(t *testing.T) {
+	l := test.NewLogger()
+	b := NewBits(10)
+	b.lostCounter.Clear()
+
+	b = NewBits(10)
+	b.lostCounter.Clear()
+	assert.True(t, b.Update(l, 55)) // We saw packet 55 and can still track 45,46,47,48,49,50,51,52,53,54
+	assert.Equal(t, int64(45), b.lostCounter.Count())
+
+	assert.True(t, b.Update(l, 100)) // We saw packet 55 and 100 and can still track 90,91,92,93,94,95,96,97,98,99
+	assert.Equal(t, int64(89), b.lostCounter.Count())
+
+	assert.True(t, b.Update(l, 200)) // We saw packet 55, 100, and 200 and can still track 190,191,192,193,194,195,196,197,198,199
+	assert.Equal(t, int64(188), b.lostCounter.Count())
 }
 
 func TestBitsDupeCounter(t *testing.T) {
@@ -124,8 +136,7 @@ func TestBitsOutOfWindowCounter(t *testing.T) {
 	assert.False(t, b.Update(l, 0))
 	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
 
-	//tODO: make sure lostcounter doesn't increase in orderly increment
-	assert.Equal(t, int64(20), b.lostCounter.Count())
+	assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
 	assert.Equal(t, int64(0), b.dupeCounter.Count())
 	assert.Equal(t, int64(1), b.outOfWindowCounter.Count())
 }
@@ -137,8 +148,6 @@ func TestBitsLostCounter(t *testing.T) {
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
-	//assert.True(t, b.Update(0))
-	assert.True(t, b.Update(l, 0))
 	assert.True(t, b.Update(l, 20))
 	assert.True(t, b.Update(l, 21))
 	assert.True(t, b.Update(l, 22))
@@ -149,7 +158,7 @@ func TestBitsLostCounter(t *testing.T) {
 	assert.True(t, b.Update(l, 27))
 	assert.True(t, b.Update(l, 28))
 	assert.True(t, b.Update(l, 29))
-	assert.Equal(t, int64(20), b.lostCounter.Count())
+	assert.Equal(t, int64(19), b.lostCounter.Count()) // packet 0 wasn't lost
 	assert.Equal(t, int64(0), b.dupeCounter.Count())
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
 
@@ -158,8 +167,6 @@ func TestBitsLostCounter(t *testing.T) {
 	b.dupeCounter.Clear()
 	b.outOfWindowCounter.Clear()
 
-	assert.True(t, b.Update(l, 0))
-	assert.Equal(t, int64(0), b.lostCounter.Count())
 	assert.True(t, b.Update(l, 9))
 	assert.Equal(t, int64(0), b.lostCounter.Count())
 	// 10 will set 0 index, 0 was already set, no lost packets
@@ -214,6 +221,62 @@ func TestBitsLostCounter(t *testing.T) {
 	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
 }
 
+func TestBitsLostCounterIssue1(t *testing.T) {
+	l := test.NewLogger()
+	b := NewBits(10)
+	b.lostCounter.Clear()
+	b.dupeCounter.Clear()
+	b.outOfWindowCounter.Clear()
+
+	assert.True(t, b.Update(l, 4))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 1))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 9))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 2))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 3))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 5))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 6))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 7))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	// assert.True(t, b.Update(l, 8))
+	assert.True(t, b.Update(l, 10))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 11))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+
+	assert.True(t, b.Update(l, 14))
+	assert.Equal(t, int64(0), b.lostCounter.Count())
+	// Issue seems to be here, we reset missing packet 8 to false here and don't increment the lost counter
+	assert.True(t, b.Update(l, 19))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 12))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 13))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 15))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 16))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 17))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 18))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 20))
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.True(t, b.Update(l, 21))
+
+	// We missed packet 8 above
+	assert.Equal(t, int64(1), b.lostCounter.Count())
+	assert.Equal(t, int64(0), b.dupeCounter.Count())
+	assert.Equal(t, int64(0), b.outOfWindowCounter.Count())
+}
+
 func BenchmarkBits(b *testing.B) {
 	z := NewBits(10)
 	for n := 0; n < b.N; n++ {