|
@@ -1463,6 +1463,19 @@ Value *HLMatrixLowerPass::lowerHLInit(CallInst *Call) {
|
|
|
// Figure out the result type
|
|
|
HLMatrixType MatTy = HLMatrixType::cast(Call->getType());
|
|
|
VectorType *LoweredTy = MatTy.getLoweredVectorTypeForReg();
|
|
|
+
|
|
|
+ // Handle case where produced by EmitHLSLFlatConversion where there's one
|
|
|
+ // vector argument, instead of scalar arguments.
|
|
|
+ if (1 == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx &&
|
|
|
+ Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx)->
|
|
|
+ getType()->isVectorTy()) {
|
|
|
+ Value *LoweredVec = Call->getArgOperand(HLOperandIndex::kInitFirstArgOpIdx);
|
|
|
+ DXASSERT(LoweredTy->getNumElements() ==
|
|
|
+ LoweredVec->getType()->getVectorNumElements(),
|
|
|
+ "Invalid matrix init argument vector element count.");
|
|
|
+ return LoweredVec;
|
|
|
+ }
|
|
|
+
|
|
|
DXASSERT(LoweredTy->getNumElements() == Call->getNumArgOperands() - HLOperandIndex::kInitFirstArgOpIdx,
|
|
|
"Invalid matrix init argument count.");
|
|
|
|