Ver código fonte

Support write mask for typed buf. (#126)

Xiang Li 8 anos atrás
pai
commit
cf8e34ec11

+ 96 - 14
lib/HLSL/HLOperationLower.cpp

@@ -2890,7 +2890,8 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK,
   }
   }
   // replace
   // replace
   helper.retVal->replaceAllUsesWith(retValNew);
   helper.retVal->replaceAllUsesWith(retValNew);
-
+  // Save new ret val.
+  helper.retVal = retValNew;
   // get status
   // get status
   UpdateStatus(ResRet, helper.status, Builder);
   UpdateStatus(ResRet, helper.status, Builder);
 }
 }
@@ -5762,6 +5763,61 @@ void TranslateStructBufSubscript(CallInst *CI, Value *handle, Value *status,
 
 
 // HLSubscript.
 // HLSubscript.
 namespace {
 namespace {
+// Lowers a LoadInst on a typed-resource subscript through TranslateLoad, erases that load, and returns the value left in the helper's retVal.
+Value *TranslateTypedBufLoad(CallInst *CI, DXIL::ResourceKind RK, // CI/RK/RC/handle are forwarded to ResLoadHelper
+                             DXIL::ResourceClass RC, Value *handle,
+                             LoadInst *ldInst, IRBuilder<> &Builder, // ldInst is erased before returning
+                             hlsl::OP *hlslOP, const DataLayout &DL) {
+  ResLoadHelper ldHelper(CI, RK, RC, handle, /*bForSubscript*/ true);
+  // Default sampleIdx for 2DMS textures: the helper's mipLevel field is set to a u32 constant 0.
+  if (RK == DxilResource::Kind::Texture2DMS ||
+      RK == DxilResource::Kind::Texture2DMSArray)
+    ldHelper.mipLevel = hlslOP->GetU32Const(0);
+  // Use ldInst as retVal so TranslateLoad redirects the load's uses to the lowered value.
+  ldHelper.retVal = ldInst;
+  TranslateLoad(ldHelper, RK, Builder, hlslOP, DL);
+  // Delete the original load; its uses are expected to have been replaced by TranslateLoad.
+  ldInst->eraseFromParent();
+  return ldHelper.retVal; // whatever TranslateLoad stored back into the helper
+}
+// Returns VecVal with lane EltIdx replaced by EltVal; a non-constant EltIdx is lowered to a switch over all lanes whose results meet in a PHI.
+Value *UpdateVectorElt(Value *VecVal, Value *EltVal, Value *EltIdx,
+                       unsigned vectorSize, Instruction *InsertPt) {
+  IRBuilder<> Builder(InsertPt);
+  if (ConstantInt *CEltIdx = dyn_cast<ConstantInt>(EltIdx)) { // constant lane: one insertelement before InsertPt
+    VecVal =
+        Builder.CreateInsertElement(VecVal, EltVal, CEltIdx->getLimitedValue());
+  } else {
+    BasicBlock *BB = InsertPt->getParent();
+    BasicBlock *EndBB = BB->splitBasicBlock(InsertPt); // InsertPt and everything after it now live in EndBB
+    // The terminator that the split left in BB is swapped for the switch created below.
+    TerminatorInst *TI = BB->getTerminator();
+    IRBuilder<> SwitchBuilder(TI);
+    LLVMContext &Ctx = InsertPt->getContext();
+    // Dispatch on the runtime index; any value without a case takes the default edge straight to EndBB.
+    SwitchInst *Switch = SwitchBuilder.CreateSwitch(EltIdx, EndBB, vectorSize);
+    TI->eraseFromParent();
+    // The PHI at the head of EndBB merges one updated vector per case plus the untouched vector from the default edge.
+    Function *F = EndBB->getParent();
+    IRBuilder<> endSwitchBuilder(EndBB->begin());
+    Type *Ty = VecVal->getType();
+    PHINode *VecPhi = endSwitchBuilder.CreatePHI(Ty, vectorSize + 1);
+    // One block per lane: insert EltVal at lane i, then branch back to EndBB.
+    for (unsigned i = 0; i < vectorSize; i++) {
+      BasicBlock *CaseBB = BasicBlock::Create(Ctx, "case", F, EndBB);
+      Switch->addCase(SwitchBuilder.getInt32(i), CaseBB); // NOTE(review): case constants are i32 — assumes EltIdx is i32; confirm
+      IRBuilder<> CaseBuilder(CaseBB);
+
+      Value *CaseVal = CaseBuilder.CreateInsertElement(VecVal, EltVal, i);
+      VecPhi->addIncoming(CaseVal, CaseBB);
+      CaseBuilder.CreateBr(EndBB);
+    }
+    VecPhi->addIncoming(VecVal, BB); // default edge: vector passes through unchanged
+    VecVal = VecPhi;
+  }
+  return VecVal;
+}
+
 void TranslateDefaultSubscript(CallInst *CI, HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
 void TranslateDefaultSubscript(CallInst *CI, HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
   auto U = CI->user_begin();
   auto U = CI->user_begin();
 
 
@@ -5779,21 +5835,14 @@ void TranslateDefaultSubscript(CallInst *CI, HLOperationLowerHelper &helper,  HL
   DXIL::ResourceClass RC = pObjHelper->GetRC(resTy);
   DXIL::ResourceClass RC = pObjHelper->GetRC(resTy);
   DXIL::ResourceKind RK = pObjHelper->GetRK(resTy);
   DXIL::ResourceKind RK = pObjHelper->GetRK(resTy);
 
 
+  Type *Ty = CI->getType()->getPointerElementType();
+
   for (auto It = CI->user_begin(); It != CI->user_end(); ) {
   for (auto It = CI->user_begin(); It != CI->user_end(); ) {
     User *user = *(It++);
     User *user = *(It++);
     Instruction *I = cast<Instruction>(user);
     Instruction *I = cast<Instruction>(user);
     IRBuilder<> Builder(I);
     IRBuilder<> Builder(I);
     if (LoadInst *ldInst = dyn_cast<LoadInst>(user)) {
     if (LoadInst *ldInst = dyn_cast<LoadInst>(user)) {
-      ResLoadHelper ldHelper(CI, RK, RC, handle, /*bForSubscript*/ true);
-      // Default sampleIdx for 2DMS textures.
-      if (RK == DxilResource::Kind::Texture2DMS ||
-          RK == DxilResource::Kind::Texture2DMSArray)
-        ldHelper.mipLevel = hlslOP->GetU32Const(0);
-      // use ldInst as retVal
-      ldHelper.retVal = ldInst;
-      TranslateLoad(ldHelper, RK, Builder, hlslOP, helper.legacyDataLayout);
-      // delete the ld
-      ldInst->eraseFromParent();
+      TranslateTypedBufLoad(CI, RK, RC, handle, ldInst, Builder, hlslOP, helper.legacyDataLayout);
     } else if (StoreInst *stInst = dyn_cast<StoreInst>(user)) {
     } else if (StoreInst *stInst = dyn_cast<StoreInst>(user)) {
       Value *val = stInst->getValueOperand();
       Value *val = stInst->getValueOperand();
       TranslateStore(RK, handle, val,
       TranslateStore(RK, handle, val,
@@ -5802,10 +5851,36 @@ void TranslateDefaultSubscript(CallInst *CI, HLOperationLowerHelper &helper,  HL
       // delete the st
       // delete the st
       stInst->eraseFromParent();
       stInst->eraseFromParent();
     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
     } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
-      // Invalid operations.
-      Translated = false;
-      for (User *GEPUser : GEP->users()) {
+      // Must be vector type here.
+      unsigned vectorSize = Ty->getVectorNumElements();
+      DXASSERT(GEP->getNumIndices() == 2, "");
+      Use *GEPIdx = GEP->idx_begin();
+      GEPIdx++;
+      Value *EltIdx = *GEPIdx;
+      for (auto GEPIt = GEP->user_begin(); GEPIt != GEP->user_end();) {
+        User *GEPUser = *(GEPIt++);
+        if (StoreInst *SI = dyn_cast<StoreInst>(GEPUser)) {
+          IRBuilder<> StBuilder(SI);
+          // Generate Ld.
+          LoadInst *tmpLd = StBuilder.CreateLoad(CI);
+
+          Value *ldVal = TranslateTypedBufLoad(CI, RK, RC, handle, tmpLd, StBuilder,
+                                          hlslOP, helper.legacyDataLayout);
+          // Update vector.
+          ldVal = UpdateVectorElt(ldVal, SI->getValueOperand(), EltIdx,
+                                  vectorSize, SI);
+          // Generate St.
+          // Reset insert point, UpdateVectorElt may move SI to different block.
+          StBuilder.SetInsertPoint(SI);
+          TranslateStore(RK, handle, ldVal,
+                         CI->getArgOperand(HLOperandIndex::kStoreOffsetOpIdx),
+                         StBuilder, hlslOP);
+          SI->eraseFromParent();
+          continue;
+        }
         if (!isa<CallInst>(GEPUser)) {
         if (!isa<CallInst>(GEPUser)) {
+          // Invalid operations.
+          Translated = false;
           CI->getContext().emitError(GEP, "Invalid operation on typed buffer");
           CI->getContext().emitError(GEP, "Invalid operation on typed buffer");
           return;
           return;
         }
         }
@@ -5813,6 +5888,8 @@ void TranslateDefaultSubscript(CallInst *CI, HLOperationLowerHelper &helper,  HL
         HLOpcodeGroup group =
         HLOpcodeGroup group =
             hlsl::GetHLOpcodeGroupByName(userCall->getCalledFunction());
             hlsl::GetHLOpcodeGroupByName(userCall->getCalledFunction());
         if (group != HLOpcodeGroup::HLIntrinsic) {
         if (group != HLOpcodeGroup::HLIntrinsic) {
+          // Invalid operations.
+          Translated = false;
           CI->getContext().emitError(userCall,
           CI->getContext().emitError(userCall,
                                      "Invalid operation on typed buffer");
                                      "Invalid operation on typed buffer");
           return;
           return;
@@ -5831,17 +5908,22 @@ void TranslateDefaultSubscript(CallInst *CI, HLOperationLowerHelper &helper,  HL
         case IntrinsicOp::IOP_InterlockedXor:
         case IntrinsicOp::IOP_InterlockedXor:
         case IntrinsicOp::IOP_InterlockedCompareStore:
         case IntrinsicOp::IOP_InterlockedCompareStore:
         case IntrinsicOp::IOP_InterlockedCompareExchange: {
         case IntrinsicOp::IOP_InterlockedCompareExchange: {
+          // Invalid operations.
+          Translated = false;
           CI->getContext().emitError(
           CI->getContext().emitError(
               userCall, "Atomic operation on typed buffer is not supported");
               userCall, "Atomic operation on typed buffer is not supported");
           return;
           return;
         } break;
         } break;
         default:
         default:
+          // Invalid operations.
+          Translated = false;
           CI->getContext().emitError(userCall,
           CI->getContext().emitError(userCall,
                                      "Invalid operation on typed buffer");
                                      "Invalid operation on typed buffer");
           return;
           return;
           break;
           break;
         }
         }
       }
       }
+      GEP->eraseFromParent();
     } else {
     } else {
       CallInst *userCall = cast<CallInst>(user);
       CallInst *userCall = cast<CallInst>(user);
       HLOpcodeGroup group =
       HLOpcodeGroup group =

+ 17 - 5
tools/clang/lib/CodeGen/CGExpr.cpp

@@ -1694,11 +1694,23 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
   llvm::Value *Zero = Builder.getInt32(0);
   llvm::Value *Zero = Builder.getInt32(0);
   if (VTy) {
   if (VTy) {
     llvm::Type *VecTy = VecDstPtr->getType()->getPointerElementType();
     llvm::Type *VecTy = VecDstPtr->getType()->getPointerElementType();
-    for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
-      if (llvm::Constant *Elt = Elts->getAggregateElement(i)) {
-        llvm::Value *EltGEP = Builder.CreateGEP(VecDstPtr, {Zero, Elt});
-        llvm::Value *SrcElt = Builder.CreateExtractElement(SrcVal, i);
-        Builder.CreateStore(SrcElt, EltGEP);
+    unsigned NumSrcElts = VTy->getNumElements();
+    if (VecTy->getVectorNumElements() == NumSrcElts) {
+      // Full vector write, create one store.
+      for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
+        if (llvm::Constant *Elt = Elts->getAggregateElement(i)) {
+          llvm::Value *SrcElt = Builder.CreateExtractElement(SrcVal, i);
+          Vec = Builder.CreateInsertElement(Vec, SrcElt, Elt);
+        }
+      }
+      Builder.CreateStore(Vec, VecDstPtr);
+    } else {
+      for (unsigned i = 0; i < VecTy->getVectorNumElements(); i++) {
+        if (llvm::Constant *Elt = Elts->getAggregateElement(i)) {
+          llvm::Value *EltGEP = Builder.CreateGEP(VecDstPtr, {Zero, Elt});
+          llvm::Value *SrcElt = Builder.CreateExtractElement(SrcVal, i);
+          Builder.CreateStore(SrcElt, EltGEP);
+        }
       }
       }
     }
     }
   } else {
   } else {

+ 11 - 0
tools/clang/test/CodeGenHLSL/writeMaskBuf2.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK-NOT: dx.op.bufferLoad
+
+// A full-mask (.xyzw) store overwrites every component, so the compiled output must not contain a buffer load.
+RWBuffer<int4> buf;
+
+[numthreads(8, 8, 1)]
+void main(uint2 id : SV_DispatchThreadId) {
+  buf[id.x].xyzw = 1;
+}

+ 11 - 0
tools/clang/test/CodeGenHLSL/writeMaskBuf3.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK: dx.op.bufferLoad
+
+// A partial-mask (.xyz) store does not write .w, so the compiled output is expected to contain a buffer load of the element.
+RWBuffer<int4> buf;
+
+[numthreads(8, 8, 1)]
+void main(uint2 id : SV_DispatchThreadId) {
+  buf[id.x].xyz = 1;
+}

+ 11 - 0
tools/clang/test/CodeGenHLSL/writeMaskBuf4.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// CHECK: dx.op.bufferLoad
+// CHECK: switch
+// A store to a dynamically indexed component is expected to compile to a buffer load followed by a switch.
+RWBuffer<int4> buf;
+
+[numthreads(8, 8, 1)]
+void main(uint2 id : SV_DispatchThreadId) {
+  buf[id.x][id.y] = 1;
+}

+ 15 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -541,6 +541,9 @@ public:
   TEST_METHOD(CodeGenVecCmpCond)
   TEST_METHOD(CodeGenVecCmpCond)
   TEST_METHOD(CodeGenWave)
   TEST_METHOD(CodeGenWave)
   TEST_METHOD(CodeGenWriteMaskBuf)
   TEST_METHOD(CodeGenWriteMaskBuf)
+  TEST_METHOD(CodeGenWriteMaskBuf2)
+  TEST_METHOD(CodeGenWriteMaskBuf3)
+  TEST_METHOD(CodeGenWriteMaskBuf4)
   TEST_METHOD(CodeGenWriteToInput)
   TEST_METHOD(CodeGenWriteToInput)
   TEST_METHOD(CodeGenWriteToInput2)
   TEST_METHOD(CodeGenWriteToInput2)
   TEST_METHOD(CodeGenWriteToInput3)
   TEST_METHOD(CodeGenWriteToInput3)
@@ -2849,6 +2852,18 @@ TEST_F(CompilerTest, CodeGenWriteMaskBuf) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\writeMaskBuf.hlsl");
   CodeGenTestCheck(L"..\\CodeGenHLSL\\writeMaskBuf.hlsl");
 }
 }
 
 
+TEST_F(CompilerTest, CodeGenWriteMaskBuf2) { // full-mask .xyzw store to a typed buffer
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\writeMaskBuf2.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenWriteMaskBuf3) { // partial-mask .xyz store to a typed buffer
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\writeMaskBuf3.hlsl");
+}
+
+TEST_F(CompilerTest, CodeGenWriteMaskBuf4) { // dynamically indexed component store to a typed buffer
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\writeMaskBuf4.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenWriteToInput) {
 TEST_F(CompilerTest, CodeGenWriteToInput) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\writeToInput.hlsl");
   CodeGenTestCheck(L"..\\CodeGenHLSL\\writeToInput.hlsl");
 }
 }