Przeglądaj źródła

Add HLLowerUDT: early UDT ptr lowering for direct use in final Dxil

- DispatchMesh and other intrinsics that use UDT ptr directly need
  this to prevent copying and reconstruction, as well as prevent vectors
  and HL matrices from being used in final Dxil.
Tex Riddell 5 lat temu
rodzic
commit
f383b21be9
3 zmienionych plików z 465 dodań i 0 usunięć
  1. 34 0
      include/dxc/HLSL/HLLowerUDT.h
  2. 1 0
      lib/HLSL/CMakeLists.txt
  3. 430 0
      lib/HLSL/HLLowerUDT.cpp

+ 34 - 0
include/dxc/HLSL/HLLowerUDT.h

@@ -0,0 +1,34 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLLowerUDT.h                                                              //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Lower user defined type used directly by certain intrinsic operations.    //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#pragma once
+#include "dxc/Support/Global.h"
+#include "dxc/DXIL/DxilConstants.h"
+#include "dxc/DXIL/DxilTypeSystem.h"
+
+namespace llvm {
+class Constant;
+class Function;
+class StructType;
+class Type;
+class Value;
+} // namespace llvm
+
namespace hlsl {

/// Returns a struct type with the same layout as structTy, but with vector
/// and HL matrix fields translated to arrays. Returns nullptr if lowering
/// fails (e.g. an embedded HLSL object type), or the original type unchanged
/// when no field needed lowering.
llvm::StructType *GetLoweredUDT(llvm::StructType *structTy);
/// Translates an initializer constant for the original UDT into an
/// initializer for the lowered type produced by GetLoweredUDT.
llvm::Constant *TranslateInitForLoweredUDT(llvm::Constant *Init, llvm::Type *NewTy,
    // We need orientation for matrix fields
    hlsl::DxilTypeSystem *pTypeSys,
    hlsl::MatrixOrientation matOrientation = hlsl::MatrixOrientation::Undefined);
/// Rewrites all uses of V (pointer to the original UDT) to use NewV
/// (pointer to the lowered UDT), translating loads/stores, GEPs, casts and
/// HL matrix intrinsics as needed.
void ReplaceUsesForLoweredUDT(llvm::Value *V, llvm::Value *NewV);

} // namespace hlsl

+ 1 - 0
lib/HLSL/CMakeLists.txt

@@ -29,6 +29,7 @@ add_llvm_library(LLVMHLSL
   DxcOptimizer.cpp
   HLDeadFunctionElimination.cpp
   HLExpandStoreIntrinsics.cpp
+  HLLowerUDT.cpp
   HLMatrixBitcastLowerPass.cpp
   HLMatrixLowerPass.cpp
   HLMatrixSubscriptUseReplacer.cpp

+ 430 - 0
lib/HLSL/HLLowerUDT.cpp

@@ -0,0 +1,430 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// HLLowerUDT.cpp                                                            //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+// Lower user defined type used directly by certain intrinsic operations.    //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "dxc/HLSL/HLLowerUDT.h"
+#include "dxc/Support/Global.h"
+#include "dxc/DXIL/DxilConstants.h"
+#include "dxc/HLSL/HLModule.h"
+#include "dxc/HLSL/HLOperations.h"
+#include "dxc/DXIL/DxilTypeSystem.h"
+#include "dxc/HLSL/HLMatrixLowerHelper.h"
+#include "dxc/HLSL/HLMatrixType.h"
+#include "dxc/HlslIntrinsicOp.h"
+#include "dxc/DXIL/DxilUtil.h"
+
+#include "HLMatrixSubscriptUseReplacer.h"
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+static Value *callHLFunction(llvm::Module &Module, HLOpcodeGroup OpcodeGroup, unsigned Opcode,
+  Type *RetTy, ArrayRef<Value*> Args, IRBuilder<> &Builder) {
+  SmallVector<Type*, 4> ArgTys;
+  ArgTys.reserve(Args.size());
+  for (Value *Arg : Args)
+    ArgTys.emplace_back(Arg->getType());
+
+  FunctionType *FuncTy = FunctionType::get(RetTy, ArgTys, /* isVarArg */ false);
+  Function *Func = GetOrCreateHLFunction(Module, FuncTy, OpcodeGroup, Opcode);
+
+  return Builder.CreateCall(Func, Args);
+}
+
+
+// Lowered UDT is the same layout, but with vectors and matrices translated to
+// arrays.
+// Returns nullptr for failure due to embedded HLSL object type.
+StructType *hlsl::GetLoweredUDT(StructType *structTy) {
+  bool changed = false;
+  SmallVector<Type*, 8> NewElTys(structTy->getNumContainedTypes());
+
+  for (unsigned iField = 0; iField < NewElTys.size(); ++iField) {
+    Type *FieldTy = structTy->getContainedType(iField);
+
+    // Default to original type
+    NewElTys[iField] = FieldTy;
+
+    // Unwrap arrays:
+    SmallVector<unsigned, 4> OuterToInnerLengths;
+    Type *EltTy = dxilutil::StripArrayTypes(FieldTy, &OuterToInnerLengths);
+    Type *NewTy = EltTy;
+
+    // Lower element if necessary
+    if (EltTy->isVectorTy()) {
+      NewTy = ArrayType::get(EltTy->getVectorElementType(),
+                             EltTy->getVectorNumElements());
+    } else if (HLMatrixType Mat = HLMatrixType::dyn_cast(EltTy)) {
+      NewTy = ArrayType::get(Mat.getElementType(/*MemRepr*/true),
+                             Mat.getNumElements());
+    } else if (dxilutil::IsHLSLObjectType(EltTy) ||
+               dxilutil::IsHLSLRayQueryType(EltTy)) {
+      // We cannot lower a structure with an embedded object type
+      return nullptr;
+    } else if (StructType *ST = dyn_cast<StructType>(EltTy)) {
+      NewTy = GetLoweredUDT(ST);
+      if (nullptr == NewTy)
+        return nullptr; // Propagate failure back to root
+    } else if (EltTy->isIntegerTy(1)) {
+      // Must translate bool to mem type
+      EltTy = IntegerType::get(EltTy->getContext(), 32);
+    }
+
+    // if unchanged, skip field
+    if (NewTy == EltTy)
+      continue;
+
+    // Rewrap Arrays:
+    for (auto itLen = OuterToInnerLengths.rbegin(),
+                  E = OuterToInnerLengths.rend();
+         itLen != E; ++itLen) {
+      NewTy = ArrayType::get(NewTy, *itLen);
+    }
+
+    // Update field, and set changed
+    NewElTys[iField] = NewTy;
+    changed = true;
+  }
+
+  if (changed) {
+    return StructType::create(
+      structTy->getContext(), NewElTys, structTy->getStructName());
+  }
+
+  return structTy;
+}
+
+Constant *hlsl::TranslateInitForLoweredUDT(Constant *Init, Type *NewTy,
+    // We need orientation for matrix fields
+    DxilTypeSystem *pTypeSys,
+    MatrixOrientation matOrientation) {
+  // handle undef and zero init
+  if (isa<UndefValue>(Init))
+    return UndefValue::get(NewTy);
+  else if (Init->getType()->isAggregateType() && Init->isZeroValue())
+    return ConstantAggregateZero::get(NewTy);
+
+  // unchanged
+  Type *Ty = Init->getType();
+  if (Ty == NewTy)
+    return Init;
+
+  SmallVector<Constant*, 16> values;
+  if (Ty->isArrayTy()) {
+    values.reserve(Ty->getArrayNumElements());
+    ConstantArray *CA = cast<ConstantArray>(Init);
+    for (unsigned i = 0; i < Ty->getArrayNumElements(); ++i)
+      values.emplace_back(
+        TranslateInitForLoweredUDT(
+          CA->getAggregateElement(i),
+          NewTy->getArrayElementType(),
+          pTypeSys, matOrientation));
+    return ConstantArray::get(cast<ArrayType>(NewTy), values);
+  } else if (Ty->isVectorTy()) {
+    values.reserve(Ty->getVectorNumElements());
+    ConstantVector *CV = cast<ConstantVector>(Init);
+    for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i)
+      values.emplace_back(CV->getAggregateElement(i));
+    return ConstantArray::get(cast<ArrayType>(NewTy), values);
+  } else if (HLMatrixType Mat = HLMatrixType::dyn_cast(Ty)) {
+    values.reserve(Mat.getNumElements());
+    ConstantArray *MatArray = cast<ConstantArray>(
+      cast<ConstantStruct>(Init)->getOperand(0));
+    for (unsigned row = 0; row < Mat.getNumRows(); ++row) {
+      ConstantVector *RowVector = cast<ConstantVector>(
+        MatArray->getOperand(row));
+      for (unsigned col = 0; col < Mat.getNumColumns(); ++col) {
+        unsigned index = matOrientation == MatrixOrientation::ColumnMajor ?
+          Mat.getColumnMajorIndex(row, col) : Mat.getRowMajorIndex(row, col);
+        values[index] = RowVector->getOperand(col);
+      }
+    }
+  } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
+    values.reserve(ST->getNumContainedTypes());
+    ConstantStruct *CS = cast<ConstantStruct>(Init);
+    for (unsigned i = 0; i < ST->getStructNumElements(); ++i) {
+      values.emplace_back(
+        TranslateInitForLoweredUDT(
+          cast<Constant>(CS->getAggregateElement(i)),
+          NewTy->getStructElementType(i),
+          pTypeSys, matOrientation));
+    }
+    return ConstantStruct::get(cast<StructType>(NewTy), values);
+  }
+  return Init;
+}
+
// Rewrite all uses of V (a pointer to the original UDT) to use NewV (a
// pointer to the lowered UDT) instead.  Handles whole-vector loads/stores,
// GEPs (instruction and constant), address-space casts, bitcasts, and HL
// matrix load/store/subscript intrinsics; other calls get a bitcast of NewV
// for later translation.  Recurses through pointer-producing users; on
// return V has no remaining uses.
void hlsl::ReplaceUsesForLoweredUDT(Value *V, Value *NewV) {
  Type *Ty = V->getType();
  Type *NewTy = NewV->getType();

  // Fast path: types already match, so a plain RAUW suffices.
  if (Ty == NewTy) {
    V->replaceAllUsesWith(NewV);
    if (Instruction *I = dyn_cast<Instruction>(V))
      I->dropAllReferences();
    if (Constant *CV = dyn_cast<Constant>(V))
      CV->removeDeadConstantUsers();
    return;
  }

  // Compare pointee types from here on.
  if (Ty->isPointerTy())
    Ty = Ty->getPointerElementType();
  if (NewTy->isPointerTy())
    NewTy = NewTy->getPointerElementType();

  while (!V->use_empty()) {
    Use &use = *V->use_begin();
    User *user = use.getUser();
    // Clear use to prevent infinite loop on unhandled case.
    use.set(UndefValue::get(V->getType()));

    if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
      // Load for non-matching type should only be vector: rebuild the
      // vector value from element-wise loads of the lowered array.
      DXASSERT(Ty->isVectorTy() && NewTy->isArrayTy() &&
        Ty->getVectorNumElements() == NewTy->getArrayNumElements(),
        "unexpected load of non-matching type");
      IRBuilder<> Builder(LI);
      Value *result = UndefValue::get(Ty);
      for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i) {
        Value *GEP = Builder.CreateInBoundsGEP(NewV,
          {Builder.getInt32(0), Builder.getInt32(i)});
        Value *El = Builder.CreateLoad(GEP);
        result = Builder.CreateInsertElement(result, El, i);
      }
      LI->replaceAllUsesWith(result);
      LI->eraseFromParent();

    } else if (StoreInst *SI = dyn_cast<StoreInst>(user)) {
      // Store for non-matching type should only be vector: scatter the
      // vector elements into the lowered array.
      // NOTE(review): assert message says "load" but this is the store
      // case — looks copy-pasted from the branch above; confirm and fix.
      DXASSERT(Ty->isVectorTy() && NewTy->isArrayTy() &&
        Ty->getVectorNumElements() == NewTy->getArrayNumElements(),
        "unexpected load of non-matching type");
      IRBuilder<> Builder(SI);
      for (unsigned i = 0; i < Ty->getVectorNumElements(); ++i) {
        Value *EE = Builder.CreateExtractElement(SI->getValueOperand(), i);
        Value *GEP = Builder.CreateInBoundsGEP(
          NewV, {Builder.getInt32(0), Builder.getInt32(i)});
        Builder.CreateStore(EE, GEP);
      }
      SI->eraseFromParent();

    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(user)) {
      // Non-constant GEP: mirror the GEP on NewV, then recurse to rewrite
      // the users of the original GEP against the new one.
      IRBuilder<> Builder(GEP);
      SmallVector<Value*, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      Value *NewGEP = Builder.CreateGEP(NewV, idxList);
      ReplaceUsesForLoweredUDT(GEP, NewGEP);
      GEP->eraseFromParent();

    } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(user)) {
      // Has to be constant GEP, NewV better be constant
      SmallVector<Value*, 4> idxList(GEP->idx_begin(), GEP->idx_end());
      Constant *NewGEP = ConstantExpr::getGetElementPtr(
        nullptr, cast<Constant>(NewV), idxList, true);
      ReplaceUsesForLoweredUDT(GEP, NewGEP);
      GEP->dropAllReferences();

    } else if (AddrSpaceCastInst *AC = dyn_cast<AddrSpaceCastInst>(user)) {
      // Address space cast: recreate it targeting the lowered type, then
      // recurse on the cast's users.
      IRBuilder<> Builder(AC);
      unsigned AddrSpace = AC->getType()->getPointerAddressSpace();
      Value *NewAC = Builder.CreateAddrSpaceCast(
          NewV, PointerType::get(NewTy, AddrSpace));
      ReplaceUsesForLoweredUDT(user, NewAC);
      AC->eraseFromParent();

    } else if (BitCastInst *BC = dyn_cast<BitCastInst>(user)) {
      IRBuilder<> Builder(BC);
      if (BC->getType()->getPointerElementType() == NewTy) {
        // if already bitcast to new type, just replace the bitcast
        // with the new value (already translated user function)
        BC->replaceAllUsesWith(NewV);
      } else {
        // Could be i8 for memcpy?
        // Replace bitcast argument with new value
        use.set(NewV);
      }

    } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(user)) {
      // Constant AddrSpaceCast, or BitCast
      if (CE->getOpcode() == Instruction::AddrSpaceCast) {
        unsigned AddrSpace = CE->getType()->getPointerAddressSpace();
        ReplaceUsesForLoweredUDT(user,
          ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV),
            PointerType::get(NewTy, AddrSpace)));
      } else if (CE->getOpcode() == Instruction::BitCast) {
        if (CE->getType()->getPointerElementType() == NewTy) {
          // if already bitcast to new type, just replace the bitcast
          // with the new value
          CE->replaceAllUsesWith(NewV);
        } else {
          // Could be i8 for memcpy?
          // Replace bitcast argument with new value
          use.set(NewV);
        }
      } else {
        DXASSERT(0, "unhandled constant expr for lowered UTD");
        CE->dropAllReferences();  // better than infinite loop on release
      }

    } else if (CallInst *CI = dyn_cast<CallInst>(user)) {
      // Lower some matrix intrinsics that access pointers early, and
      // cast arguments for user functions or special UDT intrinsics
      // for later translation.
      Function *F = CI->getCalledFunction();
      HLOpcodeGroup group = GetHLOpcodeGroupByName(F);
      HLMatrixType Mat = HLMatrixType::dyn_cast(Ty);
      bool bColMajor = false;

      switch (group) {
      case HLOpcodeGroup::HLMatLoadStore: {
        DXASSERT(Mat, "otherwise, matrix operation on non-matrix value");
        IRBuilder<> Builder(CI);
        HLMatLoadStoreOpcode opcode =
            static_cast<HLMatLoadStoreOpcode>(hlsl::GetHLOpcode(CI));
        switch (opcode) {
        case HLMatLoadStoreOpcode::ColMatLoad:
          bColMajor = true;
          __fallthrough;
        case HLMatLoadStoreOpcode::RowMatLoad: {
          // Gather the lowered array elements into a vector.
          Value *val = UndefValue::get(
            VectorType::get(NewTy->getArrayElementType(),
                            NewTy->getArrayNumElements()));
          for (unsigned i = 0; i < NewTy->getArrayNumElements(); ++i) {
            Value *GEP = Builder.CreateGEP(NewV,
              {Builder.getInt32(0), Builder.getInt32(i)});
            Value *elt = Builder.CreateLoad(GEP);
            val = Builder.CreateInsertElement(val, elt, i);
          }
          if (bColMajor) {
            // transpose matrix to match expected value orientation for
            // default cast to matrix type
            SmallVector<int, 16> ShuffleIndices;
            for (unsigned RowIdx = 0; RowIdx < Mat.getNumRows(); ++RowIdx)
              for (unsigned ColIdx = 0; ColIdx < Mat.getNumColumns(); ++ColIdx)
                ShuffleIndices.emplace_back(
                  static_cast<int>(Mat.getColumnMajorIndex(RowIdx, ColIdx)));
            val = Builder.CreateShuffleVector(val, val, ShuffleIndices);
          }
          // lower mem to reg type
          val = Mat.emitLoweredMemToReg(val, Builder);
          // cast vector back to matrix value (DefaultCast expects row major)
          unsigned newOpcode = (unsigned)HLCastOpcode::DefaultCast;
          val = callHLFunction(*F->getParent(), HLOpcodeGroup::HLCast, newOpcode,
                               Ty, { Builder.getInt32(newOpcode), val }, Builder);
          if (bColMajor) {
            // emit cast row to col to match original result
            newOpcode = (unsigned)HLCastOpcode::RowMatrixToColMatrix;
            val = callHLFunction(*F->getParent(), HLOpcodeGroup::HLCast, newOpcode,
              Ty, { Builder.getInt32(newOpcode), val }, Builder);
          }
          // replace use of HLMatLoadStore with loaded vector
          CI->replaceAllUsesWith(val);
        } break;
        case HLMatLoadStoreOpcode::ColMatStore:
          bColMajor = true;
          __fallthrough;
        case HLMatLoadStoreOpcode::RowMatStore: {
          // HLCast matrix value to vector
          unsigned newOpcode = (unsigned)(bColMajor ?
              HLCastOpcode::ColMatrixToVecCast :
              HLCastOpcode::RowMatrixToVecCast);
          Value *val = callHLFunction(*F->getParent(),
            HLOpcodeGroup::HLCast, newOpcode,
            Mat.getLoweredVectorType(false),
            { Builder.getInt32(newOpcode),
              CI->getArgOperand(HLOperandIndex::kMatStoreValOpIdx) },
            Builder);
          // lower reg to mem type
          val = Mat.emitLoweredRegToMem(val, Builder);
          // Scatter the vector elements into the lowered array.
          for (unsigned i = 0; i < NewTy->getArrayNumElements(); ++i) {
            Value *elt = Builder.CreateExtractElement(val, i);
            Value *GEP = Builder.CreateGEP(NewV,
              {Builder.getInt32(0), Builder.getInt32(i)});
            Builder.CreateStore(elt, GEP);
          }
        } break;
        default:
          DXASSERT(0, "invalid opcode");
        }
        CI->eraseFromParent();
      } break;

      case HLOpcodeGroup::HLSubscript: {
        // Collect the element indices for this subscript, then let
        // HLMatrixSubscriptUseReplacer rewrite the subscript's users
        // against the lowered pointer directly.
        SmallVector<Value*, 4> ElemIndices;
        HLSubscriptOpcode opcode =
            static_cast<HLSubscriptOpcode>(hlsl::GetHLOpcode(CI));
        switch (opcode) {
        case HLSubscriptOpcode::VectorSubscript:
          DXASSERT(0, "not handled yet");
          break;
        case HLSubscriptOpcode::ColMatElement:
          bColMajor = true;
          __fallthrough;
        case HLSubscriptOpcode::RowMatElement: {
          // Element access: indices come as a constant data sequence.
          ConstantDataSequential *cIdx = cast<ConstantDataSequential>(
            CI->getArgOperand(HLOperandIndex::kMatSubscriptSubOpIdx));
          for (unsigned i = 0; i < cIdx->getNumElements(); ++i) {
            ElemIndices.push_back(cIdx->getElementAsConstant(i));
          }
        } break;
        case HLSubscriptOpcode::ColMatSubscript:
          bColMajor = true;
          __fallthrough;
        case HLSubscriptOpcode::RowMatSubscript: {
          // Dynamic subscript: remaining call operands are the indices.
          for (unsigned Idx = HLOperandIndex::kMatSubscriptSubOpIdx; Idx < CI->getNumArgOperands(); ++Idx) {
            ElemIndices.emplace_back(CI->getArgOperand(Idx));
          }
        } break;
        default:
          DXASSERT(0, "invalid opcode");
        }

        std::vector<Instruction*> DeadInsts;
        HLMatrixSubscriptUseReplacer UseReplacer(
          CI, NewV, ElemIndices, /*AllowLoweredPtrGEPs*/true, DeadInsts);
        DXASSERT(CI->use_empty(),
                 "Expected all matrix subscript uses to have been replaced.");
        CI->eraseFromParent();
        while (!DeadInsts.empty()) {
          DeadInsts.back()->eraseFromParent();
          DeadInsts.pop_back();
        }
      } break;

      case HLOpcodeGroup::NotHL:
      case HLOpcodeGroup::HLIntrinsic: {
        // Just bitcast for now: user functions and UDT intrinsics are
        // translated later, so re-point this argument at a bitcast of NewV
        // (overwriting the undef placeholder set above).
        IRBuilder<> Builder(CI);
        use.set(Builder.CreateBitCast(NewV, V->getType()));
        continue;
      } break;

      default:
        DXASSERT(0, "invalid opcode");
      }
    } else {
      // What else?
      DXASSERT(false, "case not handled.");
    }
  }
}