Browse Source

Support vector indexing on output. (#55)

Xiang Li 8 years ago
parent
commit
3c7734c7da

+ 41 - 6
lib/HLSL/DxilGenerationPass.cpp

@@ -1199,7 +1199,6 @@ void DxilGenerationPass::GenerateDxilInputsOutputs(bool bInput) {
     if (!GV->getType()->isPointerTy()) {
       DXASSERT(bInput, "direct parameter must be input");
       Value *vertexID = undefVertexIdx;
-      Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
       Value *args[] = {OpArg, ID, /*rowIdx*/constZero, /*colIdx*/nullptr, vertexID};
       replaceDirectInputParameter(GV, dxilFunc, cols, args, bI1Cast, hlslOP, EntryBuilder);
       continue;
@@ -1222,7 +1221,6 @@ void DxilGenerationPass::GenerateDxilInputsOutputs(bool bInput) {
       }
 
       if (LoadInst *ldInst = dyn_cast<LoadInst>(info.user)) {
-        Constant *OpArg = hlslOP->GetU32Const((unsigned)opcode);
         Value *args[] = {OpArg, ID, idxVal, info.vectorIdx, vertexID};
         replaceLdWithLdInput(dxilFunc, ldInst, cols, args, bI1Cast, hlslOP);
       }
@@ -1230,10 +1228,47 @@ void DxilGenerationPass::GenerateDxilInputsOutputs(bool bInput) {
         if (bInput) {
           DXASSERT_LOCALVAR(bIsInout, bIsInout, "input should not have store use.");
         } else {
-          DXASSERT(!info.vectorIdx,
-                   "not implement vector indexing on output yet.");
-          replaceStWithStOutput(dxilFunc, stInst, opcode, ID, idxVal, cols,
-                                bI1Cast, hlslOP);
+          if (!info.vectorIdx) {
+            replaceStWithStOutput(dxilFunc, stInst, opcode, ID, idxVal, cols,
+                                  bI1Cast, hlslOP);
+          } else {
+            Value *V = stInst->getValueOperand();
+            Type *Ty = V->getType();
+            DXASSERT(Ty == Ty->getScalarType() && !Ty->isAggregateType(),
+                     "only support scalar here");
+
+            if (ConstantInt *ColIdx = dyn_cast<ConstantInt>(info.vectorIdx)) {
+              IRBuilder<> Builder(stInst);
+              Value *args[] = {OpArg, ID, idxVal, ColIdx, V};
+              GenerateStOutput(dxilFunc, args, Builder, bI1Cast);
+            } else {
+              BasicBlock *BB = stInst->getParent();
+              BasicBlock *EndBB = BB->splitBasicBlock(stInst);
+
+              TerminatorInst *TI = BB->getTerminator();
+              IRBuilder<> SwitchBuilder(TI);
+              LLVMContext &Ctx = m_pHLModule->GetCtx();
+              SwitchInst *Switch =
+                  SwitchBuilder.CreateSwitch(info.vectorIdx, EndBB, cols);
+              TI->eraseFromParent();
+
+              Function *F = EndBB->getParent();
+              for (unsigned i = 0; i < cols; i++) {
+                BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case", F, EndBB);
+                Switch->addCase(SwitchBuilder.getInt32(i), CaseBB);
+                IRBuilder<> CaseBuilder(CaseBB);
+
+                ConstantInt *CaseIdx = SwitchBuilder.getInt8(i);
+
+                Value *args[] = {OpArg, ID, idxVal, CaseIdx, V};
+                GenerateStOutput(dxilFunc, args, CaseBuilder, bI1Cast);
+
+                CaseBuilder.CreateBr(EndBB);
+              }
+            }
+            // remove stInst
+            stInst->eraseFromParent();
+          }
         }
       } else if (CallInst *CI = dyn_cast<CallInst>(info.user)) {
         HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());

+ 26 - 0
tools/clang/test/CodeGenHLSL/outputArray.hlsl

@@ -0,0 +1,26 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// CHECK: switch
+
+struct Vertex
+{
+    float4 position     : POSITION0;
+    float4 color        : COLOR0;
+    float4 a[3]         : A;
+};
+
+struct Interpolants
+{
+    float4 position     : SV_POSITION0;
+    float4 color        : COLOR0;
+    float4 a[3]         : A;
+};
+
+uint i;
+
+void main( Vertex In, int j : J, out Interpolants output )
+{
+    output = In;
+    output.a[1][i] = 3;
+    output.a[0] = 4;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -423,6 +423,7 @@ public:
   TEST_METHOD(CodeGenOutput4)
   TEST_METHOD(CodeGenOutput5)
   TEST_METHOD(CodeGenOutput6)
+  TEST_METHOD(CodeGenOutputArray)
   TEST_METHOD(CodeGenPassthrough1)
   TEST_METHOD(CodeGenPassthrough2)
   TEST_METHOD(CodeGenPrecise1)
@@ -2198,6 +2199,10 @@ TEST_F(CompilerTest, CodeGenOutput6) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\output6.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenOutputArray) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\outputArray.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenPassthrough1) {
   CodeGenTest(L"..\\CodeGenHLSL\\passthrough1.hlsl");
 }