123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- ///////////////////////////////////////////////////////////////////////////////
- // //
- // HLLowerUDT.cpp //
- // Copyright (C) Microsoft Corporation. All rights reserved. //
- // This file is distributed under the University of Illinois Open Source //
- // License. See LICENSE.TXT for details. //
- // //
- // Lower user defined type used directly by certain intrinsic operations. //
- // //
- ///////////////////////////////////////////////////////////////////////////////
- #include "dxc/HLSL/HLLowerUDT.h"
- #include "dxc/Support/Global.h"
- #include "dxc/DXIL/DxilConstants.h"
- #include "dxc/HLSL/HLModule.h"
- #include "dxc/HLSL/HLOperations.h"
- #include "dxc/DXIL/DxilTypeSystem.h"
- #include "dxc/HLSL/HLMatrixLowerHelper.h"
- #include "dxc/HLSL/HLMatrixType.h"
- #include "dxc/HlslIntrinsicOp.h"
- #include "dxc/DXIL/DxilUtil.h"
- #include "HLMatrixSubscriptUseReplacer.h"
- #include "llvm/ADT/SmallVector.h"
- #include "llvm/IR/CallSite.h"
- #include "llvm/IR/Constants.h"
- #include "llvm/IR/DebugInfo.h"
- #include "llvm/IR/Function.h"
- #include "llvm/IR/GlobalVariable.h"
- #include "llvm/IR/IRBuilder.h"
- #include "llvm/IR/Instructions.h"
- #include "llvm/IR/IntrinsicInst.h"
- #include "llvm/IR/LLVMContext.h"
- #include "llvm/IR/Module.h"
- using namespace llvm;
- using namespace hlsl;
- // Lowered UDT is the same layout, but with vectors and matrices translated to
- // arrays.
- // Returns nullptr for failure due to embedded HLSL object type.
- StructType *hlsl::GetLoweredUDT(StructType *structTy, DxilTypeSystem *pTypeSys) {
- bool changed = false;
- SmallVector<Type*, 8> NewElTys(structTy->getNumContainedTypes());
- for (unsigned iField = 0; iField < NewElTys.size(); ++iField) {
- Type *FieldTy = structTy->getContainedType(iField);
- // Default to original type
- NewElTys[iField] = FieldTy;
- // Unwrap arrays:
- SmallVector<unsigned, 4> OuterToInnerLengths;
- Type *EltTy = dxilutil::StripArrayTypes(FieldTy, &OuterToInnerLengths);
- Type *NewTy = EltTy;
- // Lower element if necessary
- if (EltTy->isVectorTy()) {
- NewTy = ArrayType::get(EltTy->getVectorElementType(),
- EltTy->getVectorNumElements());
- } else if (HLMatrixType Mat = HLMatrixType::dyn_cast(EltTy)) {
- NewTy = ArrayType::get(Mat.getElementType(/*MemRepr*/true),
- Mat.getNumElements());
- } else if (dxilutil::IsHLSLObjectType(EltTy) ||
- dxilutil::IsHLSLRayQueryType(EltTy)) {
- // We cannot lower a structure with an embedded object type
- return nullptr;
- } else if (StructType *ST = dyn_cast<StructType>(EltTy)) {
- NewTy = GetLoweredUDT(ST);
- if (nullptr == NewTy)
- return nullptr; // Propagate failure back to root
- } else if (EltTy->isIntegerTy(1)) {
- // Must translate bool to mem type
- EltTy = IntegerType::get(EltTy->getContext(), 32);
- }
- // if unchanged, skip field
- if (NewTy == EltTy)
- continue;
- // Rewrap Arrays:
- for (auto itLen = OuterToInnerLengths.rbegin(),
- E = OuterToInnerLengths.rend();
- itLen != E; ++itLen) {
- NewTy = ArrayType::get(NewTy, *itLen);
- }
- // Update field, and set changed
- NewElTys[iField] = NewTy;
- changed = true;
- }
- if (changed) {
- StructType *newStructTy = StructType::create(
- structTy->getContext(), NewElTys, structTy->getStructName());
- if (DxilStructAnnotation *pSA = pTypeSys ?
- pTypeSys->GetStructAnnotation(structTy) : nullptr) {
- if (!pTypeSys->GetStructAnnotation(newStructTy)) {
- DxilStructAnnotation &NewSA = *pTypeSys->AddStructAnnotation(newStructTy);
- for (unsigned iField = 0; iField < NewElTys.size(); ++iField) {
- NewSA.GetFieldAnnotation(iField) = pSA->GetFieldAnnotation(iField);
- }
- }
- }
- return newStructTy;
- }
- return structTy;
- }
- Constant *hlsl::TranslateInitForLoweredUDT(
- Constant *Init, Type *NewTy,
- // We need orientation for matrix fields
- DxilTypeSystem *pTypeSys,
- MatrixOrientation matOrientation) {
- // handle undef and zero init
- if (isa<UndefValue>(Init))
- return UndefValue::get(NewTy);
- else if (Init->getType()->isAggregateType() && Init->isZeroValue())
- return ConstantAggregateZero::get(NewTy);
- // unchanged
- Type *Ty = Init->getType();
- if (Ty == NewTy)
- return Init;
- SmallVector<Constant*, 16> values;
- if (Ty->isArrayTy()) {
- values.reserve(Ty->getArrayNumElements());
- ConstantArray *CA = cast<ConstantArray>(Init);
- for (unsigned i = 0; i < Ty->getArrayNumElements(); ++i)
- values.emplace_back(
- TranslateInitForLoweredUDT(
- CA->getAggregateElement(i),
- NewTy->getArrayElementType(),
- pTypeSys, matOrientation));
- return ConstantArray::get(cast<ArrayType>(NewTy), values);
- } else if (Ty->isVectorTy()) {
- values.reserve(Ty->getVectorNumElements());
- ConstantVector *CV = cast<ConstantVector>(Init);
- for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i)
- values.emplace_back(CV->getAggregateElement(i));
- return ConstantArray::get(cast<ArrayType>(NewTy), values);
- } else if (HLMatrixType Mat = HLMatrixType::dyn_cast(Ty)) {
- values.reserve(Mat.getNumElements());
- ConstantArray *MatArray = cast<ConstantArray>(
- cast<ConstantStruct>(Init)->getOperand(0));
- for (unsigned row = 0; row < Mat.getNumRows(); ++row) {
- ConstantVector *RowVector = cast<ConstantVector>(
- MatArray->getOperand(row));
- for (unsigned col = 0; col < Mat.getNumColumns(); ++col) {
- unsigned index = matOrientation == MatrixOrientation::ColumnMajor ?
- Mat.getColumnMajorIndex(row, col) : Mat.getRowMajorIndex(row, col);
- values[index] = RowVector->getOperand(col);
- }
- }
- } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
- DxilStructAnnotation *pStructAnnotation =
- pTypeSys ? pTypeSys->GetStructAnnotation(ST) : nullptr;
- values.reserve(ST->getNumContainedTypes());
- ConstantStruct *CS = cast<ConstantStruct>(Init);
- for (unsigned i = 0; i < ST->getStructNumElements(); ++i) {
- MatrixOrientation matFieldOrientation = matOrientation;
- if (pStructAnnotation) {
- DxilFieldAnnotation &FA = pStructAnnotation->GetFieldAnnotation(i);
- if (FA.HasMatrixAnnotation()) {
- matFieldOrientation = FA.GetMatrixAnnotation().Orientation;
- }
- }
- values.emplace_back(
- TranslateInitForLoweredUDT(
- cast<Constant>(CS->getAggregateElement(i)),
- NewTy->getStructElementType(i),
- pTypeSys, matFieldOrientation));
- }
- return ConstantStruct::get(cast<StructType>(NewTy), values);
- }
- return Init;
- }
- void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
- Type *Ty = V->getType();
- Type *NewTy = NewV->getType();
- if (Ty == NewTy) {
- V->replaceAllUsesWith(NewV);
- if (Instruction *I = dyn_cast<Instruction>(V))
- I->dropAllReferences();
- if (Constant *CV = dyn_cast<Constant>(V))
- CV->removeDeadConstantUsers();
- return;
- }
- if (Ty->isPointerTy())
- Ty = Ty->getPointerElementType();
- if (NewTy->isPointerTy())
- NewTy = NewTy->getPointerElementType();
- while (!V->use_empty()) {
- Use &use = *V->use_begin();
- User *user = use.getUser();
- // Clear use to prevent infinite loop on unhandled case.
- use.set(UndefValue::get(V->getType()));
- if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
- // Load for non-matching type should only be vector
- DXASSERT(Ty->isVectorTy() && NewTy->isArrayTy() &&
- Ty->getVectorNumElements() == NewTy->getArrayNumElements(),
- "unexpected load of non-matching type");
- IRBuilder<> Builder(LI);
- Value *result = UndefValue::get(Ty);
- for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i) {
- Value *GEP = Builder.CreateInBoundsGEP(NewV,
- {Builder.getInt32(0), Builder.getInt32(i)});
- Value *El = Builder.CreateLoad(GEP);
- result = Builder.CreateInsertElement(result, El, i);
- }
- LI->replaceAllUsesWith(result);
- LI->eraseFromParent();
- } else if (StoreInst *SI = dyn_cast<StoreInst>(user)) {
- // Store for non-matching type should only be vector
- DXASSERT(Ty->isVectorTy() && NewTy->isArrayTy() &&
- Ty->getVectorNumElements() == NewTy->getArrayNumElements(),
- "unexpected load of non-matching type");
- IRBuilder<> Builder(SI);
- for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i) {
- Value *EE = Builder.CreateExtractElement(SI->getValueOperand(), i);
- Value *GEP = Builder.CreateInBoundsGEP(
- NewV, {Builder.getInt32(0), Builder.getInt32(i)});
- Builder.CreateStore(EE, GEP);
- }
- SI->eraseFromParent();
- } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
- // Non-constant GEP
- IRBuilder<> Builder(GEP);
- SmallVector<Value*, 4> idxList(GEP->idx_begin(), GEP->idx_end());
- Value *NewGEP = Builder.CreateGEP(NewV, idxList);
- ReplaceUsesForLoweredUDT(GEP, NewGEP);
- GEP->eraseFromParent();
- } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(user)) {
- // Has to be constant GEP, NewV better be constant
- SmallVector<Value*, 4> idxList(GEP->idx_begin(), GEP->idx_end());
- Constant *NewGEP = ConstantExpr::getGetElementPtr(
- nullptr, cast<Constant>(NewV), idxList, true);
- ReplaceUsesForLoweredUDT(GEP, NewGEP);
- GEP->dropAllReferences();
- } else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
- // Address space cast
- IRBuilder<> Builder(AC);
- unsigned AddrSpace = AC->getType()->getPointerAddressSpace();
- Value *NewAC = Builder.CreateAddrSpaceCast(
- NewV, PointerType::get(NewTy, AddrSpace));
- ReplaceUsesForLoweredUDT(user, NewAC);
- AC->eraseFromParent();
- } else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
- IRBuilder<> Builder(BC);
- if (BC->getType()->getPointerElementType() == NewTy) {
- // if alreday bitcast to new type, just replace the bitcast
- // with the new value (already translated user function)
- BC->replaceAllUsesWith(NewV);
- } else {
- // Could be i8 for memcpy?
- // Replace bitcast argument with new value
- use.set(NewV);
- }
- } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
- // Constant AddrSpaceCast, or BitCast
- if (CE->getOpcode() == Instruction::AddrSpaceCast) {
- unsigned AddrSpace = CE->getType()->getPointerAddressSpace();
- ReplaceUsesForLoweredUDT(user,
- ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
- PointerType::get(NewTy, AddrSpace)));
- } else if (CE->getOpcode() == Instruction::BitCast) {
- if (CE->getType()->getPointerElementType() == NewTy) {
- // if alreday bitcast to new type, just replace the bitcast
- // with the new value
- CE->replaceAllUsesWith(NewV);
- } else {
- // Could be i8 for memcpy?
- // Replace bitcast argument with new value
- use.set(NewV);
- }
- } else {
- DXASSERT(0, "unhandled constant expr for lowered UTD");
- CE->dropAllReferences(); // better than infinite loop on release
- }
- } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
- // Lower some matrix intrinsics that access pointers early, and
- // cast arguments for user functions or special UDT intrinsics
- // for later translation.
- Function *F = CI->getCalledFunction();
- HLOpcodeGroup group = GetHLOpcodeGroupByName(F);
- HLMatrixType Mat = HLMatrixType::dyn_cast(Ty);
- bool bColMajor = false;
- switch (group) {
- case HLOpcodeGroup::HLMatLoadStore: {
- DXASSERT(Mat, "otherwise, matrix operation on non-matrix value");
- IRBuilder<> Builder(CI);
- HLMatLoadStoreOpcode opcode =
- static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
- switch (opcode) {
- case HLMatLoadStoreOpcode::ColMatLoad:
- bColMajor = true;
- __fallthrough;
- case HLMatLoadStoreOpcode::RowMatLoad: {
- Value *val = UndefValue::get(
- VectorType::get(NewTy->getArrayElementType(),
- NewTy->getArrayNumElements()));
- for (unsigned i = 0; i < NewTy->getArrayNumElements(); ++i) {
- Value *GEP = Builder.CreateGEP(NewV,
- {Builder.getInt32(0), Builder.getInt32(i)});
- Value *elt = Builder.CreateLoad(GEP);
- val = Builder.CreateInsertElement(val, elt, i);
- }
- if (bColMajor) {
- // transpose matrix to match expected value orientation for
- // default cast to matrix type
- SmallVector<int, 16> ShuffleIndices;
- for (unsigned RowIdx = 0; RowIdx < Mat.getNumRows(); ++RowIdx)
- for (unsigned ColIdx = 0; ColIdx < Mat.getNumColumns(); ++ColIdx)
- ShuffleIndices.emplace_back(
- static_cast<int>(Mat.getColumnMajorIndex(RowIdx, ColIdx)));
- val = Builder.CreateShuffleVector(val, val, ShuffleIndices);
- }
- // lower mem to reg type
- val = Mat.emitLoweredMemToReg(val, Builder);
- // cast vector back to matrix value (DefaultCast expects row major)
- unsigned newOpcode = (unsigned)HLCastOpcode::DefaultCast;
- val = callHLFunction(*F->getParent(), HLOpcodeGroup::HLCast, newOpcode,
- Ty, { Builder.getInt32(newOpcode), val }, Builder);
- if (bColMajor) {
- // emit cast row to col to match original result
- newOpcode = (unsigned)HLCastOpcode::RowMatrixToColMatrix;
- val = callHLFunction(*F->getParent(), HLOpcodeGroup::HLCast, newOpcode,
- Ty, { Builder.getInt32(newOpcode), val }, Builder);
- }
- // replace use of HLMatLoadStore with loaded vector
- CI->replaceAllUsesWith(val);
- } break;
- case HLMatLoadStoreOpcode::ColMatStore:
- bColMajor = true;
- __fallthrough;
- case HLMatLoadStoreOpcode::RowMatStore: {
- // HLCast matrix value to vector
- unsigned newOpcode = (unsigned)(bColMajor ?
- HLCastOpcode::ColMatrixToVecCast :
- HLCastOpcode::RowMatrixToVecCast);
- Value *val = callHLFunction(*F->getParent(),
- HLOpcodeGroup::HLCast, newOpcode,
- Mat.getLoweredVectorType(false),
- { Builder.getInt32(newOpcode),
- CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx) },
- Builder);
- // lower reg to mem type
- val = Mat.emitLoweredRegToMem(val, Builder);
- for (unsigned i = 0; i < NewTy->getArrayNumElements(); ++i) {
- Value *elt = Builder.CreateExtractElement(val, i);
- Value *GEP = Builder.CreateGEP(NewV,
- {Builder.getInt32(0), Builder.getInt32(i)});
- Builder.CreateStore(elt, GEP);
- }
- } break;
- default:
- DXASSERT(0, "invalid opcode");
- }
- CI->eraseFromParent();
- } break;
- case HLOpcodeGroup::HLSubscript: {
- SmallVector<Value*, 4> ElemIndices;
- HLSubscriptOpcode opcode =
- static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(CI));
- switch (opcode) {
- case HLSubscriptOpcode::VectorSubscript:
- DXASSERT(0, "not handled yet");
- break;
- case HLSubscriptOpcode::ColMatElement:
- bColMajor = true;
- __fallthrough;
- case HLSubscriptOpcode::RowMatElement: {
- ConstantDataSequential *cIdx = cast<ConstantDataSequential>(
- CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
- for (unsigned i = 0; i < cIdx->getNumElements(); ++i) {
- ElemIndices.push_back(cIdx->getElementAsConstant(i));
- }
- } break;
- case HLSubscriptOpcode::ColMatSubscript:
- bColMajor = true;
- __fallthrough;
- case HLSubscriptOpcode::RowMatSubscript: {
- for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < CI->getNumArgOperands(); ++Idx) {
- ElemIndices.emplace_back(CI->getArgOperand(Idx));
- }
- } break;
- default:
- DXASSERT(0, "invalid opcode");
- }
- std::vector<Instruction*> DeadInsts;
- HLMatrixSubscriptUseReplacer UseReplacer(
- CI, NewV, /*TempLoweredMatrix*/nullptr, ElemIndices, /*AllowLoweredPtrGEPs*/true, DeadInsts);
- DXASSERT(CI->use_empty(),
- "Expected all matrix subscript uses to have been replaced.");
- CI->eraseFromParent();
- while (!DeadInsts.empty()) {
- DeadInsts.back()->eraseFromParent();
- DeadInsts.pop_back();
- }
- } break;
- //case HLOpcodeGroup::NotHL: // TODO: Support lib functions
- case HLOpcodeGroup::HLIntrinsic: {
- // Just bitcast for now
- IRBuilder<> Builder(CI);
- use.set(Builder.CreateBitCast(NewV, V->getType()));
- continue;
- } break;
- default:
- DXASSERT(0, "invalid opcode");
- }
- } else {
- // What else?
- DXASSERT(false, "case not handled.");
- }
- }
- }
|