2
0
Эх сурвалжийг харах

Support addrspacecast when flatten global variable.

Xiang Li 7 жил өмнө
parent
commit
d74ff8a97d

+ 31 - 5
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -105,6 +105,7 @@ private:
 
   void RewriteForConstExpr(ConstantExpr *user, IRBuilder<> &Builder);
   void RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder);
+  void RewriteForAddrSpaceCast(ConstantExpr *user, IRBuilder<> &Builder);
   void RewriteForLoad(LoadInst *loadInst);
   void RewriteForStore(StoreInst *storeInst);
   void RewriteMemIntrin(MemIntrinsic *MI, Instruction *Inst);
@@ -3098,6 +3099,22 @@ void SROA_Helper::RewriteCall(CallInst *CI) {
   }
 }
 
+/// RewriteForConstExpr - Rewrite the GEP which is ConstantExpr.
+void SROA_Helper::RewriteForAddrSpaceCast(ConstantExpr *CE,
+                                          IRBuilder<> &Builder) {
+  SmallVector<Value *, 8> NewCasts;
+  // create new AddrSpaceCast.
+  for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
+    Value *NewGEP = Builder.CreateAddrSpaceCast(
+        NewElts[i],
+        PointerType::get(NewElts[i]->getType()->getPointerElementType(),
+                         CE->getType()->getPointerAddressSpace()));
+    NewCasts.emplace_back(NewGEP);
+  }
+  SROA_Helper helper(CE, NewCasts, DeadInsts);
+  helper.RewriteForScalarRepl(CE, Builder);
+}
+
 /// RewriteForConstExpr - Rewrite the GEP which is ConstantExpr.
 void SROA_Helper::RewriteForConstExpr(ConstantExpr *CE, IRBuilder<> &Builder) {
   if (GEPOperator *GEP = dyn_cast<GEPOperator>(CE)) {
@@ -3107,17 +3124,26 @@ void SROA_Helper::RewriteForConstExpr(ConstantExpr *CE, IRBuilder<> &Builder) {
       return;
     }
   }
+  if (CE->getOpcode() == Instruction::AddrSpaceCast) {
+    if (OldVal == CE->getOperand(0)) {
+      // Flatten AddrSpaceCast.
+      RewriteForAddrSpaceCast(CE, Builder);
+      return;
+    }
+  }
   // Skip unused CE. 
   if (CE->use_empty())
     return;
 
-  Instruction *constInst = CE->getAsInstruction();
-  Builder.Insert(constInst);
-  // Replace CE with constInst.
   for (Value::use_iterator UI = CE->use_begin(), E = CE->use_end(); UI != E;) {
     Use &TheUse = *UI++;
-    if (isa<Instruction>(TheUse.getUser()))
-      TheUse.set(constInst);
+    if (Instruction *I = dyn_cast<Instruction>(TheUse.getUser())) {
+      IRBuilder<> tmpBuilder(I);
+      // Replace CE with constInst.
+      Instruction *tmpInst = CE->getAsInstruction();
+      tmpBuilder.Insert(tmpInst);
+      TheUse.set(tmpInst);
+    }
     else {
       RewriteForConstExpr(cast<ConstantExpr>(TheUse.getUser()), Builder);
     }

+ 28 - 0
tools/clang/test/CodeGenHLSL/quick-test/flat_addrspacecast.hlsl

@@ -0,0 +1,28 @@
+// RUN: %dxc -E main -T cs_6_0 %s | FileCheck %s
+
+// Make sure generate addrspacecast.
+// CHECK: addrspacecast (float addrspace(3)*
+
+struct ST
+{
+	float3 a; // center
+	float3 b; // half extents
+
+        void func(float3 x, float3 y)
+	{
+		a = x + y;
+		b = x * y;
+	}
+};
+
+groupshared ST myST;
+StructuredBuffer<ST> buf0;
+float3 a;
+float3 b;
+RWBuffer<float3> buf1;
+[numthreads(8,8,1)]
+void main() {
+  myST = buf0[0];
+  myST.func(a, b);
+  buf1[0] = myST.b;
+}