RadixSortCount.bsl 4.1 KB
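
// Counting pass of a GPU radix sort: each thread group counts how many of the keys in its
// assigned tiles fall into each digit bucket for the current bit range, then writes per-group,
// per-digit counts to gOutputCounts (presumably consumed by later prefix-sum and reorder passes
// of the sort)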

#include "$ENGINE$/RadixSortCommon.bslinc"

shader RadixSortCount
{
    mixin RadixSortCommon;

    code
    {
        #define NUM_COUNTERS (NUM_THREADS * NUM_DIGITS)
        #define NUM_REDUCE_THREADS 64
        #define NUM_REDUCE_THREADS_PER_DIGIT (NUM_REDUCE_THREADS/NUM_DIGITS)
        #define NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT (NUM_THREADS/NUM_REDUCE_THREADS_PER_DIGIT)
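
        // NUM_REDUCE_THREADS_PER_DIGIT threads cooperate on reducing the counters of a single
        // digit; each of them first sums NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT per-thread
        // counters serially before the final parallel reduction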

        Buffer<uint> gInputKeys;
        RWBuffer<uint> gOutputCounts;
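
        // Each thread owns a contiguous row of NUM_DIGITS entries in sCounters, so counting
        // requires no atomics; sReduceCounters holds the partial sums produced while reducing
        // those rows into a single per-digit count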
        groupshared uint sCounters[NUM_COUNTERS];
        groupshared uint sReduceCounters[NUM_REDUCE_THREADS];

        [numthreads(NUM_THREADS, 1, 1)]
        void csmain(uint3 groupThreadId : SV_GroupThreadID, uint3 groupId : SV_GroupID)
        {
            uint threadId = groupThreadId.x;

            // Initialize counters to zero
            for(uint i = 0; i < NUM_DIGITS; i++)
                sCounters[threadId * NUM_DIGITS + i] = 0;

            if(threadId < NUM_REDUCE_THREADS)
                sReduceCounters[threadId] = 0;

            GroupMemoryBarrierWithGroupSync();

            // Handle the case where the number of tiles isn't exactly divisible by the number of
            // groups, in which case the first gNumExtraTiles groups each handle one extra tile
            uint tileIdx, tileCount;
            if(groupId.x < gNumExtraTiles)
            {
                tileCount = gTilesPerGroup + 1;
                tileIdx = groupId.x * tileCount;
            }
            else
            {
                tileCount = gTilesPerGroup;
                tileIdx = groupId.x * tileCount + gNumExtraTiles;
            }

            uint keyBegin = tileIdx * TILE_SIZE;
            uint keyEnd = keyBegin + tileCount * TILE_SIZE;

            // For each key determine its current digit and count how many keys of each digit
            // value there are. The key is shifted and masked by the radix so that only M bits
            // are handled at a time, which is why multiple passes are required to fully sort
            // the keys.
            while(keyBegin < keyEnd)
            {
                uint key = gInputKeys[keyBegin + threadId];
                uint digit = (key >> gBitOffset) & KEY_MASK;

                sCounters[threadId * NUM_DIGITS + digit] += 1;
                keyBegin += NUM_THREADS;
            }

            // Unless the number of keys is an exact multiple of the tile size, there will be an
            // extra set of keys that requires per-thread range checking. We handle this as a
            // special case for the last group, to avoid paying the cost of the check for every key.
            if(groupId.x == (gNumGroups - 1))
            {
                keyBegin = keyEnd;
                keyEnd = keyBegin + gNumExtraKeys;

                while(keyBegin < keyEnd)
                {
                    if((keyBegin + threadId) < keyEnd)
                    {
                        uint key = gInputKeys[keyBegin + threadId];
                        uint digit = (key >> gBitOffset) & KEY_MASK;

                        sCounters[threadId * NUM_DIGITS + digit] += 1;
                    }

                    keyBegin += NUM_THREADS;
                }
            }

            GroupMemoryBarrierWithGroupSync();

            // Reduce the counts from all threads in the group into a single NUM_DIGITS sized array
            if(threadId < NUM_REDUCE_THREADS)
            {
                uint digitIdx = threadId / NUM_REDUCE_THREADS_PER_DIGIT;
                uint setIdx = threadId & (NUM_REDUCE_THREADS_PER_DIGIT - 1);
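
                // Each digit is assigned NUM_REDUCE_THREADS_PER_DIGIT consecutive threads:
                // digitIdx selects the digit this thread reduces, setIdx its slot within that
                // digit's thread set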

                // First do part of the sum sequentially (shown to be faster than doing it fully
                // in parallel). This leaves NUM_REDUCE_THREADS_PER_DIGIT partial sums per digit.
                uint total = 0;
                for(uint i = 0; i < NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT; i++)
                {
                    // Note: Check & reduce bank conflicts
                    uint threadIdx = (setIdx * NUM_REDUCE_ELEMS_PER_THREAD_PER_DIGIT + i) * NUM_DIGITS;
                    total += sCounters[threadIdx + digitIdx];
                }

                sReduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx] = total;

                // Then do a parallel reduction on the results of the serial additions
                [unroll]
                for(uint i = 1; i < NUM_REDUCE_THREADS_PER_DIGIT; i <<= 1)
                {
                    // Only a memory barrier, no group sync, since these threads are expected to
                    // execute within the same warp/wavefront
                    WarpGroupMemoryBarrier();

                    // Note: Check & reduce bank conflicts
                    total += sReduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx + i];
                    sReduceCounters[digitIdx * NUM_REDUCE_THREADS_PER_DIGIT + setIdx] = total;
                }
            }
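
            // After the reduction, the entry at setIdx == 0 for each digit holds the total count
            // of that digit across the entire group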

            GroupMemoryBarrierWithGroupSync();

            // Write the summed-up per-digit counts to global memory, one NUM_DIGITS block per group
            if(threadId < NUM_DIGITS)
            {
                gOutputCounts[groupId.x * NUM_DIGITS + threadId] = sReduceCounters[threadId * NUM_REDUCE_THREADS_PER_DIGIT];
            }
        }
    };
};