Bläddra i källkod

Support CallInst when ReplaceStaticIndexingOnVector. (#3354)

* Support CallInst when ReplaceStaticIndexingOnVector.
Xiang Li 4 år sedan
förälder
incheckning
50981cdefe

+ 33 - 1
lib/Transforms/Scalar/LowerTypePasses.cpp

@@ -223,7 +223,8 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
         // Skip the pointer idx.
         Idx++;
         ConstantInt *constIdx = cast<ConstantInt>(Idx);
-
+        // AllocaInst for Call user.
+        AllocaInst *TmpAI = nullptr;
         for (auto GEPU = GEP->user_begin(), GEPE = GEP->user_end();
              GEPU != GEPE;) {
           Instruction *GEPUser = cast<Instruction>(*(GEPU++));
@@ -240,6 +241,37 @@ void DynamicIndexingVectorToArray::ReplaceStaticIndexingOnVector(Value *V) {
             Value *Elt = Builder.CreateExtractElement(ldVal, constIdx);
             ldInst->replaceAllUsesWith(Elt);
             ldInst->eraseFromParent();
+          } else if (CallInst *CI = dyn_cast<CallInst>(GEPUser)) {
+            // Change
+            //    call a->x
+            // into
+            //   tmp = alloca
+            //   b = ld a
+            //   st b.x, tmp
+            //   call tmp
+            //   b = ld a
+            //   b.x = ld tmp
+            //   st b, a
+            if (TmpAI == nullptr) {
+              Type *Ty = GEP->getType()->getPointerElementType();
+              IRBuilder<> AllocaB(CI->getParent()
+                                      ->getParent()
+                                      ->getEntryBlock()
+                                      .getFirstInsertionPt());
+              TmpAI = AllocaB.CreateAlloca(Ty);
+            }
+            Value *ldVal = Builder.CreateLoad(V);
+            Value *Elt = Builder.CreateExtractElement(ldVal, constIdx);
+            Builder.CreateStore(Elt, TmpAI);
+
+            CI->replaceUsesOfWith(GEP, TmpAI);
+
+            Builder.SetInsertPoint(CI->getNextNode());
+            Elt = Builder.CreateLoad(TmpAI);
+
+            ldVal = Builder.CreateLoad(V);
+            ldVal = Builder.CreateInsertElement(ldVal, Elt, constIdx);
+            Builder.CreateStore(ldVal, V);
           } else {
             // Change
             //    st val, a->x

+ 14 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/vector/GetDimOnVectorIndexing.hlsl

@@ -0,0 +1,14 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+// CHECK:call %dx.types.Dimensions @dx.op.getDimensions
+// CHECK:extractvalue %dx.types.Dimensions %{{.*}}, 0
+// CHECK:extractvalue %dx.types.Dimensions %{{.*}}, 1
+// CHECK:extractvalue %dx.types.Dimensions %{{.*}}, 3
+
+TextureCube<float4> T;  // Cube texture whose dimensions main() queries.
+
+float main() : SV_Target {  // Entry point; the RUN line compiles it with -E main -T ps_6_0.
+  uint iMips = (uint)(0);  // Passed as the last GetDimensions argument.
+    uint2 Dims = (uint2)(0);  // Its two components are passed to the call one element at a time.
+    (T.GetDimensions((uint(0u)), (Dims)[0], (Dims)[1], (iMips)));  // Statically indexed vector elements used directly as call arguments -- the case under test.
+  return iMips + Dims.x + Dims.y;  // Uses all three values after the call; presumably this keeps the CHECKed extractvalue instructions live.
+}

+ 34 - 0
tools/clang/test/HLSLFileCheck/hlsl/types/vector/VectorIndexingAsArgument.hlsl

@@ -0,0 +1,34 @@
+// RUN: %dxc -E main -T lib_6_3 %s | FileCheck %s
+
+// Make sure every pointer operand of foo comes from an alloca.
+// CHECK:%[[A0:.*]] = alloca i32
+// CHECK:%[[A1:.*]] = alloca i32
+// CHECK:%[[A2:.*]] = alloca i32
+// CHECK:%[[A3:.*]] = alloca i32
+// CHECK:%[[A4:.*]] = alloca i32
+// CHECK:%[[A5:.*]] = alloca i32
+// CHECK:%[[A6:.*]] = alloca i32
+
+// CHECK:call void @"\01?foo@@YAXIAIAI00@Z"(i32 0, i32* nonnull dereferenceable(4) %[[A1]], i32* nonnull dereferenceable(4) %[[A0]], i32* nonnull dereferenceable(4) %[[A6]])
+// CHECK:call void @"\01?foo@@YAXIAIAI00@Z"(i32 0, i32* nonnull dereferenceable(4) %[[A5]], i32* nonnull dereferenceable(4) %[[A4]], i32* nonnull dereferenceable(4) %[[A6]])
+// CHECK:call void @"\01?foo@@YAXIAIAI00@Z"(i32 0, i32* nonnull dereferenceable(4) %[[A3]], i32* nonnull dereferenceable(4) %[[A2]], i32* nonnull dereferenceable(4) %[[A6]])
+
+struct DimStruct {  // Element type of the structured buffer SB.
+  uint2 Dims;  // Two-component vector; main() indexes its elements individually.
+};
+RWStructuredBuffer<DimStruct> SB;  // Read-write buffer; main() passes SB[1].Dims elements to foo and stores into SB[0] and SB[2].
+groupshared uint2 gs_Dims;  // Groupshared vector whose elements main() passes to foo.
+
+void foo(uint i, out uint, out uint, out uint);  // Declaration only -- no body in this file; the last three parameters are out uints.
+
+
+[numthreads(1,1,1)]  // Thread-group size of 1x1x1.
+void main() {  // Calls foo three times, each with vector elements from a different kind of storage.
+  uint iMips = (uint)(0);  // Shared last argument of all three foo calls.
+  uint2 Dims = (uint2)(0);  // Local vector, zero-initialized.
+  (foo((uint(0u)), (Dims)[0], (Dims)[1], (iMips)));  // Case 1: elements of a local vector as out arguments.
+  SB[0].Dims = Dims;  // Stores the local vector into the buffer after the call.
+  (foo((uint(0u)), (SB[1].Dims)[0], (SB[1].Dims)[1], (iMips)));  // Case 2: elements of a vector field inside a structured-buffer element.
+  (foo((uint(0u)), (gs_Dims)[0], (gs_Dims)[1], (iMips)));  // Case 3: elements of a groupshared vector.
+  SB[2].Dims = gs_Dims;  // Stores the groupshared vector into the buffer after the call.
+}