|
@@ -4657,31 +4657,8 @@ void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
|
|
|
const SpirvEvalInfo &rhsVal,
|
|
|
const QualType lhsValType) {
|
|
|
|
|
|
- // Lambda for cases where we want to store per each array element.
|
|
|
- const auto storeValueForEachArrayElement = [this, &lhsPtr,
|
|
|
- &rhsVal](uint32_t arraySize,
|
|
|
- QualType arrayElemType) {
|
|
|
- for (uint32_t i = 0; i < arraySize; ++i) {
|
|
|
- const auto subRhsValType =
|
|
|
- typeTranslator.translateType(arrayElemType, rhsVal.getLayoutRule());
|
|
|
- const auto subRhsVal =
|
|
|
- theBuilder.createCompositeExtract(subRhsValType, rhsVal, {i});
|
|
|
- const auto subLhsPtrType = theBuilder.getPointerType(
|
|
|
- typeTranslator.translateType(arrayElemType, lhsPtr.getLayoutRule()),
|
|
|
- lhsPtr.getStorageClass());
|
|
|
- const auto subLhsPtr = theBuilder.createAccessChain(
|
|
|
- subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(i)});
|
|
|
-
|
|
|
- storeValue(lhsPtr.substResultId(subLhsPtr),
|
|
|
- rhsVal.substResultId(subRhsVal), arrayElemType);
|
|
|
- }
|
|
|
- };
|
|
|
-
|
|
|
QualType matElemType = {};
|
|
|
- QualType elemType = {};
|
|
|
- uint32_t numRows = 0, numCols = 0;
|
|
|
- const bool lhsIsMat =
|
|
|
- typeTranslator.isMxNMatrix(lhsValType, &matElemType, &numRows, &numCols);
|
|
|
+ const bool lhsIsMat = typeTranslator.isMxNMatrix(lhsValType, &matElemType);
|
|
|
const bool lhsIsFloatMat = lhsIsMat && matElemType->isFloatingType();
|
|
|
const bool lhsIsNonFpMat = lhsIsMat && !matElemType->isFloatingType();
|
|
|
|
|
@@ -4743,40 +4720,83 @@ void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
|
|
|
// Note: this check should happen after those setting needsLegalization.
|
|
|
// TODO: is this optimization always correct?
|
|
|
theBuilder.createStore(lhsPtr, rhsVal);
|
|
|
- } else if (lhsIsNonFpMat) {
|
|
|
+ } else if (lhsValType->isRecordType() || lhsValType->isConstantArrayType() ||
|
|
|
+ lhsIsNonFpMat) {
|
|
|
+ theBuilder.createStore(
|
|
|
+ lhsPtr, reconstructValue(rhsVal, lhsValType, lhsPtr.getLayoutRule()));
|
|
|
+ } else {
|
|
|
+ emitError("storing value of type %0 unimplemented", {}) << lhsValType;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+uint32_t SPIRVEmitter::reconstructValue(const SpirvEvalInfo &srcVal,
|
|
|
+ const QualType valType,
|
|
|
+ LayoutRule dstLR) {
|
|
|
+ // Lambda for cases where we want to reconstruct an array
|
|
|
+ const auto reconstructArray = [this, &srcVal, valType,
|
|
|
+ dstLR](uint32_t arraySize,
|
|
|
+ QualType arrayElemType) {
|
|
|
+ llvm::SmallVector<uint32_t, 4> elements;
|
|
|
+ for (uint32_t i = 0; i < arraySize; ++i) {
|
|
|
+ const auto subSrcValType =
|
|
|
+ typeTranslator.translateType(arrayElemType, srcVal.getLayoutRule());
|
|
|
+ const auto subSrcVal =
|
|
|
+ theBuilder.createCompositeExtract(subSrcValType, srcVal, {i});
|
|
|
+
|
|
|
+ elements.push_back(reconstructValue(srcVal.substResultId(subSrcVal),
|
|
|
+ arrayElemType, dstLR));
|
|
|
+ }
|
|
|
+ const auto dstValType = typeTranslator.translateType(valType, dstLR);
|
|
|
+ return theBuilder.createCompositeConstruct(dstValType, elements);
|
|
|
+ };
|
|
|
+
|
|
|
+ // Constant arrays
|
|
|
+ if (const auto *arrayType = astContext.getAsConstantArrayType(valType)) {
|
|
|
+ const auto elemType = arrayType->getElementType();
|
|
|
+ const auto size =
|
|
|
+ static_cast<uint32_t>(arrayType->getSize().getZExtValue());
|
|
|
+ return reconstructArray(size, elemType);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Non-floating-point matrices
|
|
|
+ QualType matElemType = {};
|
|
|
+ uint32_t numRows = 0, numCols = 0;
|
|
|
+ const bool isNonFpMat =
|
|
|
+ typeTranslator.isMxNMatrix(valType, &matElemType, &numRows, &numCols) &&
|
|
|
+ !matElemType->isFloatingType();
|
|
|
+
|
|
|
+ if (isNonFpMat) {
|
|
|
// Note: This check should happen before the RecordType check.
|
|
|
// Non-fp matrices are represented as arrays of vectors in SPIR-V.
|
|
|
// Each array element is a vector. Get the QualType for the vector.
|
|
|
const auto elemType = astContext.getExtVectorType(matElemType, numCols);
|
|
|
- storeValueForEachArrayElement(numRows, elemType);
|
|
|
- } else if (const auto *recordType = lhsValType->getAs<RecordType>()) {
|
|
|
+ return reconstructArray(numRows, elemType);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Note: This check should happen before the RecordType check since
|
|
|
+ // vector/matrix/resource types are represented as RecordType in the AST.
|
|
|
+ if (hlsl::IsHLSLVecMatType(valType) || hlsl::IsHLSLResourceType(valType))
|
|
|
+ return srcVal;
|
|
|
+
|
|
|
+ // Structs
|
|
|
+ if (const auto *recordType = valType->getAs<RecordType>()) {
|
|
|
uint32_t index = 0;
|
|
|
+ llvm::SmallVector<uint32_t, 4> elements;
|
|
|
for (const auto *field : recordType->getDecl()->fields()) {
|
|
|
- const auto subRhsValType = typeTranslator.translateType(
|
|
|
- field->getType(), rhsVal.getLayoutRule());
|
|
|
- const auto subRhsVal =
|
|
|
- theBuilder.createCompositeExtract(subRhsValType, rhsVal, {index});
|
|
|
- const auto subLhsPtrType = theBuilder.getPointerType(
|
|
|
- typeTranslator.translateType(field->getType(),
|
|
|
- lhsPtr.getLayoutRule()),
|
|
|
- lhsPtr.getStorageClass());
|
|
|
- const auto subLhsPtr = theBuilder.createAccessChain(
|
|
|
- subLhsPtrType, lhsPtr, {theBuilder.getConstantUint32(index)});
|
|
|
-
|
|
|
- storeValue(lhsPtr.substResultId(subLhsPtr),
|
|
|
- rhsVal.substResultId(subRhsVal), field->getType());
|
|
|
+ const auto subSrcValType = typeTranslator.translateType(
|
|
|
+ field->getType(), srcVal.getLayoutRule());
|
|
|
+ const auto subSrcVal =
|
|
|
+ theBuilder.createCompositeExtract(subSrcValType, srcVal, {index});
|
|
|
+
|
|
|
+ elements.push_back(reconstructValue(srcVal.substResultId(subSrcVal),
|
|
|
+ field->getType(), dstLR));
|
|
|
++index;
|
|
|
}
|
|
|
- } else if (const auto *arrayType =
|
|
|
- astContext.getAsConstantArrayType(lhsValType)) {
|
|
|
- const auto elemType = arrayType->getElementType();
|
|
|
- // TODO: handle extra large array size?
|
|
|
- const auto size =
|
|
|
- static_cast<uint32_t>(arrayType->getSize().getZExtValue());
|
|
|
- storeValueForEachArrayElement(size, elemType);
|
|
|
- } else {
|
|
|
- emitError("storing value of type %0 unimplemented", {}) << lhsValType;
|
|
|
+ const auto dstValType = typeTranslator.translateType(valType, dstLR);
|
|
|
+ return theBuilder.createCompositeConstruct(dstValType, elements);
|
|
|
}
|
|
|
+
|
|
|
+ return srcVal;
|
|
|
}
|
|
|
|
|
|
SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
|