|
@@ -16,6 +16,7 @@ int WaveInBreakBlock(int a : A, int b : B);
|
|
|
int WaveInEntry(int a : A, int b : B);
|
|
|
int WaveInSubLoop(int a : A, int b : B);
|
|
|
int WaveInOtherLoop(int a : A, int b : B, int c : C);
|
|
|
+int MultiWaveInMultiLoops(int a : A, int b : B, int c : C, uint d : D);
|
|
|
|
|
|
// CHECK: @dx.break.cond = internal constant
|
|
|
|
|
@@ -37,7 +38,7 @@ int WaveInOtherLoop(int a : A, int b : B, int c : C);
|
|
|
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
// CHECK-SAME: %mainBuf
|
|
|
|
|
|
-int main(int a : A, int b : B, int c : C) : SV_Target
|
|
|
+int main(int a : A, int b : B, int c : C, int d : D) : SV_Target
|
|
|
{
|
|
|
int res = 0;
|
|
|
int i = 0;
|
|
@@ -57,7 +58,7 @@ int main(int a : A, int b : B, int c : C) : SV_Target
|
|
|
}
|
|
|
}
|
|
|
return res + WaveInPostLoop(a, b) + WaveInBreakBlock(a, b) + WaveInEntry(a, b) +
|
|
|
- WaveInSubLoop(a,b) + WaveInOtherLoop(a,b,c);
|
|
|
+ WaveInSubLoop(a,b) + WaveInOtherLoop(a,b,c) + MultiWaveInMultiLoops(a,b,c,d);
|
|
|
}
|
|
|
|
|
|
// Wave moved to after the break block. Expected to keep the block in loop
|
|
@@ -79,7 +80,6 @@ int main(int a : A, int b : B, int c : C) : SV_Target
|
|
|
// CHECK: br i1
|
|
|
|
|
|
// CHECK: call i32 @dx.op.waveReadLaneFirst
|
|
|
-export
|
|
|
int WaveInPostLoop(int a : A, int b : B)
|
|
|
{
|
|
|
int res = 0;
|
|
@@ -117,7 +117,6 @@ int WaveInPostLoop(int a : A, int b : B)
|
|
|
// CHECK-SAME: %breakBuf
|
|
|
// CHECK: br i1
|
|
|
|
|
|
-export
|
|
|
int WaveInBreakBlock(int a : A, int b : B)
|
|
|
{
|
|
|
int res = 0;
|
|
@@ -147,7 +146,6 @@ int WaveInBreakBlock(int a : A, int b : B)
|
|
|
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
// CHECK-SAME: %entryBuf
|
|
|
-export
|
|
|
int WaveInEntry(int a : A, int b : B)
|
|
|
{
|
|
|
int res = 0;
|
|
@@ -190,7 +188,6 @@ int WaveInEntry(int a : A, int b : B)
|
|
|
// CHECK-SAME: %subBuf
|
|
|
// CHECK: add
|
|
|
// CHECK: br i1
|
|
|
-export
|
|
|
int WaveInSubLoop(int a : A, int b : B)
|
|
|
{
|
|
|
int res = 0;
|
|
@@ -235,13 +232,11 @@ int WaveInSubLoop(int a : A, int b : B)
|
|
|
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
// CHECK-SAME: %otherBuf
|
|
|
-// CHECK-NOT: br i1
|
|
|
|
|
|
// These verify the third break block doesn't
|
|
|
// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
// CHECK-SAME: %otherBuf
|
|
|
-// CHECK: add
|
|
|
int WaveInOtherLoop(int a : A, int b : B, int c : C)
|
|
|
{
|
|
|
int res = 0;
|
|
@@ -275,6 +270,134 @@ int WaveInOtherLoop(int a : A, int b : B, int c : C)
|
|
|
return res;
|
|
|
}
|
|
|
|
|
|
+// Complicated case where multiple waves are in multiple loops overlapping and not
|
|
|
+
|
|
|
+// Position all the wave ops
|
|
|
+// CHECK: call i32 @dx.op.waveReadLaneFirst
|
|
|
+// CHECK: call i32 @dx.op.waveActiveOp
|
|
|
+// CHECK: call i32 @dx.op.waveActiveOp
|
|
|
+// CHECK: call i32 @dx.op.waveActiveBit
|
|
|
+
|
|
|
+// These verify the first four break blocks keep the conditional
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+
|
|
|
+// Repeat for second loop
|
|
|
+
|
|
|
+// Position all the wave ops
|
|
|
+// CHECK: call i32 @dx.op.waveReadLaneFirst
|
|
|
+// CHECK: call i32 @dx.op.waveActiveOp
|
|
|
+// CHECK: call i32 @dx.op.waveActiveOp
|
|
|
+// CHECK: call i32 @dx.op.waveActiveBit
|
|
|
+
|
|
|
+// These verify the last four break blocks keep the conditional
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+// CHECK: call %dx.types.Handle @dx.op.createHandle
|
|
|
+// CHECK: call %dx.types.ResRet.i32 @dx.op.rawBufferLoad
|
|
|
+// CHECK-SAME: %mainBuf
|
|
|
+// CHECK: add
|
|
|
+// CHECK: br i1
|
|
|
+
|
|
|
+
|
|
|
+int MultiWaveInMultiLoops(int a : A, int b : B, int c : C, uint d : D)
|
|
|
+{
|
|
|
+ int res = 0;
|
|
|
+ int u = 0;
|
|
|
+ int v = 0;
|
|
|
+ int w = 0;
|
|
|
+ int x = 0;
|
|
|
+
|
|
|
+ for (;;) {
|
|
|
+ u += WaveReadLaneFirst(a);
|
|
|
+ v += WaveActiveSum(b);
|
|
|
+ for (int i = 0; i < c; i++) {
|
|
|
+ w += WaveActiveProduct(c);
|
|
|
+ x += WaveActiveBitAnd(d);
|
|
|
+ }
|
|
|
+ if (a == u + b) {
|
|
|
+ res += mainBuf[u + c][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (a == v + b) {
|
|
|
+ res += mainBuf[v + c][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (a == w + b) {
|
|
|
+ res += mainBuf[w + c][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (a == x + b) {
|
|
|
+ res += mainBuf[x + c][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ u = 0;
|
|
|
+ v = 0;
|
|
|
+ w = 0;
|
|
|
+ x = 0;
|
|
|
+
|
|
|
+ for (;;) {
|
|
|
+ u += WaveReadLaneFirst(a);
|
|
|
+ v += WaveActiveSum(b);
|
|
|
+ for (int i = 0; i < c; i++) {
|
|
|
+ w += WaveActiveProduct(c);
|
|
|
+ x += WaveActiveBitAnd(d);
|
|
|
+ }
|
|
|
+ if (a == b + u) {
|
|
|
+ res += mainBuf[c + u][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (a == b + v) {
|
|
|
+ res += mainBuf[c + v][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (a == b + w) {
|
|
|
+ res += mainBuf[c + w][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ if (a == b + x) {
|
|
|
+ res += mainBuf[c + x][b];
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ return res;
|
|
|
+}
|
|
|
+
|
|
|
// Final operations
|
|
|
// CHECK-NOT: br i1
|
|
|
// CHECK: call void @dx.op.storeOutput
|