RadixSortReorder.bsl 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. #include "$ENGINE$/RadixSortCommon.bslinc"
  2. shader RadixSortReorder
  3. {
  4. mixin RadixSortCommon;
  5. code
  6. {
  7. Buffer<uint> gInputKeys; // Keys to reorder; csmain reads them and extracts the digit via gBitOffset / KEY_MASK
  8. Buffer<uint> gInputValues; // Value paired with each key; read at the same index as its key
  9. Buffer<uint> gInputOffsets; // Output start offsets, NUM_DIGITS entries per group (indexed groupId.x * NUM_DIGITS + digit)
  10. RWBuffer<uint> gOutputKeys; // Destination for the reordered keys
  11. RWBuffer<uint> gOutputValues; // Destination for the values; each is written at the same offset as its key
  12. groupshared uint sGroupOffsets[NUM_DIGITS]; // This group's slice of gInputOffsets, loaded once at the start of csmain
  13. groupshared uint sLocalScratch[NUM_DIGITS * NUM_THREADS]; // Per-thread digit counters (thread * NUM_DIGITS + digit); prefixSum scans them in place
  14. groupshared uint sTileTotals[NUM_DIGITS]; // Running per-digit key count over the tiles this group has already written
  15. groupshared uint sCurrentTileTotal[NUM_DIGITS]; // Per-digit key count of the tile last scanned; written by prefixSum
  // Turns the per-thread digit counters stored in sLocalScratch (laid out as thread * NUM_DIGITS + digit)
  // into an exclusive prefix sum over the threads of the group, independently for every digit and in place.
  // The per-digit grand total of the counters is stored in sCurrentTileTotal.
  //
  // Contains group barriers, so every thread of the group has to call it. On return all shared memory
  // writes made by the scan are visible to the whole group.
  //
  // NOTE(review): the tree walk below appears to require NUM_THREADS to be a power of two and
  // NUM_THREADS >= NUM_DIGITS; both constants are defined outside this file - confirm.
  void prefixSum(uint threadId)
  {
      // For a given stride a thread always combines the same two tree nodes, in both sweeps
      uint leftNode = 2 * threadId + 1;
      uint rightNode = 2 * threadId + 2;

      // Up-sweep: build partial sums, halving the number of working threads on every level
      uint stride = 1;
      uint activeThreads = NUM_THREADS >> 1;
      while (activeThreads > 0)
      {
          GroupMemoryBarrierWithGroupSync();

          if (threadId < activeThreads)
          {
              uint srcRow = (stride * leftNode - 1) * NUM_DIGITS;
              uint dstRow = (stride * rightNode - 1) * NUM_DIGITS;

              // Note: With more than NUM_THREADS threads a single thread wouldn't need to walk all digits
              // Note: Running part of this step serially might perform better
              // Note: Bank conflicts have not been checked/removed here
              for (uint upDigit = 0; upDigit < NUM_DIGITS; upDigit++)
                  sLocalScratch[dstRow + upDigit] += sLocalScratch[srcRow + upDigit];
          }

          stride <<= 1;
          activeThreads >>= 1;
      }

      GroupMemoryBarrierWithGroupSync();

      // The last row now holds the per-digit totals. Save them, then clear the tree roots so the
      // down-sweep yields an exclusive scan.
      if (threadId < NUM_DIGITS)
      {
          uint rootIdx = (NUM_THREADS - 1) * NUM_DIGITS + threadId;

          sCurrentTileTotal[threadId] = sLocalScratch[rootIdx];
          sLocalScratch[rootIdx] = 0;
      }

      // Down-sweep: push the partial sums from the up-sweep back down the tree, doubling the number
      // of working threads on every level
      activeThreads = 1;
      while (activeThreads < NUM_THREADS)
      {
          GroupMemoryBarrierWithGroupSync();

          stride >>= 1;
          if (threadId < activeThreads)
          {
              uint leftRow = (stride * leftNode - 1) * NUM_DIGITS;
              uint rightRow = (stride * rightNode - 1) * NUM_DIGITS;

              // Note: Bank conflicts have not been checked/resolved here
              for (uint downDigit = 0; downDigit < NUM_DIGITS; downDigit++)
              {
                  uint leftValue = sLocalScratch[leftRow + downDigit];

                  sLocalScratch[leftRow + downDigit] = sLocalScratch[rightRow + downDigit];
                  sLocalScratch[rightRow + downDigit] += leftValue;
              }
          }

          activeThreads <<= 1;
      }

      GroupMemoryBarrierWithGroupSync();
  }
  // Reorders the keys (and their values) of a single tile into gOutputKeys / gOutputValues.
  //
  // Every thread owns KEYS_PER_LOOP consecutive key slots, starting at tile-local index
  // threadId * KEYS_PER_LOOP. Slots whose tile-local index is >= numKeys are skipped, which lets the same
  // code handle both full tiles and the trailing partial tile.
  //
  //   threadId  - index of the calling thread within its group
  //   keyBegin  - index into gInputKeys / gInputValues of the first key of the tile
  //   numKeys   - number of valid keys in the tile
  //
  // Steps: count the digits per thread, scan the counters with prefixSum(), then scatter each key to
  //   sGroupOffsets[digit] + sTileTotals[digit] + <keys with this digit in lower threads> + <earlier keys
  //   with this digit in this thread>.
  //
  // Contains group barriers, so every thread of the group has to call it. The caller must make sure no
  // thread is still reading sLocalScratch from earlier work when this is entered. sCurrentTileTotal holds
  // the per-digit key counts of this tile on return; sTileTotals is only read, never updated here.
  void reorderTileKeys(uint threadId, uint keyBegin, uint numKeys)
  {
      // Zero out the local counters. Each thread clears a strided set of entries (not just its own
      // counters), hence the barrier before counting starts.
      for (uint clearIdx = 0; clearIdx < NUM_DIGITS; clearIdx++)
          sLocalScratch[clearIdx * NUM_THREADS + threadId] = 0;

      GroupMemoryBarrierWithGroupSync();

      uint firstLocalIdx = threadId * KEYS_PER_LOOP;
      uint counterRow = threadId * NUM_DIGITS;

      // Count how many of this thread's keys fall on each digit
      for (uint countIdx = 0; countIdx < KEYS_PER_LOOP; countIdx++)
      {
          uint localIdx = firstLocalIdx + countIdx;
          if (localIdx < numKeys)
          {
              uint key = gInputKeys[keyBegin + localIdx];
              uint digit = (key >> gBitOffset) & KEY_MASK;

              sLocalScratch[counterRow + digit] += 1;
          }
      }

      // Calculate the prefix sum per-digit
      prefixSum(threadId);

      // Number of keys per digit this thread has already written during the scatter below
      uint localOffsets[NUM_DIGITS];
      for (uint resetIdx = 0; resetIdx < NUM_DIGITS; resetIdx++)
          localOffsets[resetIdx] = 0;

      // Actually re-order the keys. Keys are visited in input order, so keys with equal digits keep
      // their relative order.
      for (uint writeIdx = 0; writeIdx < KEYS_PER_LOOP; writeIdx++)
      {
          uint localIdx = firstLocalIdx + writeIdx;
          if (localIdx < numKeys)
          {
              uint idx = keyBegin + localIdx;
              uint key = gInputKeys[idx];
              uint digit = (key >> gBitOffset) & KEY_MASK;

              uint offset = sGroupOffsets[digit] + sTileTotals[digit] + sLocalScratch[counterRow + digit] + localOffsets[digit];
              localOffsets[digit]++;

              // Note: First write to local memory then attempt to coalesce when writing to global?
              gOutputKeys[offset] = key;
              gOutputValues[offset] = gInputValues[idx];
          }
      }
  }

  // Compute shader entry point. Each thread group reorders a contiguous range of tiles from gInputKeys /
  // gInputValues into gOutputKeys / gOutputValues, according to the digit selected by gBitOffset and
  // KEY_MASK. The last group additionally handles the gNumExtraKeys keys that don't fill a whole tile.
  //
  //   groupThreadId - only .x is used: index of the thread within its group
  //   groupId       - only .x is used: index of the group
  //
  // gNumExtraTiles, gTilesPerGroup, gNumGroups, gNumExtraKeys, gBitOffset and the upper-case constants
  // are not declared in this file (presumably they come from the RadixSortCommon mixin).
  [numthreads(NUM_THREADS, 1, 1)]
  void csmain(uint3 groupThreadId : SV_GroupThreadID, uint3 groupId : SV_GroupID)
  {
      uint threadId = groupThreadId.x;

      if (threadId < NUM_DIGITS)
      {
          // Load offsets for this group to local memory
          sGroupOffsets[threadId] = gInputOffsets[groupId.x * NUM_DIGITS + threadId];

          // Clear tile totals
          sTileTotals[threadId] = 0;
      }

      // Handle case when number of tiles isn't exactly divisible by number of groups, in
      // which case first N groups handle those extra tiles
      uint tileIdx, tileCount;
      if (groupId.x < gNumExtraTiles)
      {
          tileCount = gTilesPerGroup + 1;
          tileIdx = groupId.x * tileCount;
      }
      else
      {
          tileCount = gTilesPerGroup;
          tileIdx = groupId.x * tileCount + gNumExtraTiles;
      }

      // We need to generate per-thread offsets (prefix sum) of where to store the keys at
      // (This is equivalent to what was done in count & prefix sum shaders, except that was done per-group)
      uint keyBegin = tileIdx * TILE_SIZE;
      uint keyEnd = keyBegin + tileCount * TILE_SIZE;

      while (keyBegin < keyEnd)
      {
          // Makes the sGroupOffsets / sTileTotals writes above (or from the previous iteration) visible
          GroupMemoryBarrierWithGroupSync();

          // A full tile fills every per-thread key slot, so no key is skipped by the bounds check.
          // NOTE(review): keyBegin advancing by TILE_SIZE implies TILE_SIZE == NUM_THREADS * KEYS_PER_LOOP;
          // confirm against the definition of TILE_SIZE.
          reorderTileKeys(threadId, keyBegin, NUM_THREADS * KEYS_PER_LOOP);

          // All threads must be done reading sTileTotals before it is advanced past this tile
          GroupMemoryBarrierWithGroupSync();

          if (threadId < NUM_DIGITS)
              sTileTotals[threadId] += sCurrentTileTotal[threadId];

          keyBegin += TILE_SIZE;
      }

      // The last group also reorders the trailing keys that don't make up a full tile. No extra barrier
      // is needed here: the sTileTotals update above is followed by the barriers inside reorderTileKeys()
      // before any thread reads it. The condition is uniform across the group, so the barriers are safe.
      if (groupId.x == (gNumGroups - 1) && gNumExtraKeys > 0)
          reorderTileKeys(threadId, keyBegin, gNumExtraKeys);
  }
  171. };
  172. };