avl.odin 15 KB


  1. // A non-intrusive and non-recursive implementation of `AVL` trees.
  2. package container_avl
  3. @(require) import "base:intrinsics"
  4. @(require) import "base:runtime"
  5. import "core:slice"
  6. // Originally based on the CC0 implementation by Eric Biggers
  7. // See: https://github.com/ebiggers/avl_tree/
  8. // Direction specifies the traversal direction for a tree iterator.
  9. Direction :: enum i8 {
  10. // Backward is the in-order backwards direction.
  11. Backward = -1,
  12. // Forward is the in-order forwards direction.
  13. Forward = 1,
  14. }
  15. // Ordering specifies order when inserting/finding values into the tree.
  16. Ordering :: slice.Ordering
// Tree is an AVL tree holding values of type `Value`.
//
// Initialize it with `init` (`init_cmp` or `init_ordered`) before use, and
// release its nodes with `destroy`. Fields prefixed with `_` are internal.
Tree :: struct($Value: typeid) {
	// user_data is passed unchanged as the second argument of on_remove.
	user_data: rawptr,
	// on_remove is an optional callback. When a removal is performed with
	// `call_on_remove = true`, it is called with the removed value after
	// the node has been unlinked and before the node is freed.
	on_remove: proc(value: Value, user_data: rawptr),
	// Root node, or nil when the tree is empty.
	_root: ^Node(Value),
	// Allocator that nodes are allocated from (find_or_insert) and freed
	// to (remove_node).
	_node_allocator: runtime.Allocator,
	// Comparison procedure that defines the in-order sort.
	_cmp_fn: proc(a, b: Value) -> Ordering,
	// Number of nodes currently in the tree (reported by `len`).
	_size: int,
}
// Node is an AVL tree node.
//
// WARNING: While the node is part of a tree, it is unsafe to mutate
// `value` in any way that would change the node's sort position relative
// to the other elements in the tree.
Node :: struct($Value: typeid) {
	// The user-visible value stored in this node.
	value: Value,
	// Parent link; nil for the root. node_reset points a detached node's
	// parent at the node itself, and remove_node rejects such nodes.
	_parent: ^Node(Value),
	_left: ^Node(Value),
	_right: ^Node(Value),
	// Balance factor: height(right subtree) - height(left subtree).
	// An insertion on the left decrements it; one on the right increments it.
	_balance: i8,
}
// Iterator is a tree iterator.
//
// WARNING: It is unsafe to modify the tree while iterating, except via
// the iterator_remove method.
Iterator :: struct($Value: typeid) {
	// Tree being iterated.
	_tree: ^Tree(Value),
	// Current node: the one most recently returned by iterator_next (or the
	// first node before the first call); set to nil by iterator_remove.
	_cur: ^Node(Value),
	// Node the next iterator_next call moves to. It is computed ahead of
	// time so that removing _cur does not lose the iteration position.
	_next: ^Node(Value),
	// In-order direction of travel.
	_direction: Direction,
	// Whether iterator_next has been called yet; the first call returns
	// _cur instead of advancing.
	_called_next: bool,
}
  53. // init initializes a tree.
  54. init :: proc {
  55. init_ordered,
  56. init_cmp,
  57. }
  58. // init_cmp initializes a tree.
  59. init_cmp :: proc(
  60. t: ^$T/Tree($Value),
  61. cmp_fn: proc(a, b: Value) -> Ordering,
  62. node_allocator := context.allocator,
  63. ) {
  64. t._root = nil
  65. t._node_allocator = node_allocator
  66. t._cmp_fn = cmp_fn
  67. t._size = 0
  68. }
  69. // init_ordered initializes a tree containing ordered items, with
  70. // a comparison function that results in an ascending order sort.
  71. init_ordered :: proc(
  72. t: ^$T/Tree($Value),
  73. node_allocator := context.allocator,
  74. ) where intrinsics.type_is_ordered(Value) {
  75. init_cmp(t, slice.cmp_proc(Value), node_allocator)
  76. }
  77. // destroy de-initializes a tree.
  78. destroy :: proc(t: ^$T/Tree($Value), call_on_remove: bool = true) {
  79. iter := iterator(t, Direction.Forward)
  80. for _ in iterator_next(&iter) {
  81. iterator_remove(&iter, call_on_remove)
  82. }
  83. }
  84. // len returns the number of elements in the tree.
  85. len :: proc "contextless" (t: ^$T/Tree($Value)) -> int {
  86. return t._size
  87. }
  88. // first returns the first node in the tree (in-order) or nil iff
  89. // the tree is empty.
  90. first :: proc "contextless" (t: ^$T/Tree($Value)) -> ^Node(Value) {
  91. return tree_first_or_last_in_order(t, Direction.Backward)
  92. }
  93. // last returns the last element in the tree (in-order) or nil iff
  94. // the tree is empty.
  95. last :: proc "contextless" (t: ^$T/Tree($Value)) -> ^Node(Value) {
  96. return tree_first_or_last_in_order(t, Direction.Forward)
  97. }
  98. // find finds the value in the tree, and returns the corresponding
  99. // node or nil iff the value is not present.
  100. find :: proc(t: ^$T/Tree($Value), value: Value) -> ^Node(Value) {
  101. cur := t._root
  102. descend_loop: for cur != nil {
  103. switch t._cmp_fn(value, cur.value) {
  104. case .Less:
  105. cur = cur._left
  106. case .Greater:
  107. cur = cur._right
  108. case .Equal:
  109. break descend_loop
  110. }
  111. }
  112. return cur
  113. }
  114. // find_or_insert attempts to insert the value into the tree, and returns
  115. // the node, a boolean indicating if the value was inserted, and the
  116. // node allocator error if relevant. If the value is already
  117. // present, the existing node is returned un-altered.
  118. find_or_insert :: proc(
  119. t: ^$T/Tree($Value),
  120. value: Value,
  121. ) -> (
  122. n: ^Node(Value),
  123. inserted: bool,
  124. err: runtime.Allocator_Error,
  125. ) {
  126. n_ptr := &t._root
  127. for n_ptr^ != nil {
  128. n = n_ptr^
  129. switch t._cmp_fn(value, n.value) {
  130. case .Less:
  131. n_ptr = &n._left
  132. case .Greater:
  133. n_ptr = &n._right
  134. case .Equal:
  135. return
  136. }
  137. }
  138. parent := n
  139. n = new(Node(Value), t._node_allocator) or_return
  140. n.value = value
  141. n._parent = parent
  142. n_ptr^ = n
  143. tree_rebalance_after_insert(t, n)
  144. t._size += 1
  145. inserted = true
  146. return
  147. }
  148. // remove removes a node or value from the tree, and returns true iff the
  149. // removal was successful. While the node's value will be left intact,
  150. // the node itself will be freed via the tree's node allocator.
  151. remove :: proc {
  152. remove_value,
  153. remove_node,
  154. }
  155. // remove_value removes a value from the tree, and returns true iff the
  156. // removal was successful. While the node's value will be left intact,
  157. // the node itself will be freed via the tree's node allocator.
  158. remove_value :: proc(t: ^$T/Tree($Value), value: Value, call_on_remove: bool = true) -> bool {
  159. n := find(t, value)
  160. if n == nil {
  161. return false
  162. }
  163. return remove_node(t, n, call_on_remove)
  164. }
  165. // remove_node removes a node from the tree, and returns true iff the
  166. // removal was successful. While the node's value will be left intact,
  167. // the node itself will be freed via the tree's node allocator.
  168. remove_node :: proc(t: ^$T/Tree($Value), node: ^Node(Value), call_on_remove: bool = true) -> bool {
  169. if node._parent == node || (node._parent == nil && t._root != node) {
  170. return false
  171. }
  172. defer {
  173. if call_on_remove && t.on_remove != nil {
  174. t.on_remove(node.value, t.user_data)
  175. }
  176. free(node, t._node_allocator)
  177. }
  178. parent: ^Node(Value)
  179. left_deleted: bool
  180. t._size -= 1
  181. if node._left != nil && node._right != nil {
  182. parent, left_deleted = tree_swap_with_successor(t, node)
  183. } else {
  184. child := node._left
  185. if child == nil {
  186. child = node._right
  187. }
  188. parent = node._parent
  189. if parent != nil {
  190. if node == parent._left {
  191. parent._left = child
  192. left_deleted = true
  193. } else {
  194. parent._right = child
  195. left_deleted = false
  196. }
  197. if child != nil {
  198. child._parent = parent
  199. }
  200. } else {
  201. if child != nil {
  202. child._parent = parent
  203. }
  204. t._root = child
  205. node_reset(node)
  206. return true
  207. }
  208. }
  209. for {
  210. if left_deleted {
  211. parent = tree_handle_subtree_shrink(t, parent, +1, &left_deleted)
  212. } else {
  213. parent = tree_handle_subtree_shrink(t, parent, -1, &left_deleted)
  214. }
  215. if parent == nil {
  216. break
  217. }
  218. }
  219. node_reset(node)
  220. return true
  221. }
  222. // iterator returns a tree iterator in the specified direction.
  223. iterator :: proc "contextless" (t: ^$T/Tree($Value), direction: Direction) -> Iterator(Value) {
  224. it: Iterator(Value)
  225. it._tree = transmute(^Tree(Value))t
  226. it._direction = direction
  227. iterator_first(&it)
  228. return it
  229. }
  230. // iterator_from_pos returns a tree iterator in the specified direction,
  231. // spanning the range [pos, last] (inclusive).
  232. iterator_from_pos :: proc "contextless" (
  233. t: ^$T/Tree($Value),
  234. pos: ^Node(Value),
  235. direction: Direction,
  236. ) -> Iterator(Value) {
  237. it: Iterator(Value)
  238. it._tree = transmute(^Tree(Value))t
  239. it._direction = direction
  240. it._next = nil
  241. it._called_next = false
  242. if it._cur = pos; pos != nil {
  243. it._next = node_next_or_prev_in_order(it._cur, it._direction)
  244. }
  245. return it
  246. }
  247. // iterator_get returns the node currently pointed to by the iterator,
  248. // or nil iff the node has been removed, the tree is empty, or the end
  249. // of the tree has been reached.
  250. iterator_get :: proc "contextless" (it: ^$I/Iterator($Value)) -> ^Node(Value) {
  251. return it._cur
  252. }
  253. // iterator_remove removes the node currently pointed to by the iterator,
  254. // and returns true iff the removal was successful. Semantics are the
  255. // same as the Tree remove.
  256. iterator_remove :: proc(it: ^$I/Iterator($Value), call_on_remove: bool = true) -> bool {
  257. if it._cur == nil {
  258. return false
  259. }
  260. ok := remove_node(it._tree, it._cur, call_on_remove)
  261. if ok {
  262. it._cur = nil
  263. }
  264. return ok
  265. }
  266. // iterator_next advances the iterator and returns the (node, true) or
  267. // or (nil, false) iff the end of the tree has been reached.
  268. //
  269. // Note: The first call to iterator_next will return the first node instead
  270. // of advancing the iterator.
  271. iterator_next :: proc "contextless" (it: ^$I/Iterator($Value)) -> (^Node(Value), bool) {
  272. // This check is needed so that the first element gets returned from
  273. // a brand-new iterator, and so that the somewhat contrived case where
  274. // iterator_remove is called before the first call to iterator_next
  275. // returns the correct value.
  276. if !it._called_next {
  277. it._called_next = true
  278. // There can be the contrived case where iterator_remove is
  279. // called before ever calling iterator_next, which needs to be
  280. // handled as an actual call to next.
  281. //
  282. // If this happens it._cur will be nil, so only return the
  283. // first value, if it._cur is valid.
  284. if it._cur != nil {
  285. return it._cur, true
  286. }
  287. }
  288. if it._next == nil {
  289. return nil, false
  290. }
  291. it._cur = it._next
  292. it._next = node_next_or_prev_in_order(it._cur, it._direction)
  293. return it._cur, true
  294. }
  295. @(private)
  296. tree_first_or_last_in_order :: proc "contextless" (
  297. t: ^$T/Tree($Value),
  298. direction: Direction,
  299. ) -> ^Node(Value) {
  300. first, sign := t._root, i8(direction)
  301. if first != nil {
  302. for {
  303. tmp := node_get_child(first, +sign)
  304. if tmp == nil {
  305. break
  306. }
  307. first = tmp
  308. }
  309. }
  310. return first
  311. }
  312. @(private)
  313. tree_replace_child :: proc "contextless" (
  314. t: ^$T/Tree($Value),
  315. parent, old_child, new_child: ^Node(Value),
  316. ) {
  317. if parent != nil {
  318. if old_child == parent._left {
  319. parent._left = new_child
  320. } else {
  321. parent._right = new_child
  322. }
  323. } else {
  324. t._root = new_child
  325. }
  326. }
// tree_rotate performs a single rotation around `a`.
//
// a's child on the -sign side (b) takes a's place, and a becomes b's child
// on the +sign side. b's former +sign child (e) becomes a's -sign child.
// Parent links and the link from a's old parent (or the root) are updated.
// Balance factors are NOT touched; callers adjust them afterwards.
//
// sign = +1 (right rotation):
//
//         a              b
//        / \            / \
//       b   .    ->    .   a
//      / \                / \
//     .   e              e   .
@(private)
tree_rotate :: proc "contextless" (t: ^$T/Tree($Value), a: ^Node(Value), sign: i8) {
	b := node_get_child(a, -sign)
	e := node_get_child(b, +sign)
	p := a._parent
	// e moves under a, then a moves under b.
	node_set_child(a, -sign, e)
	a._parent = b
	node_set_child(b, +sign, a)
	b._parent = p
	if e != nil {
		e._parent = a
	}
	// Hook b into a's former position under p (or as the root).
	tree_replace_child(t, p, a, b)
}
// tree_double_rotate performs a double (two-step) rotation around `a` and
// its child `b`, which sits on a's -sign side.
//
// Let e be b's +sign child, with children f (-sign side) and g (+sign
// side). Afterwards:
//   - e occupies a's former position, with b as its -sign child and a as
//     its +sign child;
//   - f becomes b's +sign child, and g becomes a's -sign child.
//
// The new balance factors of a and b are derived from e's old balance
// factor, and e's becomes 0. Returns e, the new root of the rotated
// subtree.
@(private)
tree_double_rotate :: proc "contextless" (
	t: ^$T/Tree($Value),
	b, a: ^Node(Value),
	sign: i8,
) -> ^Node(Value) {
	e := node_get_child(b, +sign)
	f := node_get_child(e, -sign)
	g := node_get_child(e, +sign)
	p := a._parent
	e_bal := e._balance
	node_set_child(a, -sign, g)
	// a ends up unbalanced only if e leaned toward the -sign side.
	a_bal := -e_bal
	if sign * e_bal >= 0 {
		a_bal = 0
	}
	node_set_parent_balance(a, e, a_bal)
	node_set_child(b, +sign, f)
	// b ends up unbalanced only if e leaned toward the +sign side.
	b_bal := -e_bal
	if sign * e_bal <= 0 {
		b_bal = 0
	}
	node_set_parent_balance(b, e, b_bal)
	node_set_child(e, +sign, a)
	node_set_child(e, -sign, b)
	node_set_parent_balance(e, p, 0)
	if g != nil {
		g._parent = a
	}
	if f != nil {
		f._parent = b
	}
	// Hook e into a's former position under p (or as the root).
	tree_replace_child(t, p, a, e)
	return e
}
// tree_handle_subtree_growth processes one step of rebalancing after an
// insertion. The subtree rooted at `node`, parent's child on the `sign`
// side (-1 left, +1 right), has grown by one level.
//
// Updates parent's balance factor, rotating if parent would become too
// unbalanced. Returns false if the subtree rooted at parent also grew, in
// which case the caller must continue with parent's parent. Returns true
// once the height increase has been absorbed.
@(private)
tree_handle_subtree_growth :: proc "contextless" (
	t: ^$T/Tree($Value),
	node, parent: ^Node(Value),
	sign: i8,
) -> bool {
	old_balance_factor := parent._balance
	if old_balance_factor == 0 {
		// parent now leans toward node, and its subtree grew by one level.
		node_adjust_balance_factor(parent, sign)
		return false
	}
	new_balance_factor := old_balance_factor + sign
	if new_balance_factor == 0 {
		// The shorter side caught up: parent is balanced, height unchanged.
		node_adjust_balance_factor(parent, sign)
		return true
	}
	// parent would be off by 2: rotate to restore balance.
	if sign * node._balance > 0 {
		// node leans the same way as parent: a single rotation suffices.
		tree_rotate(t, parent, -sign)
		node_adjust_balance_factor(parent, -sign)
		node_adjust_balance_factor(node, -sign)
	} else {
		// Otherwise a double rotation is required.
		tree_double_rotate(t, node, parent, -sign)
	}
	return true
}
// tree_rebalance_after_insert restores AVL balance after `inserted` has
// been linked into the tree as a new leaf.
//
// It first adjusts the balance factor of the new node's parent. It then
// walks up the ancestors while the current subtree keeps growing in
// height, delegating each step to tree_handle_subtree_growth. The walk
// stops when that procedure reports the growth absorbed, or when the root
// has been passed.
@(private)
tree_rebalance_after_insert :: proc "contextless" (t: ^$T/Tree($Value), inserted: ^Node(Value)) {
	node, parent := inserted, inserted._parent
	switch {
	case parent == nil:
		// The new node is the root; there is nothing to rebalance.
		return
	case node == parent._left:
		node_adjust_balance_factor(parent, -1)
	case:
		node_adjust_balance_factor(parent, +1)
	}
	if parent._balance == 0 {
		// The new leaf filled parent's shorter side; height is unchanged.
		return
	}
	// The subtree rooted at parent grew by one level; propagate upward.
	for done := false; !done; {
		node = parent
		if parent = node._parent; parent == nil {
			return
		}
		if node == parent._left {
			done = tree_handle_subtree_growth(t, node, parent, -1)
		} else {
			done = tree_handle_subtree_growth(t, node, parent, +1)
		}
	}
}
// tree_swap_with_successor unlinks `x`, which must have two children. It
// does so by relinking x's in-order successor y (the leftmost node of x's
// right subtree) into x's position: y takes over x's parent link, both of
// x's children, and x's balance factor. Nodes are relinked rather than
// values copied, so pointers to other nodes remain valid. x's own fields
// are not modified.
//
// Returns the node where removal rebalancing must start, and whether it
// is that node's left subtree that shrank:
//   - y itself, with its right side shrunk, if y was x's direct right child;
//   - otherwise y's former parent q, with its left side shrunk.
@(private)
tree_swap_with_successor :: proc "contextless" (
	t: ^$T/Tree($Value),
	x: ^Node(Value),
) -> (
	^Node(Value),
	bool,
) {
	ret: ^Node(Value)
	left_deleted: bool
	y := x._right
	if y._left == nil {
		// x's right child is itself the successor.
		ret = y
	} else {
		// Descend to the leftmost node y of x's right subtree, with q
		// tracking its parent.
		q: ^Node(Value)
		for {
			q = y
			if y = y._left; y._left == nil {
				break
			}
		}
		// Splice y out: its right subtree takes its place under q.
		if q._left = y._right; q._left != nil {
			q._left._parent = q
		}
		// y adopts x's right subtree.
		y._right = x._right
		x._right._parent = y
		ret = q
		left_deleted = true
	}
	// y adopts x's left subtree, parent and balance factor.
	y._left = x._left
	x._left._parent = y
	y._parent = x._parent
	y._balance = x._balance
	tree_replace_child(t, x._parent, x, y)
	return ret, left_deleted
}
// tree_handle_subtree_shrink processes one step of rebalancing after a
// removal. One of parent's subtrees has lost one level of height, and
// `sign` is the resulting change to parent's balance factor: +1 when the
// left subtree shrank, -1 when the right one did.
//
// Updates parent's balance factor, rotating if parent becomes too
// unbalanced. Returns nil when rebalancing can stop: either the height of
// the subtree rooted here is unchanged, or the root has been reached.
// Otherwise returns the next ancestor to process, and sets left_deleted^
// to whether the shrunken subtree is that ancestor's left child.
@(private)
tree_handle_subtree_shrink :: proc "contextless" (
	t: ^$T/Tree($Value),
	parent: ^Node(Value),
	sign: i8,
	left_deleted: ^bool,
) -> ^Node(Value) {
	old_balance_factor := parent._balance
	if old_balance_factor == 0 {
		// parent was perfectly balanced. It now leans by 1, but its height
		// is unchanged, so rebalancing stops.
		node_adjust_balance_factor(parent, sign)
		return nil
	}
	node: ^Node(Value)
	new_balance_factor := old_balance_factor + sign
	if new_balance_factor == 0 {
		// parent is now balanced, so its height dropped by 1. No rotation
		// is needed here, but the caller must keep going up.
		node_adjust_balance_factor(parent, sign)
		node = parent
	} else {
		// parent is too unbalanced: rotate it down toward the shrunken side.
		// node is the child on the taller side.
		node = node_get_child(parent, sign)
		if sign * node._balance >= 0 {
			tree_rotate(t, parent, -sign)
			if node._balance == 0 {
				// The rotated subtree keeps its height; stop here.
				node_adjust_balance_factor(node, -sign)
				return nil
			}
			node_adjust_balance_factor(parent, -sign)
			node_adjust_balance_factor(node, -sign)
		} else {
			node = tree_double_rotate(t, node, parent, -sign)
		}
	}
	// node roots a subtree whose height decreased by 1; report its parent.
	parent := parent
	if parent = node._parent; parent != nil {
		left_deleted^ = node == parent._left
	}
	return parent
}
  500. @(private)
  501. node_reset :: proc "contextless" (n: ^Node($Value)) {
  502. // Mostly pointless as n will be deleted after this is called, but
  503. // attempt to be able to catch cases of n not being in the tree.
  504. n._parent = n
  505. n._left = nil
  506. n._right = nil
  507. n._balance = 0
  508. }
  509. @(private)
  510. node_set_parent_balance :: #force_inline proc "contextless" (
  511. n, parent: ^Node($Value),
  512. balance: i8,
  513. ) {
  514. n._parent = parent
  515. n._balance = balance
  516. }
  517. @(private)
  518. node_get_child :: #force_inline proc "contextless" (n: ^Node($Value), sign: i8) -> ^Node(Value) {
  519. if sign < 0 {
  520. return n._left
  521. }
  522. return n._right
  523. }
  524. @(private)
  525. node_next_or_prev_in_order :: proc "contextless" (
  526. n: ^Node($Value),
  527. direction: Direction,
  528. ) -> ^Node(Value) {
  529. next, tmp: ^Node(Value)
  530. sign := i8(direction)
  531. if next = node_get_child(n, +sign); next != nil {
  532. for {
  533. tmp = node_get_child(next, -sign)
  534. if tmp == nil {
  535. break
  536. }
  537. next = tmp
  538. }
  539. } else {
  540. tmp, next = n, n._parent
  541. for next != nil && tmp == node_get_child(next, +sign) {
  542. tmp, next = next, next._parent
  543. }
  544. }
  545. return next
  546. }
  547. @(private)
  548. node_set_child :: #force_inline proc "contextless" (
  549. n: ^Node($Value),
  550. sign: i8,
  551. child: ^Node(Value),
  552. ) {
  553. if sign < 0 {
  554. n._left = child
  555. } else {
  556. n._right = child
  557. }
  558. }
  559. @(private)
  560. node_adjust_balance_factor :: #force_inline proc "contextless" (n: ^Node($Value), amount: i8) {
  561. n._balance += amount
  562. }
  563. @(private)
  564. iterator_first :: proc "contextless" (it: ^Iterator($Value)) {
  565. // This is private because behavior when the user manually calls
  566. // iterator_first followed by iterator_next is unintuitive, since
  567. // the first call to iterator_next MUST return the first node
  568. // instead of advancing so that `for node in iterator_next(&next)`
  569. // works as expected.
  570. switch it._direction {
  571. case .Forward:
  572. it._cur = tree_first_or_last_in_order(it._tree, .Backward)
  573. case .Backward:
  574. it._cur = tree_first_or_last_in_order(it._tree, .Forward)
  575. }
  576. it._next = nil
  577. it._called_next = false
  578. if it._cur != nil {
  579. it._next = node_next_or_prev_in_order(it._cur, it._direction)
  580. }
  581. }