|
@@ -2045,8 +2045,10 @@ bool SROA_HLSL::TypeHasComponent(Type *T, uint64_t Offset, uint64_t Size,
|
|
|
}
|
|
|
|
|
|
/// LoadVectorArray - Load vector array like [2 x <4 x float>] from
|
|
|
-/// arrays like 4 [2 x float].
|
|
|
-static Value *LoadVectorArray(ArrayType *AT, ArrayRef<Value *> NewElts,
|
|
|
+/// arrays like 4 [2 x float] or struct array like
|
|
|
+/// [2 x { <4 x float>, < 4 x uint> }]
|
|
|
+/// from arrays like [ 2 x <4 x float> ], [ 2 x <4 x uint> ].
|
|
|
+static Value *LoadVectorOrStructArray(ArrayType *AT, ArrayRef<Value *> NewElts,
|
|
|
SmallVector<Value *, 8> &idxList,
|
|
|
IRBuilder<> &Builder) {
|
|
|
Type *EltTy = AT->getElementType();
|
|
@@ -2059,7 +2061,7 @@ static Value *LoadVectorArray(ArrayType *AT, ArrayRef<Value *> NewElts,
|
|
|
idxList.emplace_back(idx);
|
|
|
|
|
|
if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
|
|
|
- Value *EltVal = LoadVectorArray(EltAT, NewElts, idxList, Builder);
|
|
|
+ Value *EltVal = LoadVectorOrStructArray(EltAT, NewElts, idxList, Builder);
|
|
|
retVal = Builder.CreateInsertValue(retVal, EltVal, i);
|
|
|
} else {
|
|
|
assert(EltTy->isVectorTy() ||
|
|
@@ -2087,9 +2089,12 @@ static Value *LoadVectorArray(ArrayType *AT, ArrayRef<Value *> NewElts,
|
|
|
}
|
|
|
return retVal;
|
|
|
}
|
|
|
+
|
|
|
/// LoadVectorArray - Store vector array like [2 x <4 x float>] to
|
|
|
-/// arrays like 4 [2 x float].
|
|
|
-static void StoreVectorArray(ArrayType *AT, Value *val,
|
|
|
+/// arrays like 4 [2 x float] or struct array like
|
|
|
+/// [2 x { <4 x float>, < 4 x uint> }]
|
|
|
+/// from arrays like [ 2 x <4 x float> ], [ 2 x <4 x uint> ].
|
|
|
+static void StoreVectorOrStructArray(ArrayType *AT, Value *val,
|
|
|
ArrayRef<Value *> NewElts,
|
|
|
SmallVector<Value *, 8> &idxList,
|
|
|
IRBuilder<> &Builder) {
|
|
@@ -2104,7 +2109,7 @@ static void StoreVectorArray(ArrayType *AT, Value *val,
|
|
|
idxList.emplace_back(idx);
|
|
|
|
|
|
if (ArrayType *EltAT = dyn_cast<ArrayType>(EltTy)) {
|
|
|
- StoreVectorArray(EltAT, elt, NewElts, idxList, Builder);
|
|
|
+ StoreVectorOrStructArray(EltAT, elt, NewElts, idxList, Builder);
|
|
|
} else {
|
|
|
assert(EltTy->isVectorTy() ||
|
|
|
EltTy->isStructTy() && "must be a vector or struct type");
|
|
@@ -2532,16 +2537,21 @@ void SROA_Helper::RewriteForGEP(GEPOperator *GEP, IRBuilder<> &Builder) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-/// isVectorArray - Check if T is array of vector.
|
|
|
-static bool isVectorArray(Type *T) {
|
|
|
+static Type *getArrayEltType(Type *T) {
|
|
|
+ while (isa<ArrayType>(T)) {
|
|
|
+ T = T->getArrayElementType();
|
|
|
+ }
|
|
|
+ return T;
|
|
|
+}
|
|
|
+
|
|
|
+/// isVectorOrStructArray - Check if T is array of vector or struct.
|
|
|
+static bool isVectorOrStructArray(Type *T) {
|
|
|
if (!T->isArrayTy())
|
|
|
return false;
|
|
|
|
|
|
- while (T->getArrayElementType()->isArrayTy()) {
|
|
|
- T = T->getArrayElementType();
|
|
|
- }
|
|
|
+ T = getArrayEltType(T);
|
|
|
|
|
|
- return T->getArrayElementType()->isVectorTy();
|
|
|
+ return T->isStructTy() || T->isVectorTy();
|
|
|
}
|
|
|
|
|
|
static void SimplifyStructValUsage(Value *StructVal, std::vector<Value *> Elts,
|
|
@@ -2596,7 +2606,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
|
|
|
LI->replaceAllUsesWith(Insert);
|
|
|
DeadInsts.push_back(LI);
|
|
|
} else if (isCompatibleAggregate(LIType, ValTy)) {
|
|
|
- if (isVectorArray(LIType)) {
|
|
|
+ if (isVectorOrStructArray(LIType)) {
|
|
|
// Replace:
|
|
|
// %res = load [2 x <2 x float>] * %alloc
|
|
|
// with:
|
|
@@ -2610,7 +2620,7 @@ void SROA_Helper::RewriteForLoad(LoadInst *LI) {
|
|
|
SmallVector<Value *, 8> idxList;
|
|
|
idxList.emplace_back(zero);
|
|
|
Value *newLd =
|
|
|
- LoadVectorArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
|
|
|
+ LoadVectorOrStructArray(cast<ArrayType>(LIType), NewElts, idxList, Builder);
|
|
|
LI->replaceAllUsesWith(newLd);
|
|
|
DeadInsts.push_back(LI);
|
|
|
} else {
|
|
@@ -2674,7 +2684,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
|
|
|
}
|
|
|
DeadInsts.push_back(SI);
|
|
|
} else if (isCompatibleAggregate(SIType, ValTy)) {
|
|
|
- if (isVectorArray(SIType)) {
|
|
|
+ if (isVectorOrStructArray(SIType)) {
|
|
|
// Replace:
|
|
|
// store [2 x <2 x i32>] %val, [2 x <2 x i32>]* %alloc, align 16
|
|
|
// with:
|
|
@@ -2701,7 +2711,7 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
|
|
|
Value *zero = ConstantInt::get(i32Ty, 0);
|
|
|
SmallVector<Value *, 8> idxList;
|
|
|
idxList.emplace_back(zero);
|
|
|
- StoreVectorArray(AT, Val, NewElts, idxList, Builder);
|
|
|
+ StoreVectorOrStructArray(AT, Val, NewElts, idxList, Builder);
|
|
|
DeadInsts.push_back(SI);
|
|
|
} else {
|
|
|
// Replace:
|
|
@@ -2715,9 +2725,9 @@ void SROA_Helper::RewriteForStore(StoreInst *SI) {
|
|
|
Module *M = SI->getModule();
|
|
|
for (unsigned i = 0, e = NewElts.size(); i != e; ++i) {
|
|
|
Value *Extract = Builder.CreateExtractValue(Val, i, Val->getName());
|
|
|
- if (!HLMatrixLower::IsMatrixType(Extract->getType()))
|
|
|
+ if (!HLMatrixLower::IsMatrixType(Extract->getType())) {
|
|
|
Builder.CreateStore(Extract, NewElts[i]);
|
|
|
- else {
|
|
|
+ } else {
|
|
|
// Generate Matrix Store.
|
|
|
HLModule::EmitHLOperationCall(
|
|
|
Builder, HLOpcodeGroup::HLMatLoadStore,
|