Browse Source

Use SRet for struct return type. (#243)

1. Use SRet for struct return type.
2. Add SM6.1 and DXIL1.1.
Xiang Li 8 years ago
parent
commit
bc4a68b6b4

+ 4 - 2
include/dxc/HLSL/DxilShaderModel.h

@@ -29,7 +29,7 @@ public:
 
   // Major/Minor version of highest shader model
   static const unsigned kHighestMajor = 6;
-  static const unsigned kHighestMinor = 0;
+  static const unsigned kHighestMinor = 1;
 
   bool IsPS() const     { return m_Kind == Kind::Pixel; }
   bool IsVS() const     { return m_Kind == Kind::Vertex; }
@@ -42,8 +42,10 @@ public:
   Kind GetKind() const      { return m_Kind; }
   unsigned GetMajor() const { return m_Major; }
   unsigned GetMinor() const { return m_Minor; }
+  void GetDxilVersion(unsigned &DxilMajor, unsigned &DxilMinor) const;
   bool IsSM50Plus() const   { return m_Major >= 5; }
   bool IsSM51Plus() const   { return m_Major > 5 || (m_Major == 5 && m_Minor >= 1); }
+  bool IsSM61Plus() const   { return m_Major > 6 || (m_Major == 6 && m_Minor >= 1); }
   const char *GetName() const { return m_pszName; }
   std::string GetKindName() const;
   unsigned GetNumTempRegs() const { return DXIL::kMaxTempRegCount; }
@@ -79,7 +81,7 @@ private:
               unsigned m_NumInputRegs, unsigned m_NumOutputRegs,
               bool m_bUAVs, bool m_bTypedUavs, unsigned m_UAVRegsLim);
 
-  static const unsigned kNumShaderModels = 27;
+  static const unsigned kNumShaderModels = 33;
   static const ShaderModel ms_ShaderModels[kNumShaderModels];
 
   static const ShaderModel *GetInvalid();

+ 20 - 0
lib/HLSL/DxilGenerationPass.cpp

@@ -2838,10 +2838,30 @@ public:
 
   const char *getPassName() const override { return "HLSL DXIL Metadata Emit"; }
 
+  /// Strips the !noalias metadata attachment from every instruction in \p M.
+  /// runOnModule invokes this only when the shader model is below 6.1.
+  void patchSM60(Module &M) {
+    for (Function &Fn : M.getFunctionList()) {
+      for (BasicBlock &Block : Fn) {
+        for (Instruction &Inst : Block) {
+          // Nothing to do for instructions that carry no noalias attachment.
+          if (Inst.getMetadata(LLVMContext::MD_noalias) == nullptr)
+            continue;
+          Inst.setMetadata(LLVMContext::MD_noalias, nullptr);
+        }
+      }
+    }
+  }
+
   bool runOnModule(Module &M) override {
     if (M.HasDxilModule()) {
       // Remove store undef output.
       hlsl::OP *hlslOP = M.GetDxilModule().GetOP();
+      bool bIsSM61Plus = M.GetDxilModule().GetShaderModel()->IsSM61Plus();
+      if (!bIsSM61Plus) {
+        patchSM60(M);
+      }
       for (iplist<Function>::iterator F : M.getFunctionList()) {
         if (!hlslOP->IsDxilOpFunc(F))
           continue;

+ 1 - 0
lib/HLSL/DxilModule.cpp

@@ -927,6 +927,7 @@ vector<GlobalVariable* > &DxilModule::GetLLVMUsed() {
 
 // DXIL metadata serialization/deserialization.
 void DxilModule::EmitDxilMetadata() {
+  m_pSM->GetDxilVersion(m_DxilMajor, m_DxilMinor);
   m_pMDHelper->EmitDxilVersion(m_DxilMajor, m_DxilMinor);
   m_pMDHelper->EmitDxilShaderModel(m_pSM);
 

+ 21 - 0
lib/HLSL/DxilShaderModel.cpp

@@ -101,6 +101,21 @@ const ShaderModel *ShaderModel::GetByName(const char *pszName) {
   return Get(Kind, Major, Minor);
 }
 
+/// Reports the DXIL version that corresponds to this shader model.
+///
+/// \param DxilMajor [out] Always written. Set to 1: only shader model 6.x is
+///        accepted here, and it maps to DXIL 1.x.
+/// \param DxilMinor [out] Set to the shader model minor version
+///        (6.0 -> DXIL 1.0, 6.1 -> DXIL 1.1). Left untouched if the minor
+///        version is not recognized (asserts in builds with DXASSERT enabled).
+void ShaderModel::GetDxilVersion(unsigned &DxilMajor, unsigned &DxilMinor) const {
+  DXASSERT(m_Major == 6, "invalid major");
+  // Write the major version unconditionally so callers never observe an
+  // unassigned out-parameter.
+  DxilMajor = 1;
+  switch (m_Minor) {
+  case 0:
+    DxilMinor = 0;
+    break;
+  case 1:
+    DxilMinor = 1;
+    break;
+  default:
+    DXASSERT(0, "invalid minor");
+    break;
+  }
+}
+
 std::string ShaderModel::GetKindName() const {
   return std::string(m_pszName).substr(0, 2);
 }
@@ -118,32 +133,38 @@ const ShaderModel ShaderModel::ms_ShaderModels[kNumShaderModels] = {
   SM(Kind::Compute,  5, 0, "cs_5_0",  0,  0,   true,  true,  64),
   SM(Kind::Compute,  5, 1, "cs_5_1",  0,  0,   true,  true,  UINT_MAX),
   SM(Kind::Compute,  6, 0, "cs_6_0",  0,  0,   true,  true,  UINT_MAX),
+  SM(Kind::Compute,  6, 1, "cs_6_1",  0,  0,   true,  true,  UINT_MAX),
 
   SM(Kind::Domain,   5, 0, "ds_5_0",  32, 32,  true,  true,  64),
   SM(Kind::Domain,   5, 1, "ds_5_1",  32, 32,  true,  true,  UINT_MAX),
   SM(Kind::Domain,   6, 0, "ds_6_0",  32, 32,  true,  true,  UINT_MAX),
+  SM(Kind::Domain,   6, 1, "ds_6_1",  32, 32,  true,  true,  UINT_MAX),
 
   SM(Kind::Geometry, 4, 0, "gs_4_0",  16, 32,  false, false, 0),
   SM(Kind::Geometry, 4, 1, "gs_4_1",  32, 32,  false, false, 0),
   SM(Kind::Geometry, 5, 0, "gs_5_0",  32, 32,  true,  true,  64),
   SM(Kind::Geometry, 5, 1, "gs_5_1",  32, 32,  true,  true,  UINT_MAX),
   SM(Kind::Geometry, 6, 0, "gs_6_0",  32, 32,  true,  true,  UINT_MAX),
+  SM(Kind::Geometry, 6, 1, "gs_6_1",  32, 32,  true,  true,  UINT_MAX),
 
   SM(Kind::Hull,     5, 0, "hs_5_0",  32, 32,  true,  true,  64),
   SM(Kind::Hull,     5, 1, "hs_5_1",  32, 32,  true,  true,  UINT_MAX),
   SM(Kind::Hull,     6, 0, "hs_6_0",  32, 32,  true,  true,  UINT_MAX),
+  SM(Kind::Hull,     6, 1, "hs_6_1",  32, 32,  true,  true,  UINT_MAX),
 
   SM(Kind::Pixel,    4, 0, "ps_4_0",  32, 8,   false, false, 0),
   SM(Kind::Pixel,    4, 1, "ps_4_1",  32, 8,   false, false, 0),
   SM(Kind::Pixel,    5, 0, "ps_5_0",  32, 8,   true,  true,  64),
   SM(Kind::Pixel,    5, 1, "ps_5_1",  32, 8,   true,  true,  UINT_MAX),
   SM(Kind::Pixel,    6, 0, "ps_6_0",  32, 8,   true,  true,  UINT_MAX),
+  SM(Kind::Pixel,    6, 1, "ps_6_1",  32, 8,   true,  true,  UINT_MAX),
 
   SM(Kind::Vertex,   4, 0, "vs_4_0",  16, 16,  false, false, 0),
   SM(Kind::Vertex,   4, 1, "vs_4_1",  32, 32,  false, false, 0),
   SM(Kind::Vertex,   5, 0, "vs_5_0",  32, 32,  true,  true,  64),
   SM(Kind::Vertex,   5, 1, "vs_5_1",  32, 32,  true,  true,  UINT_MAX),
   SM(Kind::Vertex,   6, 0, "vs_6_0",  32, 32,  true,  true,  UINT_MAX),
+  SM(Kind::Vertex,   6, 1, "vs_6_1",  32, 32,  true,  true,  UINT_MAX),
 
   SM(Kind::Invalid,  0, 0, "invalid", 0,  0,   false, false, 0),
 };

+ 3 - 2
lib/HLSL/HLModule.cpp

@@ -61,8 +61,8 @@ HLModule::HLModule(Module *pModule)
           pModule, llvm::make_unique<HLExtraPropertyHelper>(pModule)))
     , m_pDebugInfoFinder(nullptr)
     , m_pSM(nullptr)
-    , m_DxilMajor(1)
-    , m_DxilMinor(0)
+    , m_DxilMajor(DXIL::kDxilMajor)
+    , m_DxilMinor(DXIL::kDxilMinor)
     , m_pOP(llvm::make_unique<OP>(pModule->getContext(), pModule))
     , m_pTypeSystem(llvm::make_unique<DxilTypeSystem>(pModule)) {
   DXASSERT_NOMSG(m_pModule != nullptr);
@@ -83,6 +83,7 @@ OP *HLModule::GetOP() const { return m_pOP.get(); }
 void HLModule::SetShaderModel(const ShaderModel *pSM) {
   DXASSERT(m_pSM == nullptr, "shader model must not change for the module");
   m_pSM = pSM;
+  m_pSM->GetDxilVersion(m_DxilMajor, m_DxilMinor);
   m_pMDHelper->SetShaderModel(m_pSM);
   CreateSignatures(m_pSM, m_InputSignature, m_OutputSignature, m_PatchConstantSignature, m_RootSignature);
 }

+ 73 - 55
tools/clang/lib/CodeGen/CGHLSLMS.cpp

@@ -301,11 +301,12 @@ CGMSHLSLRuntime::CGMSHLSLRuntime(CodeGenModule &CGM)
   const hlsl::ShaderModel *SM =
       hlsl::ShaderModel::GetByName(CGM.getCodeGenOpts().HLSLProfile.c_str());
   // Only accept valid, 6.0 shader model.
-  if (!SM->IsValid() || SM->GetMajor() != 6 || SM->GetMinor() != 0) {
+  if (!SM->IsValid() || SM->GetMajor() != 6) {
     DiagnosticsEngine &Diags = CGM.getDiags();
     unsigned DiagID =
         Diags.getCustomDiagID(DiagnosticsEngine::Error, "invalid profile %0");
     Diags.Report(DiagID) << CGM.getCodeGenOpts().HLSLProfile;
+    return;
   }
   // TODO: add AllResourceBound.
   if (CGM.getCodeGenOpts().HLSLAvoidControlFlow && !CGM.getCodeGenOpts().HLSLAllResourcesBound) {
@@ -978,7 +979,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD)) {
       const CXXRecordDecl *RD = MD->getParent();
       // For nested case like sample_slice_type.
-      if (const CXXRecordDecl *PRD = dyn_cast<CXXRecordDecl>(RD->getDeclContext())) {
+      if (const CXXRecordDecl *PRD =
+              dyn_cast<CXXRecordDecl>(RD->getDeclContext())) {
         RD = PRD;
       }
 
@@ -994,7 +996,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
         SetUAVSRV(FD->getLocation(), resClass, &UAV, RD);
         // Set global symbol to save type.
         UAV.SetGlobalSymbol(UndefValue::get(Ty));
-        MDNode * MD = m_pHLModule->DxilUAVToMDNode(UAV);
+        MDNode *MD = m_pHLModule->DxilUAVToMDNode(UAV);
         resMetadataMap[Ty] = MD;
       } break;
       case DXIL::ResourceClass::SRV: {
@@ -1002,13 +1004,13 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
         SetUAVSRV(FD->getLocation(), resClass, &SRV, RD);
         // Set global symbol to save type.
         SRV.SetGlobalSymbol(UndefValue::get(Ty));
-        MDNode * Meta = m_pHLModule->DxilSRVToMDNode(SRV);
+        MDNode *Meta = m_pHLModule->DxilSRVToMDNode(SRV);
         resMetadataMap[Ty] = Meta;
         if (FT->getNumParams() > 1) {
           QualType paramTy = MD->getParamDecl(0)->getType();
           // Add sampler type.
           if (TypeToClass(paramTy) == DXIL::ResourceClass::Sampler) {
-            llvm::Type * Ty = FT->getParamType(1)->getPointerElementType();
+            llvm::Type *Ty = FT->getParamType(1)->getPointerElementType();
             DxilSampler S;
             const RecordType *RT = paramTy->getAs<RecordType>();
             DxilSampler::SamplerKind kind =
@@ -1034,14 +1036,15 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     // Don't need to add FunctionQual for intrinsic function.
     return;
   }
-  
+
   // Set entry function
   const std::string &entryName = m_pHLModule->GetEntryFunctionName();
   bool isEntry = FD->getNameAsString() == entryName;
   if (isEntry)
     EntryFunc = F;
 
-  std::unique_ptr<HLFunctionProps> funcProps = llvm::make_unique<HLFunctionProps>();
+  std::unique_ptr<HLFunctionProps> funcProps =
+      llvm::make_unique<HLFunctionProps>();
 
   // Save patch constant function to patchConstantFunctionMap.
   bool isPatchConstantFunction = false;
@@ -1051,9 +1054,10 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       patchConstantFunctionMap[FD->getName()] = F;
     else {
       // TODO: This is not the same as how fxc handles patch constant functions.
-      //  This will fail if more than one function with the same name has a SV_TessFactor semantic.
-      //  Fxc just selects the last function defined that has the matching name when referenced
-      //  by the patchconstantfunc attribute from the hull shader currently being compiled.
+      //  This will fail if more than one function with the same name has a
+      //  SV_TessFactor semantic. Fxc just selects the last function defined
+      //  that has the matching name when referenced by the patchconstantfunc
+      //  attribute from the hull shader currently being compiled.
       // Report error
       DiagnosticsEngine &Diags = CGM.getDiags();
       unsigned DiagID =
@@ -1063,8 +1067,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
       return;
     }
 
-    for (Argument &arg : F->getArgumentList()) {
-      const ParmVarDecl *parmDecl = FD->getParamDecl(arg.getArgNo());
+    for (ParmVarDecl *parmDecl : FD->parameters()) {
       QualType Ty = parmDecl->getType();
       if (IsHLSLOutputPatchType(Ty)) {
         funcProps->ShaderProps.HS.outputControlPoints =
@@ -1080,7 +1083,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 
   // TODO: how to know VS/PS?
   funcProps->shaderKind = DXIL::ShaderKind::Invalid;
-  
+
   DiagnosticsEngine &Diags = CGM.getDiags();
   // Geometry shader.
   bool isGS = false;
@@ -1092,8 +1095,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     funcProps->ShaderProps.GS.inputPrimitive = DXIL::InputPrimitive::Undefined;
 
     if (isEntry && !SM->IsGS()) {
-      unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
-                                              "attribute maxvertexcount only valid for GS.");
+      unsigned DiagID =
+          Diags.getCustomDiagID(DiagnosticsEngine::Error,
+                                "attribute maxvertexcount only valid for GS.");
       Diags.Report(Attr->getLocation(), DiagID);
       return;
     }
@@ -1102,13 +1106,13 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     unsigned instanceCount = Attr->getCount();
     funcProps->ShaderProps.GS.instanceCount = instanceCount;
     if (isEntry && !SM->IsGS()) {
-      unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
-                                              "attribute maxvertexcount only valid for GS.");
+      unsigned DiagID =
+          Diags.getCustomDiagID(DiagnosticsEngine::Error,
+                                "attribute maxvertexcount only valid for GS.");
       Diags.Report(Attr->getLocation(), DiagID);
       return;
     }
-  }
-  else {
+  } else {
     // Set default instance count.
     if (isGS)
       funcProps->ShaderProps.GS.instanceCount = 1;
@@ -1209,7 +1213,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
           FD->getAttr<HLSLOutputTopologyAttr>()) {
     if (isHS) {
       DXIL::TessellatorOutputPrimitive primitive =
-            StringToTessOutputPrimitive(Attr->getTopology());
+          StringToTessOutputPrimitive(Attr->getTopology());
       funcProps->ShaderProps.HS.outputPrimitive = primitive;
     } else if (isEntry && !SM->IsHS()) {
       unsigned DiagID =
@@ -1264,26 +1268,26 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   bool isVS = false;
   if (const HLSLClipPlanesAttr *Attr = FD->getAttr<HLSLClipPlanesAttr>()) {
     if (isEntry && !SM->IsVS()) {
-      unsigned DiagID =
-          Diags.getCustomDiagID(DiagnosticsEngine::Error,
-                                "attribute clipplane only valid for VS.");
+      unsigned DiagID = Diags.getCustomDiagID(
+          DiagnosticsEngine::Error, "attribute clipplane only valid for VS.");
       Diags.Report(Attr->getLocation(), DiagID);
       return;
     }
 
     isVS = true;
-    // The real job is done at EmitHLSLFunctionProlog where debug info is available.
-    // Only set shader kind here.
+    // The real job is done at EmitHLSLFunctionProlog where debug info is
+    // available. Only set shader kind here.
     funcProps->shaderKind = DXIL::ShaderKind::Vertex;
   }
 
   // Pixel shader.
   bool isPS = false;
-  if (const HLSLEarlyDepthStencilAttr *Attr = FD->getAttr<HLSLEarlyDepthStencilAttr>()) {
+  if (const HLSLEarlyDepthStencilAttr *Attr =
+          FD->getAttr<HLSLEarlyDepthStencilAttr>()) {
     if (isEntry && !SM->IsPS()) {
-      unsigned DiagID =
-          Diags.getCustomDiagID(DiagnosticsEngine::Error,
-                                "attribute earlydepthstencil only valid for PS.");
+      unsigned DiagID = Diags.getCustomDiagID(
+          DiagnosticsEngine::Error,
+          "attribute earlydepthstencil only valid for PS.");
       Diags.Report(Attr->getLocation(), DiagID);
       return;
     }
@@ -1308,7 +1312,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     profileAttributes++;
 
   // TODO: check this in front-end and report error.
-  DXASSERT(profileAttributes<2, "profile attributes are mutual exclusive");
+  DXASSERT(profileAttributes < 2, "profile attributes are mutual exclusive");
 
   if (isEntry) {
     switch (funcProps->shaderKind) {
@@ -1324,11 +1328,37 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     }
   }
 
-  DxilFunctionAnnotation *FuncAnnotation = m_pHLModule->AddFunctionAnnotation(F);
+  DxilFunctionAnnotation *FuncAnnotation =
+      m_pHLModule->AddFunctionAnnotation(F);
+  bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
+
+  // Param Info
+  unsigned streamIndex = 0;
+  unsigned inputPatchCount = 0;
+  unsigned outputPatchCount = 0;
+
+  unsigned ArgNo = 0;
+  unsigned ParmIdx = 0;
+
+  if (const CXXMethodDecl *MethodDecl = dyn_cast<CXXMethodDecl>(FD)) {
+    QualType ThisTy = MethodDecl->getThisType(FD->getASTContext());
+    DxilParameterAnnotation &paramAnnotation =
+        FuncAnnotation->GetParameterAnnotation(ArgNo++);
+    // Construct annotation for this pointer.
+    ConstructFieldAttributedAnnotation(paramAnnotation, ThisTy,
+                                       bDefaultRowMajor);
+  }
 
   // Ret Info
-  DxilParameterAnnotation &retTyAnnotation = FuncAnnotation->GetRetTypeAnnotation();
   QualType retTy = FD->getReturnType();
+  DxilParameterAnnotation *pRetTyAnnotation = nullptr;
+  if (F->getReturnType()->isVoidTy() && !retTy->isVoidType()) {
+    // SRet.
+    pRetTyAnnotation = &FuncAnnotation->GetParameterAnnotation(ArgNo++);
+  } else {
+    pRetTyAnnotation = &FuncAnnotation->GetRetTypeAnnotation();
+  }
+  DxilParameterAnnotation &retTyAnnotation = *pRetTyAnnotation;
   // keep Undefined here, we cannot decide for struct
   retTyAnnotation.SetInterpolationMode(
       GetInterpMode(FD, CompType::Kind::Invalid, /*bKeepUndefined*/ true)
@@ -1336,35 +1366,22 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   SourceLocation retTySemanticLoc = SetSemantic(FD, retTyAnnotation);
   retTyAnnotation.SetParamInputQual(DxilParamInputQual::Out);
   if (isEntry) {
-    CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation, /*isPatchConstantFunction*/false);
+    CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation,
+                             /*isPatchConstantFunction*/ false);
   }
 
-  bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
-
   ConstructFieldAttributedAnnotation(retTyAnnotation, retTy, bDefaultRowMajor);
   if (FD->hasAttr<HLSLPreciseAttr>())
     retTyAnnotation.SetPrecise();
 
-  // Param Info
-  unsigned streamIndex = 0;
-  unsigned inputPatchCount = 0;
-  unsigned outputPatchCount = 0;
-
-  for (unsigned ArgNo = 0; ArgNo < F->arg_size(); ++ArgNo) {
-    unsigned ParmIdx = ArgNo;
+  for (; ArgNo < F->arg_size(); ++ArgNo, ++ParmIdx) {
+    DxilParameterAnnotation &paramAnnotation =
+        FuncAnnotation->GetParameterAnnotation(ArgNo);
 
-    DxilParameterAnnotation &paramAnnotation = FuncAnnotation->GetParameterAnnotation(ArgNo);
-    
-    if (isa<CXXMethodDecl>(FD)) {
-      // skip arg0 for this pointer
-      if (ArgNo == 0)
-        continue;
-      // update idx for rest params
-      ParmIdx--;
-    }
     const ParmVarDecl *parmDecl = FD->getParamDecl(ParmIdx);
-    
-    ConstructFieldAttributedAnnotation(paramAnnotation, parmDecl->getType(), bDefaultRowMajor);
+
+    ConstructFieldAttributedAnnotation(paramAnnotation, parmDecl->getType(),
+                                       bDefaultRowMajor);
     if (parmDecl->hasAttr<HLSLPreciseAttr>())
       paramAnnotation.SetPrecise();
 
@@ -1532,7 +1549,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
 
     paramAnnotation.SetParamInputQual(dxilInputQ);
     if (isEntry) {
-      CheckParameterAnnotation(paramSemanticLoc, paramAnnotation, /*isPatchConstantFunction*/false);
+      CheckParameterAnnotation(paramSemanticLoc, paramAnnotation,
+                               /*isPatchConstantFunction*/ false);
     }
   }
 
@@ -1561,7 +1579,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
     AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
   }
 
-  for (const ValueDecl*param : FD->params()) {
+  for (const ValueDecl *param : FD->params()) {
     QualType Ty = param->getType();
     AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
   }

+ 18 - 2
tools/clang/lib/CodeGen/TargetInfo.cpp

@@ -6181,7 +6181,14 @@ public:
   ABIArgInfo classifyReturnType(QualType RetTy) const {
     if (RetTy->isVoidType())
       return ABIArgInfo::getIgnore();
-    // do not create SRet for HLSL
+    if (isAggregateTypeForABI(RetTy))
+      return ABIArgInfo::getIndirect(0);
+
+    // Treat an enum type as its underlying type.
+    if (const EnumType *EnumTy = RetTy->getAs<EnumType>())
+      RetTy = EnumTy->getDecl()->getIntegerType();
+
+    // do not use extend for hlsl.
     return ABIArgInfo::getDirect(CGT.ConvertType(RetTy));
   }
 
@@ -6218,7 +6225,16 @@ ABIArgInfo MSDXILABIInfo::classifyArgumentType(QualType Ty) const {
 }
 
 void MSDXILABIInfo::computeInfo(CGFunctionInfo &FI) const {
-  FI.getReturnInfo() = classifyReturnType(FI.getReturnType());
+  QualType RetTy = FI.getReturnType();
+  if (RetTy->isVoidType()) {
+    FI.getReturnInfo() = ABIArgInfo::getIgnore();
+  } else if (isAggregateTypeForABI(RetTy)) {
+    if (!getCXXABI().classifyReturnType(FI))
+      FI.getReturnInfo() = classifyReturnType(RetTy);
+  } else {
+    // Make vector and matrix direct ret.
+    FI.getReturnInfo() = classifyReturnType(RetTy);
+  }
   for (auto &I : FI.arguments()) {
     I.info = classifyArgumentType(I.type);
     // Do not flat matrix

+ 0 - 1
tools/clang/test/CodeGenHLSL/BasicHLSL11_PS2.hlsl

@@ -20,7 +20,6 @@
 
 // CHECK: DILocalVariable(tag: DW_TAG_auto_variable, name: "vDiffuse"
 // CHECK: DILocalVariable(tag: DW_TAG_auto_variable, name: "fLighting"
-// CHECK: DILocalVariable(tag: DW_TAG_auto_variable, name: "Output"
 
 
 //--------------------------------------------------------------------------------------

+ 25 - 0
tools/clang/test/CodeGenHLSL/class.hlsl

@@ -0,0 +1,25 @@
+// RUN: %dxc -E main -T ps_6_0 %s | FileCheck %s
+
+// Make sure N n2[2] is lowered into [2 x float].
+// CHECK:[2 x float]
+
+  struct N {
+     float n;
+  };
+
+class X {  // Test class mixing a matrix array, a struct array and a row_major matrix.
+  float2x2 ma[2];
+  N n2[2];  // Array of struct N; the FileCheck pattern at the top of this file targets its lowering.
+  row_major float3x3 m;
+  N test_inout(float idx) {  // Method returning struct N by value.
+   return n2[idx];  // The float parameter is used directly as the array index.
+  }
+};
+
+X x0;
+
+float4 main(float4 a : A, float4 b:B) : SV_TARGET
+{
+  X x = x0;
+  return x.test_inout(a.x).n;
+}

+ 5 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -335,6 +335,7 @@ public:
   TEST_METHOD(CodeGenCbufferAlloc)
   TEST_METHOD(CodeGenCbufferAllocLegacy)
   TEST_METHOD(CodeGenCbufferInLoop)
+  TEST_METHOD(CodeGenClass)
   TEST_METHOD(CodeGenClip)
   TEST_METHOD(CodeGenClipPlanes)
   TEST_METHOD(CodeGenConstoperand1)
@@ -2181,6 +2182,10 @@ TEST_F(CompilerTest, CodeGenCbufferInLoop) {
   CodeGenTest(L"..\\CodeGenHLSL\\cbufferInLoop.hlsl");
 }
 
+TEST_F(CompilerTest, CodeGenClass) {
+  CodeGenTestCheck(L"..\\CodeGenHLSL\\class.hlsl");
+}
+
 TEST_F(CompilerTest, CodeGenClip) {
   CodeGenTestCheck(L"..\\CodeGenHLSL\\clip.hlsl");
 }