sort_private.odin 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378
  1. //+private
  2. package slice
  3. import "base:intrinsics"
  // Keep an explicit reference to the intrinsics import so it always counts as used.
  _ :: intrinsics

  // ORD(T) is true when T supports the built-in ordering operators. The `.Ordered`
  // sort kind requires it, because that kind compares elements with `<`.
  ORD :: intrinsics.type_is_ordered

  // Sort_Kind selects, at compile time, how the sorting procedures compare two
  // elements x and y:
  //   .Ordered - the built-in `x < y` (the element type must satisfy ORD)
  //   .Less    - a user procedure `call(x, y)` returning true when x sorts first
  //   .Cmp     - a user procedure whose result `call(x, y)` is compared against `.Less`
  Sort_Kind :: enum {Ordered, Less, Cmp}
  // _quick_sort_general sorts data[a:b] (half-open) in place. It is an
  // introsort-style hybrid:
  //   - quicksort partitioning;
  //   - a heap sort fallback once `max_depth` partitioning rounds have been spent;
  //   - a gap-6 shell pass plus insertion sort for ranges of at most 12 elements.
  // Elements that compare equal may be reordered (the sort is not stable).
  //
  // KIND fixes, at compile time, how two elements x and y are compared:
  //   .Ordered - `x < y`; `call` is ignored and E must satisfy ORD
  //   .Less    - `call(x, y)` is true when x sorts before y
  //   .Cmp     - x sorts before y when `call(x, y) == .Less`
  _quick_sort_general :: proc(data: $T/[]$E, a, b, max_depth: int, call: $P, $KIND: Sort_Kind) where (ORD(E) && KIND == .Ordered) || (KIND != .Ordered) #no_bounds_check {
	// less reports whether x must be placed before y; the branch is chosen at compile time.
	less :: #force_inline proc(x, y: E, call: P) -> bool {
		when KIND == .Ordered {
			return x < y
		} else when KIND == .Less {
			return call(x, y)
		} else when KIND == .Cmp {
			return call(x, y) == .Less
		} else {
			#panic("unhandled Sort_Kind")
		}
	}

	// insertion_sort sorts data[lo:hi] by sinking each element left past larger neighbours.
	insertion_sort :: proc(data: T, lo, hi: int, call: P) #no_bounds_check {
		for i := lo+1; i < hi; i += 1 {
			j := i
			for j > lo && less(data[j], data[j-1], call) {
				swap(data, j, j-1)
				j -= 1
			}
		}
	}

	// heap_sort sorts data[lo:hi] using a binary max-heap laid over that sub-range.
	heap_sort :: proc(data: T, lo, hi: int, call: P) #no_bounds_check {
		// sift_down pushes heap node `node` down until it dominates its children,
		// for a heap of `size` nodes whose node 0 is stored at data[base].
		sift_down :: proc(data: T, node, size, base: int, call: P) #no_bounds_check {
			parent := node
			for {
				child := 2*parent + 1
				if child >= size {
					return
				}
				if child+1 < size && less(data[base+child], data[base+child+1], call) {
					child += 1 // the right child is the larger one
				}
				if !less(data[base+parent], data[base+child], call) {
					return // parent already dominates both children
				}
				swap(data, base+parent, base+child)
				parent = child
			}
		}

		n := hi - lo
		// Heapify: sift down every internal node, starting from the last parent.
		for i := (n-1)/2; i >= 0; i -= 1 {
			sift_down(data, i, n, lo, call)
		}
		// Repeatedly move the current maximum behind the shrinking heap.
		for last := n-1; last >= 0; last -= 1 {
			swap(data, lo, lo+last)
			sift_down(data, 0, last, lo, call)
		}
	}

	// median3 reorders the three elements so that data[low] <= data[mid] <= data[high];
	// the median therefore ends up at `mid`, the FIRST index argument.
	median3 :: proc(data: T, mid, low, high: int, call: P) #no_bounds_check {
		if less(data[mid], data[low], call) {
			swap(data, mid, low)
		}
		if !less(data[high], data[mid], call) {
			return
		}
		swap(data, high, mid)
		if less(data[mid], data[low], call) {
			swap(data, mid, low)
		}
	}

	// do_pivot partitions data[lo:hi] (only called on ranges of more than 12 elements)
	// and returns (midlo, midhi). The pivot lands at data[midlo]. data[midlo:midhi]
	// holds only elements equivalent to it (already in final position), so the caller
	// still has to sort data[lo:midlo] and data[midhi:hi].
	do_pivot :: proc(data: T, lo, hi: int, call: P) -> (midlo, midhi: int) #no_bounds_check {
		span := hi - lo
		mid := int(uint(lo+hi)>>1) // halve through uint so a wrapped lo+hi still gives the right midpoint

		// Pivot choice: median of three, upgraded to a ninther (median of three
		// medians of three) on larger ranges. Each median3 leaves its median in its
		// first argument, so the chosen pivot ends up at data[lo].
		if span > 40 {
			step := span/8
			median3(data, lo, lo+step, lo+step*2, call)
			median3(data, mid, mid-step, mid+step, call)
			median3(data, hi-1, hi-1-step, hi-1-step*2, call)
		}
		median3(data, lo, mid, hi-1, call)
		pivot := lo

		// Partition invariants ("<" meaning "sorts before"):
		//   data[lo+1:lt_end]       < pivot
		//   data[lt_end:le_end]     <= pivot
		//   data[gt_start:hi-1]     > pivot
		//   data[hi-1]              >= pivot (guaranteed by median3 above)
		lt_end, gt_start := lo+1, hi-1
		for lt_end < gt_start && less(data[lt_end], data[pivot], call) {
			lt_end += 1
		}
		le_end := lt_end
		for {
			for le_end < gt_start && !less(data[pivot], data[le_end], call) {
				le_end += 1 // data[le_end] <= pivot
			}
			for le_end < gt_start && less(data[pivot], data[gt_start-1], call) {
				gt_start -= 1 // data[gt_start-1] > pivot
			}
			if le_end >= gt_start {
				break
			}
			// data[le_end] > pivot and data[gt_start-1] <= pivot: exchange them.
			swap(data, le_end, gt_start-1)
			le_end += 1
			gt_start -= 1
		}

		// Very few elements sorting after the pivot hints at many duplicates of it.
		protect := hi-gt_start < 5
		if !protect && hi-gt_start < span/4 {
			// Probe three positions for elements equivalent to the pivot.
			dups := 0
			if !less(data[pivot], data[hi-1], call) {
				swap(data, gt_start, hi-1)
				gt_start += 1
				dups += 1
			}
			if !less(data[le_end-1], data[pivot], call) {
				le_end -= 1
				dups += 1
			}
			if !less(data[mid], data[pivot], call) {
				swap(data, mid, le_end-1)
				le_end -= 1
				dups += 1
			}
			protect = dups > 1 // two or more hits: assume a duplicate-heavy range
		}
		if protect {
			// Split data[lt_end:le_end] into elements strictly before the pivot (left)
			// and elements equivalent to it (right, next to the pivot's final slot).
			for {
				for lt_end < le_end && !less(data[le_end-1], data[pivot], call) {
					le_end -= 1
				}
				for lt_end < le_end && less(data[lt_end], data[pivot], call) {
					lt_end += 1
				}
				if lt_end >= le_end {
					break
				}
				swap(data, lt_end, le_end-1)
				lt_end += 1
				le_end -= 1
			}
		}

		// Move the pivot between the two partitions.
		swap(data, pivot, le_end-1)
		return le_end-1, gt_start
	}

	lo, hi, depth := a, b, max_depth
	// Recurse into the smaller side and keep looping on the larger side, so the
	// recursion depth stays logarithmic in the range length.
	for hi-lo > 12 {
		if depth == 0 {
			// Partitioning budget exhausted: finish with heap sort.
			heap_sort(data, lo, hi, call)
			return
		}
		depth -= 1
		midlo, midhi := do_pivot(data, lo, hi, call)
		if midlo-lo < hi-midhi {
			_quick_sort_general(data, lo, midlo, depth, call, KIND)
			lo = midhi
		} else {
			_quick_sort_general(data, midhi, hi, depth, call, KIND)
			hi = midlo
		}
	}

	if hi-lo > 1 {
		// At most 12 elements remain: one shell sort pass with gap 6,
		// then a final insertion sort.
		for i := lo+6; i < hi; i += 1 {
			if less(data[i], data[i-6], call) {
				swap(data, i, i-6)
			}
		}
		insertion_sort(data, lo, hi, call)
	}
}
  // _stable_sort_general sorts all of `data` in place, keeping elements that
  // compare equal in their original relative order.
  //
  // KIND fixes, at compile time, how two elements x and y are compared:
  //   .Ordered - `x < y`; `call` is ignored and E must satisfy ORD
  //   .Less    - `call(x, y)` is true when x sorts before y
  //   .Cmp     - x sorts before y when `call(x, y) == .Less`
  _stable_sort_general :: proc(data: $T/[]$E, call: $P, $KIND: Sort_Kind) where (ORD(E) && KIND == .Ordered) || (KIND != .Ordered) #no_bounds_check {
	// less reports whether x must be placed before y; the branch is chosen at compile time.
	less :: #force_inline proc(x, y: E, call: P) -> bool {
		when KIND == .Ordered {
			return x < y
		} else when KIND == .Less {
			return call(x, y)
		} else when KIND == .Cmp {
			return call(x, y) == .Less
		} else {
			#panic("unhandled Sort_Kind")
		}
	}

	// Insertion sort: each element walks left only past neighbours it sorts strictly
	// before, so equal elements never trade places and the result is stable.
	// TODO(bill): use a different algorithm as insertion sort is O(n^2)
	for i := 1; i < len(data); i += 1 {
		for j := i; j > 0; j -= 1 {
			if !less(data[j], data[j-1], call) {
				break
			}
			swap(data, j, j-1)
		}
	}
}
  // _quick_sort_general_with_indices sorts data[a:b] (half-open) in place with the
  // same introsort-style hybrid as the plain variant:
  //   - quicksort partitioning;
  //   - a heap sort fallback after `max_depth` rounds;
  //   - shell pass plus insertion sort for ranges of at most 12 elements.
  // Every element exchange is also applied to `indices`, so both slices are
  // permuted in lockstep. If the caller fills `indices` with 0..<len(data)
  // beforehand (presumably what callers do - confirm), indices[i] ends up naming
  // the original position of data[i]. Asserts len(data) == len(indices).
  // Elements that compare equal may be reordered (not stable).
  //
  // KIND fixes, at compile time, how two elements x and y are compared:
  //   .Ordered - `x < y`; `call` is ignored and E must satisfy ORD
  //   .Less    - `call(x, y)` is true when x sorts before y
  //   .Cmp     - x sorts before y when `call(x, y) == .Less`
  _quick_sort_general_with_indices :: proc(data: $T/[]$E, indices: []int, a, b, max_depth: int, call: $P, $KIND: Sort_Kind) where (ORD(E) && KIND == .Ordered) || (KIND != .Ordered) #no_bounds_check {
	// less reports whether x must be placed before y; the branch is chosen at compile time.
	less :: #force_inline proc(x, y: E, call: P) -> bool {
		when KIND == .Ordered {
			return x < y
		} else when KIND == .Less {
			return call(x, y)
		} else when KIND == .Cmp {
			return call(x, y) == .Less
		} else {
			#panic("unhandled Sort_Kind")
		}
	}

	// swap_both exchanges positions i and j in data, then the same positions in indices.
	swap_both :: #force_inline proc(data: T, indices: []int, i, j: int) {
		swap(data, i, j)
		swap(indices, i, j)
	}

	// insertion_sort sorts data[lo:hi] by sinking each element left past larger neighbours.
	insertion_sort :: proc(data: T, indices: []int, lo, hi: int, call: P) #no_bounds_check {
		for i := lo+1; i < hi; i += 1 {
			for j := i; j > lo; j -= 1 {
				if !less(data[j], data[j-1], call) {
					break
				}
				swap_both(data, indices, j, j-1)
			}
		}
	}

	// heap_sort sorts data[lo:hi] using a binary max-heap laid over that sub-range.
	heap_sort :: proc(data: T, indices: []int, lo, hi: int, call: P) #no_bounds_check {
		// sift_down pushes heap node `node` down until it dominates its children,
		// for a heap of `size` nodes whose node 0 is stored at data[base].
		sift_down :: proc(data: T, indices: []int, node, size, base: int, call: P) #no_bounds_check {
			parent := node
			for {
				child := 2*parent + 1
				if child >= size {
					return
				}
				if child+1 < size && less(data[base+child], data[base+child+1], call) {
					child += 1 // the right child is the larger one
				}
				if !less(data[base+parent], data[base+child], call) {
					return // parent already dominates both children
				}
				swap_both(data, indices, base+parent, base+child)
				parent = child
			}
		}

		n := hi - lo
		// Heapify: sift down every internal node, starting from the last parent.
		for i := (n-1)/2; i >= 0; i -= 1 {
			sift_down(data, indices, i, n, lo, call)
		}
		// Repeatedly move the current maximum behind the shrinking heap.
		for last := n-1; last >= 0; last -= 1 {
			swap_both(data, indices, lo, lo+last)
			sift_down(data, indices, 0, last, lo, call)
		}
	}

	// median3 reorders the three elements so that data[low] <= data[mid] <= data[high];
	// the median therefore ends up at `mid`, the FIRST index argument.
	median3 :: proc(data: T, indices: []int, mid, low, high: int, call: P) #no_bounds_check {
		if less(data[mid], data[low], call) {
			swap_both(data, indices, mid, low)
		}
		if !less(data[high], data[mid], call) {
			return
		}
		swap_both(data, indices, high, mid)
		if less(data[mid], data[low], call) {
			swap_both(data, indices, mid, low)
		}
	}

	// do_pivot partitions data[lo:hi] (only called on ranges of more than 12 elements)
	// and returns (midlo, midhi). The pivot lands at data[midlo]. data[midlo:midhi]
	// holds only elements equivalent to it (already in final position), so the caller
	// still has to sort data[lo:midlo] and data[midhi:hi].
	do_pivot :: proc(data: T, indices: []int, lo, hi: int, call: P) -> (midlo, midhi: int) #no_bounds_check {
		span := hi - lo
		mid := int(uint(lo+hi)>>1) // halve through uint so a wrapped lo+hi still gives the right midpoint

		// Pivot choice: median of three, upgraded to a ninther (median of three
		// medians of three) on larger ranges. Each median3 leaves its median in its
		// first argument, so the chosen pivot ends up at data[lo].
		if span > 40 {
			step := span/8
			median3(data, indices, lo, lo+step, lo+step*2, call)
			median3(data, indices, mid, mid-step, mid+step, call)
			median3(data, indices, hi-1, hi-1-step, hi-1-step*2, call)
		}
		median3(data, indices, lo, mid, hi-1, call)
		pivot := lo

		// Partition invariants ("<" meaning "sorts before"):
		//   data[lo+1:lt_end]       < pivot
		//   data[lt_end:le_end]     <= pivot
		//   data[gt_start:hi-1]     > pivot
		//   data[hi-1]              >= pivot (guaranteed by median3 above)
		lt_end, gt_start := lo+1, hi-1
		for lt_end < gt_start && less(data[lt_end], data[pivot], call) {
			lt_end += 1
		}
		le_end := lt_end
		for {
			for le_end < gt_start && !less(data[pivot], data[le_end], call) {
				le_end += 1 // data[le_end] <= pivot
			}
			for le_end < gt_start && less(data[pivot], data[gt_start-1], call) {
				gt_start -= 1 // data[gt_start-1] > pivot
			}
			if le_end >= gt_start {
				break
			}
			// data[le_end] > pivot and data[gt_start-1] <= pivot: exchange them.
			swap_both(data, indices, le_end, gt_start-1)
			le_end += 1
			gt_start -= 1
		}

		// Very few elements sorting after the pivot hints at many duplicates of it.
		protect := hi-gt_start < 5
		if !protect && hi-gt_start < span/4 {
			// Probe three positions for elements equivalent to the pivot.
			dups := 0
			if !less(data[pivot], data[hi-1], call) {
				swap_both(data, indices, gt_start, hi-1)
				gt_start += 1
				dups += 1
			}
			if !less(data[le_end-1], data[pivot], call) {
				le_end -= 1
				dups += 1
			}
			if !less(data[mid], data[pivot], call) {
				swap_both(data, indices, mid, le_end-1)
				le_end -= 1
				dups += 1
			}
			protect = dups > 1 // two or more hits: assume a duplicate-heavy range
		}
		if protect {
			// Split data[lt_end:le_end] into elements strictly before the pivot (left)
			// and elements equivalent to it (right, next to the pivot's final slot).
			for {
				for lt_end < le_end && !less(data[le_end-1], data[pivot], call) {
					le_end -= 1
				}
				for lt_end < le_end && less(data[lt_end], data[pivot], call) {
					lt_end += 1
				}
				if lt_end >= le_end {
					break
				}
				swap_both(data, indices, lt_end, le_end-1)
				lt_end += 1
				le_end -= 1
			}
		}

		// Move the pivot between the two partitions.
		swap_both(data, indices, pivot, le_end-1)
		return le_end-1, gt_start
	}

	assert(len(data) == len(indices))

	lo, hi, depth := a, b, max_depth
	// Recurse into the smaller side and keep looping on the larger side, so the
	// recursion depth stays logarithmic in the range length.
	for hi-lo > 12 {
		if depth == 0 {
			// Partitioning budget exhausted: finish with heap sort.
			heap_sort(data, indices, lo, hi, call)
			return
		}
		depth -= 1
		midlo, midhi := do_pivot(data, indices, lo, hi, call)
		if midlo-lo < hi-midhi {
			_quick_sort_general_with_indices(data, indices, lo, midlo, depth, call, KIND)
			lo = midhi
		} else {
			_quick_sort_general_with_indices(data, indices, midhi, hi, depth, call, KIND)
			hi = midlo
		}
	}

	if hi-lo > 1 {
		// At most 12 elements remain: one shell sort pass with gap 6,
		// then a final insertion sort.
		for i := lo+6; i < hi; i += 1 {
			if less(data[i], data[i-6], call) {
				swap_both(data, indices, i, i-6)
			}
		}
		insertion_sort(data, indices, lo, hi, call)
	}
}