Browse Source

PythonUtil: weightedChoice should throw IndexError on empty list

Also includes a unit test.

Closes #682
pythonengineer 6 years ago
parent
commit
46a3a72029
2 changed files with 67 additions and 0 deletions
  1. 5 0
      direct/src/showbase/PythonUtil.py
  2. 62 0
      tests/showbase/test_PythonUtil.py

+ 5 - 0
direct/src/showbase/PythonUtil.py

@@ -1130,6 +1130,10 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
     """given a list of (weight, item) pairs, chooses an item based on the
     """given a list of (weight, item) pairs, chooses an item based on the
     weights. rng must return 0..1. if you happen to have the sum of the
     weights. rng must return 0..1. if you happen to have the sum of the
     weights, pass it in 'sum'."""
     weights, pass it in 'sum'."""
+    # Throw an IndexError if we got an empty list.
+    if not choiceList:
+        raise IndexError('Cannot choose from an empty sequence')
+
     # TODO: add support for dicts
     # TODO: add support for dicts
     if sum is None:
     if sum is None:
         sum = 0.
         sum = 0.
@@ -1138,6 +1142,7 @@ def weightedChoice(choiceList, rng=random.random, sum=None):
 
 
     rand = rng()
     rand = rng()
     accum = rand * sum
     accum = rand * sum
+    item = None
     for weight, item in choiceList:
     for weight, item in choiceList:
         accum -= weight
         accum -= weight
         if accum <= 0.:
         if accum <= 0.:

+ 62 - 0
tests/showbase/test_PythonUtil.py

@@ -1,4 +1,5 @@
 from direct.showbase import PythonUtil
 from direct.showbase import PythonUtil
+import pytest
 
 
 
 
 def test_queue():
 def test_queue():
@@ -103,3 +104,64 @@ def test_priority_callbacks():
     pc.clear()
     pc.clear()
     pc()
     pc()
     assert len(l) == 0
     assert len(l) == 0
+
+def test_weighted_choice():
+    # Test PythonUtil.weightedChoice() with no valid list.
+    with pytest.raises(IndexError):
+        PythonUtil.weightedChoice([])
+
+    # Create a sample choice list.
+    # This contains a few tuples containing only a weight
+    # and an arbitrary item.
+    choicelist = [(3, 'item1'), (1, 'item2'), (7, 'item3')]
+
+    # These are the items that we expect.
+    items = ['item1', 'item2', 'item3']
+
+    # Test PythonUtil.weightedChoice() with our choice list.
+    item = PythonUtil.weightedChoice(choicelist)
+
+    # Assert that what we got was at least an available item.
+    assert item in items
+
+    # Create yet another sample choice list, but with a couple more items.
+    choicelist = [(2, 'item1'), (25, 'item2'), (14, 'item3'), (5, 'item4'),
+                  (7, 'item5'), (3, 'item6'), (6, 'item7'), (50, 'item8')]
+
+    # Set the items that we expect again.
+    items = ['item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7', 'item8']
+
+    # The sum of all of the weights is 112.
+    weightsum = 2 + 25 + 14 + 5 + 7 + 3 + 6 + 50
+
+    # Test PythonUtil.weightedChoice() with the sum.
+    item = PythonUtil.weightedChoice(choicelist, sum=weightsum)
+
+    # Assert that we got a valid item (most of the time this should be 'item8').
+    assert item in items
+
+    # Test PythonUtil.weightedChoice(), but with an invalid sum.
+    item = PythonUtil.weightedChoice(choicelist, sum=1)
+
+    # Assert that we got 'item1'.
+    assert item == items[0]
+
+    # Test PythonUtil.weightedChoice() with an invalid sum.
+    # This time, we're using 2000 so that regardless of the random
+    # number, we will still reach the very last item.
+    item = PythonUtil.weightedChoice(choicelist, sum=100000)
+
+    # Assert that we got 'item8', since we would get the last item.
+    assert item == items[-1]
+
+    # Create a bogus random function.
+    rnd = lambda: 0.5
+
+    # Test PythonUtil.weightedChoice() with the bogus function.
+    item = PythonUtil.weightedChoice(choicelist, rng=rnd, sum=weightsum)
+
+    # Assert that we got 'item6'.
+    # We expect 'item6' because 0.5 multiplied by 112 is 56.0.
+    # When subtracting that number by each weight, it will reach 0
+    # by the time it hits 'item6' in the iteration.
+    assert item == items[5]