Browse Source

Support return of resource for library.
1. Don't check hasMulticomponentUAVLoads for lib.
2. Skip mem2reg for unpromotable handle alloca for lib.
3. Support find resource attribute from function call.
4. Avoid unpack for dxil types.
5. Add resource attribute for fieldAnnotation.

Xiang Li 8 năm trước cách đây
mục cha
commit
a318eb6106

+ 1 - 0
include/dxc/HLSL/DxilOperations.h

@@ -87,6 +87,7 @@ public:
   static bool IsDxilOpFuncCallInst(const llvm::Instruction *I, OpCode opcode);
   static bool IsDxilOpWave(OpCode C);
   static bool IsDxilOpGradient(OpCode C);
+  static bool IsDxilOpType(llvm::StructType *ST);
   static bool IsDupDxilOpType(llvm::StructType *ST);
   static llvm::StructType *GetOriginalDxilOpType(llvm::StructType *ST,
                                                  llvm::Module &M);

+ 6 - 1
lib/HLSL/DxilGenerationPass.cpp

@@ -1540,6 +1540,8 @@ void DxilLegalizeResourceUsePass::PromoteLocalResource(Function &F) {
   OP *hlslOP = HLM.GetOP();
   Type *HandleTy = hlslOP->GetHandleType();
 
+  bool IsLib = HLM.GetShaderModel()->IsLib();
+
   BasicBlock &BB = F.getEntryBlock();
   unsigned allocaSize = 0;
   while (1) {
@@ -1550,7 +1552,10 @@ void DxilLegalizeResourceUsePass::PromoteLocalResource(Function &F) {
     for (BasicBlock::iterator I = BB.begin(), E = --BB.end(); I != E; ++I)
       if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { // Is it an alloca?
         if (HandleTy == dxilutil::GetArrayEltTy(AI->getAllocatedType())) {
-          DXASSERT(isAllocaPromotable(AI), "otherwise, non-promotable resource array alloca found");
+          // Skip for unpromotable for lib.
+          if (!isAllocaPromotable(AI) && IsLib)
+            continue;
+
           Allocas.push_back(AI);
         }
       }

+ 3 - 0
lib/HLSL/DxilModule.cpp

@@ -354,6 +354,9 @@ void DxilModule::CollectShaderFlags(ShaderFlags &Flags) {
     GetValidatorVersion(valMajor, valMinor);
     hasMulticomponentUAVLoadsBackCompat = valMajor <= 1 && valMinor == 0;
   }
+  // Don't check hasMulticomponentUAVLoads for lib.
+  if (m_pSM->IsLib())
+    hasMulticomponentUAVLoads = true;
 
   Type *int16Ty = Type::getInt16Ty(GetCtx());
   Type *int64Ty = Type::getInt64Ty(GetCtx());

+ 7 - 0
lib/HLSL/DxilOperations.cpp

@@ -343,6 +343,13 @@ bool OP::IsDxilOpFunc(const llvm::Function *F) {
   return IsDxilOpFuncName(F->getName());
 }
 
+bool OP::IsDxilOpType(llvm::StructType *ST) {
+  if (!ST->hasName())
+    return false;
+  StringRef Name = ST->getName();
+  return Name.startswith(m_TypePrefix);
+}
+
 bool OP::IsDupDxilOpType(llvm::StructType *ST) {
   if (!ST->hasName())
     return false;

+ 40 - 0
lib/HLSL/HLOperationLower.cpp

@@ -167,6 +167,46 @@ private:
       HandleMetaMap[Handle] = Attrib;
       return HandleMetaMap[Handle];
     }
+    if (LoadInst *LI = dyn_cast<LoadInst>(Handle)) {
+      Value *Ptr = LI->getPointerOperand();
+
+      for (User *U : Ptr->users()) {
+        if (CallInst *CI = dyn_cast<CallInst>(U)) {
+          DxilFunctionAnnotation *FnAnnot = HLM.GetFunctionAnnotation(CI->getCalledFunction());
+          if (FnAnnot) {
+            for (auto &arg : CI->arg_operands()) {
+              if (arg == Ptr) {
+                unsigned argNo = arg.getOperandNo();
+                DxilParameterAnnotation &ParamAnnot = FnAnnot->GetParameterAnnotation(argNo);
+                MDNode *MD = ParamAnnot.GetResourceAttribute();
+                if (!MD) {
+                  Handle->getContext().emitError(
+                      "cannot map resource to handle");
+                  return HandleMetaMap[Handle];
+                }
+                DxilResourceBase Res(DxilResource::Class::Invalid);
+                HLM.LoadDxilResourceBaseFromMDNode(MD, Res);
+
+                ResAttribute Attrib = {Res.GetClass(), Res.GetKind(),
+                                       Res.GetGlobalSymbol()->getType()};
+
+                HandleMetaMap[Handle] = Attrib;
+                return HandleMetaMap[Handle];
+              }
+            }
+          }
+        }
+        if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
+          Value *V = SI->getValueOperand();
+          ResAttribute Attrib = FindCreateHandleResourceBase(V);
+          HandleMetaMap[Handle] = Attrib;
+          return HandleMetaMap[Handle];
+        }
+      }
+      // Cannot find.
+      Handle->getContext().emitError("cannot map resource to handle");
+      return HandleMetaMap[Handle];
+    }
     if (CallInst *CI = dyn_cast<CallInst>(Handle)) {
       MDNode *MD = HLM.GetDxilResourceAttrib(CI->getCalledFunction());
       if (!MD) {

+ 4 - 1
lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp

@@ -20,6 +20,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "dxc/HLSL/DxilOperations.h"   // HLSL Change - avoid unpack for dxil types.
 using namespace llvm;
 
 #define DEBUG_TYPE "instcombine"
@@ -519,7 +520,9 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
 
   if (auto *ST = dyn_cast<StructType>(T)) {
     // If the struct only have one element, we unpack.
-    if (ST->getNumElements() == 1) {
+    if (ST->getNumElements() == 1
+        && !hlsl::OP::IsDxilOpType(ST) // HLSL Change - avoid unpack dxil types.
+        ) {
       LoadInst *NewLoad = combineLoadToNewType(IC, LI, ST->getTypeAtIndex(0U),
                                                ".unpack");
       return IC.ReplaceInstUsesWith(LI, IC.Builder->CreateInsertValue(

+ 1 - 0
lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp

@@ -5252,6 +5252,7 @@ void SROA_Parameter_HLSL::flattenArgument(
       flatParamAnnotation.SetCompType(annotation.GetCompType().GetKind());
       flatParamAnnotation.SetMatrixAnnotation(annotation.GetMatrixAnnotation());
       flatParamAnnotation.SetPrecise(annotation.IsPrecise());
+      flatParamAnnotation.SetResourceAttribute(annotation.GetResourceAttribute());
 
       // Add debug info.
       if (DDI && V != Arg) {

+ 84 - 36
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -199,6 +199,10 @@ private:
                                      DxilTypeSystem &dxilTypeSys);
   unsigned AddTypeAnnotation(QualType Ty, DxilTypeSystem &dxilTypeSys,
                              unsigned &arrayEltSize);
+  MDNode *GetOrAddResTypeMD(QualType resTy);
+  void ConstructFieldAttributedAnnotation(DxilFieldAnnotation &fieldAnnotation,
+                                          QualType fieldTy,
+                                          bool bDefaultRowMajor);
 
   std::unordered_map<Constant*, DxilFieldAnnotation> m_ConstVarAnnotationMap;
 
@@ -650,7 +654,68 @@ static CompType::Kind BuiltinTyToCompTy(const BuiltinType *BTy, bool bSNorm,
   return kind;
 }
 
-static void ConstructFieldAttributedAnnotation(DxilFieldAnnotation &fieldAnnotation, QualType fieldTy, bool bDefaultRowMajor) {
+static DxilSampler::SamplerKind KeywordToSamplerKind(llvm::StringRef keyword) {
+  // TODO: refactor for faster search (switch by 1/2/3 first letters, then
+  // compare)
+  return llvm::StringSwitch<DxilSampler::SamplerKind>(keyword)
+    .Case("SamplerState", DxilSampler::SamplerKind::Default)
+    .Case("SamplerComparisonState", DxilSampler::SamplerKind::Comparison)
+    .Default(DxilSampler::SamplerKind::Invalid);
+}
+
+MDNode *CGMSHLSLRuntime::GetOrAddResTypeMD(QualType resTy) {
+  const RecordType *RT = resTy->getAs<RecordType>();
+  if (!RT)
+    return nullptr;
+  RecordDecl *RD = RT->getDecl();
+  SourceLocation loc = RD->getLocation();
+
+  hlsl::DxilResourceBase::Class resClass = TypeToClass(resTy);
+  llvm::Type *Ty = CGM.getTypes().ConvertType(resTy);
+  auto it = resMetadataMap.find(Ty);
+  if (it != resMetadataMap.end())
+    return it->second;
+
+  // Save resource type metadata.
+  switch (resClass) {
+  case DXIL::ResourceClass::UAV: {
+    DxilResource UAV;
+    // TODO: save globalcoherent to variable in EmitHLSLBuiltinCallExpr.
+    SetUAVSRV(loc, resClass, &UAV, RD);
+    // Set global symbol to save type.
+    UAV.SetGlobalSymbol(UndefValue::get(Ty));
+    MDNode *MD = m_pHLModule->DxilUAVToMDNode(UAV);
+    resMetadataMap[Ty] = MD;
+    return MD;
+  } break;
+  case DXIL::ResourceClass::SRV: {
+    DxilResource SRV;
+    SetUAVSRV(loc, resClass, &SRV, RD);
+    // Set global symbol to save type.
+    SRV.SetGlobalSymbol(UndefValue::get(Ty));
+    MDNode *MD = m_pHLModule->DxilSRVToMDNode(SRV);
+    resMetadataMap[Ty] = MD;
+    return MD;
+  } break;
+  case DXIL::ResourceClass::Sampler: {
+    DxilSampler S;
+    DxilSampler::SamplerKind kind = KeywordToSamplerKind(RD->getName());
+    S.SetSamplerKind(kind);
+    // Set global symbol to save type.
+    S.SetGlobalSymbol(UndefValue::get(Ty));
+    MDNode *MD = m_pHLModule->DxilSamplerToMDNode(S);
+    resMetadataMap[Ty] = MD;
+    return MD;
+  }
+  default:
+    // Skip OutputStream for GS.
+    return nullptr;
+  }
+}
+
+void CGMSHLSLRuntime::ConstructFieldAttributedAnnotation(
+    DxilFieldAnnotation &fieldAnnotation, QualType fieldTy,
+    bool bDefaultRowMajor) {
   QualType Ty = fieldTy;
   if (Ty->isReferenceType())
     Ty = Ty.getNonReferenceType();
@@ -690,6 +755,11 @@ static void ConstructFieldAttributedAnnotation(DxilFieldAnnotation &fieldAnnotat
   if (hlsl::IsHLSLVecType(Ty))
     EltTy = hlsl::GetHLSLVecElementType(Ty);
 
+  if (IsHLSLResourceType(Ty)) {
+    MDNode *MD = GetOrAddResTypeMD(Ty);
+    fieldAnnotation.SetResourceAttribute(MD);
+  }
+
   bool bSNorm = false;
   bool bUNorm = false;
 
@@ -711,15 +781,15 @@ static void ConstructFieldAttributedAnnotation(DxilFieldAnnotation &fieldAnnotat
     const BuiltinType *BTy = EltTy->getAs<BuiltinType>();
     CompType::Kind kind = BuiltinTyToCompTy(BTy, bSNorm, bUNorm);
     fieldAnnotation.SetCompType(kind);
-  }
-  else if (EltTy->isEnumeralType()) {
+  } else if (EltTy->isEnumeralType()) {
     const EnumType *ETy = EltTy->getAs<EnumType>();
     QualType type = ETy->getDecl()->getIntegerType();
-    if (const BuiltinType *BTy = dyn_cast<BuiltinType>(type->getCanonicalTypeInternal()))
-        fieldAnnotation.SetCompType(BuiltinTyToCompTy(BTy, bSNorm, bUNorm));
-  }
-  else
-    DXASSERT(!bSNorm && !bUNorm, "snorm/unorm on invalid type, validate at handleHLSLTypeAttr");
+    if (const BuiltinType *BTy =
+            dyn_cast<BuiltinType>(type->getCanonicalTypeInternal()))
+      fieldAnnotation.SetCompType(BuiltinTyToCompTy(BTy, bSNorm, bUNorm));
+  } else
+    DXASSERT(!bSNorm && !bUNorm,
+             "snorm/unorm on invalid type, validate at handleHLSLTypeAttr");
 }
 
 static void ConstructFieldInterpolation(DxilFieldAnnotation &fieldAnnotation,
@@ -977,15 +1047,6 @@ static DxilResource::Kind KeywordToKind(StringRef keyword) {
   return DxilResource::Kind::Invalid;
 }
 
-static DxilSampler::SamplerKind KeywordToSamplerKind(llvm::StringRef keyword) {
-  // TODO: refactor for faster search (switch by 1/2/3 first letters, then
-  // compare)
-  return llvm::StringSwitch<DxilSampler::SamplerKind>(keyword)
-    .Case("SamplerState", DxilSampler::SamplerKind::Default)
-    .Case("SamplerComparisonState", DxilSampler::SamplerKind::Comparison)
-    .Default(DxilSampler::SamplerKind::Invalid);
-}
-
 void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   // Add hlsl intrinsic attr
   unsigned intrinsicOpcode;
@@ -1009,34 +1070,21 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       // Save resource type metadata.
       switch (resClass) {
       case DXIL::ResourceClass::UAV: {
-        DxilResource UAV;
-        // TODO: save globalcoherent to variable in EmitHLSLBuiltinCallExpr.
-        SetUAVSRV(FD->getLocation(), resClass, &UAV, RD);
-        // Set global symbol to save type.
-        UAV.SetGlobalSymbol(UndefValue::get(Ty));
-        MDNode *MD = m_pHLModule->DxilUAVToMDNode(UAV);
+        MDNode *MD = GetOrAddResTypeMD(recordTy);
+        DXASSERT(MD, "else invalid resource type");
         resMetadataMap[Ty] = MD;
       } break;
       case DXIL::ResourceClass::SRV: {
-        DxilResource SRV;
-        SetUAVSRV(FD->getLocation(), resClass, &SRV, RD);
-        // Set global symbol to save type.
-        SRV.SetGlobalSymbol(UndefValue::get(Ty));
-        MDNode *Meta = m_pHLModule->DxilSRVToMDNode(SRV);
+        MDNode *Meta = GetOrAddResTypeMD(recordTy);
+        DXASSERT(Meta, "else invalid resource type");
         resMetadataMap[Ty] = Meta;
         if (FT->getNumParams() > 1) {
           QualType paramTy = MD->getParamDecl(0)->getType();
           // Add sampler type.
           if (TypeToClass(paramTy) == DXIL::ResourceClass::Sampler) {
             llvm::Type *Ty = FT->getParamType(1)->getPointerElementType();
-            DxilSampler S;
-            const RecordType *RT = paramTy->getAs<RecordType>();
-            DxilSampler::SamplerKind kind =
-                KeywordToSamplerKind(RT->getDecl()->getName());
-            S.SetSamplerKind(kind);
-            // Set global symbol to save type.
-            S.SetGlobalSymbol(UndefValue::get(Ty));
-            MDNode *MD = m_pHLModule->DxilSamplerToMDNode(S);
+            MDNode *MD = GetOrAddResTypeMD(paramTy);
+            DXASSERT(MD, "else invalid resource type");
             resMetadataMap[Ty] = MD;
           }
         }

+ 11 - 0
tools/clang/test/CodeGenHLSL/shader-compat-suite/lib_out_param_res.hlsl

@@ -0,0 +1,11 @@
+// RUN: %dxc -T lib_6_1 %s | FileCheck %s
+
+// CHECK: call void @"\01?GetBuf@@YA?AV?$Buffer@V?$vector@M$03@@@@XZ"(%dx.types.Handle* nonnull %{{.*}})
+// Make sure resource return type works.
+
+Buffer<float4> GetBuf();
+
+float4 test(uint i) {
+  Buffer<float4> buf = GetBuf();
+  return buf[i];
+}