Browse Source

PythonUtil: weightedChoice should throw IndexError on empty list

Also includes a unit test.

Closes #682
pythonengineer 6 years ago
parent
commit
46a3a72029
2 changed files with 67 additions and 0 deletions
  1. 5 0
      direct/src/showbase/PythonUtil.py
  2. 62 0
      tests/showbase/test_PythonUtil.py

+ 5 - 0
direct/src/showbase/PythonUtil.py

@@ -1130,6 +1130,10 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
     """given a list of (weight, item) pairs, chooses an item based on the
     weights. rng must return 0..1. if you happen to have the sum of the
     weights, pass it in 'sum'."""
+    # Throw an IndexError if we got an empty list.
+    if not choiceList:
+        raise IndexError('Cannot choose from an empty sequence')
+
     # TODO: add support for dicts
     if sum is None:
         sum = 0.
@@ -1138,6 +1142,7 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
 
     rand = rng()
     accum = rand * sum
+    item = None
     for weight, item in choiceList:
         accum -= weight
         if accum <= 0.:

+ 62 - 0
tests/showbase/test_PythonUtil.py

@@ -1,4 +1,5 @@
 from direct.showbase import PythonUtil
+import pytest
 
 
 def test_queue():
@@ -103,3 +104,64 @@ def test_priority_callbacks():
     pc.clear()
     pc()
     assert len(l) == 0
+
+def test_weighted_choice():
+    # Test PythonUtil.weightedChoice() with no valid list.
+    with pytest.raises(IndexError):
+        PythonUtil.weightedChoice([])
+
+    # Create a sample choice list.
+    # This contains a few tuples containing only a weight
+    # and an arbitrary item.
+    choicelist = [(3, 'item1'), (1, 'item2'), (7, 'item3')]
+
+    # These are the items that we expect.
+    items = ['item1', 'item2', 'item3']
+
+    # Test PythonUtil.weightedChoice() with our choice list.
+    item = PythonUtil.weightedChoice(choicelist)
+
+    # Assert that what we got was at least an available item.
+    assert item in items
+
+    # Create yet another sample choice list, but with a couple more items.
+    choicelist = [(2, 'item1'), (25, 'item2'), (14, 'item3'), (5, 'item4'),
+                  (7, 'item5'), (3, 'item6'), (6, 'item7'), (50, 'item8')]
+
+    # Set the items that we expect again.
+    items = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7', 'item8']
+
+    # The sum of all of the weights is 112.
+    weightsum = 2 + 25 + 14 + 5 + 7 + 3 + 6 + 50
+
+    # Test PythonUtil.weightedChoice() with the sum.
+    item = PythonUtil.weightedChoice(choicelist, sum=weightsum)
+
+    # Assert that we got a valid item (most of the time this should be 'item8').
+    assert item in items
+
+    # Test PythonUtil.weightedChoice(), but with an invalid sum.
+    item = PythonUtil.weightedChoice(choicelist, sum=1)
+
+    # Assert that we got 'item1'.
+    assert item == items[0]
+
+    # Test PythonUtil.weightedChoice() with an invalid sum.
+    # This time, we're using 2000 so that regardless of the random
+    # number, we will still reach the very last item.
+    item = PythonUtil.weightedChoice(choicelist, sum=100000)
+
+    # Assert that we got 'item8', since we would get the last item.
+    assert item == items[-1]
+
+    # Create a bogus random function.
+    rnd = lambda: 0.5
+
+    # Test PythonUtil.weightedChoice() with the bogus function.
+    item = PythonUtil.weightedChoice(choicelist, rng=rnd, sum=weightsum)
+
+    # Assert that we got 'item6'.
+    # We expect 'item6' because 0.5 multiplied by 112 is 56.0.
+    # When subtracting that number by each weight, it will reach 0
+    # by the time it hits 'item6' in the iteration.
+    assert item == items[5]