Browse Source

Support dynamic indexing on matrix subscript. (#140)

Xiang Li 8 years ago
parent
commit
8b5a14ec74

+ 36 - 12
lib/HLSL/HLOperationLower.cpp

@@ -4244,12 +4244,43 @@ void TranslateCBGep(GetElementPtrInst *GEP, Value *handle, Value *baseOffset,
                     DxilFieldAnnotation *prevFieldAnnotation,
                     const DataLayout &DL, DxilTypeSystem &dxilTypeSys);
 
+Value *GenerateVecEltFromGEP(Value *ldData, GetElementPtrInst *GEP,
+                             IRBuilder<> &Builder) {
+  DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
+  Value *baseIdx = (GEP->idx_begin())->get();
+  Value *zeroIdx = Builder.getInt32(0);
+  DXASSERT_LOCALVAR(baseIdx && zeroIdx, baseIdx == zeroIdx,
+                    "base index must be 0");
+  Value *idx = (GEP->idx_begin() + 1)->get();
+  if (ConstantInt *cidx = dyn_cast<ConstantInt>(idx)) {
+    return Builder.CreateExtractElement(ldData, idx);
+  } else {
+    // Dynamic indexing.
+    // Copy vec to array.
+    Type *Ty = ldData->getType();
+    Type *EltTy = Ty->getVectorElementType();
+    unsigned vecSize = Ty->getVectorNumElements();
+    ArrayType *AT = ArrayType::get(EltTy, vecSize);
+    IRBuilder<> AllocaBuilder(
+        GEP->getParent()->getParent()->getEntryBlock().getFirstInsertionPt());
+    Value *tempArray = AllocaBuilder.CreateAlloca(AT);
+    Value *zero = Builder.getInt32(0);
+    for (unsigned int i = 0; i < vecSize; i++) {
+      Value *Elt = Builder.CreateExtractElement(ldData, Builder.getInt32(i));
+      Value *Ptr =
+          Builder.CreateInBoundsGEP(tempArray, {zero, Builder.getInt32(i)});
+      Builder.CreateStore(Elt, Ptr);
+    }
+    // Load from temp array.
+    Value *EltGEP = Builder.CreateInBoundsGEP(tempArray, {zero, idx});
+    return Builder.CreateLoad(EltGEP);
+  }
+}
+
 void TranslateCBAddressUser(Instruction *user, Value *handle, Value *baseOffset,
                             hlsl::OP *hlslOP,
                             DxilFieldAnnotation *prevFieldAnnotation,
                             DxilTypeSystem &dxilTypeSys, const DataLayout &DL) {
-  const Value *zeroIdx = hlslOP->GetU32Const(0);
-
   IRBuilder<> Builder(user);
   if (CallInst *CI = dyn_cast<CallInst>(user)) {
     HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction());
@@ -4350,11 +4381,8 @@ void TranslateCBAddressUser(Instruction *user, Value *handle, Value *baseOffset,
       for (auto U = CI->user_begin(); U != CI->user_end();) {
         Value *subsUser = *(U++);
         if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-          DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
-          Value *baseIdx = (GEP->idx_begin())->get();
-          DXASSERT_LOCALVAR(baseIdx && zeroIdx, baseIdx == zeroIdx, "base index must be 0");
-          Value *idx = (GEP->idx_begin() + 1)->get();
-          Value *subData = Builder.CreateExtractElement(ldData, idx);
+          Value *subData = GenerateVecEltFromGEP(ldData, GEP, Builder);
+
           for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
             Value *gepUser = *(gepU++);
             // Must be load here;
@@ -4829,11 +4857,7 @@ void TranslateCBAddressUserLegacy(Instruction *user, Value *handle,
       for (auto U = CI->user_begin(); U != CI->user_end();) {
         Value *subsUser = *(U++);
         if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(subsUser)) {
-          DXASSERT(GEP->getNumIndices() == 2, "must have 2 level");
-          Value *baseIdx = (GEP->idx_begin())->get();
-          DXASSERT_LOCALVAR(baseIdx, baseIdx == zeroIdx, "base index must be 0");
-          Value *idx = (GEP->idx_begin() + 1)->get();
-          Value *subData = Builder.CreateExtractElement(ldData, idx);
+          Value *subData = GenerateVecEltFromGEP(ldData, GEP, Builder);
           for (auto gepU = GEP->user_begin(); gepU != GEP->user_end();) {
             Value *gepUser = *(gepU++);
             // Must be load here;

+ 10 - 0
tools/clang/test/CodeGenHLSL/matSubscript7.hlsl

@@ -0,0 +1,10 @@
+// RUN: %dxc -E main -T vs_6_0 %s | FileCheck %s
+
+// CHECK: @main
+
+float4x4 m;
+uint i;
+float4 main() : SV_POSITION {
+  float4x4 m2 = m;
+  return m[2][i] + m2[i][i];
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -418,6 +418,7 @@ public:
   TEST_METHOD(CodeGenMatSubscript4)
   TEST_METHOD(CodeGenMatSubscript5)
   TEST_METHOD(CodeGenMatSubscript6)
+  TEST_METHOD(CodeGenMatSubscript7)
   TEST_METHOD(CodeGenMaxMin)
   TEST_METHOD(CodeGenMinprec1)
   TEST_METHOD(CodeGenMinprec2)
@@ -2368,6 +2369,10 @@ TEST_F(CompilerTest, CodeGenMatSubscript6) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\matSubscript6.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenMatSubscript7) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\matSubscript7.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenMaxMin) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\max_min.hlsl");
 }