Enumeration.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. """Utilities for enumeration of finite and countably infinite sets.
  2. """
  3. ###
  4. # Countable iteration
  5. # Simplifies some calculations
  6. class Aleph0(int):
  7. _singleton = None
  8. def __new__(type):
  9. if type._singleton is None:
  10. type._singleton = int.__new__(type)
  11. return type._singleton
  12. def __repr__(self): return '<aleph0>'
  13. def __str__(self): return 'inf'
  14. def __cmp__(self, b):
  15. return 1
  16. def __sub__(self, b):
  17. raise ValueError,"Cannot subtract aleph0"
  18. __rsub__ = __sub__
  19. def __add__(self, b):
  20. return self
  21. __radd__ = __add__
  22. def __mul__(self, b):
  23. if b == 0: return b
  24. return self
  25. __rmul__ = __mul__
  26. def __floordiv__(self, b):
  27. if b == 0: raise ZeroDivisionError
  28. return self
  29. __rfloordiv__ = __floordiv__
  30. __truediv__ = __floordiv__
  31. __rtuediv__ = __floordiv__
  32. __div__ = __floordiv__
  33. __rdiv__ = __floordiv__
  34. def __pow__(self, b):
  35. if b == 0: return 1
  36. return self
  37. aleph0 = Aleph0()
  38. def base(line):
  39. return line*(line+1)//2
  40. def pairToN((x,y)):
  41. line,index = x+y,y
  42. return base(line)+index
  43. def getNthPairInfo(N):
  44. # Avoid various singularities
  45. if N==0:
  46. return (0,0)
  47. # Gallop to find bounds for line
  48. line = 1
  49. next = 2
  50. while base(next)<=N:
  51. line = next
  52. next = line << 1
  53. # Binary search for starting line
  54. lo = line
  55. hi = line<<1
  56. while lo + 1 != hi:
  57. #assert base(lo) <= N < base(hi)
  58. mid = (lo + hi)>>1
  59. if base(mid)<=N:
  60. lo = mid
  61. else:
  62. hi = mid
  63. line = lo
  64. return line, N - base(line)
  65. def getNthPair(N):
  66. line,index = getNthPairInfo(N)
  67. return (line - index, index)
  68. def getNthPairBounded(N,W=aleph0,H=aleph0,useDivmod=False):
  69. """getNthPairBounded(N, W, H) -> (x, y)
  70. Return the N-th pair such that 0 <= x < W and 0 <= y < H."""
  71. if W <= 0 or H <= 0:
  72. raise ValueError,"Invalid bounds"
  73. elif N >= W*H:
  74. raise ValueError,"Invalid input (out of bounds)"
  75. # Simple case...
  76. if W is aleph0 and H is aleph0:
  77. return getNthPair(N)
  78. # Otherwise simplify by assuming W < H
  79. if H < W:
  80. x,y = getNthPairBounded(N,H,W,useDivmod=useDivmod)
  81. return y,x
  82. if useDivmod:
  83. return N%W,N//W
  84. else:
  85. # Conceptually we want to slide a diagonal line across a
  86. # rectangle. This gives more interesting results for large
  87. # bounds than using divmod.
  88. # If in lower left, just return as usual
  89. cornerSize = base(W)
  90. if N < cornerSize:
  91. return getNthPair(N)
  92. # Otherwise if in upper right, subtract from corner
  93. if H is not aleph0:
  94. M = W*H - N - 1
  95. if M < cornerSize:
  96. x,y = getNthPair(M)
  97. return (W-1-x,H-1-y)
  98. # Otherwise, compile line and index from number of times we
  99. # wrap.
  100. N = N - cornerSize
  101. index,offset = N%W,N//W
  102. # p = (W-1, 1+offset) + (-1,1)*index
  103. return (W-1-index, 1+offset+index)
  104. def getNthPairBoundedChecked(N,W=aleph0,H=aleph0,useDivmod=False,GNP=getNthPairBounded):
  105. x,y = GNP(N,W,H,useDivmod)
  106. assert 0 <= x < W and 0 <= y < H
  107. return x,y
  108. def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
  109. """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_W)
  110. Return the N-th W-tuple, where for 0 <= x_i < H."""
  111. if useLeftToRight:
  112. elts = [None]*W
  113. for i in range(W):
  114. elts[i],N = getNthPairBounded(N, H)
  115. return tuple(elts)
  116. else:
  117. if W==0:
  118. return ()
  119. elif W==1:
  120. return (N,)
  121. elif W==2:
  122. return getNthPairBounded(N, H, H)
  123. else:
  124. LW,RW = W//2, W - (W//2)
  125. L,R = getNthPairBounded(N, H**LW, H**RW)
  126. return (getNthNTuple(L,LW,H=H,useLeftToRight=useLeftToRight) +
  127. getNthNTuple(R,RW,H=H,useLeftToRight=useLeftToRight))
  128. def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
  129. t = GNT(N,W,H,useLeftToRight)
  130. assert len(t) == W
  131. for i in t:
  132. assert i < H
  133. return t
  134. def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
  135. """getNthTuple(N, maxSize, maxElement) -> x
  136. Return the N-th tuple where len(x) < maxSize and for y in x, 0 <=
  137. y < maxElement."""
  138. # All zero sized tuples are isomorphic, don't ya know.
  139. if N == 0:
  140. return ()
  141. N -= 1
  142. if maxElement is not aleph0:
  143. if maxSize is aleph0:
  144. raise NotImplementedError,'Max element size without max size unhandled'
  145. bounds = [maxElement**i for i in range(1, maxSize+1)]
  146. S,M = getNthPairVariableBounds(N, bounds)
  147. else:
  148. S,M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
  149. return getNthNTuple(M, S+1, maxElement, useLeftToRight=useLeftToRight)
  150. def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0,
  151. useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
  152. # FIXME: maxsize is inclusive
  153. t = GNT(N,maxSize,maxElement,useDivmod,useLeftToRight)
  154. assert len(t) <= maxSize
  155. for i in t:
  156. assert i < maxElement
  157. return t
  158. def getNthPairVariableBounds(N, bounds):
  159. """getNthPairVariableBounds(N, bounds) -> (x, y)
  160. Given a finite list of bounds (which may be finite or aleph0),
  161. return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
  162. bounds[x]."""
  163. if not bounds:
  164. raise ValueError,"Invalid bounds"
  165. if not (0 <= N < sum(bounds)):
  166. raise ValueError,"Invalid input (out of bounds)"
  167. level = 0
  168. active = range(len(bounds))
  169. active.sort(key=lambda i: bounds[i])
  170. prevLevel = 0
  171. for i,index in enumerate(active):
  172. level = bounds[index]
  173. W = len(active) - i
  174. if level is aleph0:
  175. H = aleph0
  176. else:
  177. H = level - prevLevel
  178. levelSize = W*H
  179. if N<levelSize: # Found the level
  180. idelta,delta = getNthPairBounded(N, W, H)
  181. return active[i+idelta],prevLevel+delta
  182. else:
  183. N -= levelSize
  184. prevLevel = level
  185. else:
  186. raise RuntimError,"Unexpected loop completion"
  187. def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
  188. x,y = GNVP(N,bounds)
  189. assert 0 <= x < len(bounds) and 0 <= y < bounds[x]
  190. return (x,y)
  191. ###
  192. def testPairs():
  193. W = 3
  194. H = 6
  195. a = [[' ' for x in range(10)] for y in range(10)]
  196. b = [[' ' for x in range(10)] for y in range(10)]
  197. for i in range(min(W*H,40)):
  198. x,y = getNthPairBounded(i,W,H)
  199. x2,y2 = getNthPairBounded(i,W,H,useDivmod=True)
  200. print i,(x,y),(x2,y2)
  201. a[y][x] = '%2d'%i
  202. b[y2][x2] = '%2d'%i
  203. print '-- a --'
  204. for ln in a[::-1]:
  205. if ''.join(ln).strip():
  206. print ' '.join(ln)
  207. print '-- b --'
  208. for ln in b[::-1]:
  209. if ''.join(ln).strip():
  210. print ' '.join(ln)
  211. def testPairsVB():
  212. bounds = [2,2,4,aleph0,5,aleph0]
  213. a = [[' ' for x in range(15)] for y in range(15)]
  214. b = [[' ' for x in range(15)] for y in range(15)]
  215. for i in range(min(sum(bounds),40)):
  216. x,y = getNthPairVariableBounds(i, bounds)
  217. print i,(x,y)
  218. a[y][x] = '%2d'%i
  219. print '-- a --'
  220. for ln in a[::-1]:
  221. if ''.join(ln).strip():
  222. print ' '.join(ln)
  223. ###
  224. # Toggle to use checked versions of enumeration routines.
  225. if False:
  226. getNthPairVariableBounds = getNthPairVariableBoundsChecked
  227. getNthPairBounded = getNthPairBoundedChecked
  228. getNthNTuple = getNthNTupleChecked
  229. getNthTuple = getNthTupleChecked
  230. if __name__ == '__main__':
  231. testPairs()
  232. testPairsVB()