浏览代码

Support array of struct when flatten types. (#143)

Xiang Li 8 年之前
父节点
当前提交
29a09802e3

+ 28 - 18
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -2045,8 +2045,10 @@ bool SROA_HLSL::TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size,
 }
 
 /// LoadVectorArray - Load vector array like [2 x <4 x float>] from
-///  arrays like 4 [2 x float].
-static Value *LoadVectorArray(ArrayType *AT, ArrayRef<Value *> NewElts,
+///  arrays like 4 [2 x float] or struct array like
+///  [2 x { <4 x float>, < 4 x uint> }]
+/// from arrays like [ 2 x <4 x float> ], [ 2 x <4 x uint> ].
+static Value *LoadVectorOrStructArray(ArrayType *AT, ArrayRef<Value *> NewElts,
                               SmallVector<Value *, 8> &idxList,
                               IRBuilder<> &Builder) {
   Type *EltTy = AT->getElementType();
@@ -2059,7 +2061,7 @@ static Value *LoadVectorArray(ArrayType *AT, ArrayRef<Value *> NewElts,
     idxList.emplace_back(idx);
 
     if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
-      Value *EltVal = LoadVectorArray(EltAT, NewElts, idxList, Builder);
+      Value *EltVal = LoadVectorOrStructArray(EltAT, NewElts, idxList, Builder);
       retVal = Builder.CreateInsertValue(retVal, EltVal, i);
     } else {
       assert(EltTy->isVectorTy() ||
@@ -2087,9 +2089,12 @@ static Value *LoadVectorArray(ArrayType *AT, ArrayRef<Value *> NewElts,
   }
   return retVal;
 }
+
 /// LoadVectorArray - Store vector array like [2 x <4 x float>] to
-///  arrays like 4 [2 x float].
-static void StoreVectorArray(ArrayType *AT, Value *val,
+///  arrays like 4 [2 x float] or struct array like
+///  [2 x { <4 x float>, < 4 x uint> }]
+/// from arrays like [ 2 x <4 x float> ], [ 2 x <4 x uint> ].
+static void StoreVectorOrStructArray(ArrayType *AT, Value *val,
                              ArrayRef<Value *> NewElts,
                              SmallVector<Value *, 8> &idxList,
                              IRBuilder<> &Builder) {
@@ -2104,7 +2109,7 @@ static void StoreVectorArray(ArrayType *AT, Value *val,
     idxList.emplace_back(idx);
 
     if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
-      StoreVectorArray(EltAT, elt, NewElts, idxList, Builder);
+      StoreVectorOrStructArray(EltAT, elt, NewElts, idxList, Builder);
     } else {
       assert(EltTy->isVectorTy() ||
              EltTy->isStructTy() && "must be a vector or struct type");
@@ -2532,16 +2537,21 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
   }
 }
 
-/// isVectorArray - Check if T is array of vector.
-static bool isVectorArray(Type *T) {
+static Type *getArrayEltType(Type *T) {
+  while (isa<ArrayType>(T)) {
+    T = T->getArrayElementType();
+  }
+  return T;
+}
+
+/// isVectorOrStructArray - Check if T is array of vector or struct.
+static bool isVectorOrStructArray(Type *T) {
   if (!T->isArrayTy())
     return false;
 
-  while (T->getArrayElementType()->isArrayTy()) {
-    T = T->getArrayElementType();
-  }
+  T = getArrayEltType(T);
 
-  return T->getArrayElementType()->isVectorTy();
+  return T->isStructTy() || T->isVectorTy();
 }
 
 static void SimplifyStructValUsage(Value *StructVal, std::vector<Value *> Elts,
@@ -2596,7 +2606,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
     LI->replaceAllUsesWith(Insert);
     DeadInsts.push_back(LI);
   } else if (isCompatibleAggregate(LIType, ValTy)) {
-    if (isVectorArray(LIType)) {
+    if (isVectorOrStructArray(LIType)) {
       // Replace:
       //   %res = load [2 x <2 x float>] * %alloc
       // with:
@@ -2610,7 +2620,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
       SmallVector<Value *, 8> idxList;
       idxList.emplace_back(zero);
       Value *newLd =
-          LoadVectorArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
+          LoadVectorOrStructArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
       LI->replaceAllUsesWith(newLd);
       DeadInsts.push_back(LI);
     } else {
@@ -2674,7 +2684,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
     }
     DeadInsts.push_back(SI);
   } else if (isCompatibleAggregate(SIType, ValTy)) {
-    if (isVectorArray(SIType)) {
+    if (isVectorOrStructArray(SIType)) {
       // Replace:
       //   store [2 x <2 x i32>] %val, [2 x <2 x i32>]* %alloc, align 16
       // with:
@@ -2701,7 +2711,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
       Value *zero = ConstantInt::get(i32Ty, 0);
       SmallVector<Value *, 8> idxList;
       idxList.emplace_back(zero);
-      StoreVectorArray(AT, Val, NewElts, idxList, Builder);
+      StoreVectorOrStructArray(AT, Val, NewElts, idxList, Builder);
       DeadInsts.push_back(SI);
     } else {
       // Replace:
@@ -2715,9 +2725,9 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
       Module *M = SI->getModule();
       for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
         Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
-        if (!HLMatrixLower::IsMatrixType(Extract->getType()))
+        if (!HLMatrixLower::IsMatrixType(Extract->getType())) {
           Builder.CreateStore(Extract, NewElts[i]);
-        else {
+        } else {
           // Generate Matrix Store.
           HLModule::EmitHLOperationCall(
               Builder, HLOpcodeGroup::HLMatLoadStore,

+ 34 - 0
tools/clang/test/CodeGenHLSL/structArray.hlsl

@@ -0,0 +1,34 @@
+// RUN: %dxc -E main -T vs_6_0 %s
+
+struct Vertex
+{
+    float4 position     : POSITION0;
+    float4 color        : COLOR0;
+};
+
+struct Interpolants
+{
+    float4 position     : SV_POSITION0;
+    float4 color        : COLOR0;
+};
+
+
+struct T {
+  float4 t;
+};
+
+struct TA {
+  T  ta[2];
+};
+
+TA test(T t[2]) {
+  TA ta = { t };
+  return ta;
+}
+
+Interpolants main(  Vertex In)
+{
+  TA ta = In;
+
+  return test(ta.ta);
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -520,6 +520,7 @@ public:
   TEST_METHOD(CodeGenStruct_Buf1)
   TEST_METHOD(CodeGenStruct_BufHasCounter)
   TEST_METHOD(CodeGenStruct_BufHasCounter2)
+  TEST_METHOD(CodeGenStructArray)
   TEST_METHOD(CodeGenStructCast)
   TEST_METHOD(CodeGenStructCast2)
   TEST_METHOD(CodeGenStructInBuffer)
@@ -2780,6 +2781,10 @@ TEST_F(CompilerTest, CodeGenStruct_BufHasCounter2) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\struct_bufHasCounter2.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenStructArray) {
+  CodeGenTest(L"..\\CodeGenHLSL\\structArray.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenStructCast) {
   CodeGenTest(L"..\\CodeGenHLSL\\StructCast.hlsl");
 }