RadixSortPrefixScan.bsl 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. #include "$ENGINE$/RadixSortCommon.bslinc"
  shader RadixSortPrefixScan
  {
      mixin RadixSortCommon;

      code
      {
          // Per-group digit counts. Entry (g * NUM_DIGITS + d) is the count for
          // group g, digit d.
          Buffer<uint> gInputCounts;
          // Same layout as gInputCounts. Receives, for every (g, d):
          //   (sum over all groups of the counts of every digit < d)
          // + (sum of the counts of digit d in groups 0..g-1)
          RWBuffer<uint> gOutputOffsets;

          // Working copy of the counts, scanned in place along the group axis.
          groupshared uint sDigitPrefixSum[MAX_NUM_GROUPS * NUM_DIGITS];
          // Exclusive prefix sum of the per-digit totals across digits.
          groupshared uint sTotalPrefixSum[NUM_DIGITS];

          // Prefix-scan step of the radix sort: one thread per group row.
          //
          // Assumptions (not checked here; MAX_NUM_GROUPS / NUM_DIGITS come from
          // the RadixSortCommon mixin):
          //  - MAX_NUM_GROUPS is a power of two; the scan-tree indexing below is
          //    only correct in that case.
          //  - groupId is unused, so presumably this is dispatched as a single
          //    thread group (extra groups would redo identical work) - confirm.
          //  - Rows for groups that were not counted are presumably zero (cleared,
          //    or out-of-range Buffer loads returning 0) - confirm with the caller.
          [numthreads(MAX_NUM_GROUPS, 1, 1)]
          void csmain(uint3 groupThreadId : SV_GroupThreadID, uint3 groupId : SV_GroupID)
          {
              uint threadId = groupThreadId.x;

              // Load all counts into shared memory. Element
              // (iteration * MAX_NUM_GROUPS + threadId) is used so adjacent threads
              // touch adjacent addresses, rather than each thread walking its own
              // NUM_DIGITS-wide row. A group sync precedes every read of this data.
              [unroll]
              for (uint loadIter = 0; loadIter < NUM_DIGITS; loadIter++)
              {
                  uint idx = loadIter * MAX_NUM_GROUPS + threadId;
                  sDigitPrefixSum[idx] = gInputCounts[idx];
              }

              // Blelloch scan over the group axis, applied to every digit column.
              // Upsweep: build a tree of partial sums in place.
              // Note: running more than MAX_NUM_GROUPS threads would let a thread
              // handle fewer than NUM_DIGITS columns; bank conflicts in the tree
              // phases have not been examined.
              uint offset = 1;
              for (uint upActive = MAX_NUM_GROUPS >> 1; upActive > 0; upActive >>= 1)
              {
                  GroupMemoryBarrierWithGroupSync();
                  if (threadId < upActive)
                  {
                      uint leftRow = (offset * (2 * threadId + 1) - 1) * NUM_DIGITS;
                      uint rightRow = (offset * (2 * threadId + 2) - 1) * NUM_DIGITS;
                      for (uint d = 0; d < NUM_DIGITS; d++)
                          sDigitPrefixSum[rightRow + d] += sDigitPrefixSum[leftRow + d];
                  }
                  offset <<= 1;
              }
              GroupMemoryBarrierWithGroupSync();

              // The last group's row now holds each digit's total over all groups.
              // One thread turns those totals into an exclusive prefix sum across
              // digits and clears the tree roots for the downsweep. Handling every
              // digit in a single thread keeps this correct even when
              // MAX_NUM_GROUPS < NUM_DIGITS.
              if (threadId == 0)
              {
                  uint runningTotal = 0;
                  for (uint digit = 0; digit < NUM_DIGITS; digit++)
                  {
                      uint rootIdx = (MAX_NUM_GROUPS - 1) * NUM_DIGITS + digit;
                      sTotalPrefixSum[digit] = runningTotal;
                      runningTotal += sDigitPrefixSum[rootIdx];
                      sDigitPrefixSum[rootIdx] = 0;
                  }
              }

              // Downsweep: push prefixes back down the tree so each row ends up
              // holding, per digit, the sum of the counts of all earlier groups.
              // The sync at the top of each iteration also publishes the root reset.
              for (uint downActive = 1; downActive < MAX_NUM_GROUPS; downActive <<= 1)
              {
                  GroupMemoryBarrierWithGroupSync();
                  offset >>= 1;
                  if (threadId < downActive)
                  {
                      uint leftRow = (offset * (2 * threadId + 1) - 1) * NUM_DIGITS;
                      uint rightRow = (offset * (2 * threadId + 2) - 1) * NUM_DIGITS;
                      for (uint d = 0; d < NUM_DIGITS; d++)
                      {
                          uint leftValue = sDigitPrefixSum[leftRow + d];
                          sDigitPrefixSum[leftRow + d] = sDigitPrefixSum[rightRow + d];
                          sDigitPrefixSum[rightRow + d] += leftValue;
                      }
                  }
              }
              // Make the downsweep results and sTotalPrefixSum visible to all threads.
              GroupMemoryBarrierWithGroupSync();

              // Global offset = digit base offset + per-group exclusive prefix,
              // written with the same adjacent-thread/adjacent-address pattern as
              // the load.
              [unroll]
              for (uint storeIter = 0; storeIter < NUM_DIGITS; storeIter++)
              {
                  uint idx = storeIter * MAX_NUM_GROUPS + threadId;
                  gOutputOffsets[idx] = sTotalPrefixSum[idx % NUM_DIGITS] + sDigitPrefixSum[idx];
              }
          }
      };
  };