CGHLSLMS.cpp 191 KB


  1. //===----- CGHLSLMS.cpp - Interface to HLSL Runtime ----------------===//
  2. ///////////////////////////////////////////////////////////////////////////////
  3. // //
  4. // CGHLSLMS.cpp //
  5. // Copyright (C) Microsoft Corporation. All rights reserved. //
  6. // Licensed under the MIT license. See COPYRIGHT in the project root for //
  7. // full license information. //
  8. // //
  9. // This provides a class for HLSL code generation. //
  10. // //
  11. ///////////////////////////////////////////////////////////////////////////////
  12. #include "CGHLSLRuntime.h"
  13. #include "CodeGenFunction.h"
  14. #include "CodeGenModule.h"
  15. #include "CGRecordLayout.h"
  16. #include "dxc/HlslIntrinsicOp.h"
  17. #include "dxc/HLSL/HLMatrixLowerHelper.h"
  18. #include "dxc/HLSL/HLModule.h"
  19. #include "dxc/HLSL/HLOperations.h"
  20. #include "dxc/HLSL/DXILOperations.h"
  21. #include "dxc/HLSL/DxilTypeSystem.h"
  22. #include "clang/AST/DeclTemplate.h"
  23. #include "clang/AST/HlslTypes.h"
  24. #include "clang/Frontend/CodeGenOptions.h"
  25. #include "llvm/IR/Constants.h"
  26. #include "llvm/IR/IRBuilder.h"
  27. #include "llvm/IR/GetElementPtrTypeIterator.h"
  28. #include <memory>
  29. #include <unordered_map>
  30. #include <unordered_set>
  31. #include "dxc/HLSL/DxilRootSignature.h"
  32. #include "dxc/HLSL/DxilCBuffer.h"
  33. #include "clang/Parse/ParseHLSL.h" // root sig would be in Parser if part of lang
  34. #include "dxc/Support/WinIncludes.h" // stream support
  35. #include "dxc/dxcapi.h" // stream support
  36. using namespace clang;
  37. using namespace CodeGen;
  38. using namespace hlsl;
  39. using namespace llvm;
  40. using std::unique_ptr;
  41. static const bool KeepUndefinedTrue = true; // Keep interpolation mode undefined if not set explicitly.
  42. namespace {
  43. /// Use this class to represent HLSL cbuffer in high-level DXIL.
  44. class HLCBuffer : public DxilCBuffer {
  45. public:
  46. HLCBuffer() = default;
  47. virtual ~HLCBuffer() = default;
  48. void AddConst(std::unique_ptr<DxilResourceBase> &pItem);
  49. std::vector<std::unique_ptr<DxilResourceBase>> &GetConstants();
  50. private:
  51. std::vector<std::unique_ptr<DxilResourceBase>> constants; // constants inside const buffer
  52. };
  53. //------------------------------------------------------------------------------
  54. //
  55. // HLCBuffer methods.
  56. //
  57. void HLCBuffer::AddConst(std::unique_ptr<DxilResourceBase> &pItem) {
  58. pItem->SetID(constants.size());
  59. constants.push_back(std::move(pItem));
  60. }
  61. std::vector<std::unique_ptr<DxilResourceBase>> &HLCBuffer::GetConstants() {
  62. return constants;
  63. }
  64. class CGMSHLSLRuntime : public CGHLSLRuntime {
  65. private:
  66. /// Convenience reference to LLVM Context
  67. llvm::LLVMContext &Context;
  68. /// Convenience reference to the current module
  69. llvm::Module &TheModule;
  70. HLModule *m_pHLModule;
  71. llvm::Type *CBufferType;
  72. uint32_t globalCBIndex;
  73. // TODO: make sure how minprec works
  74. llvm::DataLayout legacyLayout;
  75. // decl map to constant id for program
  76. llvm::DenseMap<HLSLBufferDecl *, uint32_t> constantBufMap;
  77. bool m_bDebugInfo;
  78. HLCBuffer &GetGlobalCBuffer() {
  79. return *static_cast<HLCBuffer*>(&(m_pHLModule->GetCBuffer(globalCBIndex)));
  80. }
  81. void AddConstant(VarDecl *constDecl, HLCBuffer &CB);
  82. uint32_t AddSampler(VarDecl *samplerDecl);
  83. uint32_t AddUAVSRV(VarDecl *decl, hlsl::DxilResourceBase::Class resClass);
  84. uint32_t AddCBuffer(HLSLBufferDecl *D);
  85. hlsl::DxilResourceBase::Class TypeToClass(clang::QualType Ty);
  86. // Save the entryFunc so don't need to find it with original name.
  87. llvm::Function *EntryFunc;
  88. // Map to save patch constant functions
  89. StringMap<Function *> patchConstantFunctionMap;
  90. bool IsPatchConstantFunction(const Function *F);
  91. // List for functions with clip plane.
  92. std::vector<Function *> clipPlaneFuncList;
  93. std::unordered_map<Value *, DebugLoc> debugInfoMap;
  94. Value *EmitHLSLMatrixLoad(CGBuilderTy &Builder, Value *Ptr, QualType Ty);
  95. void EmitHLSLMatrixStore(CGBuilderTy &Builder, Value *Val, Value *DestPtr,
  96. QualType Ty);
  97. // Flatten the val into scalar val and push into elts and eltTys.
  98. void FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Value *, 4> &elts,
  99. SmallVector<QualType, 4> &eltTys, QualType Ty,
  100. Value *val);
  101. // Push every value on InitListExpr into EltValList and EltTyList.
  102. void ScanInitList(CodeGenFunction &CGF, InitListExpr *E,
  103. SmallVector<Value *, 4> &EltValList,
  104. SmallVector<QualType, 4> &EltTyList);
  105. // Only scan init list to get the element size;
  106. unsigned ScanInitList(InitListExpr *E);
  107. void FlattenAggregatePtrToGepList(CodeGenFunction &CGF, Value *Ptr,
  108. SmallVector<Value *, 4> &idxList,
  109. clang::QualType Type, llvm::Type *Ty,
  110. SmallVector<Value *, 4> &GepList,
  111. SmallVector<QualType, 4> &EltTyList);
  112. void LoadFlattenedGepList(CodeGenFunction &CGF, ArrayRef<Value *> GepList,
  113. ArrayRef<QualType> EltTyList,
  114. SmallVector<Value *, 4> &EltList);
  115. void StoreFlattenedGepList(CodeGenFunction &CGF, ArrayRef<Value *> GepList,
  116. ArrayRef<QualType> GepTyList,
  117. ArrayRef<Value *> EltValList,
  118. ArrayRef<QualType> SrcTyList);
  119. void EmitHLSLAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
  120. llvm::Value *DestPtr,
  121. SmallVector<Value *, 4> &idxList,
  122. clang::QualType Type,
  123. llvm::Type *Ty);
  124. void EmitHLSLAggregateStore(CodeGenFunction &CGF, llvm::Value *Val,
  125. llvm::Value *DestPtr,
  126. SmallVector<Value *, 4> &idxList,
  127. clang::QualType Type, llvm::Type *Ty);
  128. void EmitHLSLFlatConversionToAggregate(CodeGenFunction &CGF, Value *SrcVal,
  129. llvm::Value *DestPtr,
  130. SmallVector<Value *, 4> &idxList,
  131. QualType Type, QualType SrcType,
  132. llvm::Type *Ty);
  133. void EmitHLSLRootSignature(CodeGenFunction &CGF, HLSLRootSignatureAttr *RSA,
  134. llvm::Function *Fn);
  135. void CheckParameterAnnotation(SourceLocation SLoc,
  136. const DxilParameterAnnotation &paramInfo,
  137. bool isPatchConstantFunction);
  138. void CheckParameterAnnotation(SourceLocation SLoc,
  139. DxilParamInputQual paramQual,
  140. llvm::StringRef semFullName,
  141. bool isPatchConstantFunction);
  142. void SetEntryFunction();
  143. SourceLocation SetSemantic(const NamedDecl *decl,
  144. DxilParameterAnnotation &paramInfo);
  145. hlsl::InterpolationMode GetInterpMode(const Decl *decl, CompType compType,
  146. bool bKeepUndefined);
  147. hlsl::CompType GetCompType(const BuiltinType *BT);
  148. // save intrinsic opcode
  149. std::unordered_map<Function *, unsigned> m_IntrinsicMap;
  150. void AddHLSLIntrinsicOpcodeToFunction(Function *, unsigned opcode);
  151. // Type annotation related.
  152. unsigned ConstructStructAnnotation(DxilStructAnnotation *annotation,
  153. const RecordDecl *RD,
  154. DxilTypeSystem &dxilTypeSys);
  155. unsigned AddTypeAnnotation(QualType Ty, DxilTypeSystem &dxilTypeSys,
  156. unsigned &arrayEltSize);
  157. std::unordered_map<Constant*, DxilFieldAnnotation> m_ConstVarAnnotationMap;
  158. public:
  159. CGMSHLSLRuntime(CodeGenModule &CGM);
  160. bool IsHlslObjectType(llvm::Type * Ty) override;
  161. /// Add resouce to the program
  162. void addResource(Decl *D) override;
  163. void FinishCodeGen() override;
  164. Value *EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E, Value *DestPtr) override;
  165. QualType UpdateHLSLIncompleteArrayType(VarDecl &D) override;
  166. RValue EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF, const FunctionDecl *FD,
  167. const CallExpr *E,
  168. ReturnValueSlot ReturnValue) override;
  169. void EmitHLSLOutParamConversionInit(
  170. CodeGenFunction &CGF, const FunctionDecl *FD, const CallExpr *E,
  171. llvm::SmallVector<LValue, 8> &castArgList,
  172. llvm::SmallVector<const Stmt *, 8> &argList,
  173. const std::function<void(const VarDecl *, llvm::Value *)> &TmpArgMap)
  174. override;
  175. void EmitHLSLOutParamConversionCopyBack(
  176. CodeGenFunction &CGF, llvm::SmallVector<LValue, 8> &castArgList) override;
  177. Value *EmitHLSLMatrixOperationCall(CodeGenFunction &CGF, const clang::Expr *E,
  178. llvm::Type *RetType,
  179. ArrayRef<Value *> paramList) override;
  180. void EmitHLSLDiscard(CodeGenFunction &CGF) override;
  181. Value *EmitHLSLMatrixSubscript(CodeGenFunction &CGF, llvm::Type *RetType,
  182. Value *Ptr, Value *Idx, QualType Ty) override;
  183. Value *EmitHLSLMatrixElement(CodeGenFunction &CGF, llvm::Type *RetType,
  184. ArrayRef<Value *> paramList,
  185. QualType Ty) override;
  186. Value *EmitHLSLMatrixLoad(CodeGenFunction &CGF, Value *Ptr,
  187. QualType Ty) override;
  188. void EmitHLSLMatrixStore(CodeGenFunction &CGF, Value *Val, Value *DestPtr,
  189. QualType Ty) override;
  190. void EmitHLSLAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
  191. llvm::Value *DestPtr,
  192. clang::QualType Ty) override;
  193. void EmitHLSLAggregateStore(CodeGenFunction &CGF, llvm::Value *Val,
  194. llvm::Value *DestPtr,
  195. clang::QualType Ty) override;
  196. void EmitHLSLFlatConversionToAggregate(CodeGenFunction &CGF, Value *Val,
  197. Value *DestPtr,
  198. QualType Ty,
  199. QualType SrcTy) override;
  200. Value *EmitHLSLLiteralCast(CodeGenFunction &CGF, Value *Src, QualType SrcType,
  201. QualType DstType) override;
  202. void EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
  203. clang::QualType SrcTy,
  204. llvm::Value *DestPtr,
  205. clang::QualType DestTy) override;
  206. void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override;
  207. void EmitHLSLFunctionProlog(llvm::Function *, const FunctionDecl *FD) override;
  208. void AddControlFlowHint(CodeGenFunction &CGF, const Stmt &S,
  209. llvm::TerminatorInst *TI,
  210. ArrayRef<const Attr *> Attrs) override;
  211. void FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D, llvm::Value *V) override;
  212. /// Get or add constant to the program
  213. HLCBuffer &GetOrCreateCBuffer(HLSLBufferDecl *D);
  214. };
  215. }
  216. //------------------------------------------------------------------------------
  217. //
  218. // CGMSHLSLRuntime methods.
  219. //
  220. CGMSHLSLRuntime::CGMSHLSLRuntime(CodeGenModule &CGM)
  221. : CGHLSLRuntime(CGM), Context(CGM.getLLVMContext()), EntryFunc(nullptr),
  222. TheModule(CGM.getModule()), legacyLayout(HLModule::GetLegacyDataLayoutDesc()),
  223. CBufferType(
  224. llvm::StructType::create(TheModule.getContext(), "ConstantBuffer")) {
  225. const hlsl::ShaderModel *SM =
  226. hlsl::ShaderModel::GetByName(CGM.getCodeGenOpts().HLSLProfile.c_str());
  227. if (!SM->IsValid()) {
  228. DiagnosticsEngine &Diags = CGM.getDiags();
  229. unsigned DiagID =
  230. Diags.getCustomDiagID(DiagnosticsEngine::Error, "invalid profile %0");
  231. Diags.Report(DiagID) << CGM.getCodeGenOpts().HLSLProfile;
  232. }
  233. // TODO: add AllResourceBound.
  234. if (CGM.getCodeGenOpts().HLSLAvoidControlFlow && !CGM.getCodeGenOpts().HLSLAllResourcesBound) {
  235. if (SM->GetMajor() >= 5 && SM->GetMinor() >= 1) {
  236. DiagnosticsEngine &Diags = CGM.getDiags();
  237. unsigned DiagID =
  238. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  239. "Gfa option cannot be used in SM_5_1+ unless "
  240. "all_resources_bound flag is specified");
  241. Diags.Report(DiagID);
  242. }
  243. }
  244. // Create HLModule.
  245. const bool skipInit = true;
  246. m_pHLModule = &TheModule.GetOrCreateHLModule(skipInit);
  247. // Set Option.
  248. HLOptions opts;
  249. opts.bIEEEStrict = CGM.getCodeGenOpts().UnsafeFPMath;
  250. opts.bDefaultRowMajor = CGM.getCodeGenOpts().HLSLDefaultRowMajor;
  251. opts.bDisableOptimizations = CGM.getCodeGenOpts().DisableLLVMOpts;
  252. opts.bLegacyCBufferLoad = !CGM.getCodeGenOpts().HLSLNotUseLegacyCBufLoad;
  253. opts.bAllResourcesBound = CGM.getCodeGenOpts().HLSLAllResourcesBound;
  254. m_pHLModule->SetHLOptions(opts);
  255. m_bDebugInfo = CGM.getCodeGenOpts().getDebugInfo() == CodeGenOptions::FullDebugInfo;
  256. // set profile
  257. m_pHLModule->SetShaderModel(SM);
  258. // set entry name
  259. m_pHLModule->SetEntryFunctionName(CGM.getCodeGenOpts().HLSLEntryFunction);
  260. // add globalCB
  261. unique_ptr<HLCBuffer> CB = std::make_unique<HLCBuffer>();
  262. std::string globalCBName = "$Globals";
  263. CB->SetGlobalSymbol(nullptr);
  264. CB->SetGlobalName(globalCBName);
  265. globalCBIndex = m_pHLModule->GetCBuffers().size();
  266. CB->SetID(globalCBIndex);
  267. CB->SetRangeSize(1);
  268. CB->SetLowerBound(UINT_MAX);
  269. DXVERIFY_NOMSG(globalCBIndex == m_pHLModule->AddCBuffer(std::move(CB)));
  270. }
  271. bool CGMSHLSLRuntime::IsHlslObjectType(llvm::Type *Ty) {
  272. return HLModule::IsHLSLObjectType(Ty);
  273. }
  274. void CGMSHLSLRuntime::AddHLSLIntrinsicOpcodeToFunction(Function *F,
  275. unsigned opcode) {
  276. m_IntrinsicMap[F] = opcode;
  277. }
  278. void CGMSHLSLRuntime::CheckParameterAnnotation(
  279. SourceLocation SLoc, const DxilParameterAnnotation &paramInfo,
  280. bool isPatchConstantFunction) {
  281. if (!paramInfo.HasSemanticString()) {
  282. return;
  283. }
  284. llvm::StringRef semFullName = paramInfo.GetSemanticStringRef();
  285. DxilParamInputQual paramQual = paramInfo.GetParamInputQual();
  286. if (paramQual == DxilParamInputQual::Inout) {
  287. CheckParameterAnnotation(SLoc, DxilParamInputQual::In, semFullName, isPatchConstantFunction);
  288. CheckParameterAnnotation(SLoc, DxilParamInputQual::Out, semFullName, isPatchConstantFunction);
  289. return;
  290. }
  291. CheckParameterAnnotation(SLoc, paramQual, semFullName, isPatchConstantFunction);
  292. }
  293. void CGMSHLSLRuntime::CheckParameterAnnotation(
  294. SourceLocation SLoc, DxilParamInputQual paramQual, llvm::StringRef semFullName,
  295. bool isPatchConstantFunction) {
  296. const ShaderModel *SM = m_pHLModule->GetShaderModel();
  297. DXIL::SigPointKind sigPoint = SigPointFromInputQual(
  298. paramQual, SM->GetKind(), isPatchConstantFunction);
  299. llvm::StringRef semName;
  300. unsigned semIndex;
  301. Semantic::DecomposeNameAndIndex(semFullName, &semName, &semIndex);
  302. const Semantic *pSemantic =
  303. Semantic::GetByName(semName, sigPoint, SM->GetMajor(), SM->GetMinor());
  304. if (pSemantic->IsInvalid()) {
  305. DiagnosticsEngine &Diags = CGM.getDiags();
  306. unsigned DiagID =
  307. Diags.getCustomDiagID(DiagnosticsEngine::Error, "invalid semantic '%0' for %1");
  308. Diags.Report(SLoc, DiagID) << semName << m_pHLModule->GetShaderModel()->GetKindName();
  309. }
  310. }
  311. SourceLocation
  312. CGMSHLSLRuntime::SetSemantic(const NamedDecl *decl,
  313. DxilParameterAnnotation &paramInfo) {
  314. for (const hlsl::UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  315. switch (it->getKind()) {
  316. case hlsl::UnusualAnnotation::UA_SemanticDecl: {
  317. const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
  318. paramInfo.SetSemanticString(sd->SemanticName);
  319. return it->Loc;
  320. }
  321. }
  322. }
  323. return SourceLocation();
  324. }
  325. static bool HasTessFactorSemantic(const ValueDecl *decl) {
  326. for (const hlsl::UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  327. switch (it->getKind()) {
  328. case hlsl::UnusualAnnotation::UA_SemanticDecl: {
  329. const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
  330. const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
  331. if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
  332. return true;
  333. }
  334. }
  335. }
  336. return false;
  337. }
  338. static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
  339. if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
  340. return false;
  341. if (const RecordType *RT = Ty->getAsStructureType()) {
  342. RecordDecl *RD = RT->getDecl();
  343. for (FieldDecl *fieldDecl : RD->fields()) {
  344. if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
  345. return true;
  346. }
  347. return false;
  348. }
  349. if (const clang::ArrayType *arrayTy = Ty->getAsArrayTypeUnsafe())
  350. return HasTessFactorSemantic(decl);
  351. return false;
  352. }
  353. // TODO: get from type annotation.
  354. static bool IsPatchConstantFunctionDecl(const FunctionDecl *FD) {
  355. if (!FD->getReturnType()->isVoidType()) {
  356. // Try to find TessFactor in return type.
  357. if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
  358. return true;
  359. }
  360. // Try to find TessFactor in out param.
  361. for (ParmVarDecl *param : FD->params()) {
  362. if (param->hasAttr<HLSLOutAttr>()) {
  363. if (HasTessFactorSemanticRecurse(param, param->getType()))
  364. return true;
  365. }
  366. }
  367. return false;
  368. }
  369. static DXIL::TessellatorDomain StringToDomain(StringRef domain) {
  370. if (domain == "isoline")
  371. return DXIL::TessellatorDomain::IsoLine;
  372. if (domain == "tri")
  373. return DXIL::TessellatorDomain::Tri;
  374. if (domain == "quad")
  375. return DXIL::TessellatorDomain::Quad;
  376. return DXIL::TessellatorDomain::Undefined;
  377. }
  378. static DXIL::TessellatorPartitioning StringToPartitioning(StringRef partition) {
  379. if (partition == "integer")
  380. return DXIL::TessellatorPartitioning::Integer;
  381. if (partition == "pow2")
  382. return DXIL::TessellatorPartitioning::Pow2;
  383. if (partition == "fractional_even")
  384. return DXIL::TessellatorPartitioning::FractionalEven;
  385. if (partition == "fractional_odd")
  386. return DXIL::TessellatorPartitioning::FractionalOdd;
  387. return DXIL::TessellatorPartitioning::Undefined;
  388. }
  389. static DXIL::TessellatorOutputPrimitive
  390. StringToTessOutputPrimitive(StringRef primitive) {
  391. if (primitive == "point")
  392. return DXIL::TessellatorOutputPrimitive::Point;
  393. if (primitive == "line")
  394. return DXIL::TessellatorOutputPrimitive::Line;
  395. if (primitive == "triangle_cw")
  396. return DXIL::TessellatorOutputPrimitive::TriangleCW;
  397. if (primitive == "triangle_ccw")
  398. return DXIL::TessellatorOutputPrimitive::TriangleCCW;
  399. return DXIL::TessellatorOutputPrimitive::Undefined;
  400. }
  401. static unsigned AlignTo8Bytes(unsigned offset, bool b8BytesAlign) {
  402. DXASSERT((offset & 0x3) == 0, "offset should be divisible by 4");
  403. if (!b8BytesAlign)
  404. return offset;
  405. else if ((offset & 0x7) == 0)
  406. return offset;
  407. else
  408. return offset + 4;
  409. }
  410. static unsigned AlignBaseOffset(unsigned baseOffset, unsigned size,
  411. QualType Ty, bool bDefaultRowMajor) {
  412. bool b8BytesAlign = false;
  413. if (Ty->isBuiltinType()) {
  414. const clang::BuiltinType *BT = Ty->getAs<clang::BuiltinType>();
  415. if (BT->getKind() == clang::BuiltinType::Kind::Double ||
  416. BT->getKind() == clang::BuiltinType::Kind::LongLong)
  417. b8BytesAlign = true;
  418. }
  419. if (unsigned remainder = (baseOffset & 0xf)) {
  420. // Align to 4 x 4 bytes.
  421. unsigned aligned = baseOffset - remainder + 16;
  422. // If cannot fit in the remainder, need align.
  423. bool bNeedAlign = (remainder + size) > 16;
  424. // Array always start aligned.
  425. bNeedAlign |= Ty->isArrayType();
  426. if (IsHLSLMatType(Ty)) {
  427. bool bColMajor = !bDefaultRowMajor;
  428. if (const AttributedType *AT = dyn_cast<AttributedType>(Ty)) {
  429. switch (AT->getAttrKind()) {
  430. case AttributedType::Kind::attr_hlsl_column_major:
  431. bColMajor = true;
  432. break;
  433. case AttributedType::Kind::attr_hlsl_row_major:
  434. bColMajor = false;
  435. break;
  436. default:
  437. // Do nothing
  438. break;
  439. }
  440. }
  441. unsigned row, col;
  442. hlsl::GetHLSLMatRowColCount(Ty, row, col);
  443. bNeedAlign |= bColMajor && col > 1;
  444. bNeedAlign |= !bColMajor && row > 1;
  445. }
  446. if (bNeedAlign)
  447. return AlignTo8Bytes(aligned, b8BytesAlign);
  448. else
  449. return AlignTo8Bytes(baseOffset, b8BytesAlign);
  450. } else
  451. return baseOffset;
  452. }
  453. static unsigned AlignBaseOffset(QualType Ty, unsigned baseOffset,
  454. bool bDefaultRowMajor,
  455. CodeGen::CodeGenModule &CGM,
  456. llvm::DataLayout &layout) {
  457. QualType paramTy = Ty.getCanonicalType();
  458. if (const ReferenceType *RefType = dyn_cast<ReferenceType>(paramTy))
  459. paramTy = RefType->getPointeeType();
  460. // Get size.
  461. llvm::Type *Type = CGM.getTypes().ConvertType(paramTy);
  462. unsigned size = layout.getTypeAllocSize(Type);
  463. return AlignBaseOffset(baseOffset, size, paramTy, bDefaultRowMajor);
  464. }
  465. static unsigned GetMatrixSizeInCB(QualType Ty, bool defaultRowMajor,
  466. bool b64Bit) {
  467. bool bColMajor = !defaultRowMajor;
  468. if (const AttributedType *AT = dyn_cast<AttributedType>(Ty)) {
  469. switch (AT->getAttrKind()) {
  470. case AttributedType::Kind::attr_hlsl_column_major:
  471. bColMajor = true;
  472. break;
  473. case AttributedType::Kind::attr_hlsl_row_major:
  474. bColMajor = false;
  475. break;
  476. default:
  477. // Do nothing
  478. break;
  479. }
  480. }
  481. unsigned row, col;
  482. hlsl::GetHLSLMatRowColCount(Ty, row, col);
  483. unsigned EltSize = b64Bit ? 8 : 4;
  484. // Align to 4 * 4bytes.
  485. unsigned alignment = 4 * 4;
  486. if (bColMajor) {
  487. unsigned rowSize = EltSize * row;
  488. // 3x64bit or 4x64bit align to 32 bytes.
  489. if (rowSize > alignment)
  490. alignment <<= 1;
  491. return alignment * (col - 1) + row * EltSize;
  492. } else {
  493. unsigned rowSize = EltSize * col;
  494. // 3x64bit or 4x64bit align to 32 bytes.
  495. if (rowSize > alignment)
  496. alignment <<= 1;
  497. return alignment * (row - 1) + col * EltSize;
  498. }
  499. }
  500. static CompType::Kind BuiltinTyToCompTy(const BuiltinType *BTy, bool bSNorm,
  501. bool bUNorm) {
  502. CompType::Kind kind = CompType::Kind::Invalid;
  503. switch (BTy->getKind()) {
  504. case BuiltinType::UInt:
  505. kind = CompType::Kind::U32;
  506. break;
  507. case BuiltinType::UShort:
  508. kind = CompType::Kind::U16;
  509. break;
  510. case BuiltinType::ULongLong:
  511. kind = CompType::Kind::U64;
  512. break;
  513. case BuiltinType::Int:
  514. kind = CompType::Kind::I32;
  515. break;
  516. case BuiltinType::Min12Int:
  517. case BuiltinType::Short:
  518. kind = CompType::Kind::I16;
  519. break;
  520. case BuiltinType::LongLong:
  521. kind = CompType::Kind::I64;
  522. break;
  523. case BuiltinType::Min10Float:
  524. case BuiltinType::Half:
  525. if (bSNorm)
  526. kind = CompType::Kind::SNormF16;
  527. else if (bUNorm)
  528. kind = CompType::Kind::UNormF16;
  529. else
  530. kind = CompType::Kind::F16;
  531. break;
  532. case BuiltinType::Float:
  533. if (bSNorm)
  534. kind = CompType::Kind::SNormF32;
  535. else if (bUNorm)
  536. kind = CompType::Kind::UNormF32;
  537. else
  538. kind = CompType::Kind::F32;
  539. break;
  540. case BuiltinType::Double:
  541. if (bSNorm)
  542. kind = CompType::Kind::SNormF64;
  543. else if (bUNorm)
  544. kind = CompType::Kind::UNormF64;
  545. else
  546. kind = CompType::Kind::F64;
  547. break;
  548. case BuiltinType::Bool:
  549. kind = CompType::Kind::I1;
  550. break;
  551. }
  552. return kind;
  553. }
  554. static void ConstructFieldAttributedAnnotation(DxilFieldAnnotation &fieldAnnotation, QualType fieldTy, bool bDefaultRowMajor) {
  555. QualType Ty = fieldTy;
  556. if (Ty->isReferenceType())
  557. Ty = Ty.getNonReferenceType();
  558. // Get element type.
  559. if (Ty->isArrayType()) {
  560. while (isa<clang::ArrayType>(Ty)) {
  561. const clang::ArrayType *ATy = dyn_cast<clang::ArrayType>(Ty);
  562. Ty = ATy->getElementType();
  563. }
  564. }
  565. QualType EltTy = Ty;
  566. if (hlsl::IsHLSLMatType(Ty)) {
  567. DxilMatrixAnnotation Matrix;
  568. Matrix.Orientation = bDefaultRowMajor ? MatrixOrientation::RowMajor
  569. : MatrixOrientation::ColumnMajor;
  570. if (const AttributedType *AT = dyn_cast<AttributedType>(Ty)) {
  571. switch (AT->getAttrKind()) {
  572. case AttributedType::Kind::attr_hlsl_column_major:
  573. Matrix.Orientation = MatrixOrientation::ColumnMajor;
  574. break;
  575. case AttributedType::Kind::attr_hlsl_row_major:
  576. Matrix.Orientation = MatrixOrientation::RowMajor;
  577. break;
  578. default:
  579. // Do nothing
  580. break;
  581. }
  582. }
  583. unsigned row, col;
  584. hlsl::GetHLSLMatRowColCount(Ty, row, col);
  585. Matrix.Cols = col;
  586. Matrix.Rows = row;
  587. fieldAnnotation.SetMatrixAnnotation(Matrix);
  588. EltTy = hlsl::GetHLSLMatElementType(Ty);
  589. }
  590. if (hlsl::IsHLSLVecType(Ty))
  591. EltTy = hlsl::GetHLSLVecElementType(Ty);
  592. bool bSNorm = false;
  593. bool bUNorm = false;
  594. if (const AttributedType *AT = dyn_cast<AttributedType>(Ty)) {
  595. switch (AT->getAttrKind()) {
  596. case AttributedType::Kind::attr_hlsl_snorm:
  597. bSNorm = true;
  598. break;
  599. case AttributedType::Kind::attr_hlsl_unorm:
  600. bUNorm = true;
  601. break;
  602. default:
  603. // Do nothing
  604. break;
  605. }
  606. }
  607. if (EltTy->isBuiltinType()) {
  608. const BuiltinType *BTy = EltTy->getAs<BuiltinType>();
  609. CompType::Kind kind = BuiltinTyToCompTy(BTy, bSNorm, bUNorm);
  610. fieldAnnotation.SetCompType(kind);
  611. }
  612. else
  613. DXASSERT(!bSNorm && !bUNorm, "snorm/unorm on invalid type, validate at handleHLSLTypeAttr");
  614. }
  615. static void ConstructFieldInterpolation(DxilFieldAnnotation &fieldAnnotation,
  616. FieldDecl *fieldDecl) {
  617. // Keep undefined for interpMode here.
  618. InterpolationMode InterpMode = {fieldDecl->hasAttr<HLSLNoInterpolationAttr>(),
  619. fieldDecl->hasAttr<HLSLLinearAttr>(),
  620. fieldDecl->hasAttr<HLSLNoPerspectiveAttr>(),
  621. fieldDecl->hasAttr<HLSLCentroidAttr>(),
  622. fieldDecl->hasAttr<HLSLSampleAttr>()};
  623. if (InterpMode.GetKind() != InterpolationMode::Kind::Undefined)
  624. fieldAnnotation.SetInterpolationMode(InterpMode);
  625. }
  626. unsigned CGMSHLSLRuntime::ConstructStructAnnotation(DxilStructAnnotation *annotation,
  627. const RecordDecl *RD,
  628. DxilTypeSystem &dxilTypeSys) {
  629. unsigned fieldIdx = 0;
  630. unsigned offset = 0;
  631. bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  632. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  633. if (CXXRD->getNumBases()) {
  634. // Add base as field.
  635. for (const auto &I : CXXRD->bases()) {
  636. const CXXRecordDecl *BaseDecl =
  637. cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
  638. std::string fieldSemName = "";
  639. QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
  640. // Align offset.
  641. offset = AlignBaseOffset(parentTy, offset, bDefaultRowMajor, CGM,
  642. legacyLayout);
  643. unsigned CBufferOffset = offset;
  644. unsigned arrayEltSize = 0;
  645. // Process field to make sure the size of field is ready.
  646. unsigned size =
  647. AddTypeAnnotation(parentTy, dxilTypeSys, arrayEltSize);
  648. // Update offset.
  649. offset += size;
  650. if (size > 0) {
  651. DxilFieldAnnotation &fieldAnnotation =
  652. annotation->GetFieldAnnotation(fieldIdx++);
  653. fieldAnnotation.SetCBufferOffset(CBufferOffset);
  654. fieldAnnotation.SetFieldName(BaseDecl->getNameAsString());
  655. }
  656. }
  657. }
  658. }
  659. for (auto fieldDecl : RD->fields()) {
  660. std::string fieldSemName = "";
  661. QualType fieldTy = fieldDecl->getType();
  662. // Align offset.
  663. offset = AlignBaseOffset(fieldTy, offset, bDefaultRowMajor, CGM, legacyLayout);
  664. unsigned CBufferOffset = offset;
  665. bool userOffset = false;
  666. // Try to get info from fieldDecl.
  667. for (const hlsl::UnusualAnnotation *it :
  668. fieldDecl->getUnusualAnnotations()) {
  669. switch (it->getKind()) {
  670. case hlsl::UnusualAnnotation::UA_SemanticDecl: {
  671. const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
  672. fieldSemName = sd->SemanticName;
  673. } break;
  674. case hlsl::UnusualAnnotation::UA_ConstantPacking: {
  675. const hlsl::ConstantPacking *cp = cast<hlsl::ConstantPacking>(it);
  676. CBufferOffset = cp->Subcomponent << 2;
  677. CBufferOffset += cp->ComponentOffset;
  678. // Change to byte.
  679. CBufferOffset <<= 2;
  680. userOffset = true;
  681. } break;
  682. case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
  683. // register assignment only works on global constant.
  684. DiagnosticsEngine &Diags = CGM.getDiags();
  685. unsigned DiagID = Diags.getCustomDiagID(
  686. DiagnosticsEngine::Error,
  687. "location semantics cannot be specified on members.");
  688. Diags.Report(it->Loc, DiagID);
  689. return 0;
  690. } break;
  691. default:
  692. llvm_unreachable("only semantic for input/output");
  693. break;
  694. }
  695. }
  696. unsigned arrayEltSize = 0;
  697. // Process field to make sure the size of field is ready.
  698. unsigned size = AddTypeAnnotation(fieldDecl->getType(), dxilTypeSys, arrayEltSize);
  699. // Update offset.
  700. offset += size;
  701. DxilFieldAnnotation &fieldAnnotation = annotation->GetFieldAnnotation(fieldIdx++);
  702. ConstructFieldAttributedAnnotation(fieldAnnotation, fieldTy, bDefaultRowMajor);
  703. ConstructFieldInterpolation(fieldAnnotation, fieldDecl);
  704. if (fieldDecl->hasAttr<HLSLPreciseAttr>())
  705. fieldAnnotation.SetPrecise();
  706. fieldAnnotation.SetCBufferOffset(CBufferOffset);
  707. fieldAnnotation.SetFieldName(fieldDecl->getName());
  708. if (!fieldSemName.empty())
  709. fieldAnnotation.SetSemanticString(fieldSemName);
  710. }
  711. annotation->SetCBufferSize(offset);
  712. if (offset == 0) {
  713. annotation->MarkEmptyStruct();
  714. }
  715. return offset;
  716. }
  717. static bool IsElementInputOutputType(QualType Ty) {
  718. return Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty);
  719. }
  720. // Return the size for constant buffer of each decl.
  721. unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
  722. DxilTypeSystem &dxilTypeSys,
  723. unsigned &arrayEltSize) {
  724. QualType paramTy = Ty.getCanonicalType();
  725. if (const ReferenceType *RefType = dyn_cast<ReferenceType>(paramTy))
  726. paramTy = RefType->getPointeeType();
  727. // Get size.
  728. llvm::Type *Type = CGM.getTypes().ConvertType(paramTy);
  729. unsigned size = legacyLayout.getTypeAllocSize(Type);
  730. if (IsHLSLMatType(Ty)) {
  731. unsigned col, row;
  732. llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Type, col, row);
  733. bool b64Bit = legacyLayout.getTypeAllocSize(EltTy) == 8;
  734. size = GetMatrixSizeInCB(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor,
  735. b64Bit);
  736. }
  737. // Skip element types.
  738. if (IsElementInputOutputType(paramTy))
  739. return size;
  740. else if (IsHLSLStreamOutputType(Ty)) {
  741. return AddTypeAnnotation(GetHLSLOutputPatchElementType(Ty), dxilTypeSys,
  742. arrayEltSize);
  743. } else if (IsHLSLInputPatchType(Ty))
  744. return AddTypeAnnotation(GetHLSLInputPatchElementType(Ty), dxilTypeSys,
  745. arrayEltSize);
  746. else if (IsHLSLOutputPatchType(Ty))
  747. return AddTypeAnnotation(GetHLSLOutputPatchElementType(Ty), dxilTypeSys,
  748. arrayEltSize);
  749. else if (const RecordType *RT = paramTy->getAsStructureType()) {
  750. RecordDecl *RD = RT->getDecl();
  751. llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
  752. // Skip if already created.
  753. if (DxilStructAnnotation *annotation = dxilTypeSys.GetStructAnnotation(ST)) {
  754. unsigned structSize = annotation->GetCBufferSize();
  755. return structSize;
  756. }
  757. DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
  758. return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
  759. } else if (const RecordType *RT = dyn_cast<RecordType>(paramTy)) {
  760. // For this pointer.
  761. RecordDecl *RD = RT->getDecl();
  762. llvm::StructType *ST = CGM.getTypes().ConvertRecordDeclType(RD);
  763. // Skip if already created.
  764. if (DxilStructAnnotation *annotation = dxilTypeSys.GetStructAnnotation(ST)) {
  765. unsigned structSize = annotation->GetCBufferSize();
  766. return structSize;
  767. }
  768. DxilStructAnnotation *annotation = dxilTypeSys.AddStructAnnotation(ST);
  769. return ConstructStructAnnotation(annotation, RD, dxilTypeSys);
  770. } else if (IsHLSLResouceType(Ty))
  771. return AddTypeAnnotation(GetHLSLResourceResultType(Ty), dxilTypeSys, arrayEltSize);
  772. else {
  773. unsigned arraySize = 0;
  774. QualType arrayElementTy = Ty;
  775. if (Ty->isConstantArrayType()) {
  776. const ConstantArrayType *arrayTy =
  777. CGM.getContext().getAsConstantArrayType(Ty);
  778. DXASSERT(arrayTy != nullptr, "Must array type here");
  779. arraySize = arrayTy->getSize().getLimitedValue();
  780. arrayElementTy = arrayTy->getElementType();
  781. }
  782. else if (Ty->isIncompleteArrayType()) {
  783. const IncompleteArrayType *arrayTy = CGM.getContext().getAsIncompleteArrayType(Ty);
  784. arrayElementTy = arrayTy->getElementType();
  785. } else
  786. DXASSERT(0, "Must array type here");
  787. unsigned elementSize = AddTypeAnnotation(arrayElementTy, dxilTypeSys, arrayEltSize);
  788. // Only set arrayEltSize once.
  789. if (arrayEltSize == 0)
  790. arrayEltSize = elementSize;
  791. // Align to 4 * 4bytes.
  792. unsigned alignedSize = (elementSize + 15) & 0xfffffff0;
  793. return alignedSize * (arraySize - 1) + elementSize;
  794. }
  795. }
  796. static DxilResource::Kind KeywordToKind(StringRef keyword) {
  797. // TODO: refactor for faster search (switch by 1/2/3 first letters, then
  798. // compare)
  799. if (keyword == "Texture1D" || keyword == "RWTexture1D" || keyword == "RasterizerOrderedTexture1D")
  800. return DxilResource::Kind::Texture1D;
  801. if (keyword == "Texture2D" || keyword == "RWTexture2D" || keyword == "RasterizerOrderedTexture2D")
  802. return DxilResource::Kind::Texture2D;
  803. if (keyword == "Texture2DMS" || keyword == "RWTexture2DMS")
  804. return DxilResource::Kind::Texture2DMS;
  805. if (keyword == "Texture3D" || keyword == "RWTexture3D" || keyword == "RasterizerOrderedTexture3D")
  806. return DxilResource::Kind::Texture3D;
  807. if (keyword == "TextureCube" || keyword == "RWTextureCube")
  808. return DxilResource::Kind::TextureCube;
  809. if (keyword == "Texture1DArray" || keyword == "RWTexture1DArray" || keyword == "RasterizerOrderedTexture1DArray")
  810. return DxilResource::Kind::Texture1DArray;
  811. if (keyword == "Texture2DArray" || keyword == "RWTexture2DArray" || keyword == "RasterizerOrderedTexture2DArray")
  812. return DxilResource::Kind::Texture2DArray;
  813. if (keyword == "Texture2DMSArray" || keyword == "RWTexture2DMSArray")
  814. return DxilResource::Kind::Texture2DMSArray;
  815. if (keyword == "TextureCubeArray" || keyword == "RWTextureCubeArray")
  816. return DxilResource::Kind::TextureCubeArray;
  817. if (keyword == "ByteAddressBuffer" || keyword == "RWByteAddressBuffer" || keyword == "RasterizerOrderedByteAddressBuffer")
  818. return DxilResource::Kind::RawBuffer;
  819. if (keyword == "StructuredBuffer" || keyword == "RWStructuredBuffer" || keyword == "RasterizerOrderedStructuredBuffer")
  820. return DxilResource::Kind::StructuredBuffer;
  821. if (keyword == "AppendStructuredBuffer" || keyword == "ConsumeStructuredBuffer")
  822. return DxilResource::Kind::StructuredBuffer;
  823. // TODO: this is not efficient.
  824. bool isBuffer = keyword == "Buffer";
  825. isBuffer |= keyword == "RWBuffer";
  826. isBuffer |= keyword == "RasterizerOrderedBuffer";
  827. if (isBuffer)
  828. return DxilResource::Kind::TypedBuffer;
  829. return DxilResource::Kind::Invalid;
  830. }
  831. void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
  832. // Add hlsl intrinsic attr
  833. unsigned intrinsicOpcode;
  834. StringRef intrinsicGroup;
  835. if (hlsl::GetIntrinsicOp(FD, intrinsicOpcode, intrinsicGroup)) {
  836. AddHLSLIntrinsicOpcodeToFunction(F, intrinsicOpcode);
  837. F->addFnAttr(hlsl::HLPrefix, intrinsicGroup);
  838. // Save resource type annotation.
  839. if (const CXXMethodDecl *MD = dyn_cast<CXXMethodDecl>(FD)) {
  840. const CXXRecordDecl *RD = MD->getParent();
  841. // For nested case like sample_slice_type.
  842. if (const CXXRecordDecl *PRD = dyn_cast<CXXRecordDecl>(RD->getDeclContext())) {
  843. RD = PRD;
  844. }
  845. QualType recordTy = MD->getASTContext().getRecordType(RD);
  846. hlsl::DxilResourceBase::Class resClass = TypeToClass(recordTy);
  847. llvm::Type *Ty = F->getFunctionType()->params()[0]->getPointerElementType();
  848. // Add resource type annotation.
  849. switch (resClass) {
  850. case DXIL::ResourceClass::Sampler:
  851. m_pHLModule->AddResourceTypeAnnotation(Ty, DXIL::ResourceClass::Sampler,
  852. DXIL::ResourceKind::Sampler);
  853. break;
  854. case DXIL::ResourceClass::UAV:
  855. case DXIL::ResourceClass::SRV: {
  856. hlsl::DxilResource::Kind kind = KeywordToKind(RD->getName());
  857. m_pHLModule->AddResourceTypeAnnotation(Ty, resClass, kind);
  858. } break;
  859. }
  860. }
  861. // Don't need to add FunctionQual for intrinsic function.
  862. return;
  863. }
  864. // Set entry function
  865. const std::string &entryName = m_pHLModule->GetEntryFunctionName();
  866. bool isEntry = FD->getNameAsString() == entryName;
  867. if (isEntry)
  868. EntryFunc = F;
  869. std::unique_ptr<HLFunctionProps> funcProps = std::make_unique<HLFunctionProps>();
  870. // Save patch constant function to patchConstantFunctionMap.
  871. bool isPatchConstantFunction = false;
  872. if (IsPatchConstantFunctionDecl(FD)) {
  873. isPatchConstantFunction = true;
  874. if (patchConstantFunctionMap.count(FD->getName()) == 0)
  875. patchConstantFunctionMap[FD->getName()] = F;
  876. else {
  877. // TODO: This is not the same as how fxc handles patch constant functions.
  878. // This will fail if more than one function with the same name has a SV_TessFactor semantic.
  879. // Fxc just selects the last function defined that has the matching name when referenced
  880. // by the patchconstantfunc attribute from the hull shader currently being compiled.
  881. // Report error
  882. DiagnosticsEngine &Diags = CGM.getDiags();
  883. unsigned DiagID =
  884. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  885. "Multiple definitions for patchconstantfunc.");
  886. Diags.Report(FD->getLocation(), DiagID);
  887. return;
  888. }
  889. for (Argument &arg : F->getArgumentList()) {
  890. const ParmVarDecl *parmDecl = FD->getParamDecl(arg.getArgNo());
  891. QualType Ty = parmDecl->getType();
  892. if (IsHLSLOutputPatchType(Ty)) {
  893. funcProps->ShaderProps.HS.outputControlPoints =
  894. GetHLSLOutputPatchCount(parmDecl->getType());
  895. } else if (IsHLSLInputPatchType(Ty)) {
  896. funcProps->ShaderProps.HS.inputControlPoints =
  897. GetHLSLInputPatchCount(parmDecl->getType());
  898. }
  899. }
  900. }
  901. const ShaderModel *SM = m_pHLModule->GetShaderModel();
  902. // TODO: how to know VS/PS?
  903. funcProps->shaderKind = DXIL::ShaderKind::Invalid;
  904. DiagnosticsEngine &Diags = CGM.getDiags();
  905. // Geometry shader.
  906. bool isGS = false;
  907. if (const HLSLMaxVertexCountAttr *Attr =
  908. FD->getAttr<HLSLMaxVertexCountAttr>()) {
  909. isGS = true;
  910. funcProps->shaderKind = DXIL::ShaderKind::Geometry;
  911. funcProps->ShaderProps.GS.maxVertexCount = Attr->getCount();
  912. if (isEntry && !SM->IsGS()) {
  913. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
  914. "attribute maxvertexcount only valid for GS.");
  915. Diags.Report(Attr->getLocation(), DiagID);
  916. return;
  917. }
  918. }
  919. if (const HLSLInstanceAttr *Attr = FD->getAttr<HLSLInstanceAttr>()) {
  920. unsigned instanceCount = Attr->getCount();
  921. funcProps->ShaderProps.GS.instanceCount = instanceCount;
  922. if (isEntry && !SM->IsGS()) {
  923. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
  924. "attribute maxvertexcount only valid for GS.");
  925. Diags.Report(Attr->getLocation(), DiagID);
  926. return;
  927. }
  928. }
  929. else {
  930. // Set default instance count.
  931. if (isGS)
  932. funcProps->ShaderProps.GS.instanceCount = 1;
  933. }
  934. // Computer shader.
  935. bool isCS = false;
  936. if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
  937. isCS = true;
  938. funcProps->shaderKind = DXIL::ShaderKind::Compute;
  939. funcProps->ShaderProps.CS.numThreads[0] = Attr->getX();
  940. funcProps->ShaderProps.CS.numThreads[1] = Attr->getY();
  941. funcProps->ShaderProps.CS.numThreads[2] = Attr->getZ();
  942. if (isEntry && !SM->IsCS()) {
  943. unsigned DiagID = Diags.getCustomDiagID(
  944. DiagnosticsEngine::Error, "attribute numthreads only valid for CS.");
  945. Diags.Report(Attr->getLocation(), DiagID);
  946. return;
  947. }
  948. }
  949. // Hull shader.
  950. bool isHS = false;
  951. if (const HLSLPatchConstantFuncAttr *Attr =
  952. FD->getAttr<HLSLPatchConstantFuncAttr>()) {
  953. if (isEntry && !SM->IsHS()) {
  954. unsigned DiagID = Diags.getCustomDiagID(
  955. DiagnosticsEngine::Error,
  956. "attribute patchconstantfunc only valid for HS.");
  957. Diags.Report(Attr->getLocation(), DiagID);
  958. return;
  959. }
  960. isHS = true;
  961. funcProps->shaderKind = DXIL::ShaderKind::Hull;
  962. StringRef funcName = Attr->getFunctionName();
  963. if (patchConstantFunctionMap.count(funcName) == 1) {
  964. Function *patchConstFunc = patchConstantFunctionMap[funcName];
  965. funcProps->ShaderProps.HS.patchConstantFunc = patchConstFunc;
  966. DXASSERT_NOMSG(m_pHLModule->HasHLFunctionProps(patchConstFunc));
  967. // Check no inout parameter for patch constant function.
  968. DxilFunctionAnnotation *patchConstFuncAnnotation =
  969. m_pHLModule->GetFunctionAnnotation(patchConstFunc);
  970. for (unsigned i = 0; i < patchConstFuncAnnotation->GetNumParameters();
  971. i++) {
  972. if (patchConstFuncAnnotation->GetParameterAnnotation(i)
  973. .GetParamInputQual() == DxilParamInputQual::Inout) {
  974. unsigned DiagID = Diags.getCustomDiagID(
  975. DiagnosticsEngine::Error,
  976. "Patch Constant function should not have inout param.");
  977. Diags.Report(Attr->getLocation(), DiagID);
  978. return;
  979. }
  980. }
  981. } else {
  982. // TODO: Bring this in line with fxc behavior. In fxc, patchconstantfunc
  983. // selection is based only on name (last function with matching name),
  984. // not by whether it has SV_TessFactor output.
  985. //// Report error
  986. // DiagnosticsEngine &Diags = CGM.getDiags();
  987. // unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
  988. // "Cannot find
  989. // patchconstantfunc.");
  990. // Diags.Report(Attr->getLocation(), DiagID);
  991. }
  992. }
  993. if (const HLSLOutputControlPointsAttr *Attr =
  994. FD->getAttr<HLSLOutputControlPointsAttr>()) {
  995. if (isHS) {
  996. funcProps->ShaderProps.HS.outputControlPoints = Attr->getCount();
  997. } else if (isEntry && !SM->IsHS()) {
  998. unsigned DiagID = Diags.getCustomDiagID(
  999. DiagnosticsEngine::Error,
  1000. "attribute outputcontrolpoints only valid for HS.");
  1001. Diags.Report(Attr->getLocation(), DiagID);
  1002. return;
  1003. }
  1004. }
  1005. if (const HLSLPartitioningAttr *Attr = FD->getAttr<HLSLPartitioningAttr>()) {
  1006. if (isHS) {
  1007. DXIL::TessellatorPartitioning partition =
  1008. StringToPartitioning(Attr->getScheme());
  1009. funcProps->ShaderProps.HS.partition = partition;
  1010. } else if (isEntry && !SM->IsHS()) {
  1011. unsigned DiagID =
  1012. Diags.getCustomDiagID(DiagnosticsEngine::Warning,
  1013. "attribute partitioning only valid for HS.");
  1014. Diags.Report(Attr->getLocation(), DiagID);
  1015. }
  1016. }
  1017. if (const HLSLOutputTopologyAttr *Attr =
  1018. FD->getAttr<HLSLOutputTopologyAttr>()) {
  1019. if (isHS) {
  1020. DXIL::TessellatorOutputPrimitive primitive =
  1021. StringToTessOutputPrimitive(Attr->getTopology());
  1022. funcProps->ShaderProps.HS.outputPrimitive = primitive;
  1023. } else if (isEntry && !SM->IsHS()) {
  1024. unsigned DiagID =
  1025. Diags.getCustomDiagID(DiagnosticsEngine::Warning,
  1026. "attribute outputtopology only valid for HS.");
  1027. Diags.Report(Attr->getLocation(), DiagID);
  1028. }
  1029. }
  1030. if (isHS) {
  1031. funcProps->ShaderProps.HS.maxTessFactor = DXIL::kHSMaxTessFactorUpperBound;
  1032. }
  1033. if (const HLSLMaxTessFactorAttr *Attr =
  1034. FD->getAttr<HLSLMaxTessFactorAttr>()) {
  1035. if (isHS) {
  1036. // TODO: change getFactor to return float.
  1037. llvm::APInt intV(32, Attr->getFactor());
  1038. funcProps->ShaderProps.HS.maxTessFactor = intV.bitsToFloat();
  1039. } else if (isEntry && !SM->IsHS()) {
  1040. unsigned DiagID =
  1041. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1042. "attribute maxtessfactor only valid for HS.");
  1043. Diags.Report(Attr->getLocation(), DiagID);
  1044. return;
  1045. }
  1046. }
  1047. // Hull or domain shader.
  1048. bool isDS = false;
  1049. if (const HLSLDomainAttr *Attr = FD->getAttr<HLSLDomainAttr>()) {
  1050. if (isEntry && !SM->IsHS() && !SM->IsDS()) {
  1051. unsigned DiagID =
  1052. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1053. "attribute domain only valid for HS or DS.");
  1054. Diags.Report(Attr->getLocation(), DiagID);
  1055. return;
  1056. }
  1057. isDS = !isHS;
  1058. if (isDS)
  1059. funcProps->shaderKind = DXIL::ShaderKind::Domain;
  1060. DXIL::TessellatorDomain domain = StringToDomain(Attr->getDomainType());
  1061. if (isHS)
  1062. funcProps->ShaderProps.HS.domain = domain;
  1063. else
  1064. funcProps->ShaderProps.DS.domain = domain;
  1065. }
  1066. // Vertex shader.
  1067. bool isVS = false;
  1068. if (const HLSLClipPlanesAttr *Attr = FD->getAttr<HLSLClipPlanesAttr>()) {
  1069. if (isEntry && !SM->IsVS()) {
  1070. unsigned DiagID =
  1071. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1072. "attribute clipplane only valid for VS.");
  1073. Diags.Report(Attr->getLocation(), DiagID);
  1074. return;
  1075. }
  1076. isVS = true;
  1077. // The real job is done at EmitHLSLFunctionProlog where debug info is available.
  1078. // Only set shader kind here.
  1079. funcProps->shaderKind = DXIL::ShaderKind::Vertex;
  1080. }
  1081. // Pixel shader.
  1082. bool isPS = false;
  1083. if (const HLSLEarlyDepthStencilAttr *Attr = FD->getAttr<HLSLEarlyDepthStencilAttr>()) {
  1084. if (isEntry && !SM->IsPS()) {
  1085. unsigned DiagID =
  1086. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1087. "attribute earlydepthstencil only valid for PS.");
  1088. Diags.Report(Attr->getLocation(), DiagID);
  1089. return;
  1090. }
  1091. isPS = true;
  1092. funcProps->ShaderProps.PS.EarlyDepthStencil = true;
  1093. funcProps->shaderKind = DXIL::ShaderKind::Pixel;
  1094. }
  1095. unsigned profileAttributes = 0;
  1096. if (isCS)
  1097. profileAttributes++;
  1098. if (isHS)
  1099. profileAttributes++;
  1100. if (isDS)
  1101. profileAttributes++;
  1102. if (isGS)
  1103. profileAttributes++;
  1104. if (isVS)
  1105. profileAttributes++;
  1106. if (isPS)
  1107. profileAttributes++;
  1108. // TODO: check this in front-end and report error.
  1109. DXASSERT(profileAttributes<2, "profile attributes are mutual exclusive");
  1110. if (isEntry) {
  1111. switch (funcProps->shaderKind) {
  1112. case ShaderModel::Kind::Compute:
  1113. case ShaderModel::Kind::Hull:
  1114. case ShaderModel::Kind::Domain:
  1115. case ShaderModel::Kind::Geometry:
  1116. case ShaderModel::Kind::Vertex:
  1117. case ShaderModel::Kind::Pixel:
  1118. DXASSERT(funcProps->shaderKind == SM->GetKind(),
  1119. "attribute profile not match entry function profile");
  1120. break;
  1121. }
  1122. }
  1123. DxilFunctionAnnotation *FuncAnnotation = m_pHLModule->AddFunctionAnnotation(F);
  1124. // Ret Info
  1125. DxilParameterAnnotation &retTyAnnotation = FuncAnnotation->GetRetTypeAnnotation();
  1126. QualType retTy = FD->getReturnType();
  1127. // keep Undefined here, we cannot decide for struct
  1128. retTyAnnotation.SetInterpolationMode(
  1129. GetInterpMode(FD, CompType::Kind::Invalid, /*bKeepUndefined*/ true)
  1130. .GetKind());
  1131. SourceLocation retTySemanticLoc = SetSemantic(FD, retTyAnnotation);
  1132. retTyAnnotation.SetParamInputQual(DxilParamInputQual::Out);
  1133. if (isEntry) {
  1134. CheckParameterAnnotation(retTySemanticLoc, retTyAnnotation, /*isPatchConstantFunction*/false);
  1135. }
  1136. bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  1137. ConstructFieldAttributedAnnotation(retTyAnnotation, retTy, bDefaultRowMajor);
  1138. if (FD->hasAttr<HLSLPreciseAttr>())
  1139. retTyAnnotation.SetPrecise();
  1140. // Param Info
  1141. unsigned streamIndex = 0;
  1142. unsigned inputPatchCount = 0;
  1143. unsigned outputPatchCount = 0;
  1144. unsigned primitiveCount = 0;
  1145. for (unsigned ArgNo = 0; ArgNo < F->arg_size(); ++ArgNo) {
  1146. unsigned ParmIdx = ArgNo;
  1147. DxilParameterAnnotation &paramAnnotation = FuncAnnotation->GetParameterAnnotation(ArgNo);
  1148. if (isa<CXXMethodDecl>(FD)) {
  1149. // skip arg0 for this pointer
  1150. if (ArgNo == 0)
  1151. continue;
  1152. // update idx for rest params
  1153. ParmIdx--;
  1154. }
  1155. const ParmVarDecl *parmDecl = FD->getParamDecl(ParmIdx);
  1156. ConstructFieldAttributedAnnotation(paramAnnotation, parmDecl->getType(), bDefaultRowMajor);
  1157. if (parmDecl->hasAttr<HLSLPreciseAttr>())
  1158. paramAnnotation.SetPrecise();
  1159. // keep Undefined here, we cannot decide for struct
  1160. InterpolationMode paramIM =
  1161. GetInterpMode(parmDecl, CompType::Kind::Invalid, KeepUndefinedTrue);
  1162. paramAnnotation.SetInterpolationMode(paramIM);
  1163. SourceLocation paramSemanticLoc = SetSemantic(parmDecl, paramAnnotation);
  1164. DxilParamInputQual dxilInputQ = DxilParamInputQual::In;
  1165. if (parmDecl->hasAttr<HLSLInOutAttr>())
  1166. dxilInputQ = DxilParamInputQual::Inout;
  1167. else if (parmDecl->hasAttr<HLSLOutAttr>())
  1168. dxilInputQ = DxilParamInputQual::Out;
  1169. if (IsHLSLOutputPatchType(parmDecl->getType())) {
  1170. outputPatchCount++;
  1171. if (dxilInputQ != DxilParamInputQual::In) {
  1172. unsigned DiagID = Diags.getCustomDiagID(
  1173. DiagnosticsEngine::Error, "OutputPatch should not be out/inout parameter");
  1174. Diags.Report(parmDecl->getLocation(), DiagID);
  1175. continue;
  1176. }
  1177. dxilInputQ = DxilParamInputQual::OutputPatch;
  1178. if (isDS)
  1179. funcProps->ShaderProps.DS.inputControlPoints =
  1180. GetHLSLOutputPatchCount(parmDecl->getType());
  1181. }
  1182. else if (IsHLSLInputPatchType(parmDecl->getType())) {
  1183. inputPatchCount++;
  1184. if (dxilInputQ != DxilParamInputQual::In) {
  1185. unsigned DiagID = Diags.getCustomDiagID(
  1186. DiagnosticsEngine::Error, "InputPatch should not be out/inout parameter");
  1187. Diags.Report(parmDecl->getLocation(), DiagID);
  1188. continue;
  1189. }
  1190. dxilInputQ = DxilParamInputQual::InputPatch;
  1191. if (isHS) {
  1192. funcProps->ShaderProps.HS.inputControlPoints =
  1193. GetHLSLInputPatchCount(parmDecl->getType());
  1194. }
  1195. else if (isGS) {
  1196. if (funcProps->ShaderProps.GS.inputPrimitive !=
  1197. DXIL::InputPrimitive::Undefined) {
  1198. DiagnosticsEngine &Diags = CGM.getDiags();
  1199. unsigned DiagID =
  1200. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1201. "may only have one InputPatch parameter");
  1202. Diags.Report(FD->getLocation(), DiagID);
  1203. }
  1204. funcProps->ShaderProps.GS.inputPrimitive = (DXIL::InputPrimitive)(
  1205. (unsigned)DXIL::InputPrimitive::ControlPointPatch1 +
  1206. GetHLSLInputPatchCount(parmDecl->getType())-1);
  1207. // Set to InputPrimitive for GS.
  1208. dxilInputQ = DxilParamInputQual::InputPrimitive;
  1209. }
  1210. }
  1211. else if (IsHLSLStreamOutputType(parmDecl->getType())) {
  1212. // TODO: validation this at ASTContext::getFunctionType in AST/ASTContext.cpp
  1213. DXASSERT(dxilInputQ == DxilParamInputQual::Inout, "stream output parameter must be inout");
  1214. switch (streamIndex) {
  1215. case 0:
  1216. dxilInputQ = DxilParamInputQual::OutStream0;
  1217. break;
  1218. case 1:
  1219. dxilInputQ = DxilParamInputQual::OutStream1;
  1220. break;
  1221. case 2:
  1222. dxilInputQ = DxilParamInputQual::OutStream2;
  1223. break;
  1224. case 3:
  1225. default:
  1226. // TODO: validation this at ASTContext::getFunctionType in AST/ASTContext.cpp
  1227. DXASSERT(streamIndex==3, "stream number out of bound");
  1228. dxilInputQ = DxilParamInputQual::OutStream3;
  1229. break;
  1230. }
  1231. DXIL::PrimitiveTopology &streamTopology = funcProps->ShaderProps.GS.streamPrimitiveTopologies[streamIndex];
  1232. if (IsHLSLPointStreamType(parmDecl->getType()))
  1233. streamTopology = DXIL::PrimitiveTopology::PointList;
  1234. else if (IsHLSLLineStreamType(parmDecl->getType()))
  1235. streamTopology = DXIL::PrimitiveTopology::LineStrip;
  1236. else {
  1237. DXASSERT(IsHLSLTriangleStreamType(parmDecl->getType()), "invalid StreamType");
  1238. streamTopology = DXIL::PrimitiveTopology::TriangleStrip;
  1239. }
  1240. if (streamIndex > 0) {
  1241. bool bAllPoint = streamTopology == DXIL::PrimitiveTopology::PointList &&
  1242. funcProps->ShaderProps.GS.streamPrimitiveTopologies[0] == DXIL::PrimitiveTopology::PointList;
  1243. if (!bAllPoint) {
  1244. DiagnosticsEngine &Diags = CGM.getDiags();
  1245. unsigned DiagID = Diags.getCustomDiagID(
  1246. DiagnosticsEngine::Error,
  1247. "when multiple GS output streams are used they must be pointlists.");
  1248. Diags.Report(FD->getLocation(), DiagID);
  1249. }
  1250. }
  1251. streamIndex++;
  1252. }
  1253. if (parmDecl->hasAttr<HLSLTriangleAttr>()) {
  1254. funcProps->ShaderProps.GS.inputPrimitive = DXIL::InputPrimitive::Triangle;
  1255. dxilInputQ = DxilParamInputQual::InputPrimitive;
  1256. primitiveCount++;
  1257. } else if (parmDecl->hasAttr<HLSLTriangleAdjAttr>()) {
  1258. funcProps->ShaderProps.GS.inputPrimitive =
  1259. DXIL::InputPrimitive::TriangleWithAdjacency;
  1260. dxilInputQ = DxilParamInputQual::InputPrimitive;
  1261. primitiveCount++;
  1262. } else if (parmDecl->hasAttr<HLSLPointAttr>()) {
  1263. funcProps->ShaderProps.GS.inputPrimitive = DXIL::InputPrimitive::Point;
  1264. dxilInputQ = DxilParamInputQual::InputPrimitive;
  1265. primitiveCount++;
  1266. }
  1267. paramAnnotation.SetParamInputQual(dxilInputQ);
  1268. if (isEntry) {
  1269. CheckParameterAnnotation(paramSemanticLoc, paramAnnotation, /*isPatchConstantFunction*/false);
  1270. }
  1271. }
  1272. if (inputPatchCount > 1) {
  1273. DiagnosticsEngine &Diags = CGM.getDiags();
  1274. unsigned DiagID = Diags.getCustomDiagID(
  1275. DiagnosticsEngine::Error, "may only have one InputPatch parameter");
  1276. Diags.Report(FD->getLocation(), DiagID);
  1277. }
  1278. if (outputPatchCount > 1) {
  1279. DiagnosticsEngine &Diags = CGM.getDiags();
  1280. unsigned DiagID = Diags.getCustomDiagID(
  1281. DiagnosticsEngine::Error, "may only have one OutputPatch parameter");
  1282. Diags.Report(FD->getLocation(), DiagID);
  1283. }
  1284. primitiveCount += inputPatchCount;
  1285. if (primitiveCount > 1 && inputPatchCount < 2) {
  1286. DiagnosticsEngine &Diags = CGM.getDiags();
  1287. unsigned DiagID = Diags.getCustomDiagID(
  1288. DiagnosticsEngine::Error, "may only have one Primitive parameter");
  1289. Diags.Report(FD->getLocation(), DiagID);
  1290. }
  1291. // Type annotation for parameters and return type.
  1292. DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  1293. unsigned arrayEltSize = 0;
  1294. AddTypeAnnotation(FD->getReturnType(), dxilTypeSys, arrayEltSize);
  1295. // Type annotation for this pointer.
  1296. if (const CXXMethodDecl *MFD = dyn_cast<CXXMethodDecl>(FD)) {
  1297. const CXXRecordDecl *RD = MFD->getParent();
  1298. QualType Ty = CGM.getContext().getTypeDeclType(RD);
  1299. AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
  1300. }
  1301. for (const ValueDecl*param : FD->params()) {
  1302. QualType Ty = param->getType();
  1303. AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
  1304. }
  1305. if (isHS) {
  1306. // Check
  1307. Function *patchConstFunc = funcProps->ShaderProps.HS.patchConstantFunc;
  1308. if (m_pHLModule->HasHLFunctionProps(patchConstFunc)) {
  1309. HLFunctionProps &patchProps =
  1310. m_pHLModule->GetHLFunctionProps(patchConstFunc);
  1311. if (patchProps.ShaderProps.HS.outputControlPoints != 0 &&
  1312. patchProps.ShaderProps.HS.outputControlPoints !=
  1313. funcProps->ShaderProps.HS.outputControlPoints) {
  1314. unsigned DiagID = Diags.getCustomDiagID(
  1315. DiagnosticsEngine::Error,
  1316. "Patch constant function's output patch input "
  1317. "should have %0 elements, but has %1.");
  1318. Diags.Report(FD->getLocation(), DiagID)
  1319. << funcProps->ShaderProps.HS.outputControlPoints
  1320. << patchProps.ShaderProps.HS.outputControlPoints;
  1321. }
  1322. if (patchProps.ShaderProps.HS.inputControlPoints != 0 &&
  1323. patchProps.ShaderProps.HS.inputControlPoints !=
  1324. funcProps->ShaderProps.HS.inputControlPoints) {
  1325. unsigned DiagID =
  1326. Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1327. "Patch constant function's input patch input "
  1328. "should have %0 elements, but has %1.");
  1329. Diags.Report(FD->getLocation(), DiagID)
  1330. << funcProps->ShaderProps.HS.inputControlPoints
  1331. << patchProps.ShaderProps.HS.inputControlPoints;
  1332. }
  1333. }
  1334. }
  1335. // Only add functionProps when exist.
  1336. if (profileAttributes || isPatchConstantFunction)
  1337. m_pHLModule->AddHLFunctionProps(F, funcProps);
  1338. }
  1339. void CGMSHLSLRuntime::EmitHLSLFunctionProlog(Function *F, const FunctionDecl *FD) {
  1340. // Support clip plane need debug info which not available when create function attribute.
  1341. if (const HLSLClipPlanesAttr *Attr = FD->getAttr<HLSLClipPlanesAttr>()) {
  1342. HLFunctionProps &funcProps = m_pHLModule->GetHLFunctionProps(F);
  1343. // Initialize to null.
  1344. memset(funcProps.ShaderProps.VS.clipPlanes, 0, sizeof(funcProps.ShaderProps.VS.clipPlanes));
  1345. // Create global for each clip plane, and use the clip plane val as init val.
  1346. auto AddClipPlane = [&](Expr *clipPlane, unsigned idx) {
  1347. if (DeclRefExpr *decl = dyn_cast<DeclRefExpr>(clipPlane)) {
  1348. const VarDecl *VD = cast<VarDecl>(decl->getDecl());
  1349. Constant *clipPlaneVal = CGM.GetAddrOfGlobalVar(VD);
  1350. funcProps.ShaderProps.VS.clipPlanes[idx] = clipPlaneVal;
  1351. if (m_bDebugInfo) {
  1352. CodeGenFunction CGF(CGM);
  1353. ApplyDebugLocation applyDebugLoc(CGF, clipPlane);
  1354. debugInfoMap[clipPlaneVal] = CGF.Builder.getCurrentDebugLocation();
  1355. }
  1356. } else {
  1357. // Must be a MemberExpr.
  1358. const MemberExpr *ME = cast<MemberExpr>(clipPlane);
  1359. CodeGenFunction CGF(CGM);
  1360. CodeGen::LValue LV = CGF.EmitMemberExpr(ME);
  1361. Value *addr = LV.getAddress();
  1362. funcProps.ShaderProps.VS.clipPlanes[idx] = cast<Constant>(addr);
  1363. if (m_bDebugInfo) {
  1364. CodeGenFunction CGF(CGM);
  1365. ApplyDebugLocation applyDebugLoc(CGF, clipPlane);
  1366. debugInfoMap[addr] = CGF.Builder.getCurrentDebugLocation();
  1367. }
  1368. }
  1369. };
  1370. if (Expr *clipPlane = Attr->getClipPlane1())
  1371. AddClipPlane(clipPlane, 0);
  1372. if (Expr *clipPlane = Attr->getClipPlane2())
  1373. AddClipPlane(clipPlane, 1);
  1374. if (Expr *clipPlane = Attr->getClipPlane3())
  1375. AddClipPlane(clipPlane, 2);
  1376. if (Expr *clipPlane = Attr->getClipPlane4())
  1377. AddClipPlane(clipPlane, 3);
  1378. if (Expr *clipPlane = Attr->getClipPlane5())
  1379. AddClipPlane(clipPlane, 4);
  1380. if (Expr *clipPlane = Attr->getClipPlane6())
  1381. AddClipPlane(clipPlane, 5);
  1382. clipPlaneFuncList.emplace_back(F);
  1383. }
  1384. }
  1385. void CGMSHLSLRuntime::AddControlFlowHint(CodeGenFunction &CGF, const Stmt &S,
  1386. llvm::TerminatorInst *TI,
  1387. ArrayRef<const Attr *> Attrs) {
  1388. // Build hints.
  1389. bool bNoBranchFlatten = true;
  1390. bool bBranch = false;
  1391. bool bFlatten = false;
  1392. std::vector<DXIL::ControlFlowHint> hints;
  1393. for (const auto *Attr : Attrs) {
  1394. if (isa<HLSLBranchAttr>(Attr)) {
  1395. hints.emplace_back(DXIL::ControlFlowHint::Branch);
  1396. bNoBranchFlatten = false;
  1397. bBranch = true;
  1398. }
  1399. else if (isa<HLSLFlattenAttr>(Attr)) {
  1400. hints.emplace_back(DXIL::ControlFlowHint::Flatten);
  1401. bNoBranchFlatten = false;
  1402. bFlatten = true;
  1403. } else if (isa<HLSLForceCaseAttr>(Attr)) {
  1404. if (isa<SwitchStmt>(&S)) {
  1405. hints.emplace_back(DXIL::ControlFlowHint::ForceCase);
  1406. }
  1407. }
  1408. // Ignore fastopt, allow_uav_condition and call for now.
  1409. }
  1410. if (bNoBranchFlatten) {
  1411. // CHECK control flow option.
  1412. if (CGF.CGM.getCodeGenOpts().HLSLPreferControlFlow)
  1413. hints.emplace_back(DXIL::ControlFlowHint::Branch);
  1414. else if (CGF.CGM.getCodeGenOpts().HLSLAvoidControlFlow)
  1415. hints.emplace_back(DXIL::ControlFlowHint::Flatten);
  1416. }
  1417. if (bFlatten && bBranch) {
  1418. DiagnosticsEngine &Diags = CGM.getDiags();
  1419. unsigned DiagID = Diags.getCustomDiagID(
  1420. DiagnosticsEngine::Error,
  1421. "can't use branch and flatten attributes together");
  1422. Diags.Report(S.getLocStart(), DiagID);
  1423. }
  1424. if (hints.size()) {
  1425. // Add meta data to the instruction.
  1426. MDNode *hintsNode = DxilMDHelper::EmitControlFlowHints(Context, hints);
  1427. TI->setMetadata(DxilMDHelper::kDxilControlFlowHintMDName, hintsNode);
  1428. }
  1429. }
  1430. void CGMSHLSLRuntime::FinishAutoVar(CodeGenFunction &CGF, const VarDecl &D, llvm::Value *V) {
  1431. if (D.hasAttr<HLSLPreciseAttr>()) {
  1432. AllocaInst *AI = cast<AllocaInst>(V);
  1433. HLModule::MarkPreciseAttributeWithMetadata(AI);
  1434. }
  1435. // Add type annotation for local variable.
  1436. DxilTypeSystem &typeSys = m_pHLModule->GetTypeSystem();
  1437. unsigned arrayEltSize = 0;
  1438. AddTypeAnnotation(D.getType(), typeSys, arrayEltSize);
  1439. }
  1440. hlsl::InterpolationMode CGMSHLSLRuntime::GetInterpMode(const Decl *decl,
  1441. CompType compType,
  1442. bool bKeepUndefined) {
  1443. InterpolationMode Interp(
  1444. decl->hasAttr<HLSLNoInterpolationAttr>(), decl->hasAttr<HLSLLinearAttr>(),
  1445. decl->hasAttr<HLSLNoPerspectiveAttr>(), decl->hasAttr<HLSLCentroidAttr>(),
  1446. decl->hasAttr<HLSLSampleAttr>());
  1447. DXASSERT(Interp.IsValid(), "otherwise front-end missing validation");
  1448. if (Interp.IsUndefined() && !bKeepUndefined) {
  1449. // Type-based default: linear for floats, constant for others.
  1450. if (compType.IsFloatTy())
  1451. Interp = InterpolationMode::Kind::Linear;
  1452. else
  1453. Interp = InterpolationMode::Kind::Constant;
  1454. }
  1455. return Interp;
  1456. }
  1457. hlsl::CompType CGMSHLSLRuntime::GetCompType(const BuiltinType *BT) {
  1458. hlsl::CompType ElementType = hlsl::CompType::getInvalid();
  1459. switch (BT->getKind()) {
  1460. case BuiltinType::Bool:
  1461. ElementType = hlsl::CompType::getI1();
  1462. break;
  1463. case BuiltinType::Double:
  1464. ElementType = hlsl::CompType::getF64();
  1465. break;
  1466. case BuiltinType::Float:
  1467. ElementType = hlsl::CompType::getF32();
  1468. break;
  1469. case BuiltinType::Min10Float:
  1470. case BuiltinType::Half:
  1471. ElementType = hlsl::CompType::getF16();
  1472. break;
  1473. case BuiltinType::Int:
  1474. ElementType = hlsl::CompType::getI32();
  1475. break;
  1476. case BuiltinType::LongLong:
  1477. ElementType = hlsl::CompType::getI64();
  1478. break;
  1479. case BuiltinType::Min12Int:
  1480. case BuiltinType::Short:
  1481. ElementType = hlsl::CompType::getI16();
  1482. break;
  1483. case BuiltinType::UInt:
  1484. ElementType = hlsl::CompType::getU32();
  1485. break;
  1486. case BuiltinType::ULongLong:
  1487. ElementType = hlsl::CompType::getU64();
  1488. break;
  1489. case BuiltinType::UShort:
  1490. ElementType = hlsl::CompType::getU16();
  1491. break;
  1492. default:
  1493. llvm_unreachable("unsupported type");
  1494. break;
  1495. }
  1496. return ElementType;
  1497. }
  1498. /// Add resouce to the program
  1499. void CGMSHLSLRuntime::addResource(Decl *D) {
  1500. if (HLSLBufferDecl *BD = dyn_cast<HLSLBufferDecl>(D))
  1501. GetOrCreateCBuffer(BD);
  1502. else if (VarDecl *VD = dyn_cast<VarDecl>(D)) {
  1503. hlsl::DxilResourceBase::Class resClass = TypeToClass(VD->getType());
  1504. // skip decl has init which is resource.
  1505. if (VD->hasInit() && resClass != DXIL::ResourceClass::Invalid)
  1506. return;
  1507. // skip static global.
  1508. if (!VD->isExternallyVisible())
  1509. return;
  1510. if (D->hasAttr<HLSLGroupSharedAttr>()) {
  1511. GlobalVariable *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(VD));
  1512. m_pHLModule->AddGroupSharedVariable(GV);
  1513. return;
  1514. }
  1515. switch (resClass) {
  1516. case hlsl::DxilResourceBase::Class::Sampler:
  1517. AddSampler(VD);
  1518. break;
  1519. case hlsl::DxilResourceBase::Class::UAV:
  1520. case hlsl::DxilResourceBase::Class::SRV:
  1521. AddUAVSRV(VD, resClass);
  1522. break;
  1523. case hlsl::DxilResourceBase::Class::Invalid: {
  1524. // normal global constant, add to global CB
  1525. HLCBuffer &globalCB = GetGlobalCBuffer();
  1526. AddConstant(VD, globalCB);
  1527. break;
  1528. }
  1529. case DXIL::ResourceClass::CBuffer:
  1530. DXASSERT(0, "cbuffer should not be here");
  1531. break;
  1532. }
  1533. }
  1534. }
  1535. // TODO: collect such helper utility functions in one place.
  1536. static DxilResourceBase::Class KeywordToClass(const std::string &keyword) {
  1537. // TODO: refactor for faster search (switch by 1/2/3 first letters, then
  1538. // compare)
  1539. if (keyword == "SamplerState")
  1540. return DxilResourceBase::Class::Sampler;
  1541. if (keyword == "SamplerComparisonState")
  1542. return DxilResourceBase::Class::Sampler;
  1543. if (keyword == "ConstantBuffer")
  1544. return DxilResourceBase::Class::CBuffer;
  1545. if (keyword == "TextureBuffer")
  1546. return DxilResourceBase::Class::SRV;
  1547. bool isSRV = keyword == "Buffer";
  1548. isSRV |= keyword == "ByteAddressBuffer";
  1549. isSRV |= keyword == "StructuredBuffer";
  1550. isSRV |= keyword == "Texture1D";
  1551. isSRV |= keyword == "Texture1DArray";
  1552. isSRV |= keyword == "Texture2D";
  1553. isSRV |= keyword == "Texture2DArray";
  1554. isSRV |= keyword == "Texture3D";
  1555. isSRV |= keyword == "TextureCube";
  1556. isSRV |= keyword == "TextureCubeArray";
  1557. isSRV |= keyword == "Texture2DMS";
  1558. isSRV |= keyword == "Texture2DMSArray";
  1559. if (isSRV)
  1560. return DxilResourceBase::Class::SRV;
  1561. bool isUAV = keyword == "RWBuffer";
  1562. isUAV |= keyword == "RWByteAddressBuffer";
  1563. isUAV |= keyword == "RWStructuredBuffer";
  1564. isUAV |= keyword == "RWTexture1D";
  1565. isUAV |= keyword == "RWTexture1DArray";
  1566. isUAV |= keyword == "RWTexture2D";
  1567. isUAV |= keyword == "RWTexture2DArray";
  1568. isUAV |= keyword == "RWTexture3D";
  1569. isUAV |= keyword == "RWTextureCube";
  1570. isUAV |= keyword == "RWTextureCubeArray";
  1571. isUAV |= keyword == "RWTexture2DMS";
  1572. isUAV |= keyword == "RWTexture2DMSArray";
  1573. isUAV |= keyword == "AppendStructuredBuffer";
  1574. isUAV |= keyword == "ConsumeStructuredBuffer";
  1575. isUAV |= keyword == "RasterizerOrderedBuffer";
  1576. isUAV |= keyword == "RasterizerOrderedByteAddressBuffer";
  1577. isUAV |= keyword == "RasterizerOrderedStructuredBuffer";
  1578. isUAV |= keyword == "RasterizerOrderedTexture1D";
  1579. isUAV |= keyword == "RasterizerOrderedTexture1DArray";
  1580. isUAV |= keyword == "RasterizerOrderedTexture2D";
  1581. isUAV |= keyword == "RasterizerOrderedTexture2DArray";
  1582. isUAV |= keyword == "RasterizerOrderedTexture3D";
  1583. if (isUAV)
  1584. return DxilResourceBase::Class::UAV;
  1585. return DxilResourceBase::Class::Invalid;
  1586. }
  1587. static DxilSampler::SamplerKind KeywordToSamplerKind(const std::string &keyword) {
  1588. // TODO: refactor for faster search (switch by 1/2/3 first letters, then
  1589. // compare)
  1590. if (keyword == "SamplerState")
  1591. return DxilSampler::SamplerKind::Default;
  1592. if (keyword == "SamplerComparisonState")
  1593. return DxilSampler::SamplerKind::Comparison;
  1594. return DxilSampler::SamplerKind::Invalid;
  1595. }
  1596. // This should probably be refactored to ASTContextHLSL, and follow types
  1597. // rather than do string comparisons.
  1598. DXIL::ResourceClass
  1599. hlsl::GetResourceClassForType(const clang::ASTContext &context,
  1600. clang::QualType Ty) {
  1601. Ty = Ty.getCanonicalType();
  1602. if (const clang::ArrayType *arrayType = context.getAsArrayType(Ty)) {
  1603. return GetResourceClassForType(context, arrayType->getElementType());
  1604. } else if (const RecordType *RT = Ty->getAsStructureType()) {
  1605. return KeywordToClass(RT->getDecl()->getName());
  1606. } else if (const RecordType *RT = Ty->getAs<RecordType>()) {
  1607. if (const ClassTemplateSpecializationDecl *templateDecl =
  1608. dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl())) {
  1609. return KeywordToClass(templateDecl->getName());
  1610. }
  1611. }
  1612. return hlsl::DxilResourceBase::Class::Invalid;
  1613. }
  1614. hlsl::DxilResourceBase::Class CGMSHLSLRuntime::TypeToClass(clang::QualType Ty) {
  1615. return hlsl::GetResourceClassForType(CGM.getContext(), Ty);
  1616. }
  1617. uint32_t CGMSHLSLRuntime::AddSampler(VarDecl *samplerDecl) {
  1618. llvm::Constant *val = CGM.GetAddrOfGlobalVar(samplerDecl);
  1619. unique_ptr<DxilSampler> hlslRes(new DxilSampler);
  1620. hlslRes->SetLowerBound(UINT_MAX);
  1621. hlslRes->SetGlobalSymbol(cast<llvm::GlobalVariable>(val));
  1622. hlslRes->SetGlobalName(samplerDecl->getName());
  1623. QualType VarTy = samplerDecl->getType();
  1624. if (const clang::ArrayType *arrayType =
  1625. CGM.getContext().getAsArrayType(VarTy)) {
  1626. if (arrayType->isConstantArrayType()) {
  1627. uint32_t arraySize =
  1628. cast<ConstantArrayType>(arrayType)->getSize().getLimitedValue();
  1629. hlslRes->SetRangeSize(arraySize);
  1630. } else {
  1631. hlslRes->SetRangeSize(UINT_MAX);
  1632. }
  1633. // use elementTy
  1634. VarTy = arrayType->getElementType();
  1635. // Support more dim.
  1636. while (const clang::ArrayType *arrayType =
  1637. CGM.getContext().getAsArrayType(VarTy)) {
  1638. unsigned rangeSize = hlslRes->GetRangeSize();
  1639. if (arrayType->isConstantArrayType()) {
  1640. uint32_t arraySize =
  1641. cast<ConstantArrayType>(arrayType)->getSize().getLimitedValue();
  1642. if (rangeSize != UINT_MAX)
  1643. hlslRes->SetRangeSize(rangeSize * arraySize);
  1644. } else
  1645. hlslRes->SetRangeSize(UINT_MAX);
  1646. // use elementTy
  1647. VarTy = arrayType->getElementType();
  1648. }
  1649. } else
  1650. hlslRes->SetRangeSize(1);
  1651. const RecordType *RT = VarTy->getAs<RecordType>();
  1652. DxilSampler::SamplerKind kind = KeywordToSamplerKind(RT->getDecl()->getName());
  1653. hlslRes->SetSamplerKind(kind);
  1654. for (hlsl::UnusualAnnotation *it : samplerDecl->getUnusualAnnotations()) {
  1655. switch (it->getKind()) {
  1656. case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
  1657. hlsl::RegisterAssignment *ra = cast<hlsl::RegisterAssignment>(it);
  1658. hlslRes->SetLowerBound(ra->RegisterNumber);
  1659. hlslRes->SetSpaceID(ra->RegisterSpace);
  1660. break;
  1661. }
  1662. default:
  1663. llvm_unreachable("only register for sampler");
  1664. break;
  1665. }
  1666. }
  1667. hlslRes->SetID(m_pHLModule->GetSamplers().size());
  1668. return m_pHLModule->AddSampler(std::move(hlslRes));
  1669. }
  1670. uint32_t CGMSHLSLRuntime::AddUAVSRV(VarDecl *decl,
  1671. hlsl::DxilResourceBase::Class resClass) {
  1672. llvm::GlobalVariable *val =
  1673. cast<llvm::GlobalVariable>(CGM.GetAddrOfGlobalVar(decl));
  1674. QualType VarTy = decl->getType().getCanonicalType();
  1675. unique_ptr<HLResource> hlslRes(new HLResource);
  1676. hlslRes->SetLowerBound(UINT_MAX);
  1677. hlslRes->SetGlobalSymbol(val);
  1678. hlslRes->SetGlobalName(decl->getName());
  1679. if (const clang::ArrayType *arrayType =
  1680. CGM.getContext().getAsArrayType(VarTy)) {
  1681. if (arrayType->isConstantArrayType()) {
  1682. uint32_t arraySize =
  1683. cast<ConstantArrayType>(arrayType)->getSize().getLimitedValue();
  1684. hlslRes->SetRangeSize(arraySize);
  1685. } else
  1686. hlslRes->SetRangeSize(UINT_MAX);
  1687. // use elementTy
  1688. VarTy = arrayType->getElementType();
  1689. // Support more dim.
  1690. while (const clang::ArrayType *arrayType =
  1691. CGM.getContext().getAsArrayType(VarTy)) {
  1692. unsigned rangeSize = hlslRes->GetRangeSize();
  1693. if (arrayType->isConstantArrayType()) {
  1694. uint32_t arraySize =
  1695. cast<ConstantArrayType>(arrayType)->getSize().getLimitedValue();
  1696. if (rangeSize != UINT_MAX)
  1697. hlslRes->SetRangeSize(rangeSize * arraySize);
  1698. } else
  1699. hlslRes->SetRangeSize(UINT_MAX);
  1700. // use elementTy
  1701. VarTy = arrayType->getElementType();
  1702. }
  1703. } else
  1704. hlslRes->SetRangeSize(1);
  1705. for (hlsl::UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  1706. switch (it->getKind()) {
  1707. case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
  1708. hlsl::RegisterAssignment *ra = cast<hlsl::RegisterAssignment>(it);
  1709. hlslRes->SetLowerBound(ra->RegisterNumber);
  1710. hlslRes->SetSpaceID(ra->RegisterSpace);
  1711. break;
  1712. }
  1713. default:
  1714. llvm_unreachable("only register for uav/srv");
  1715. break;
  1716. }
  1717. }
  1718. const RecordType *RT = VarTy->getAs<RecordType>();
  1719. RecordDecl *RD = RT->getDecl();
  1720. hlsl::DxilResource::Kind kind = KeywordToKind(RT->getDecl()->getName());
  1721. hlslRes->SetKind(kind);
  1722. // Get the result type from handle field.
  1723. FieldDecl *FD = *(RD->field_begin());
  1724. DXASSERT(FD->getName() == "h", "must be handle field");
  1725. QualType resultTy = FD->getType();
  1726. // Type annotation for result type of resource.
  1727. DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  1728. unsigned arrayEltSize = 0;
  1729. AddTypeAnnotation(decl->getType(), dxilTypeSys, arrayEltSize);
  1730. if (kind == hlsl::DxilResource::Kind::Texture2DMS ||
  1731. kind == hlsl::DxilResource::Kind::Texture2DMSArray) {
  1732. const ClassTemplateSpecializationDecl *templateDecl =
  1733. dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl());
  1734. const clang::TemplateArgument &sampleCountArg =
  1735. templateDecl->getTemplateArgs()[1];
  1736. uint32_t sampleCount = sampleCountArg.getAsIntegral().getLimitedValue();
  1737. hlslRes->SetSampleCount(sampleCount);
  1738. }
  1739. if (kind != hlsl::DxilResource::Kind::StructuredBuffer) {
  1740. QualType Ty = resultTy;
  1741. QualType EltTy = Ty;
  1742. if (hlsl::IsHLSLMatType(Ty))
  1743. EltTy = hlsl::GetHLSLMatElementType(Ty);
  1744. if (hlsl::IsHLSLVecType(Ty))
  1745. EltTy = hlsl::GetHLSLVecElementType(Ty);
  1746. EltTy = EltTy.getCanonicalType();
  1747. bool bSNorm = false;
  1748. bool bUNorm = false;
  1749. if (const AttributedType *AT = dyn_cast<AttributedType>(Ty)) {
  1750. switch (AT->getAttrKind()) {
  1751. case AttributedType::Kind::attr_hlsl_snorm:
  1752. bSNorm = true;
  1753. break;
  1754. case AttributedType::Kind::attr_hlsl_unorm:
  1755. bUNorm = true;
  1756. break;
  1757. default:
  1758. // Do nothing
  1759. break;
  1760. }
  1761. }
  1762. if (EltTy->isBuiltinType()) {
  1763. const BuiltinType *BTy = EltTy->getAs<BuiltinType>();
  1764. CompType::Kind kind = BuiltinTyToCompTy(BTy, bSNorm, bUNorm);
  1765. hlslRes->SetCompType(kind);
  1766. } else
  1767. DXASSERT(!bSNorm && !bUNorm, "snorm/unorm on invalid type");
  1768. }
  1769. // TODO: set resource
  1770. // hlslRes.SetGloballyCoherent();
  1771. hlslRes->SetROV(RT->getDecl()->getName().startswith("RasterizerOrdered"));
  1772. if (kind == hlsl::DxilResource::Kind::TypedBuffer ||
  1773. kind == hlsl::DxilResource::Kind::StructuredBuffer) {
  1774. const ClassTemplateSpecializationDecl *templateDecl =
  1775. dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl());
  1776. const clang::TemplateArgument &retTyArg =
  1777. templateDecl->getTemplateArgs()[0];
  1778. llvm::Type *retTy = CGM.getTypes().ConvertType(retTyArg.getAsType());
  1779. uint32_t strideInBytes = legacyLayout.getTypeAllocSize(retTy);
  1780. hlslRes->SetElementStride(strideInBytes);
  1781. }
  1782. if (resClass == hlsl::DxilResourceBase::Class::SRV) {
  1783. hlslRes->SetRW(false);
  1784. hlslRes->SetID(m_pHLModule->GetSRVs().size());
  1785. return m_pHLModule->AddSRV(std::move(hlslRes));
  1786. } else {
  1787. hlslRes->SetRW(true);
  1788. hlslRes->SetID(m_pHLModule->GetUAVs().size());
  1789. return m_pHLModule->AddUAV(std::move(hlslRes));
  1790. }
  1791. }
  1792. static bool IsResourceInType(const clang::ASTContext &context,
  1793. clang::QualType Ty) {
  1794. Ty = Ty.getCanonicalType();
  1795. if (const clang::ArrayType *arrayType = context.getAsArrayType(Ty)) {
  1796. return IsResourceInType(context, arrayType->getElementType());
  1797. } else if (const RecordType *RT = Ty->getAsStructureType()) {
  1798. if (KeywordToClass(RT->getDecl()->getName()) != DxilResourceBase::Class::Invalid)
  1799. return true;
  1800. const CXXRecordDecl* typeRecordDecl = RT->getAsCXXRecordDecl();
  1801. if (typeRecordDecl && !typeRecordDecl->isImplicit()) {
  1802. for (auto field : typeRecordDecl->fields()) {
  1803. if (IsResourceInType(context, field->getType()))
  1804. return true;
  1805. }
  1806. }
  1807. } else if (const RecordType *RT = Ty->getAs<RecordType>()) {
  1808. if (const ClassTemplateSpecializationDecl *templateDecl =
  1809. dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl())) {
  1810. if (KeywordToClass(templateDecl->getName()) != DxilResourceBase::Class::Invalid)
  1811. return true;
  1812. }
  1813. }
  1814. return false; // no resources found
  1815. }
  1816. void CGMSHLSLRuntime::AddConstant(VarDecl *constDecl, HLCBuffer &CB) {
  1817. if (constDecl->getStorageClass() == SC_Static) {
  1818. // For static inside cbuffer, take as global static.
  1819. // Don't add to cbuffer.
  1820. CGM.EmitGlobal(constDecl);
  1821. return;
  1822. }
  1823. // Search defined structure for resource objects and fail
  1824. if (IsResourceInType(CGM.getContext(), constDecl->getType())) {
  1825. DiagnosticsEngine &Diags = CGM.getDiags();
  1826. unsigned DiagID = Diags.getCustomDiagID(
  1827. DiagnosticsEngine::Error,
  1828. "object types not supported in global aggregate instances, cbuffers, or tbuffers.");
  1829. Diags.Report(constDecl->getLocation(), DiagID);
  1830. return;
  1831. }
  1832. llvm::Constant *constVal = CGM.GetAddrOfGlobalVar(constDecl);
  1833. bool isGlobalCB = CB.GetID() == globalCBIndex;
  1834. uint32_t offset = 0;
  1835. bool userOffset = false;
  1836. for (hlsl::UnusualAnnotation *it : constDecl->getUnusualAnnotations()) {
  1837. switch (it->getKind()) {
  1838. case hlsl::UnusualAnnotation::UA_ConstantPacking: {
  1839. if (!isGlobalCB) {
  1840. // TODO: check cannot mix packoffset elements with nonpackoffset
  1841. // elements in a cbuffer.
  1842. hlsl::ConstantPacking *cp = cast<hlsl::ConstantPacking>(it);
  1843. offset = cp->Subcomponent << 2;
  1844. offset += cp->ComponentOffset;
  1845. // Change to byte.
  1846. offset <<= 2;
  1847. userOffset = true;
  1848. } else {
  1849. DiagnosticsEngine &Diags = CGM.getDiags();
  1850. unsigned DiagID = Diags.getCustomDiagID(
  1851. DiagnosticsEngine::Error,
  1852. "packoffset is only allowed in a constant buffer.");
  1853. Diags.Report(it->Loc, DiagID);
  1854. }
  1855. break;
  1856. }
  1857. case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
  1858. if (isGlobalCB) {
  1859. RegisterAssignment *ra = cast<RegisterAssignment>(it);
  1860. offset = ra->RegisterNumber << 2;
  1861. // Change to byte.
  1862. offset <<= 2;
  1863. userOffset = true;
  1864. }
  1865. break;
  1866. }
  1867. case hlsl::UnusualAnnotation::UA_SemanticDecl:
  1868. // skip semantic on constant
  1869. break;
  1870. }
  1871. }
  1872. std::unique_ptr<DxilResourceBase> pHlslConst = std::make_unique<DxilResourceBase>(DXIL::ResourceClass::Invalid);
  1873. pHlslConst->SetLowerBound(UINT_MAX);
  1874. pHlslConst->SetGlobalSymbol(cast<llvm::GlobalVariable>(constVal));
  1875. pHlslConst->SetGlobalName(constDecl->getName());
  1876. if (userOffset) {
  1877. pHlslConst->SetLowerBound(offset);
  1878. }
  1879. DxilTypeSystem &dxilTypeSys = m_pHLModule->GetTypeSystem();
  1880. // Just add type annotation here.
  1881. // Offset will be allocated later.
  1882. QualType Ty = constDecl->getType();
  1883. if (CB.GetRangeSize() != 1) {
  1884. while (Ty->isArrayType()) {
  1885. Ty = Ty->getAsArrayTypeUnsafe()->getElementType();
  1886. }
  1887. }
  1888. unsigned arrayEltSize = 0;
  1889. unsigned size = AddTypeAnnotation(Ty, dxilTypeSys, arrayEltSize);
  1890. pHlslConst->SetRangeSize(size);
  1891. CB.AddConst(pHlslConst);
  1892. // Save fieldAnnotation for the const var.
  1893. DxilFieldAnnotation fieldAnnotation;
  1894. if (userOffset)
  1895. fieldAnnotation.SetCBufferOffset(offset);
  1896. // Get the nested element type.
  1897. if (Ty->isArrayType()) {
  1898. while (const ConstantArrayType *arrayTy =
  1899. CGM.getContext().getAsConstantArrayType(Ty)) {
  1900. Ty = arrayTy->getElementType();
  1901. }
  1902. }
  1903. bool bDefaultRowMajor = m_pHLModule->GetHLOptions().bDefaultRowMajor;
  1904. ConstructFieldAttributedAnnotation(fieldAnnotation, Ty, bDefaultRowMajor);
  1905. m_ConstVarAnnotationMap[constVal] = fieldAnnotation;
  1906. }
  1907. uint32_t CGMSHLSLRuntime::AddCBuffer(HLSLBufferDecl *D) {
  1908. unique_ptr<HLCBuffer> CB = std::make_unique<HLCBuffer>();
  1909. // setup the CB
  1910. CB->SetGlobalSymbol(nullptr);
  1911. CB->SetGlobalName(D->getNameAsString());
  1912. CB->SetLowerBound(UINT_MAX);
  1913. if (!D->isCBuffer()) {
  1914. CB->SetKind(DXIL::ResourceKind::TBuffer);
  1915. }
  1916. // the global variable will only used once by the createHandle?
  1917. // SetHandle(llvm::Value *pHandle);
  1918. for (hlsl::UnusualAnnotation *it : D->getUnusualAnnotations()) {
  1919. switch (it->getKind()) {
  1920. case hlsl::UnusualAnnotation::UA_RegisterAssignment: {
  1921. hlsl::RegisterAssignment *ra = cast<hlsl::RegisterAssignment>(it);
  1922. uint32_t regNum = ra->RegisterNumber;
  1923. uint32_t regSpace = ra->RegisterSpace;
  1924. CB->SetSpaceID(regSpace);
  1925. CB->SetLowerBound(regNum);
  1926. break;
  1927. }
  1928. case hlsl::UnusualAnnotation::UA_SemanticDecl:
  1929. // skip semantic on constant buffer
  1930. break;
  1931. case hlsl::UnusualAnnotation::UA_ConstantPacking:
  1932. llvm_unreachable("no packoffset on constant buffer");
  1933. break;
  1934. }
  1935. }
  1936. // Add constant
  1937. if (D->isConstantBufferView()) {
  1938. VarDecl *constDecl = cast<VarDecl>(*D->decls_begin());
  1939. CB->SetRangeSize(1);
  1940. QualType Ty = constDecl->getType();
  1941. if (Ty->isArrayType()) {
  1942. if (!Ty->isIncompleteArrayType()) {
  1943. unsigned arraySize = 1;
  1944. while (Ty->isArrayType()) {
  1945. Ty = Ty->getCanonicalTypeUnqualified();
  1946. const ConstantArrayType *AT = cast<ConstantArrayType>(Ty);
  1947. arraySize *= AT->getSize().getLimitedValue();
  1948. Ty = AT->getElementType();
  1949. }
  1950. CB->SetRangeSize(arraySize);
  1951. } else {
  1952. CB->SetRangeSize(UINT_MAX);
  1953. }
  1954. }
  1955. AddConstant(constDecl, *CB.get());
  1956. } else {
  1957. auto declsEnds = D->decls_end();
  1958. CB->SetRangeSize(1);
  1959. for (auto it = D->decls_begin(); it != declsEnds; it++) {
  1960. if (VarDecl *constDecl = dyn_cast<VarDecl>(*it))
  1961. AddConstant(constDecl, *CB.get());
  1962. else if (isa<EmptyDecl>(*it)) {
  1963. } else if (isa<CXXRecordDecl>(*it)) {
  1964. } else {
  1965. HLSLBufferDecl *inner = cast<HLSLBufferDecl>(*it);
  1966. GetOrCreateCBuffer(inner);
  1967. }
  1968. }
  1969. }
  1970. CB->SetID(m_pHLModule->GetCBuffers().size());
  1971. return m_pHLModule->AddCBuffer(std::move(CB));
  1972. }
  1973. HLCBuffer &CGMSHLSLRuntime::GetOrCreateCBuffer(HLSLBufferDecl *D) {
  1974. if (constantBufMap.count(D) != 0) {
  1975. uint32_t cbIndex = constantBufMap[D];
  1976. return *static_cast<HLCBuffer*>(&(m_pHLModule->GetCBuffer(cbIndex)));
  1977. }
  1978. uint32_t cbID = AddCBuffer(D);
  1979. constantBufMap[D] = cbID;
  1980. return *static_cast<HLCBuffer*>(&(m_pHLModule->GetCBuffer(cbID)));
  1981. }
  1982. bool CGMSHLSLRuntime::IsPatchConstantFunction(const Function *F) {
  1983. DXASSERT_NOMSG(F != nullptr);
  1984. for (auto && p : patchConstantFunctionMap) {
  1985. if (p.second == F) return true;
  1986. }
  1987. return false;
  1988. }
  1989. void CGMSHLSLRuntime::SetEntryFunction() {
  1990. if (EntryFunc == nullptr) {
  1991. DiagnosticsEngine &Diags = CGM.getDiags();
  1992. unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
  1993. "cannot find entry function %0");
  1994. Diags.Report(DiagID) << CGM.getCodeGenOpts().HLSLEntryFunction;
  1995. return;
  1996. }
  1997. m_pHLModule->SetEntryFunction(EntryFunc);
  1998. }
  1999. // Here the size is CB size. So don't need check type.
  2000. static unsigned AlignCBufferOffset(unsigned offset, unsigned size, llvm::Type *Ty) {
  2001. // offset is already 4 bytes aligned.
  2002. bool b8BytesAlign = Ty->isDoubleTy();
  2003. if (llvm::IntegerType *IT = dyn_cast<llvm::IntegerType>(Ty)) {
  2004. b8BytesAlign = IT->getBitWidth() > 32;
  2005. }
  2006. // Align it to 4 x 4bytes.
  2007. if (unsigned remainder = (offset & 0xf)) {
  2008. unsigned aligned = offset - remainder + 16;
  2009. // If cannot fit in the remainder, need align.
  2010. bool bNeedAlign = (remainder + size) > 16;
  2011. // Array always start aligned.
  2012. bNeedAlign |= Ty->isArrayTy();
  2013. if (bNeedAlign)
  2014. return AlignTo8Bytes(aligned, b8BytesAlign);
  2015. else
  2016. return AlignTo8Bytes(offset, b8BytesAlign);
  2017. } else
  2018. return offset;
  2019. }
  2020. static unsigned AllocateDxilConstantBuffer(HLCBuffer &CB) {
  2021. unsigned offset = 0;
  2022. // Scan user allocated constants first.
  2023. // Update offset.
  2024. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  2025. if (C->GetLowerBound() == UINT_MAX)
  2026. continue;
  2027. unsigned size = C->GetRangeSize();
  2028. unsigned nextOffset = size + C->GetLowerBound();
  2029. if (offset < nextOffset)
  2030. offset = nextOffset;
  2031. }
  2032. // Alloc after user allocated constants.
  2033. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  2034. if (C->GetLowerBound() != UINT_MAX)
  2035. continue;
  2036. unsigned size = C->GetRangeSize();
  2037. llvm::Type *Ty = C->GetGlobalSymbol()->getType()->getPointerElementType();
  2038. // Align offset.
  2039. offset = AlignCBufferOffset(offset, size, Ty);
  2040. if (C->GetLowerBound() == UINT_MAX) {
  2041. C->SetLowerBound(offset);
  2042. }
  2043. offset += size;
  2044. }
  2045. return offset;
  2046. }
  2047. static void AllocateDxilConstantBuffers(HLModule *pHLModule) {
  2048. for (unsigned i = 0; i < pHLModule->GetCBuffers().size(); i++) {
  2049. HLCBuffer &CB = *static_cast<HLCBuffer*>(&(pHLModule->GetCBuffer(i)));
  2050. unsigned size = AllocateDxilConstantBuffer(CB);
  2051. CB.SetSize(size);
  2052. }
  2053. }
  2054. static void ReplaceUseInFunction(Value *V, Value *NewV, Function *F,
  2055. IRBuilder<> &Builder) {
  2056. for (auto U = V->user_begin(); U != V->user_end(); ) {
  2057. User *user = *(U++);
  2058. if (Instruction *I = dyn_cast<Instruction>(user)) {
  2059. if (I->getParent()->getParent() == F) {
  2060. // replace use with GEP if in F
  2061. for (unsigned i = 0; i < I->getNumOperands(); i++) {
  2062. if (I->getOperand(i) == V)
  2063. I->setOperand(i, NewV);
  2064. }
  2065. }
  2066. } else {
  2067. // For constant operator, create local clone which use GEP.
  2068. // Only support GEP and bitcast.
  2069. if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(user)) {
  2070. std::vector<Value *> idxList(GEPOp->idx_begin(), GEPOp->idx_end());
  2071. Value *NewGEP = Builder.CreateInBoundsGEP(NewV, idxList);
  2072. ReplaceUseInFunction(GEPOp, NewGEP, F, Builder);
  2073. } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(user)) {
  2074. // Change the init val into NewV with Store.
  2075. GV->setInitializer(nullptr);
  2076. Builder.CreateStore(NewV, GV);
  2077. } else {
  2078. // Must be bitcast here.
  2079. BitCastOperator *BC = cast<BitCastOperator>(user);
  2080. Value *NewBC = Builder.CreateBitCast(NewV, BC->getType());
  2081. ReplaceUseInFunction(BC, NewBC, F, Builder);
  2082. }
  2083. }
  2084. }
  2085. }
  2086. void MarkUsedFunctionForConst(Value *V, std::unordered_set<Function*> &usedFunc) {
  2087. for (auto U = V->user_begin(); U != V->user_end();) {
  2088. User *user = *(U++);
  2089. if (Instruction *I = dyn_cast<Instruction>(user)) {
  2090. Function *F = I->getParent()->getParent();
  2091. usedFunc.insert(F);
  2092. } else {
  2093. // For constant operator, create local clone which use GEP.
  2094. // Only support GEP and bitcast.
  2095. if (GEPOperator *GEPOp = dyn_cast<GEPOperator>(user)) {
  2096. MarkUsedFunctionForConst(GEPOp, usedFunc);
  2097. } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(user)) {
  2098. MarkUsedFunctionForConst(GV, usedFunc);
  2099. } else {
  2100. // Must be bitcast here.
  2101. BitCastOperator *BC = cast<BitCastOperator>(user);
  2102. MarkUsedFunctionForConst(BC, usedFunc);
  2103. }
  2104. }
  2105. }
  2106. }
  2107. static bool CreateCBufferVariable(HLCBuffer &CB,
  2108. llvm::Module &M) {
  2109. bool bUsed = false;
  2110. // Build Struct for CBuffer.
  2111. SmallVector<llvm::Type*, 4> Elements;
  2112. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  2113. Value *GV = C->GetGlobalSymbol();
  2114. if (GV->hasNUsesOrMore(1))
  2115. bUsed = true;
  2116. // Global variable must be pointer type.
  2117. llvm::Type *Ty = GV->getType()->getPointerElementType();
  2118. Elements.emplace_back(Ty);
  2119. }
  2120. // Don't create CBuffer variable for unused cbuffer.
  2121. if (!bUsed)
  2122. return false;
  2123. bool isCBArray = CB.GetRangeSize() != 1;
  2124. llvm::GlobalVariable *cbGV = nullptr;
  2125. llvm::Type *cbTy = nullptr;
  2126. unsigned cbIndexDepth = 0;
  2127. if (!isCBArray) {
  2128. llvm::StructType *CBStructTy =
  2129. llvm::StructType::create(Elements, CB.GetGlobalName());
  2130. cbGV = new llvm::GlobalVariable(M, CBStructTy, /*IsConstant*/ true,
  2131. llvm::GlobalValue::ExternalLinkage,
  2132. /*InitVal*/ nullptr, CB.GetGlobalName());
  2133. cbTy = cbGV->getType();
  2134. } else {
  2135. // For array of ConstantBuffer, create array of struct instead of struct of
  2136. // array.
  2137. DXASSERT(CB.GetConstants().size() == 1,
  2138. "ConstantBuffer should have 1 constant");
  2139. Value *GV = CB.GetConstants()[0]->GetGlobalSymbol();
  2140. llvm::Type *CBEltTy =
  2141. GV->getType()->getPointerElementType()->getArrayElementType();
  2142. cbIndexDepth = 1;
  2143. while (CBEltTy->isArrayTy()) {
  2144. CBEltTy = CBEltTy->getArrayElementType();
  2145. cbIndexDepth++;
  2146. }
  2147. // Add one level struct type to match normal case.
  2148. llvm::StructType *CBStructTy =
  2149. llvm::StructType::create({CBEltTy}, CB.GetGlobalName());
  2150. llvm::ArrayType *CBArrayTy =
  2151. llvm::ArrayType::get(CBStructTy, CB.GetRangeSize());
  2152. cbGV = new llvm::GlobalVariable(M, CBArrayTy, /*IsConstant*/ true,
  2153. llvm::GlobalValue::ExternalLinkage,
  2154. /*InitVal*/ nullptr, CB.GetGlobalName());
  2155. cbTy = llvm::PointerType::get(CBStructTy,
  2156. cbGV->getType()->getPointerAddressSpace());
  2157. }
  2158. CB.SetGlobalSymbol(cbGV);
  2159. llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  2160. llvm::Type *idxTy = opcodeTy;
  2161. llvm::FunctionType *SubscriptFuncTy =
  2162. llvm::FunctionType::get(cbTy, { opcodeTy, cbGV->getType(), idxTy}, false);
  2163. Function *subscriptFunc =
  2164. GetOrCreateHLFunction(M, SubscriptFuncTy, HLOpcodeGroup::HLSubscript,
  2165. (unsigned)HLSubscriptOpcode::CBufferSubscript);
  2166. Constant *opArg = ConstantInt::get(opcodeTy, (unsigned)HLSubscriptOpcode::CBufferSubscript);
  2167. Constant *zeroIdx = ConstantInt::get(opcodeTy, 0);
  2168. Value *args[] = { opArg, nullptr, zeroIdx };
  2169. llvm::LLVMContext &Context = M.getContext();
  2170. llvm::Type *i32Ty = llvm::Type::getInt32Ty(Context);
  2171. Value *zero = ConstantInt::get(i32Ty, (uint64_t)0);
  2172. std::vector<Value *> indexArray(CB.GetConstants().size());
  2173. std::vector<std::unordered_set<Function*>> constUsedFuncList(CB.GetConstants().size());
  2174. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  2175. Value *idx = ConstantInt::get(i32Ty, C->GetID());
  2176. indexArray[C->GetID()] = idx;
  2177. Value *GV = C->GetGlobalSymbol();
  2178. MarkUsedFunctionForConst(GV, constUsedFuncList[C->GetID()]);
  2179. }
  2180. for (Function &F : M.functions()) {
  2181. if (!F.isDeclaration()) {
  2182. IRBuilder<> Builder(F.getEntryBlock().getFirstInsertionPt());
  2183. args[HLOperandIndex::kSubscriptObjectOpIdx] = cbGV;
  2184. // create HL subscript to make all the use of cbuffer start from it.
  2185. Instruction *cbSubscript = cast<Instruction>(Builder.CreateCall(subscriptFunc, {args} ));
  2186. // Replace constant var with GEP pGV
  2187. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  2188. Value *GV = C->GetGlobalSymbol();
  2189. if (constUsedFuncList[C->GetID()].count(&F) == 0)
  2190. continue;
  2191. Value *idx = indexArray[C->GetID()];
  2192. if (!isCBArray) {
  2193. Instruction *GEP = cast<Instruction>(
  2194. Builder.CreateInBoundsGEP(cbSubscript, {zero, idx}));
  2195. // TODO: make sure the debug info is synced to GEP.
  2196. // GEP->setDebugLoc(GV);
  2197. ReplaceUseInFunction(GV, GEP, &F, Builder);
  2198. // Delete if no use in F.
  2199. if (GEP->user_empty())
  2200. GEP->eraseFromParent();
  2201. } else {
  2202. for (auto U = GV->user_begin(); U != GV->user_end();) {
  2203. User *user = *(U++);
  2204. if (user->user_empty())
  2205. continue;
  2206. Instruction *I = dyn_cast<Instruction>(user);
  2207. if (I && I->getParent()->getParent() != &F)
  2208. continue;
  2209. IRBuilder<> *instBuilder = &Builder;
  2210. unique_ptr<IRBuilder<> > B;
  2211. if (I) {
  2212. B = make_unique<IRBuilder<> >(I);
  2213. instBuilder = B.get();
  2214. }
  2215. GEPOperator *GEPOp = cast<GEPOperator>(user);
  2216. std::vector<Value *> idxList;
  2217. DXASSERT(GEPOp->getNumIndices() >= 1 + cbIndexDepth,
  2218. "must indexing ConstantBuffer array");
  2219. idxList.reserve(GEPOp->getNumIndices() - (cbIndexDepth - 1));
  2220. gep_type_iterator GI = gep_type_begin(*GEPOp), E = gep_type_end(*GEPOp);
  2221. idxList.push_back(GI.getOperand());
  2222. // change array index with 0 for struct index.
  2223. idxList.push_back(zero);
  2224. GI++;
  2225. Value *arrayIdx = GI.getOperand();
  2226. GI++;
  2227. for (unsigned curIndex = 1; GI != E && curIndex < cbIndexDepth; ++GI, ++curIndex) {
  2228. arrayIdx = instBuilder->CreateMul(arrayIdx, Builder.getInt32(GI->getArrayNumElements()));
  2229. arrayIdx = instBuilder->CreateAdd(arrayIdx, GI.getOperand());
  2230. }
  2231. for (; GI != E; ++GI) {
  2232. idxList.push_back(GI.getOperand());
  2233. }
  2234. args[HLOperandIndex::kSubscriptIndexOpIdx] = arrayIdx;
  2235. Instruction *cbSubscript =
  2236. cast<Instruction>(instBuilder->CreateCall(subscriptFunc, {args}));
  2237. Instruction *NewGEP = cast<Instruction>(
  2238. instBuilder->CreateInBoundsGEP(cbSubscript, idxList));
  2239. ReplaceUseInFunction(GEPOp, NewGEP, &F, *instBuilder);
  2240. }
  2241. }
  2242. }
  2243. // Delete if no use in F.
  2244. if (cbSubscript->user_empty())
  2245. cbSubscript->eraseFromParent();
  2246. }
  2247. }
  2248. return true;
  2249. }
  2250. static void ConstructCBufferAnnotation(
  2251. HLCBuffer &CB, DxilTypeSystem &dxilTypeSys,
  2252. std::unordered_map<Constant *, DxilFieldAnnotation> &AnnotationMap) {
  2253. Value *GV = CB.GetGlobalSymbol();
  2254. llvm::StructType *CBStructTy =
  2255. dyn_cast<llvm::StructType>(GV->getType()->getPointerElementType());
  2256. if (!CBStructTy) {
  2257. // For Array of ConstantBuffer.
  2258. llvm::ArrayType *CBArrayTy =
  2259. cast<llvm::ArrayType>(GV->getType()->getPointerElementType());
  2260. CBStructTy = cast<llvm::StructType>(CBArrayTy->getArrayElementType());
  2261. }
  2262. DxilStructAnnotation *CBAnnotation =
  2263. dxilTypeSys.AddStructAnnotation(CBStructTy);
  2264. CBAnnotation->SetCBufferSize(CB.GetSize());
  2265. // Set fieldAnnotation for each constant var.
  2266. for (const std::unique_ptr<DxilResourceBase> &C : CB.GetConstants()) {
  2267. Constant *GV = C->GetGlobalSymbol();
  2268. DxilFieldAnnotation &fieldAnnotation =
  2269. CBAnnotation->GetFieldAnnotation(C->GetID());
  2270. fieldAnnotation = AnnotationMap[GV];
  2271. // This is after CBuffer allocation.
  2272. fieldAnnotation.SetCBufferOffset(C->GetLowerBound());
  2273. fieldAnnotation.SetFieldName(C->GetGlobalName());
  2274. }
  2275. }
  2276. static void ConstructCBuffer(
  2277. HLModule *pHLModule,
  2278. llvm::Type *CBufferType,
  2279. std::unordered_map<Constant *, DxilFieldAnnotation> &AnnotationMap) {
  2280. DxilTypeSystem &dxilTypeSys = pHLModule->GetTypeSystem();
  2281. for (unsigned i = 0; i < pHLModule->GetCBuffers().size(); i++) {
  2282. HLCBuffer &CB = *static_cast<HLCBuffer*>(&(pHLModule->GetCBuffer(i)));
  2283. if (CB.GetConstants().size() == 0) {
  2284. // Create Fake variable for cbuffer which is empty.
  2285. llvm::GlobalVariable *pGV = new llvm::GlobalVariable(
  2286. *pHLModule->GetModule(), CBufferType, true,
  2287. llvm::GlobalValue::ExternalLinkage, nullptr, CB.GetGlobalName());
  2288. CB.SetGlobalSymbol(pGV);
  2289. } else {
  2290. bool bCreated = CreateCBufferVariable(CB, *pHLModule->GetModule());
  2291. if (bCreated)
  2292. ConstructCBufferAnnotation(CB, dxilTypeSys, AnnotationMap);
  2293. else {
  2294. // Create Fake variable for cbuffer which is unused.
  2295. llvm::GlobalVariable *pGV = new llvm::GlobalVariable(
  2296. *pHLModule->GetModule(), CBufferType, true,
  2297. llvm::GlobalValue::ExternalLinkage, nullptr, CB.GetGlobalName());
  2298. CB.SetGlobalSymbol(pGV);
  2299. }
  2300. }
  2301. // Clear the constants which useless now.
  2302. CB.GetConstants().clear();
  2303. }
  2304. }
  2305. static void ReplaceBoolVectorSubscript(CallInst *CI) {
  2306. Value *Ptr = CI->getArgOperand(0);
  2307. Value *Idx = CI->getArgOperand(1);
  2308. Value *IdxList[] = {ConstantInt::get(Idx->getType(), 0), Idx};
  2309. llvm::Type *i1Ty = llvm::Type::getInt1Ty(Idx->getContext());
  2310. for (auto It = CI->user_begin(), E = CI->user_end(); It != E;) {
  2311. Instruction *user = cast<Instruction>(*(It++));
  2312. IRBuilder<> Builder(user);
  2313. Value *GEP = Builder.CreateInBoundsGEP(Ptr, IdxList);
  2314. if (LoadInst *LI = dyn_cast<LoadInst>(user)) {
  2315. Value *NewLd = Builder.CreateLoad(GEP);
  2316. Value *cast = Builder.CreateZExt(NewLd, LI->getType());
  2317. LI->replaceAllUsesWith(cast);
  2318. LI->eraseFromParent();
  2319. } else {
  2320. // Must be a store inst here.
  2321. StoreInst *SI = cast<StoreInst>(user);
  2322. Value *V = SI->getValueOperand();
  2323. Value *cast = Builder.CreateTrunc(V, i1Ty);
  2324. Builder.CreateStore(cast, GEP);
  2325. SI->eraseFromParent();
  2326. }
  2327. }
  2328. CI->eraseFromParent();
  2329. }
  2330. static void ReplaceBoolVectorSubscript(Function *F) {
  2331. for (auto It = F->user_begin(), E = F->user_end(); It != E; ) {
  2332. User *user = *(It++);
  2333. CallInst *CI = cast<CallInst>(user);
  2334. ReplaceBoolVectorSubscript(CI);
  2335. }
  2336. }
  2337. // Add function body for intrinsic if possible.
  //
  // Returns the HL function for (group, opcode) whose signature is funcTy.
  // funcTy's first parameter is the i32 opcode operand; the bodies emitted
  // below skip it before reading the real arguments.
  //
  // For group == HLIntrinsic, a few ops also receive a body that expands them
  // into other HL calls:
  //   MOP_Append (handle not a StreamOutput): Buf[IncrementCounter(Buf)] = val
  //   MOP_Consume:                            return Buf[DecrementCounter(Buf)]
  //   IOP_sincos:                             *sinPtr = sin(x); *cosPtr = cos(x)
  // Every other op, and every other group, only gets the function returned by
  // GetOrCreateHLFunction.
  //
  // F is the original intrinsic function. Its ReadNone/ReadOnly attributes are
  // copied onto the returned function.
  2338. static Function *CreateOpFunction(llvm::Module &M, Function *F,
  2339. llvm::FunctionType *funcTy,
  2340. HLOpcodeGroup group, unsigned opcode) {
  2341. Function *opFunc = nullptr;
  2342. llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  2343. if (group == HLOpcodeGroup::HLIntrinsic) {
  2344. IntrinsicOp intriOp = static_cast<IntrinsicOp>(opcode);
  2345. switch (intriOp) {
  2346. case IntrinsicOp::MOP_Append:
  2347. case IntrinsicOp::MOP_Consume: {
  2348. bool bAppend = intriOp == IntrinsicOp::MOP_Append;
  2349. llvm::Type *handleTy = funcTy->getParamType(HLOperandIndex::kHandleOpIdx);
  2350. // Don't generate body for OutputStream::Append.
  2351. if (bAppend && HLModule::IsStreamOutputPtrType(handleTy)) {
  2352. opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
  2353. break;
  2354. }
  2355. opFunc = GetOrCreateHLFunctionWithBody(M, funcTy, group, opcode,
  2356. bAppend ? "append" : "consume");
  2357. llvm::Type *counterTy = llvm::Type::getInt32Ty(M.getContext());
  2358. llvm::FunctionType *IncCounterFuncTy =
  2359. llvm::FunctionType::get(counterTy, {opcodeTy, handleTy}, false);
  // The counter op's i32 result is used as the element index below.
  2360. unsigned counterOpcode = bAppend ? (unsigned)IntrinsicOp::MOP_IncrementCounter:
  2361. (unsigned)IntrinsicOp::MOP_DecrementCounter;
  2362. Function *incCounterFunc =
  2363. GetOrCreateHLFunction(M, IncCounterFuncTy, group,
  2364. counterOpcode);
  2365. llvm::Type *idxTy = counterTy;
  2366. llvm::Type *valTy = bAppend ?
  2367. funcTy->getParamType(HLOperandIndex::kAppendValOpIndex):funcTy->getReturnType();
  2368. llvm::Type *subscriptTy = valTy;
  2369. if (!valTy->isPointerTy()) {
  2370. // Return type for subscript should be pointer type.
  2371. subscriptTy = llvm::PointerType::get(valTy, 0);
  2372. }
  2373. llvm::FunctionType *SubscriptFuncTy =
  2374. llvm::FunctionType::get(subscriptTy, {opcodeTy, handleTy, idxTy}, false);
  2375. Function *subscriptFunc =
  2376. GetOrCreateHLFunction(M, SubscriptFuncTy, HLOpcodeGroup::HLSubscript,
  2377. (unsigned)HLSubscriptOpcode::DefaultSubscript);
  // NOTE(review): an "Entry" block is appended unconditionally. This assumes
  // GetOrCreateHLFunctionWithBody returned a function with no body yet;
  // confirm it never hands back an already-populated function.
  2378. BasicBlock *BB = BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
  2379. IRBuilder<> Builder(BB);
  2380. auto argIter = opFunc->args().begin();
  2381. // Skip the opcode arg.
  2382. argIter++;
  2383. Argument *thisArg = argIter++;
  2384. // int counter = IncrementCounter/DecrementCounter(Buf);
  2385. Value *incCounterOpArg =
  2386. ConstantInt::get(idxTy, counterOpcode);
  2387. Value *counter =
  2388. Builder.CreateCall(incCounterFunc, {incCounterOpArg, thisArg});
  2389. // Buf[counter];
  2390. Value *subscriptOpArg = ConstantInt::get(
  2391. idxTy, (unsigned)HLSubscriptOpcode::DefaultSubscript);
  2392. Value *subscript =
  2393. Builder.CreateCall(subscriptFunc, {subscriptOpArg, thisArg, counter});
  2394. if (bAppend) {
  2395. Argument *valArg = argIter;
  2396. // Buf[counter] = val;
  // A value passed by pointer is copied byte-wise into the element; a value
  // passed directly is stored.
  2397. if (valTy->isPointerTy()) {
  2398. Value *valArgCast = Builder.CreateBitCast(valArg, llvm::Type::getInt8PtrTy(F->getContext()));
  2399. Value *subscriptCast = Builder.CreateBitCast(subscript, llvm::Type::getInt8PtrTy(F->getContext()));
  // NOTE(review): the copy size (8) and alignment (8) are hard-coded, so a
  // value larger than 8 bytes is only partially copied.
  2400. // TODO: use real type size and alignment.
  2401. Value *tySize = ConstantInt::get(idxTy, 8);
  2402. unsigned Align = 8;
  2403. Builder.CreateMemCpy(subscriptCast, valArgCast, tySize, Align);
  2404. } else
  2405. Builder.CreateStore(valArg, subscript);
  2406. Builder.CreateRetVoid();
  2407. } else {
  2408. // return Buf[counter];
  // A pointer-typed result returns the element address itself; otherwise
  // the element is loaded and returned by value.
  2409. if (valTy->isPointerTy())
  2410. Builder.CreateRet(subscript);
  2411. else {
  2412. Value *retVal = Builder.CreateLoad(subscript);
  2413. Builder.CreateRet(retVal);
  2414. }
  2415. }
  2416. } break;
  2417. case IntrinsicOp::IOP_sincos: {
  // sincos(x, sinPtr, cosPtr) becomes two HL calls (IOP_sin, IOP_cos) whose
  // results are stored through the two pointer arguments.
  2418. opFunc = GetOrCreateHLFunctionWithBody(M, funcTy, group, opcode, "sincos");
  2419. llvm::Type *valTy = funcTy->getParamType(HLOperandIndex::kTrinaryOpSrc0Idx);
  2420. llvm::FunctionType *sinFuncTy =
  2421. llvm::FunctionType::get(valTy, {opcodeTy, valTy}, false);
  2422. unsigned sinOp = static_cast<unsigned>(IntrinsicOp::IOP_sin);
  2423. unsigned cosOp = static_cast<unsigned>(IntrinsicOp::IOP_cos);
  2424. Function *sinFunc = GetOrCreateHLFunction(M, sinFuncTy, group, sinOp);
  2425. Function *cosFunc = GetOrCreateHLFunction(M, sinFuncTy, group, cosOp);
  2426. BasicBlock *BB = BasicBlock::Create(opFunc->getContext(), "Entry", opFunc);
  2427. IRBuilder<> Builder(BB);
  2428. auto argIter = opFunc->args().begin();
  2429. // Skip the opcode arg.
  2430. argIter++;
  2431. Argument *valArg = argIter++;
  2432. Argument *sinPtrArg = argIter++;
  2433. Argument *cosPtrArg = argIter++;
  2434. Value *sinOpArg =
  2435. ConstantInt::get(opcodeTy, sinOp);
  2436. Value *sinVal = Builder.CreateCall(sinFunc, {sinOpArg, valArg});
  2437. Builder.CreateStore(sinVal, sinPtrArg);
  2438. Value *cosOpArg =
  2439. ConstantInt::get(opcodeTy, cosOp);
  2440. Value *cosVal = Builder.CreateCall(cosFunc, {cosOpArg, valArg});
  2441. Builder.CreateStore(cosVal, cosPtrArg);
  2442. // Ret.
  2443. Builder.CreateRetVoid();
  2444. } break;
  2445. default:
  2446. opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
  2447. break;
  2448. }
  2449. } else {
  2450. opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
  2451. }
  2452. // Add attribute
  // Propagate the memory-effect attributes of the original intrinsic F.
  2453. if (F->hasFnAttribute(Attribute::ReadNone))
  2454. opFunc->addFnAttr(Attribute::ReadNone);
  2455. if (F->hasFnAttribute(Attribute::ReadOnly))
  2456. opFunc->addFnAttr(Attribute::ReadOnly);
  2457. return opFunc;
  2458. }
  // Rewrites every call to F into a call to an HL function (from
  // CreateOpFunction) that takes an extra leading i32 opcode operand set to
  // `opcode`, then erases F.
  //
  // - Parameters that are pointers to HLSL object types (except StreamOutput)
  //   become by-value object parameters; each call site loads the object.
  // - HLSubscript/VectorSubscript (asserted to be on a bool vector) is instead
  //   lowered by ReplaceBoolVectorSubscript. That path returns early and does
  //   NOT erase F, which is left without users.
  // - HLSubscript/DoubleSubscript is fused with its single user (the second
  //   subscript call). The new call takes the resource object, the second
  //   subscript's index (coord) and, appended last, the first subscript's
  //   index (the "sampleIdx or mipLevel"). It returns a pointer to element 0
  //   of the resource struct type.
  2459. static void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F,
  2460. unsigned opcode) {
  2461. llvm::Module &M = *HLM.GetModule();
  2462. llvm::FunctionType *oldFuncTy = F->getFunctionType();
  2463. SmallVector<llvm::Type *, 4> paramTyList;
  2464. // Add the opcode param
  2465. llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  2466. paramTyList.emplace_back(opcodeTy);
  2467. paramTyList.append(oldFuncTy->param_begin(), oldFuncTy->param_end());
  2468. for (unsigned i = 1; i < paramTyList.size(); i++) {
  2469. llvm::Type *Ty = paramTyList[i];
  2470. if (Ty->isPointerTy()) {
  2471. Ty = Ty->getPointerElementType();
  2472. if (HLModule::IsHLSLObjectType(Ty) &&
  2473. // StreamOutput don't need handle.
  2474. !HLModule::IsStreamOutputType(Ty)) {
  2475. // Use object type directly, not by pointer.
  2476. // This will make sure temp object variable only used by ld/st.
  2477. paramTyList[i] = Ty;
  2478. }
  2479. }
  2480. }
  2481. HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByAttr(F);
  2482. if (group == HLOpcodeGroup::HLSubscript &&
  2483. opcode == static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)) {
  2484. llvm::FunctionType *FT = F->getFunctionType();
  2485. llvm::Type *VecArgTy = FT->getParamType(0);
  2486. llvm::VectorType *VType =
  2487. cast<llvm::VectorType>(VecArgTy->getPointerElementType());
  2488. llvm::Type *Ty = VType->getElementType();
  2489. DXASSERT(Ty->isIntegerTy(), "Only bool could use VectorSubscript");
  2490. llvm::IntegerType *ITy = cast<IntegerType>(Ty);
  2491. DXASSERT_LOCALVAR(ITy, ITy->getBitWidth() == 1, "Only bool could use VectorSubscript");
  2492. // The return type is i8*.
  2493. // Replace all uses with i1*.
  // NOTE(review): returns without erasing F (it has no users afterwards).
  2494. ReplaceBoolVectorSubscript(F);
  2495. return;
  2496. }
  2497. bool isDoubleSubscriptFunc = group == HLOpcodeGroup::HLSubscript &&
  2498. opcode == static_cast<unsigned>(HLSubscriptOpcode::DoubleSubscript);
  2499. llvm::Type *RetTy = oldFuncTy->getReturnType();
  2500. if (isDoubleSubscriptFunc) {
  // Only the first call to F is inspected to compute the fused signature;
  // this assumes every call to F has the same operand types.
  2501. CallInst *doubleSub = cast<CallInst>(*F->user_begin());
  2502. // Change currentIdx type into coord type.
  2503. auto U = doubleSub->user_begin();
  2504. Value *user = *U;
  2505. CallInst *secSub = cast<CallInst>(user);
  2506. unsigned coordIdx = HLOperandIndex::kSubscriptIndexOpIdx;
  2507. // opcode operand not add yet, so the index need -1.
  2508. if (GetHLOpcodeGroupByName(secSub->getCalledFunction()) == HLOpcodeGroup::NotHL)
  2509. coordIdx -= 1;
  2510. Value *coord = secSub->getArgOperand(coordIdx);
  2511. llvm::Type *coordTy = coord->getType();
  2512. paramTyList[HLOperandIndex::kSubscriptIndexOpIdx] = coordTy;
  2513. // Add the sampleIdx or mipLevel parameter to the end.
  2514. paramTyList.emplace_back(opcodeTy);
  2515. // Change return type to be resource ret type.
  2516. // opcode operand not add yet, so the index need -1.
  2517. Value *objPtr = doubleSub->getArgOperand(HLOperandIndex::kSubscriptObjectOpIdx-1);
  2518. // Must be a GEP
  2519. GEPOperator *objGEP = cast<GEPOperator>(objPtr);
  // Walk the GEP's indexed types to find the first HLSL object (resource) type.
  2520. gep_type_iterator GEPIt = gep_type_begin(objGEP), E = gep_type_end(objGEP);
  2521. llvm::Type *resTy = nullptr;
  2522. while (GEPIt != E) {
  2523. if (HLModule::IsHLSLObjectType(*GEPIt)) {
  2524. resTy = *GEPIt;
  2525. break;
  2526. }
  2527. GEPIt++;
  2528. }
  2529. DXASSERT(resTy, "must find the resource type");
  2530. // Change object type to resource type.
  2531. paramTyList[HLOperandIndex::kSubscriptObjectOpIdx] = resTy;
  2532. // Change RetTy into pointer of resource return type.
  2533. RetTy = cast<StructType>(resTy)->getElementType(0)->getPointerTo();
  2534. llvm::Type *sliceTy = objGEP->getType()->getPointerElementType();
  2535. DXIL::ResourceClass RC = HLM.GetResourceClass(sliceTy);
  2536. DXIL::ResourceKind RK = HLM.GetResourceKind(sliceTy);
  2537. HLM.AddResourceTypeAnnotation(resTy, RC, RK);
  2538. }
  2539. llvm::FunctionType *funcTy =
  2540. llvm::FunctionType::get(RetTy, paramTyList, false);
  2541. Function *opFunc = CreateOpFunction(M, F, funcTy, group, opcode);
  // Rewrite each call site; `user` is advanced before the old call is erased.
  2542. for (auto user = F->user_begin(); user != F->user_end();) {
  2543. // User must be a call.
  2544. CallInst *oldCI = cast<CallInst>(*(user++));
  2545. SmallVector<Value *, 4> opcodeParamList;
  2546. Value *opcodeConst = Constant::getIntegerValue(opcodeTy, APInt(32, opcode));
  2547. opcodeParamList.emplace_back(opcodeConst);
  2548. opcodeParamList.append(oldCI->arg_operands().begin(),
  2549. oldCI->arg_operands().end());
  2550. IRBuilder<> Builder(oldCI);
  2551. if (isDoubleSubscriptFunc) {
  2552. // Change obj to the resource pointer.
  2553. Value *objVal = opcodeParamList[HLOperandIndex::kSubscriptObjectOpIdx];
  2554. GEPOperator *objGEP = cast<GEPOperator>(objVal);
  2555. SmallVector<Value *, 8> IndexList;
  2556. IndexList.append(objGEP->idx_begin(), objGEP->idx_end());
  2557. Value *lastIndex = IndexList.back();
  2558. ConstantInt *constIndex = cast<ConstantInt>(lastIndex);
  2559. DXASSERT_LOCALVAR(constIndex, constIndex->getLimitedValue() == 1, "last index must 1");
  2560. // Remove the last index.
  2561. IndexList.pop_back();
  // When only the leading index remains, the base pointer is used directly.
  // NOTE(review): this assumes that remaining index is 0.
  2562. objVal = objGEP->getPointerOperand();
  2563. if (IndexList.size() > 1)
  2564. objVal = Builder.CreateInBoundsGEP(objVal, IndexList);
  2565. // Change obj to the resource pointer.
  2566. opcodeParamList[HLOperandIndex::kSubscriptObjectOpIdx] = objVal;
  2567. // Set idx and mipIdx.
  // Note: `user` below shadows the loop iterator within this branch.
  2568. Value *mipIdx = opcodeParamList[HLOperandIndex::kSubscriptIndexOpIdx];
  2569. auto U = oldCI->user_begin();
  2570. Value *user = *U;
  2571. CallInst *secSub = cast<CallInst>(user);
  2572. unsigned idxOpIndex = HLOperandIndex::kSubscriptIndexOpIdx;
  2573. if (GetHLOpcodeGroupByName(secSub->getCalledFunction()) == HLOpcodeGroup::NotHL)
  2574. idxOpIndex--;
  2575. Value *idx = secSub->getArgOperand(idxOpIndex);
  2576. DXASSERT(secSub->hasOneUse(), "subscript should only has one use");
  2577. // Add the sampleIdx or mipLevel parameter to the end.
  2578. opcodeParamList[HLOperandIndex::kSubscriptIndexOpIdx] = idx;
  2579. opcodeParamList.emplace_back(mipIdx);
  2580. // Insert new call before secSub to make sure idx is ready to use.
  2581. Builder.SetInsertPoint(secSub);
  2582. }
  // Replace object pointers with loaded object values to match paramTyList.
  2583. for (unsigned i = 1; i < opcodeParamList.size(); i++) {
  2584. Value *arg = opcodeParamList[i];
  2585. llvm::Type *Ty = arg->getType();
  2586. if (Ty->isPointerTy()) {
  2587. Ty = Ty->getPointerElementType();
  2588. if (HLModule::IsHLSLObjectType(Ty) &&
  2589. // StreamOutput don't need handle.
  2590. !HLModule::IsStreamOutputType(Ty)) {
  2591. // Use object type directly, not by pointer.
  2592. // This will make sure temp object variable only used by ld/st.
  2593. if (GEPOperator *argGEP = dyn_cast<GEPOperator>(arg)) {
  2594. std::vector<Value*> idxList(argGEP->idx_begin(), argGEP->idx_end());
  2595. // Create instruction to avoid GEPOperator.
  2596. GetElementPtrInst *GEP = GetElementPtrInst::CreateInBounds(argGEP->getPointerOperand(),
  2597. idxList);
  2598. Builder.Insert(GEP);
  2599. arg = GEP;
  2600. }
  2601. opcodeParamList[i] = Builder.CreateLoad(arg);
  2602. }
  2603. }
  2604. }
  2605. Value *CI = Builder.CreateCall(opFunc, opcodeParamList);
  2606. if (!isDoubleSubscriptFunc) {
  2607. // replace new call and delete the old call
  2608. oldCI->replaceAllUsesWith(CI);
  2609. oldCI->eraseFromParent();
  2610. } else {
  2611. // For double subscript.
  2612. // Replace single users use with new CI.
  2613. auto U = oldCI->user_begin();
  2614. Value *user = *U;
  2615. CallInst *secSub = cast<CallInst>(user);
  2616. secSub->replaceAllUsesWith(CI);
  2617. secSub->eraseFromParent();
  2618. oldCI->eraseFromParent();
  2619. }
  2620. }
  2621. // delete the function
  2622. F->eraseFromParent();
  2623. }
  2624. static void AddOpcodeParamForIntrinsics(HLModule &HLM
  2625. , std::unordered_map<Function *, unsigned> &intrinsicMap) {
  2626. for (auto mapIter = intrinsicMap.begin(); mapIter != intrinsicMap.end();
  2627. mapIter++) {
  2628. Function *F = mapIter->first;
  2629. if (F->user_empty()) {
  2630. // delete the function
  2631. F->eraseFromParent();
  2632. continue;
  2633. }
  2634. unsigned opcode = mapIter->second;
  2635. AddOpcodeParamForIntrinsic(HLM, F, opcode);
  2636. }
  2637. }
  2638. static void SimplifyScalarToVec1Splat(BitCastInst *BCI, std::vector<Instruction *> &deadInsts) {
  2639. Value *Ptr = BCI->getOperand(0);
  2640. // For case like SsaoBuffer[DTid.xy].xxx;
  2641. // It will translated into
  2642. //%8 = bitcast float* %7 to <1 x float>*
  2643. //%9 = load <1 x float>, <1 x float>* %8
  2644. //%10 = shufflevector <1 x float> %9, <1 x float> undef, <3 x i32>
  2645. //zeroinitializer
  2646. // To remove the bitcast,
  2647. // We transform it into
  2648. // %8 = load float, float* %7
  2649. // %9 = insertelement <1 x float> undef, float %8, i64 0
  2650. // %10 = shufflevector <1 x float> %9, <1 x float> undef, <3 x i32>
  2651. // zeroinitializer
  2652. IRBuilder<> Builder(BCI);
  2653. Value *SVal = Builder.CreateLoad(Ptr);
  2654. Value *VVal = UndefValue::get(BCI->getType()->getPointerElementType());
  2655. VVal = Builder.CreateInsertElement(VVal, SVal, (uint64_t)0);
  2656. for (Value::user_iterator Iter = BCI->user_begin(), IterE = BCI->user_end();
  2657. Iter != IterE;) {
  2658. Instruction *I = cast<Instruction>(*(Iter++));
  2659. if (LoadInst *ldInst = dyn_cast<LoadInst>(I)) {
  2660. ldInst->replaceAllUsesWith(VVal);
  2661. deadInsts.emplace_back(ldInst);
  2662. } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
  2663. GEP->replaceAllUsesWith(Ptr);
  2664. deadInsts.emplace_back(GEP);
  2665. } else {
  2666. // Must be StoreInst here.
  2667. StoreInst *stInst = cast<StoreInst>(I);
  2668. Value *Val = stInst->getValueOperand();
  2669. IRBuilder<> Builder(stInst);
  2670. Val = Builder.CreateExtractElement(Val, (uint64_t)0);
  2671. Builder.CreateStore(Val, Ptr);
  2672. deadInsts.emplace_back(stInst);
  2673. }
  2674. }
  2675. deadInsts.emplace_back(BCI);
  2676. }
  2677. static void SimplifyVectorTrunc(BitCastInst *BCI, std::vector<Instruction *> &deadInsts) {
  2678. // Transform
  2679. //%a.addr = alloca <2 x float>, align 4
  2680. //%1 = bitcast <2 x float>* %a.addr to <1 x float>*
  2681. //%2 = getelementptr inbounds <1 x float>, <1 x float>* %1, i32 0, i32 0
  2682. // into
  2683. //%a.addr = alloca <2 x float>, align 4
  2684. //%2 = getelementptr inbounds <2 x float>, <2 x float>* %2, i32 0, i32 0
  2685. Value *bigVec = BCI->getOperand(0);
  2686. llvm::Type *idxTy = llvm::Type::getInt32Ty(BCI->getContext());
  2687. Constant *zeroIdx = ConstantInt::get(idxTy, 0);
  2688. unsigned vecSize = bigVec->getType()->getPointerElementType()->getVectorNumElements();
  2689. for (auto It = BCI->user_begin(), E = BCI->user_end(); It != E;) {
  2690. Instruction *I = cast<Instruction>(*(It++));
  2691. if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
  2692. DXASSERT_NOMSG(
  2693. !isa<llvm::VectorType>(GEP->getType()->getPointerElementType()));
  2694. IRBuilder<> Builder(GEP);
  2695. std::vector<Value *> idxList(GEP->idx_begin(), GEP->idx_end());
  2696. Value *NewGEP = Builder.CreateInBoundsGEP(bigVec, idxList);
  2697. GEP->replaceAllUsesWith(NewGEP);
  2698. deadInsts.emplace_back(GEP);
  2699. } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
  2700. IRBuilder<> Builder(LI);
  2701. Value *NewLI = Builder.CreateLoad(bigVec);
  2702. NewLI = Builder.CreateShuffleVector(NewLI, NewLI, {0});
  2703. LI->replaceAllUsesWith(NewLI);
  2704. deadInsts.emplace_back(LI);
  2705. } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
  2706. Value *V = SI->getValueOperand();
  2707. IRBuilder<> Builder(LI);
  2708. for (unsigned i = 0; i < vecSize; i++) {
  2709. Value *Elt = Builder.CreateExtractElement(V, i);
  2710. Value *EltGEP = Builder.CreateInBoundsGEP(
  2711. bigVec, {zeroIdx, ConstantInt::get(idxTy, i)});
  2712. Builder.CreateStore(Elt, EltGEP);
  2713. }
  2714. deadInsts.emplace_back(SI);
  2715. } else {
  2716. DXASSERT(0, "not support yet");
  2717. }
  2718. }
  2719. deadInsts.emplace_back(BCI);
  2720. }
  2721. static void SimplifyArrayToVector(Value *Cast, Value *Ptr, llvm::Type *i32Ty,
  2722. std::vector<Instruction *> &deadInsts) {
  2723. // Transform
  2724. // %4 = bitcast [4 x i32]* %Val2 to <4 x i32>*
  2725. // store <4 x i32> %5, <4 x i32>* %4, !tbaa !0
  2726. // Into
  2727. //%6 = extractelement <4 x i32> %5, i64 0
  2728. //%7 = getelementptr inbounds [4 x i32], [4 x i32]* %Val2, i32 0, i32 0
  2729. // store i32 %6, i32* %7
  2730. //%8 = extractelement <4 x i32> %5, i64 1
  2731. //%9 = getelementptr inbounds [4 x i32], [4 x i32]* %Val2, i32 0, i32 1
  2732. // store i32 %8, i32* %9
  2733. //%10 = extractelement <4 x i32> %5, i64 2
  2734. //%11 = getelementptr inbounds [4 x i32], [4 x i32]* %Val2, i32 0, i32 2
  2735. // store i32 %10, i32* %11
  2736. //%12 = extractelement <4 x i32> %5, i64 3
  2737. //%13 = getelementptr inbounds [4 x i32], [4 x i32]* %Val2, i32 0, i32 3
  2738. // store i32 %12, i32* %13
  2739. Value *zeroIdx = ConstantInt::get(i32Ty, 0);
  2740. for (User *U : Cast->users()) {
  2741. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  2742. IRBuilder<> Builder(LI);
  2743. unsigned vecSize = LI->getType()->getVectorNumElements();
  2744. Value *NewLd = UndefValue::get(LI->getType());
  2745. for (unsigned i = 0; i < vecSize; i++) {
  2746. Value *GEP = Builder.CreateInBoundsGEP(
  2747. Ptr, {zeroIdx, ConstantInt::get(i32Ty, i)});
  2748. Value *Elt = Builder.CreateLoad(GEP);
  2749. NewLd = Builder.CreateInsertElement(NewLd, Elt, i);
  2750. }
  2751. LI->replaceAllUsesWith(NewLd);
  2752. deadInsts.emplace_back(LI);
  2753. } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) {
  2754. Value *V = SI->getValueOperand();
  2755. IRBuilder<> Builder(SI);
  2756. unsigned vecSize = V->getType()->getVectorNumElements();
  2757. for (unsigned i = 0; i < vecSize; i++) {
  2758. Value *Elt = Builder.CreateExtractElement(V, i);
  2759. Value *GEP = Builder.CreateInBoundsGEP(
  2760. Ptr, {zeroIdx, ConstantInt::get(i32Ty, i)});
  2761. Builder.CreateStore(Elt, GEP);
  2762. }
  2763. deadInsts.emplace_back(SI);
  2764. } else {
  2765. DXASSERT(0, "not support yet");
  2766. }
  2767. }
  2768. }
  2769. static void SimplifyArrayToVector(BitCastInst *BCI, std::vector<Instruction *> &deadInsts) {
  2770. Value *Ptr = BCI->getOperand(0);
  2771. llvm::Type *i32Ty = llvm::Type::getInt32Ty(BCI->getContext());
  2772. SimplifyArrayToVector(BCI, Ptr, i32Ty, deadInsts);
  2773. deadInsts.emplace_back(BCI);
  2774. }
  2775. static void SimplifyBoolCast(BitCastInst *BCI, llvm::Type *i1Ty, std::vector<Instruction *> &deadInsts) {
  2776. // Transform
  2777. //%22 = bitcast i1* %21 to i8*
  2778. //%23 = load i8, i8* %22, !tbaa !3, !range !7
  2779. //%tobool5 = trunc i8 %23 to i1
  2780. // To
  2781. //%tobool5 = load i1, i1* %21, !tbaa !3, !range !7
  2782. Value *i1Ptr = BCI->getOperand(0);
  2783. for (User *U : BCI->users()) {
  2784. if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
  2785. if (!LI->hasOneUse()) {
  2786. continue;
  2787. }
  2788. if (TruncInst *TI = dyn_cast<TruncInst>(*LI->user_begin())) {
  2789. if (TI->getType() == i1Ty) {
  2790. IRBuilder<> Builder(LI);
  2791. Value *i1Val = Builder.CreateLoad(i1Ptr);
  2792. TI->replaceAllUsesWith(i1Val);
  2793. deadInsts.emplace_back(LI);
  2794. deadInsts.emplace_back(TI);
  2795. }
  2796. }
  2797. }
  2798. }
  2799. }
  // Host-side math callbacks used by EvalUnaryIntrinsic / EvalBinaryIntrinsic
  // to constant-fold float and double HL intrinsic calls (e.g. tanf / tan).
  // The __cdecl annotation presumably matches the C runtime's calling
  // convention on MSVC so CRT functions can be passed directly; confirm
  // before changing it.
  2800. typedef float(__cdecl *FloatUnaryEvalFuncType)(float);
  2801. typedef double(__cdecl *DoubleUnaryEvalFuncType)(double);
  2802. typedef float(__cdecl *FloatBinaryEvalFuncType)(float, float);
  2803. typedef double(__cdecl *DoubleBinaryEvalFuncType)(double, double);
  2804. static Value * EvalUnaryIntrinsic(CallInst *CI,
  2805. FloatUnaryEvalFuncType floatEvalFunc,
  2806. DoubleUnaryEvalFuncType doubleEvalFunc) {
  2807. Value *V = CI->getArgOperand(0);
  2808. ConstantFP *fpV = cast<ConstantFP>(V);
  2809. llvm::Type *Ty = CI->getType();
  2810. Value *Result = nullptr;
  2811. if (Ty->isDoubleTy()) {
  2812. double dV = fpV->getValueAPF().convertToDouble();
  2813. Value *dResult = ConstantFP::get(V->getType(), doubleEvalFunc(dV));
  2814. CI->replaceAllUsesWith(dResult);
  2815. Result = dResult;
  2816. } else {
  2817. DXASSERT_NOMSG(Ty->isFloatTy());
  2818. float fV = fpV->getValueAPF().convertToFloat();
  2819. Value *dResult = ConstantFP::get(V->getType(), floatEvalFunc(fV));
  2820. CI->replaceAllUsesWith(dResult);
  2821. Result = dResult;
  2822. }
  2823. CI->eraseFromParent();
  2824. return Result;
  2825. }
  2826. static Value * EvalBinaryIntrinsic(CallInst *CI,
  2827. FloatBinaryEvalFuncType floatEvalFunc,
  2828. DoubleBinaryEvalFuncType doubleEvalFunc) {
  2829. Value *V0 = CI->getArgOperand(0);
  2830. ConstantFP *fpV0 = cast<ConstantFP>(V0);
  2831. Value *V1 = CI->getArgOperand(1);
  2832. ConstantFP *fpV1 = cast<ConstantFP>(V1);
  2833. llvm::Type *Ty = CI->getType();
  2834. Value *Result = nullptr;
  2835. if (Ty->isDoubleTy()) {
  2836. double dV0 = fpV0->getValueAPF().convertToDouble();
  2837. double dV1 = fpV1->getValueAPF().convertToDouble();
  2838. Value *dResult = ConstantFP::get(V0->getType(), doubleEvalFunc(dV0, dV1));
  2839. CI->replaceAllUsesWith(dResult);
  2840. Result = dResult;
  2841. } else {
  2842. DXASSERT_NOMSG(Ty->isFloatTy());
  2843. float fV0 = fpV0->getValueAPF().convertToFloat();
  2844. float fV1 = fpV1->getValueAPF().convertToFloat();
  2845. Value *dResult = ConstantFP::get(V0->getType(), floatEvalFunc(fV0, fV1));
  2846. CI->replaceAllUsesWith(dResult);
  2847. Result = dResult;
  2848. }
  2849. CI->eraseFromParent();
  2850. return Result;
  2851. }
  2852. static Value * TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp) {
  2853. switch (intriOp) {
  2854. case IntrinsicOp::IOP_tan: {
  2855. return EvalUnaryIntrinsic(CI, tanf, tan);
  2856. } break;
  2857. case IntrinsicOp::IOP_tanh: {
  2858. return EvalUnaryIntrinsic(CI, tanhf, tanh);
  2859. } break;
  2860. case IntrinsicOp::IOP_sin: {
  2861. return EvalUnaryIntrinsic(CI, sinf, sin);
  2862. } break;
  2863. case IntrinsicOp::IOP_sinh: {
  2864. return EvalUnaryIntrinsic(CI, sinhf, sinh);
  2865. } break;
  2866. case IntrinsicOp::IOP_cos: {
  2867. return EvalUnaryIntrinsic(CI, cosf, cos);
  2868. } break;
  2869. case IntrinsicOp::IOP_cosh: {
  2870. return EvalUnaryIntrinsic(CI, coshf, cosh);
  2871. } break;
  2872. case IntrinsicOp::IOP_asin: {
  2873. return EvalUnaryIntrinsic(CI, asinf, asin);
  2874. } break;
  2875. case IntrinsicOp::IOP_acos: {
  2876. return EvalUnaryIntrinsic(CI, acosf, acos);
  2877. } break;
  2878. case IntrinsicOp::IOP_atan: {
  2879. return EvalUnaryIntrinsic(CI, atanf, atan);
  2880. } break;
  2881. case IntrinsicOp::IOP_atan2: {
  2882. Value *V0 = CI->getArgOperand(0);
  2883. ConstantFP *fpV0 = cast<ConstantFP>(V0);
  2884. Value *V1 = CI->getArgOperand(1);
  2885. ConstantFP *fpV1 = cast<ConstantFP>(V1);
  2886. llvm::Type *Ty = CI->getType();
  2887. Value *Result = nullptr;
  2888. if (Ty->isDoubleTy()) {
  2889. double dV0 = fpV0->getValueAPF().convertToDouble();
  2890. double dV1 = fpV1->getValueAPF().convertToDouble();
  2891. Value *atanV = ConstantFP::get(CI->getType(), atan(dV0 / dV1));
  2892. CI->replaceAllUsesWith(atanV);
  2893. Result = atanV;
  2894. } else {
  2895. DXASSERT_NOMSG(Ty->isFloatTy());
  2896. float fV0 = fpV0->getValueAPF().convertToFloat();
  2897. float fV1 = fpV1->getValueAPF().convertToFloat();
  2898. Value *atanV = ConstantFP::get(CI->getType(), atanf(fV0 / fV1));
  2899. CI->replaceAllUsesWith(atanV);
  2900. Result = atanV;
  2901. }
  2902. CI->eraseFromParent();
  2903. return Result;
  2904. } break;
  2905. case IntrinsicOp::IOP_sqrt: {
  2906. return EvalUnaryIntrinsic(CI, sqrtf, sqrt);
  2907. } break;
  2908. case IntrinsicOp::IOP_rsqrt: {
  2909. auto rsqrtF = [](float v) -> float { return 1.0 / sqrtf(v); };
  2910. auto rsqrtD = [](double v) -> double { return 1.0 / sqrt(v); };
  2911. return EvalUnaryIntrinsic(CI, rsqrtF, rsqrtD);
  2912. } break;
  2913. case IntrinsicOp::IOP_exp: {
  2914. return EvalUnaryIntrinsic(CI, expf, exp);
  2915. } break;
  2916. case IntrinsicOp::IOP_exp2: {
  2917. return EvalUnaryIntrinsic(CI, exp2f, exp2);
  2918. } break;
  2919. case IntrinsicOp::IOP_log: {
  2920. return EvalUnaryIntrinsic(CI, logf, log);
  2921. } break;
  2922. case IntrinsicOp::IOP_log10: {
  2923. return EvalUnaryIntrinsic(CI, log10f, log10);
  2924. } break;
  2925. case IntrinsicOp::IOP_log2: {
  2926. return EvalUnaryIntrinsic(CI, log2f, log2);
  2927. } break;
  2928. case IntrinsicOp::IOP_pow: {
  2929. return EvalBinaryIntrinsic(CI, powf, pow);
  2930. } break;
  2931. case IntrinsicOp::IOP_max: {
  2932. auto maxF = [](float a, float b) -> float { return a > b ? a:b; };
  2933. auto maxD = [](double a, double b) -> double { return a > b ? a:b; };
  2934. return EvalBinaryIntrinsic(CI, maxF, maxD);
  2935. } break;
  2936. case IntrinsicOp::IOP_min: {
  2937. auto minF = [](float a, float b) -> float { return a < b ? a:b; };
  2938. auto minD = [](double a, double b) -> double { return a < b ? a:b; };
  2939. return EvalBinaryIntrinsic(CI, minF, minD);
  2940. } break;
  2941. case IntrinsicOp::IOP_rcp: {
  2942. auto rcpF = [](float v) -> float { return 1.0 / v; };
  2943. auto rcpD = [](double v) -> double { return 1.0 / v; };
  2944. return EvalUnaryIntrinsic(CI, rcpF, rcpD);
  2945. } break;
  2946. case IntrinsicOp::IOP_ceil: {
  2947. return EvalUnaryIntrinsic(CI, ceilf, ceil);
  2948. } break;
  2949. case IntrinsicOp::IOP_floor: {
  2950. return EvalUnaryIntrinsic(CI, floorf, floor);
  2951. } break;
  2952. case IntrinsicOp::IOP_round: {
  2953. return EvalUnaryIntrinsic(CI, roundf, round);
  2954. } break;
  2955. case IntrinsicOp::IOP_trunc: {
  2956. return EvalUnaryIntrinsic(CI, truncf, trunc);
  2957. } break;
  2958. case IntrinsicOp::IOP_frac: {
  2959. auto fracF = [](float v) -> float {
  2960. int exp = 0;
  2961. return frexpf(v, &exp);
  2962. };
  2963. auto fracD = [](double v) -> double {
  2964. int exp = 0;
  2965. return frexp(v, &exp);
  2966. };
  2967. return EvalUnaryIntrinsic(CI, fracF, fracD);
  2968. } break;
  2969. case IntrinsicOp::IOP_isnan: {
  2970. Value *V = CI->getArgOperand(0);
  2971. ConstantFP *fV = cast<ConstantFP>(V);
  2972. bool isNan = fV->getValueAPF().isNaN();
  2973. Constant *cNan = ConstantInt::get(CI->getType(), isNan ? 1 : 0);
  2974. CI->replaceAllUsesWith(cNan);
  2975. CI->eraseFromParent();
  2976. return cNan;
  2977. } break;
  2978. case IntrinsicOp::IOP_firstbithigh: {
  2979. Value *V = CI->getArgOperand(0);
  2980. ConstantInt *iV = cast<ConstantInt>(V);
  2981. APInt v = iV->getValue();
  2982. Value *firstbit = nullptr;
  2983. if (v == 0) {
  2984. firstbit = ConstantInt::get(CI->getType(), -1);
  2985. } else {
  2986. bool mask = true;
  2987. if (v.isNegative())
  2988. mask = false;
  2989. unsigned bitWidth = v.getBitWidth();
  2990. for (int i = bitWidth - 2; i >= 0; i--) {
  2991. if (v[i] == mask) {
  2992. firstbit = ConstantInt::get(CI->getType(), bitWidth-1-i);
  2993. break;
  2994. }
  2995. }
  2996. }
  2997. CI->replaceAllUsesWith(firstbit);
  2998. CI->eraseFromParent();
  2999. return firstbit;
  3000. } break;
  3001. case IntrinsicOp::IOP_ufirstbithigh: {
  3002. Value *V = CI->getArgOperand(0);
  3003. ConstantInt *iV = cast<ConstantInt>(V);
  3004. APInt v = iV->getValue();
  3005. Value *firstbit = nullptr;
  3006. if (v == 0) {
  3007. firstbit = ConstantInt::get(CI->getType(), -1);
  3008. } else {
  3009. unsigned bitWidth = v.getBitWidth();
  3010. for (int i = bitWidth - 1; i >= 0; i--) {
  3011. if (v[i]) {
  3012. firstbit = ConstantInt::get(CI->getType(), bitWidth-1-i);
  3013. break;
  3014. }
  3015. }
  3016. }
  3017. CI->replaceAllUsesWith(firstbit);
  3018. CI->eraseFromParent();
  3019. return firstbit;
  3020. } break;
  3021. default:
  3022. return nullptr;
  3023. }
  3024. }
  3025. static void SimpleTransformForHLDXIR(Instruction *I,
  3026. std::vector<Instruction *> &deadInsts) {
  3027. unsigned opcode = I->getOpcode();
  3028. switch (opcode) {
  3029. case Instruction::BitCast: {
  3030. BitCastInst *BCI = cast<BitCastInst>(I);
  3031. llvm::Type *ToTy = BCI->getType();
  3032. llvm::Type *FromTy = BCI->getOperand(0)->getType();
  3033. if (ToTy->isPointerTy() && FromTy->isPointerTy()) {
  3034. ToTy = ToTy->getPointerElementType();
  3035. FromTy = FromTy->getPointerElementType();
  3036. llvm::Type *i1Ty = llvm::Type::getInt1Ty(ToTy->getContext());
  3037. if (ToTy->isVectorTy()) {
  3038. unsigned vecSize = ToTy->getVectorNumElements();
  3039. if (vecSize == 1 &&
  3040. ToTy->getVectorElementType() == FromTy) {
  3041. SimplifyScalarToVec1Splat(BCI, deadInsts);
  3042. } else if (FromTy->isVectorTy() && vecSize == 1) {
  3043. if (FromTy->getScalarType() == ToTy->getScalarType()) {
  3044. SimplifyVectorTrunc(BCI, deadInsts);
  3045. }
  3046. } else if (FromTy->isArrayTy()) {
  3047. llvm::Type *FromEltTy = FromTy->getArrayElementType();
  3048. llvm::Type *ToEltTy = ToTy->getVectorElementType();
  3049. if (FromTy->getArrayNumElements() == vecSize &&
  3050. FromEltTy == ToEltTy) {
  3051. SimplifyArrayToVector(BCI, deadInsts);
  3052. }
  3053. }
  3054. }
  3055. else if (FromTy == i1Ty) {
  3056. SimplifyBoolCast(BCI, i1Ty, deadInsts);
  3057. }
  3058. // TODO: support array to array cast.
  3059. }
  3060. } break;
  3061. case Instruction::Load: {
  3062. LoadInst *ldInst = cast<LoadInst>(I);
  3063. DXASSERT_LOCALVAR(ldInst, !HLMatrixLower::IsMatrixType(ldInst->getType()),
  3064. "matrix load should use HL LdStMatrix");
  3065. } break;
  3066. case Instruction::Store: {
  3067. StoreInst *stInst = cast<StoreInst>(I);
  3068. Value *V = stInst->getValueOperand();
  3069. DXASSERT_LOCALVAR(V, !HLMatrixLower::IsMatrixType(V->getType()),
  3070. "matrix store should use HL LdStMatrix");
  3071. } break;
  3072. case Instruction::LShr:
  3073. case Instruction::AShr:
  3074. case Instruction::Shl: {
  3075. llvm::BinaryOperator *BO = cast<llvm::BinaryOperator>(I);
  3076. Value *op2 = BO->getOperand(1);
  3077. IntegerType *Ty = cast<IntegerType>(BO->getType()->getScalarType());
  3078. unsigned bitWidth = Ty->getBitWidth();
  3079. // Clamp op2 to 0 ~ bitWidth-1
  3080. if (ConstantInt *cOp2 = dyn_cast<ConstantInt>(op2)) {
  3081. unsigned iOp2 = cOp2->getLimitedValue();
  3082. unsigned clampedOp2 = iOp2 & (bitWidth - 1);
  3083. if (iOp2 != clampedOp2) {
  3084. BO->setOperand(1, ConstantInt::get(op2->getType(), clampedOp2));
  3085. }
  3086. } else {
  3087. Value *mask = ConstantInt::get(op2->getType(), bitWidth - 1);
  3088. IRBuilder<> Builder(I);
  3089. op2 = Builder.CreateAnd(op2, mask);
  3090. BO->setOperand(1, op2);
  3091. }
  3092. } break;
  3093. }
  3094. }
  3095. // Do simple transform to make later lower pass easier.
  3096. static void SimpleTransformForHLDXIR(llvm::Module *pM) {
  3097. std::vector<Instruction *> deadInsts;
  3098. for (Function &F : pM->functions()) {
  3099. for (BasicBlock &BB : F.getBasicBlockList()) {
  3100. for (BasicBlock::iterator Iter = BB.begin(); Iter != BB.end(); ) {
  3101. Instruction *I = (Iter++);
  3102. SimpleTransformForHLDXIR(I, deadInsts);
  3103. }
  3104. }
  3105. }
  3106. llvm::Type *i32Ty = llvm::Type::getInt32Ty(pM->getContext());
  3107. for (GlobalVariable &GV : pM->globals()) {
  3108. if (HLModule::IsStaticGlobal(&GV)) {
  3109. for (User *U : GV.users()) {
  3110. if (BitCastOperator *BCO = dyn_cast<BitCastOperator>(U)) {
  3111. llvm::Type *ToTy = BCO->getType();
  3112. llvm::Type *FromTy = BCO->getOperand(0)->getType();
  3113. if (ToTy->isPointerTy() && FromTy->isPointerTy()) {
  3114. ToTy = ToTy->getPointerElementType();
  3115. FromTy = FromTy->getPointerElementType();
  3116. if (ToTy->isVectorTy()) {
  3117. unsigned vecSize = ToTy->getVectorNumElements();
  3118. if (FromTy->isArrayTy()) {
  3119. llvm::Type *FromEltTy = FromTy->getArrayElementType();
  3120. llvm::Type *ToEltTy = ToTy->getVectorElementType();
  3121. if (FromTy->getArrayNumElements() == vecSize &&
  3122. FromEltTy == ToEltTy) {
  3123. SimplifyArrayToVector(BCO, &GV, i32Ty, deadInsts);
  3124. }
  3125. }
  3126. }
  3127. // TODO: support array to array cast.
  3128. }
  3129. }
  3130. }
  3131. }
  3132. }
  3133. for (Instruction * I : deadInsts)
  3134. I->dropAllReferences();
  3135. for (Instruction * I : deadInsts)
  3136. I->eraseFromParent();
  3137. }
// Finalizes the HL module once all code has been emitted. Steps, in order:
//   1. Resolve the entry function; bail out if none was found (only expected
//      after an error diagnostic).
//   2. Unless emitting high-level output (HLSLHighLevel), erase defined
//      functions other than the entry / patch-constant function once they
//      have no users, then erase unused declarations.
//   3. Copy each clip-plane value into a new external global
//      "SV_ClipPlane<i>" at the start of its function.
//   4. Allocate and construct the constant buffers.
//   5. Replace llvm.global_ctors with direct calls at the top of the entry
//      function.
//   6. Add opcode parameters to intrinsics, set linkage, mark every function
//      AlwaysInline, and run SimpleTransformForHLDXIR on the module.
void CGMSHLSLRuntime::FinishCodeGen() {
  SetEntryFunction();
  // If at this point we haven't determined the entry function it's an error.
  if (m_pHLModule->GetEntryFunction() == nullptr) {
    assert(CGM.getDiags().hasErrorOccurred() &&
           "else SetEntryFunction should have reported this condition");
    return;
  }
  // Remove all useless functions.
  if (!CGM.getCodeGenOpts().HLSLHighLevel) {
    // For hull shaders the patch constant function must survive as well.
    Function *patchConstantFunc = nullptr;
    if (m_pHLModule->GetShaderModel()->IsHS()) {
      patchConstantFunc = m_pHLModule->GetHLFunctionProps(EntryFunc)
                              .ShaderProps.HS.patchConstantFunc;
    }
    // First sweep: erase defined non-entry functions with no users right
    // away. Those that still have users become candidates; a candidate may
    // be erased later once its callers are gone, or may stay (e.g. when the
    // entry function calls it).
    std::unordered_set<Function *> DeadFuncSet;
    for (auto FIt = TheModule.functions().begin(),
              FE = TheModule.functions().end();
         FIt != FE;) {
      // Advance before F may be erased.
      Function *F = FIt++;
      if (F != EntryFunc && F != patchConstantFunc && !F->isDeclaration()) {
        if (F->user_empty())
          F->eraseFromParent();
        else
          DeadFuncSet.insert(F);
      }
    }
    // Repeat until a sweep erases nothing: erasing a function removes the
    // uses in its body, which can leave other candidates user-free.
    while (!DeadFuncSet.empty()) {
      bool noUpdate = true;
      for (auto FIt = DeadFuncSet.begin(), FE = DeadFuncSet.end(); FIt != FE;) {
        // Advance before erasing F from the set (invalidates only F's slot).
        Function *F = *(FIt++);
        if (F->user_empty()) {
          DeadFuncSet.erase(F);
          F->eraseFromParent();
          noUpdate = false;
        }
      }
      // Avoid dead loop.
      if (noUpdate)
        break;
    }
    // Remove unused external function.
    // Also drop the erased declaration's m_IntrinsicMap entry so the map
    // never holds a dangling Function*.
    for (auto FIt = TheModule.functions().begin(),
              FE = TheModule.functions().end();
         FIt != FE;) {
      Function *F = FIt++;
      if (F->isDeclaration() && F->user_empty()) {
        if (m_IntrinsicMap.count(F))
          m_IntrinsicMap.erase(F);
        F->eraseFromParent();
      }
    }
  }
  // Create copy for clip plane.
  // For each recorded clip plane, load the current value at the function's
  // first insertion point, store it into a new external global
  // "SV_ClipPlane<i>", and repoint the function props at that global.
  for (Function *F : clipPlaneFuncList) {
    HLFunctionProps &props = m_pHLModule->GetHLFunctionProps(F);
    IRBuilder<> Builder(F->getEntryBlock().getFirstInsertionPt());
    for (unsigned i = 0; i < DXIL::kNumClipPlanes; i++) {
      Value *clipPlane = props.ShaderProps.VS.clipPlanes[i];
      if (!clipPlane)
        continue;
      if (m_bDebugInfo) {
        // NOTE(review): operator[] default-inserts when clipPlane has no
        // recorded location; presumably every clip plane has one.
        Builder.SetCurrentDebugLocation(debugInfoMap[clipPlane]);
      }
      llvm::Type *Ty = clipPlane->getType()->getPointerElementType();
      // Constant *zeroInit = ConstantFP::get(Ty, 0);
      GlobalVariable *GV = new llvm::GlobalVariable(
          TheModule, Ty, /*IsConstant*/ false, // constant false to store.
          llvm::GlobalValue::ExternalLinkage,
          /*InitVal*/ nullptr, Twine("SV_ClipPlane") + Twine(i));
      Value *initVal = Builder.CreateLoad(clipPlane);
      Builder.CreateStore(initVal, GV);
      props.ShaderProps.VS.clipPlanes[i] = GV;
    }
  }
  // Allocate constant buffers.
  AllocateDxilConstantBuffers(m_pHLModule);
  // TODO: create temp variable for constant which has store use.
  // Create Global variable and type annotation for each CBuffer.
  ConstructCBuffer(m_pHLModule, CBufferType, m_ConstVarAnnotationMap);
  // add global call to entry func
  // Emits, before InsertPt, a direct call to every function listed in the
  // named ctor-style array global (operand 1 of each struct entry), then
  // erases the global. Zero entries, null pointers and non-Function operands
  // are skipped. The global is left alone if its initializer is not a
  // ConstantArray.
  auto AddGlobalCall = [&](StringRef globalName, Instruction *InsertPt) {
    GlobalVariable *GV = TheModule.getGlobalVariable(globalName);
    if (GV) {
      if (ConstantArray *CA = dyn_cast<ConstantArray>(GV->getInitializer())) {
        IRBuilder<> Builder(InsertPt);
        for (User::op_iterator i = CA->op_begin(), e = CA->op_end(); i != e;
             ++i) {
          if (isa<ConstantAggregateZero>(*i))
            continue;
          ConstantStruct *CS = cast<ConstantStruct>(*i);
          if (isa<ConstantPointerNull>(CS->getOperand(1)))
            continue;
          // Must have a function or null ptr.
          if (!isa<Function>(CS->getOperand(1)))
            continue;
          Function *Ctor = cast<Function>(CS->getOperand(1));
          assert(Ctor->getReturnType()->isVoidTy() && Ctor->arg_size() == 0 &&
                 "function type must be void (void)");
          Builder.CreateCall(Ctor);
        }
        // remove the GV
        GV->eraseFromParent();
      }
    }
  };
  // need this for "llvm.global_dtors"?
  AddGlobalCall("llvm.global_ctors",
                EntryFunc->getEntryBlock().getFirstInsertionPt());
  // translate opcode into parameter for intrinsic functions
  AddOpcodeParamForIntrinsics(*m_pHLModule, m_IntrinsicMap);
  // Pin entry point and constant buffers, mark everything else internal.
  // (Concretely: entry, patch-constant functions and declarations get
  // external linkage; every other function becomes internal.)
  for (Function &f : m_pHLModule->GetModule()->functions()) {
    if (&f == m_pHLModule->GetEntryFunction() || IsPatchConstantFunction(&f) ||
        f.isDeclaration()) {
      f.setLinkage(GlobalValue::LinkageTypes::ExternalLinkage);
    } else {
      f.setLinkage(GlobalValue::LinkageTypes::InternalLinkage);
    }
    // Always inline.
    f.addFnAttr(llvm::Attribute::AlwaysInline);
  }
  // Do simple transform to make later lower pass easier.
  SimpleTransformForHLDXIR(m_pHLModule->GetModule());
}
  3263. RValue CGMSHLSLRuntime::EmitHLSLBuiltinCallExpr(CodeGenFunction &CGF,
  3264. const FunctionDecl *FD,
  3265. const CallExpr *E,
  3266. ReturnValueSlot ReturnValue) {
  3267. StringRef name = FD->getName();
  3268. const Decl *TargetDecl = E->getCalleeDecl();
  3269. llvm::Value *Callee = CGF.EmitScalarExpr(E->getCallee());
  3270. RValue RV = CGF.EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue,
  3271. TargetDecl);
  3272. if (RV.isScalar() && RV.getScalarVal() != nullptr) {
  3273. if (CallInst *CI = dyn_cast<CallInst>(RV.getScalarVal())) {
  3274. Function *F = CI->getCalledFunction();
  3275. HLOpcodeGroup group = hlsl::GetHLOpcodeGroupByAttr(F);
  3276. if (group == HLOpcodeGroup::HLIntrinsic) {
  3277. bool allOperandImm = true;
  3278. for (auto &operand : CI->arg_operands()) {
  3279. bool isImm = isa<ConstantInt>(operand) || isa<ConstantFP>(operand);
  3280. if (!isImm) {
  3281. allOperandImm = false;
  3282. break;
  3283. }
  3284. }
  3285. if (allOperandImm) {
  3286. unsigned intrinsicOpcode;
  3287. StringRef intrinsicGroup;
  3288. hlsl::GetIntrinsicOp(FD, intrinsicOpcode, intrinsicGroup);
  3289. IntrinsicOp opcode = static_cast<IntrinsicOp>(intrinsicOpcode);
  3290. if (Value *Result = TryEvalIntrinsic(CI, opcode)) {
  3291. RV = RValue::get(Result);
  3292. }
  3293. }
  3294. }
  3295. }
  3296. }
  3297. return RV;
  3298. }
  3299. static HLOpcodeGroup GetHLOpcodeGroup(const clang::Stmt::StmtClass stmtClass) {
  3300. switch (stmtClass) {
  3301. case Stmt::CStyleCastExprClass:
  3302. case Stmt::ImplicitCastExprClass:
  3303. case Stmt::CXXFunctionalCastExprClass:
  3304. return HLOpcodeGroup::HLCast;
  3305. case Stmt::InitListExprClass:
  3306. return HLOpcodeGroup::HLInit;
  3307. case Stmt::BinaryOperatorClass:
  3308. case Stmt::CompoundAssignOperatorClass:
  3309. return HLOpcodeGroup::HLBinOp;
  3310. case Stmt::UnaryOperatorClass:
  3311. return HLOpcodeGroup::HLUnOp;
  3312. case Stmt::ExtMatrixElementExprClass:
  3313. return HLOpcodeGroup::HLSubscript;
  3314. case Stmt::CallExprClass:
  3315. return HLOpcodeGroup::HLIntrinsic;
  3316. case Stmt::ConditionalOperatorClass:
  3317. return HLOpcodeGroup::HLSelect;
  3318. default:
  3319. llvm_unreachable("not support operation");
  3320. }
  3321. }
  3322. // NOTE: This table must match BinaryOperator::Opcode
  3323. static const HLBinaryOpcode BinaryOperatorKindMap[] = {
  3324. HLBinaryOpcode::Invalid, // PtrMemD
  3325. HLBinaryOpcode::Invalid, // PtrMemI
  3326. HLBinaryOpcode::Mul, HLBinaryOpcode::Div, HLBinaryOpcode::Rem,
  3327. HLBinaryOpcode::Add, HLBinaryOpcode::Sub, HLBinaryOpcode::Shl,
  3328. HLBinaryOpcode::Shr, HLBinaryOpcode::LT, HLBinaryOpcode::GT,
  3329. HLBinaryOpcode::LE, HLBinaryOpcode::GE, HLBinaryOpcode::EQ,
  3330. HLBinaryOpcode::NE, HLBinaryOpcode::And, HLBinaryOpcode::Xor,
  3331. HLBinaryOpcode::Or, HLBinaryOpcode::LAnd, HLBinaryOpcode::LOr,
  3332. HLBinaryOpcode::Invalid, // Assign,
  3333. // The assign part is done by matrix store
  3334. HLBinaryOpcode::Mul, // MulAssign
  3335. HLBinaryOpcode::Div, // DivAssign
  3336. HLBinaryOpcode::Rem, // RemAssign
  3337. HLBinaryOpcode::Add, // AddAssign
  3338. HLBinaryOpcode::Sub, // SubAssign
  3339. HLBinaryOpcode::Shl, // ShlAssign
  3340. HLBinaryOpcode::Shr, // ShrAssign
  3341. HLBinaryOpcode::And, // AndAssign
  3342. HLBinaryOpcode::Xor, // XorAssign
  3343. HLBinaryOpcode::Or, // OrAssign
  3344. HLBinaryOpcode::Invalid, // Comma
  3345. };
  3346. // NOTE: This table must match UnaryOperator::Opcode
  3347. static const HLUnaryOpcode UnaryOperatorKindMap[] = {
  3348. HLUnaryOpcode::PostInc, HLUnaryOpcode::PostDec,
  3349. HLUnaryOpcode::PreInc, HLUnaryOpcode::PreDec,
  3350. HLUnaryOpcode::Invalid, // AddrOf,
  3351. HLUnaryOpcode::Invalid, // Deref,
  3352. HLUnaryOpcode::Plus, HLUnaryOpcode::Minus,
  3353. HLUnaryOpcode::Not, HLUnaryOpcode::LNot,
  3354. HLUnaryOpcode::Invalid, // Real,
  3355. HLUnaryOpcode::Invalid, // Imag,
  3356. HLUnaryOpcode::Invalid, // Extension
  3357. };
  3358. static bool IsRowMajorMatrix(QualType Ty, bool bDefaultRowMajor) {
  3359. if (const AttributedType *AT = Ty->getAs<AttributedType>()) {
  3360. if (AT->getAttrKind() == AttributedType::attr_hlsl_row_major)
  3361. return true;
  3362. else if (AT->getAttrKind() == AttributedType::attr_hlsl_column_major)
  3363. return false;
  3364. else
  3365. return bDefaultRowMajor;
  3366. } else {
  3367. return bDefaultRowMajor;
  3368. }
  3369. }
  3370. static bool IsUnsigned(QualType Ty) {
  3371. Ty = Ty.getCanonicalType().getNonReferenceType();
  3372. if (hlsl::IsHLSLVecMatType(Ty))
  3373. Ty = CGHLSLRuntime::GetHLSLVecMatElementType(Ty);
  3374. if (Ty->isExtVectorType())
  3375. Ty = Ty->getAs<clang::ExtVectorType>()->getElementType();
  3376. return Ty->isUnsignedIntegerType();
  3377. }
  3378. static unsigned GetHLOpcode(const Expr *E) {
  3379. switch (E->getStmtClass()) {
  3380. case Stmt::CompoundAssignOperatorClass:
  3381. case Stmt::BinaryOperatorClass: {
  3382. const clang::BinaryOperator *binOp = cast<clang::BinaryOperator>(E);
  3383. HLBinaryOpcode binOpcode = BinaryOperatorKindMap[binOp->getOpcode()];
  3384. if (HasUnsignedOpcode(binOpcode)) {
  3385. if (IsUnsigned(binOp->getLHS()->getType())) {
  3386. binOpcode = GetUnsignedOpcode(binOpcode);
  3387. }
  3388. }
  3389. return static_cast<unsigned>(binOpcode);
  3390. }
  3391. case Stmt::UnaryOperatorClass: {
  3392. const UnaryOperator *unOp = cast<clang::UnaryOperator>(E);
  3393. HLUnaryOpcode unOpcode = UnaryOperatorKindMap[unOp->getOpcode()];
  3394. return static_cast<unsigned>(unOpcode);
  3395. }
  3396. case Stmt::ImplicitCastExprClass:
  3397. case Stmt::CStyleCastExprClass: {
  3398. const CastExpr *CE = cast<CastExpr>(E);
  3399. bool toUnsigned = IsUnsigned(E->getType());
  3400. bool fromUnsigned = IsUnsigned(CE->getSubExpr()->getType());
  3401. if (toUnsigned && fromUnsigned)
  3402. return static_cast<unsigned>(HLCastOpcode::UnsignedUnsignedCast);
  3403. else if (toUnsigned)
  3404. return static_cast<unsigned>(HLCastOpcode::ToUnsignedCast);
  3405. else if (fromUnsigned)
  3406. return static_cast<unsigned>(HLCastOpcode::FromUnsignedCast);
  3407. else
  3408. return static_cast<unsigned>(HLCastOpcode::DefaultCast);
  3409. }
  3410. default:
  3411. return 0;
  3412. }
  3413. }
  3414. static Value *
  3415. EmitHLSLMatrixOperationCallImp(CGBuilderTy &Builder, HLOpcodeGroup group,
  3416. unsigned opcode, llvm::Type *RetType,
  3417. ArrayRef<Value *> paramList, llvm::Module &M) {
  3418. SmallVector<llvm::Type *, 4> paramTyList;
  3419. // Add the opcode param
  3420. llvm::Type *opcodeTy = llvm::Type::getInt32Ty(M.getContext());
  3421. paramTyList.emplace_back(opcodeTy);
  3422. for (Value *param : paramList) {
  3423. paramTyList.emplace_back(param->getType());
  3424. }
  3425. llvm::FunctionType *funcTy =
  3426. llvm::FunctionType::get(RetType, paramTyList, false);
  3427. Function *opFunc = GetOrCreateHLFunction(M, funcTy, group, opcode);
  3428. SmallVector<Value *, 4> opcodeParamList;
  3429. Value *opcodeConst = Constant::getIntegerValue(opcodeTy, APInt(32, opcode));
  3430. opcodeParamList.emplace_back(opcodeConst);
  3431. opcodeParamList.append(paramList.begin(), paramList.end());
  3432. return Builder.CreateCall(opFunc, opcodeParamList);
  3433. }
  3434. static Value *EmitHLSLArrayInit(CGBuilderTy &Builder, HLOpcodeGroup group,
  3435. unsigned opcode, llvm::Type *RetType,
  3436. ArrayRef<Value *> paramList, llvm::Module &M) {
  3437. // It's a matrix init.
  3438. if (!RetType->isVoidTy())
  3439. return EmitHLSLMatrixOperationCallImp(Builder, group, opcode, RetType,
  3440. paramList, M);
  3441. Value *arrayPtr = paramList[0];
  3442. llvm::ArrayType *AT =
  3443. cast<llvm::ArrayType>(arrayPtr->getType()->getPointerElementType());
  3444. // Avoid the arrayPtr.
  3445. unsigned paramSize = paramList.size() - 1;
  3446. // Support simple case here.
  3447. if (paramSize == AT->getArrayNumElements()) {
  3448. bool typeMatch = true;
  3449. llvm::Type *EltTy = AT->getArrayElementType();
  3450. if (EltTy->isAggregateType()) {
  3451. // Aggregate Type use pointer in initList.
  3452. EltTy = llvm::PointerType::get(EltTy, 0);
  3453. }
  3454. for (unsigned i = 1; i < paramList.size(); i++) {
  3455. if (paramList[i]->getType() != EltTy) {
  3456. typeMatch = false;
  3457. break;
  3458. }
  3459. }
  3460. // Both size and type match.
  3461. if (typeMatch) {
  3462. bool isPtr = EltTy->isPointerTy();
  3463. llvm::Type *i32Ty = llvm::Type::getInt32Ty(EltTy->getContext());
  3464. Constant *zero = ConstantInt::get(i32Ty, 0);
  3465. for (unsigned i = 1; i < paramList.size(); i++) {
  3466. Constant *idx = ConstantInt::get(i32Ty, i - 1);
  3467. Value *GEP = Builder.CreateInBoundsGEP(arrayPtr, {zero, idx});
  3468. Value *Elt = paramList[i];
  3469. if (isPtr) {
  3470. Elt = Builder.CreateLoad(Elt);
  3471. }
  3472. Builder.CreateStore(Elt, GEP);
  3473. }
  3474. // The return value will not be used.
  3475. return nullptr;
  3476. }
  3477. }
  3478. // Other case will be lowered in later pass.
  3479. return EmitHLSLMatrixOperationCallImp(Builder, group, opcode, RetType,
  3480. paramList, M);
  3481. }
  3482. void CGMSHLSLRuntime::FlattenValToInitList(CodeGenFunction &CGF, SmallVector<Value *, 4> &elts,
  3483. SmallVector<QualType, 4> &eltTys,
  3484. QualType Ty, Value *val) {
  3485. CGBuilderTy &Builder = CGF.Builder;
  3486. llvm::Type *valTy = val->getType();
  3487. if (valTy->isPointerTy()) {
  3488. llvm::Type *valEltTy = valTy->getPointerElementType();
  3489. if (valEltTy->isVectorTy() ||
  3490. valEltTy->isSingleValueType()) {
  3491. Value *ldVal = Builder.CreateLoad(val);
  3492. FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
  3493. } else if (HLMatrixLower::IsMatrixType(valEltTy)) {
  3494. Value *ldVal = EmitHLSLMatrixLoad(Builder, val, Ty);
  3495. FlattenValToInitList(CGF, elts, eltTys, Ty, ldVal);
  3496. } else {
  3497. llvm::Type *i32Ty = llvm::Type::getInt32Ty(valTy->getContext());
  3498. Value *zero = ConstantInt::get(i32Ty, 0);
  3499. if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(valEltTy)) {
  3500. QualType EltTy = Ty->getAsArrayTypeUnsafe()->getElementType();
  3501. for (unsigned i = 0; i < AT->getArrayNumElements(); i++) {
  3502. Value *gepIdx = ConstantInt::get(i32Ty, i);
  3503. Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
  3504. FlattenValToInitList(CGF, elts, eltTys, EltTy,EltPtr);
  3505. }
  3506. } else {
  3507. // Struct.
  3508. StructType *ST = cast<StructType>(valEltTy);
  3509. if (HLModule::IsHLSLObjectType(ST)) {
  3510. // Save object directly like basic type.
  3511. elts.emplace_back(Builder.CreateLoad(val));
  3512. eltTys.emplace_back(Ty);
  3513. } else {
  3514. RecordDecl *RD = Ty->getAsStructureType()->getDecl();
  3515. const CGRecordLayout& RL = CGF.getTypes().getCGRecordLayout(RD);
  3516. // Take care base.
  3517. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  3518. if (CXXRD->getNumBases()) {
  3519. for (const auto &I : CXXRD->bases()) {
  3520. const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
  3521. I.getType()->castAs<RecordType>()->getDecl());
  3522. if (BaseDecl->field_empty())
  3523. continue;
  3524. QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
  3525. unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
  3526. Value *gepIdx = ConstantInt::get(i32Ty, i);
  3527. Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
  3528. FlattenValToInitList(CGF, elts, eltTys, parentTy, EltPtr);
  3529. }
  3530. }
  3531. }
  3532. for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
  3533. fieldIter != fieldEnd; ++fieldIter) {
  3534. unsigned i = RL.getLLVMFieldNo(*fieldIter);
  3535. Value *gepIdx = ConstantInt::get(i32Ty, i);
  3536. Value *EltPtr = Builder.CreateInBoundsGEP(val, {zero, gepIdx});
  3537. FlattenValToInitList(CGF, elts, eltTys, fieldIter->getType(), EltPtr);
  3538. }
  3539. }
  3540. }
  3541. }
  3542. } else {
  3543. if (HLMatrixLower::IsMatrixType(valTy)) {
  3544. unsigned col, row;
  3545. llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(valTy, col, row);
  3546. unsigned matSize = col * row;
  3547. bool isRowMajor = IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor);
  3548. HLCastOpcode opcode = isRowMajor ? HLCastOpcode::RowMatrixToVecCast
  3549. : HLCastOpcode::ColMatrixToVecCast;
  3550. // Cast to vector.
  3551. val = EmitHLSLMatrixOperationCallImp(
  3552. Builder, HLOpcodeGroup::HLCast,
  3553. static_cast<unsigned>(opcode),
  3554. llvm::VectorType::get(EltTy, matSize), {val}, TheModule);
  3555. valTy = val->getType();
  3556. }
  3557. if (valTy->isVectorTy()) {
  3558. QualType EltTy = GetHLSLVecMatElementType(Ty);
  3559. unsigned vecSize = valTy->getVectorNumElements();
  3560. for (unsigned i = 0; i < vecSize; i++) {
  3561. Value *Elt = Builder.CreateExtractElement(val, i);
  3562. elts.emplace_back(Elt);
  3563. eltTys.emplace_back(EltTy);
  3564. }
  3565. } else {
  3566. DXASSERT(valTy->isSingleValueType(), "must be single value type here");
  3567. elts.emplace_back(val);
  3568. eltTys.emplace_back(Ty);
  3569. }
  3570. }
  3571. }
  3572. // Cast elements in initlist if not match the target type.
  3573. // idx is current element index in initlist, Ty is target type.
  3574. static void AddMissingCastOpsInInitList(SmallVector<Value *, 4> &elts, SmallVector<QualType, 4> eltTys, unsigned &idx, QualType Ty, CodeGenFunction &CGF) {
  3575. if (Ty->isArrayType()) {
  3576. const clang::ArrayType *AT = Ty->getAsArrayTypeUnsafe();
  3577. // Must be ConstantArrayType here.
  3578. unsigned arraySize = cast<ConstantArrayType>(AT)->getSize().getLimitedValue();
  3579. QualType EltTy = AT->getElementType();
  3580. for (unsigned i = 0; i < arraySize; i++)
  3581. AddMissingCastOpsInInitList(elts, eltTys, idx, EltTy, CGF);
  3582. } else if (IsHLSLVecType(Ty)) {
  3583. QualType EltTy = GetHLSLVecElementType(Ty);
  3584. unsigned vecSize = GetHLSLVecSize(Ty);
  3585. for (unsigned i=0;i< vecSize;i++)
  3586. AddMissingCastOpsInInitList(elts, eltTys, idx, EltTy, CGF);
  3587. } else if (IsHLSLMatType(Ty)) {
  3588. QualType EltTy = GetHLSLMatElementType(Ty);
  3589. unsigned row, col;
  3590. GetHLSLMatRowColCount(Ty, row, col);
  3591. unsigned matSize = row*col;
  3592. for (unsigned i = 0; i < matSize; i++)
  3593. AddMissingCastOpsInInitList(elts, eltTys, idx, EltTy, CGF);
  3594. } else if (Ty->isRecordType()) {
  3595. if (HLModule::IsHLSLObjectType(CGF.ConvertType(Ty))) {
  3596. // Skip hlsl object.
  3597. idx++;
  3598. } else {
  3599. const RecordType *RT = Ty->getAsStructureType();
  3600. // For CXXRecord.
  3601. if (!RT)
  3602. RT = Ty->getAs<RecordType>();
  3603. RecordDecl *RD = RT->getDecl();
  3604. for (FieldDecl *field : RD->fields())
  3605. AddMissingCastOpsInInitList(elts, eltTys, idx, field->getType(), CGF);
  3606. }
  3607. }
  3608. else {
  3609. // Basic type.
  3610. Value *val = elts[idx];
  3611. llvm::Type *srcTy = val->getType();
  3612. llvm::Type *dstTy = CGF.ConvertType(Ty);
  3613. if (srcTy != dstTy) {
  3614. Instruction::CastOps castOp =
  3615. static_cast<Instruction::CastOps>(HLModule::FindCastOp(
  3616. IsUnsigned(eltTys[idx]), IsUnsigned(Ty), srcTy, dstTy));
  3617. elts[idx] = CGF.Builder.CreateCast(castOp, val, dstTy);
  3618. }
  3619. idx++;
  3620. }
  3621. }
  3622. static void StoreInitListToDestPtr(Value *DestPtr, SmallVector<Value *, 4> &elts, unsigned &idx, CGBuilderTy &Builder, llvm::Module &M) {
  3623. llvm::Type *Ty = DestPtr->getType()->getPointerElementType();
  3624. llvm::Type *i32Ty = llvm::Type::getInt32Ty(Ty->getContext());
  3625. if (Ty->isVectorTy()) {
  3626. Value *Result = UndefValue::get(Ty);
  3627. for (unsigned i = 0; i < Ty->getVectorNumElements(); i++)
  3628. Result = Builder.CreateInsertElement(Result, elts[idx+i], i);
  3629. Builder.CreateStore(Result, DestPtr);
  3630. idx += Ty->getVectorNumElements();
  3631. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  3632. unsigned row, col;
  3633. HLMatrixLower::GetMatrixInfo(Ty, col, row);
  3634. std::vector<Value*> matInitList(col*row);
  3635. for (unsigned i = 0; i < col; i++) {
  3636. for (unsigned r = 0; r < row; r++) {
  3637. unsigned matIdx = i * row + r;
  3638. matInitList[matIdx] = elts[idx+matIdx];
  3639. }
  3640. }
  3641. idx += row*col;
  3642. Value *matVal = EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLInit,
  3643. /*opcode*/0, Ty, matInitList, M);
  3644. EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLMatLoadStore,
  3645. static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore), Ty,
  3646. {DestPtr, matVal}, M);
  3647. } else if (Ty->isStructTy()) {
  3648. if (HLModule::IsHLSLObjectType(Ty)) {
  3649. Builder.CreateStore(elts[idx], DestPtr);
  3650. idx++;
  3651. } else {
  3652. Constant *zero = ConstantInt::get(i32Ty, 0);
  3653. for (unsigned i = 0; i < Ty->getStructNumElements(); i++) {
  3654. Constant *gepIdx = ConstantInt::get(i32Ty, i);
  3655. Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
  3656. StoreInitListToDestPtr(GEP, elts, idx, Builder, M);
  3657. }
  3658. }
  3659. } else if (Ty->isArrayTy()) {
  3660. Constant *zero = ConstantInt::get(i32Ty, 0);
  3661. for (unsigned i = 0; i < Ty->getArrayNumElements(); i++) {
  3662. Constant *gepIdx = ConstantInt::get(i32Ty, i);
  3663. Value *GEP = Builder.CreateInBoundsGEP(DestPtr, {zero, gepIdx});
  3664. StoreInitListToDestPtr(GEP, elts, idx, Builder, M);
  3665. }
  3666. } else {
  3667. DXASSERT(Ty->isSingleValueType(), "invalid type");
  3668. llvm::Type *i1Ty = Builder.getInt1Ty();
  3669. Value *V = elts[idx];
  3670. if (V->getType() == i1Ty && DestPtr->getType()->getPointerElementType() != i1Ty) {
  3671. V = Builder.CreateZExt(V, DestPtr->getType()->getPointerElementType());
  3672. }
  3673. Builder.CreateStore(V, DestPtr);
  3674. idx++;
  3675. }
  3676. }
  3677. void CGMSHLSLRuntime::ScanInitList(CodeGenFunction &CGF, InitListExpr *E,
  3678. SmallVector<Value *, 4> &EltValList,
  3679. SmallVector<QualType, 4> &EltTyList) {
  3680. unsigned NumInitElements = E->getNumInits();
  3681. for (unsigned i = 0; i != NumInitElements; ++i) {
  3682. Expr *init = E->getInit(i);
  3683. QualType iType = init->getType();
  3684. if (InitListExpr *initList = dyn_cast<InitListExpr>(init)) {
  3685. ScanInitList(CGF, initList, EltValList, EltTyList);
  3686. } else if (CodeGenFunction::hasScalarEvaluationKind(iType)) {
  3687. llvm::Value *initVal = CGF.EmitScalarExpr(init);
  3688. FlattenValToInitList(CGF, EltValList, EltTyList, iType, initVal);
  3689. } else {
  3690. AggValueSlot Slot =
  3691. CGF.CreateAggTemp(init->getType(), "Agg.InitList.tmp");
  3692. CGF.EmitAggExpr(init, Slot);
  3693. llvm::Value *aggPtr = Slot.getAddr();
  3694. FlattenValToInitList(CGF, EltValList, EltTyList, iType, aggPtr);
  3695. }
  3696. }
  3697. }
  3698. unsigned CGMSHLSLRuntime::ScanInitList(InitListExpr *E) {
  3699. unsigned NumInitElements = E->getNumInits();
  3700. unsigned size = 0;
  3701. for (unsigned i = 0; i != NumInitElements; ++i) {
  3702. Expr *init = E->getInit(i);
  3703. QualType iType = init->getType();
  3704. if (InitListExpr *initList = dyn_cast<InitListExpr>(init)) {
  3705. size += ScanInitList(initList);
  3706. } else if (CodeGenFunction::hasScalarEvaluationKind(iType)) {
  3707. size += GetElementCount(iType);
  3708. } else {
  3709. DXASSERT(0, "not support yet");
  3710. }
  3711. }
  3712. return size;
  3713. }
  3714. QualType CGMSHLSLRuntime::UpdateHLSLIncompleteArrayType(VarDecl &D) {
  3715. if (!D.hasInit())
  3716. return D.getType();
  3717. InitListExpr *E = dyn_cast<InitListExpr>(D.getInit());
  3718. if (!E)
  3719. return D.getType();
  3720. unsigned arrayEltCount = ScanInitList(E);
  3721. QualType ResultTy = E->getType();
  3722. QualType EltTy = QualType(ResultTy->getArrayElementTypeNoTypeQual(), 0);
  3723. unsigned eltCount = GetElementCount(EltTy);
  3724. llvm::APInt ArySize(32, arrayEltCount / eltCount);
  3725. QualType ArrayTy = CGM.getContext().getConstantArrayType(
  3726. EltTy, ArySize, clang::ArrayType::Normal, 0);
  3727. D.setType(ArrayTy);
  3728. E->setType(ArrayTy);
  3729. return ArrayTy;
  3730. }
  3731. Value *CGMSHLSLRuntime::EmitHLSLInitListExpr(CodeGenFunction &CGF, InitListExpr *E,
  3732. // The destPtr when emiting aggregate init, for normal case, it will be null.
  3733. Value *DestPtr) {
  3734. SmallVector<Value *, 4> EltValList;
  3735. SmallVector<QualType, 4> EltTyList;
  3736. ScanInitList(CGF, E, EltValList, EltTyList);
  3737. QualType ResultTy = E->getType();
  3738. unsigned idx = 0;
  3739. // Create cast if need.
  3740. AddMissingCastOpsInInitList(EltValList, EltTyList, idx, ResultTy, CGF);
  3741. DXASSERT(idx == EltValList.size(), "size must match");
  3742. llvm::Type *RetTy = CGF.ConvertType(ResultTy);
  3743. if (DestPtr) {
  3744. SmallVector<Value *, 4> ParamList;
  3745. DXASSERT(RetTy->isAggregateType(), "");
  3746. ParamList.emplace_back(DestPtr);
  3747. ParamList.append(EltValList.begin(), EltValList.end());
  3748. idx = 0;
  3749. StoreInitListToDestPtr(DestPtr, EltValList, idx, CGF.Builder, TheModule);
  3750. return nullptr;
  3751. }
  3752. if (IsHLSLVecType(ResultTy)) {
  3753. Value *Result = UndefValue::get(RetTy);
  3754. for (unsigned i = 0; i < RetTy->getVectorNumElements(); i++)
  3755. Result = CGF.Builder.CreateInsertElement(Result, EltValList[i], i);
  3756. return Result;
  3757. } else {
  3758. // Must be matrix here.
  3759. DXASSERT(IsHLSLMatType(ResultTy), "must be matrix type here.");
  3760. return EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLInit,
  3761. /*opcode*/ 0, RetTy, EltValList,
  3762. TheModule);
  3763. }
  3764. }
  3765. Value *CGMSHLSLRuntime::EmitHLSLMatrixOperationCall(
  3766. CodeGenFunction &CGF, const clang::Expr *E, llvm::Type *RetType,
  3767. ArrayRef<Value *> paramList) {
  3768. HLOpcodeGroup group = GetHLOpcodeGroup(E->getStmtClass());
  3769. unsigned opcode = GetHLOpcode(E);
  3770. if (group == HLOpcodeGroup::HLInit)
  3771. return EmitHLSLArrayInit(CGF.Builder, group, opcode, RetType, paramList,
  3772. TheModule);
  3773. else
  3774. return EmitHLSLMatrixOperationCallImp(CGF.Builder, group, opcode, RetType,
  3775. paramList, TheModule);
  3776. }
  3777. void CGMSHLSLRuntime::EmitHLSLDiscard(CodeGenFunction &CGF) {
  3778. EmitHLSLMatrixOperationCallImp(
  3779. CGF.Builder, HLOpcodeGroup::HLIntrinsic,
  3780. static_cast<unsigned>(IntrinsicOp::IOP_clip),
  3781. llvm::Type::getVoidTy(CGF.getLLVMContext()),
  3782. {ConstantFP::get(llvm::Type::getFloatTy(CGF.getLLVMContext()), -1.0f)},
  3783. TheModule);
  3784. }
  3785. Value *CGMSHLSLRuntime::EmitHLSLLiteralCast(CodeGenFunction &CGF, Value *Src,
  3786. QualType SrcType,
  3787. QualType DstType) {
  3788. auto &Builder = CGF.Builder;
  3789. llvm::Type *DstTy = CGF.ConvertType(DstType);
  3790. bool bDstSigned = DstType->isSignedIntegerType();
  3791. if (ConstantInt *CI = dyn_cast<ConstantInt>(Src)) {
  3792. APInt v = CI->getValue();
  3793. if (llvm::IntegerType *IT = dyn_cast<llvm::IntegerType>(DstTy)) {
  3794. v = v.trunc(IT->getBitWidth());
  3795. switch (IT->getBitWidth()) {
  3796. case 32:
  3797. return Builder.getInt32(v.getLimitedValue());
  3798. case 64:
  3799. return Builder.getInt64(v.getLimitedValue());
  3800. case 16:
  3801. return Builder.getInt16(v.getLimitedValue());
  3802. case 8:
  3803. return Builder.getInt8(v.getLimitedValue());
  3804. default:
  3805. return nullptr;
  3806. }
  3807. } else {
  3808. DXASSERT_NOMSG(DstTy->isFloatingPointTy());
  3809. int64_t val = v.getLimitedValue();
  3810. if (v.isNegative())
  3811. val = 0-v.abs().getLimitedValue();
  3812. if (DstTy->isDoubleTy())
  3813. return ConstantFP::get(DstTy, (double)val);
  3814. else if (DstTy->isFloatTy())
  3815. return ConstantFP::get(DstTy, (float)val);
  3816. else {
  3817. if (bDstSigned)
  3818. return Builder.CreateSIToFP(Src, DstTy);
  3819. else
  3820. return Builder.CreateUIToFP(Src, DstTy);
  3821. }
  3822. }
  3823. } else if (ConstantFP *CF = dyn_cast<ConstantFP>(Src)) {
  3824. APFloat v = CF->getValueAPF();
  3825. double dv = v.convertToDouble();
  3826. if (llvm::IntegerType *IT = dyn_cast<llvm::IntegerType>(DstTy)) {
  3827. switch (IT->getBitWidth()) {
  3828. case 32:
  3829. return Builder.getInt32(dv);
  3830. case 64:
  3831. return Builder.getInt64(dv);
  3832. case 16:
  3833. return Builder.getInt16(dv);
  3834. case 8:
  3835. return Builder.getInt8(dv);
  3836. default:
  3837. return nullptr;
  3838. }
  3839. } else {
  3840. if (DstTy->isFloatTy()) {
  3841. float fv = dv;
  3842. return ConstantFP::get(DstTy->getContext(), APFloat(fv));
  3843. } else {
  3844. return Builder.CreateFPTrunc(Src, DstTy);
  3845. }
  3846. }
  3847. } else if (UndefValue *UV = dyn_cast<UndefValue>(Src)) {
  3848. return UndefValue::get(DstTy);
  3849. } else {
  3850. Instruction *I = cast<Instruction>(Src);
  3851. if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
  3852. Value *T = SI->getTrueValue();
  3853. Value *F = SI->getFalseValue();
  3854. Value *Cond = SI->getCondition();
  3855. if (isa<llvm::ConstantInt>(T) && isa<llvm::ConstantInt>(F)) {
  3856. llvm::APInt lhs = cast<llvm::ConstantInt>(T)->getValue();
  3857. llvm::APInt rhs = cast<llvm::ConstantInt>(F)->getValue();
  3858. if (DstTy == Builder.getInt32Ty()) {
  3859. T = Builder.getInt32(lhs.getLimitedValue());
  3860. F = Builder.getInt32(rhs.getLimitedValue());
  3861. Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
  3862. return Sel;
  3863. } else if (DstTy->isFloatingPointTy()) {
  3864. T = ConstantFP::get(DstTy, lhs.getLimitedValue());
  3865. F = ConstantFP::get(DstTy, rhs.getLimitedValue());
  3866. Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
  3867. return Sel;
  3868. }
  3869. } else if (isa<llvm::ConstantFP>(T) && isa<llvm::ConstantFP>(F)) {
  3870. llvm::APFloat lhs = cast<llvm::ConstantFP>(T)->getValueAPF();
  3871. llvm::APFloat rhs = cast<llvm::ConstantFP>(F)->getValueAPF();
  3872. double ld = lhs.convertToDouble();
  3873. double rd = rhs.convertToDouble();
  3874. if (DstTy->isFloatTy()) {
  3875. float lf = ld;
  3876. float rf = rd;
  3877. T = ConstantFP::get(DstTy->getContext(), APFloat(lf));
  3878. F = ConstantFP::get(DstTy->getContext(), APFloat(rf));
  3879. Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
  3880. return Sel;
  3881. } else if (DstTy == Builder.getInt32Ty()) {
  3882. T = Builder.getInt32(ld);
  3883. F = Builder.getInt32(rd);
  3884. Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
  3885. return Sel;
  3886. } else if (DstTy == Builder.getInt64Ty()) {
  3887. T = Builder.getInt64(ld);
  3888. F = Builder.getInt64(rd);
  3889. Value *Sel = Builder.CreateSelect(Cond, T, F, "cond");
  3890. return Sel;
  3891. }
  3892. }
  3893. }
  3894. // TODO: support other opcode if need.
  3895. return nullptr;
  3896. }
  3897. }
  3898. Value *CGMSHLSLRuntime::EmitHLSLMatrixSubscript(CodeGenFunction &CGF,
  3899. llvm::Type *RetType,
  3900. llvm::Value *Ptr,
  3901. llvm::Value *Idx,
  3902. clang::QualType Ty) {
  3903. unsigned opcode =
  3904. IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
  3905. ? static_cast<unsigned>(HLSubscriptOpcode::RowMatSubscript)
  3906. : static_cast<unsigned>(HLSubscriptOpcode::ColMatSubscript);
  3907. Value *matBase = Ptr;
  3908. if (matBase->getType()->isPointerTy()) {
  3909. RetType =
  3910. llvm::PointerType::get(RetType->getPointerElementType(),
  3911. matBase->getType()->getPointerAddressSpace());
  3912. }
  3913. return EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLSubscript,
  3914. opcode, RetType, {Ptr, Idx}, TheModule);
  3915. }
  3916. Value *CGMSHLSLRuntime::EmitHLSLMatrixElement(CodeGenFunction &CGF,
  3917. llvm::Type *RetType,
  3918. ArrayRef<Value *> paramList,
  3919. QualType Ty) {
  3920. unsigned opcode =
  3921. IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
  3922. ? static_cast<unsigned>(HLSubscriptOpcode::RowMatElement)
  3923. : static_cast<unsigned>(HLSubscriptOpcode::ColMatElement);
  3924. Value *matBase = paramList[0];
  3925. if (matBase->getType()->isPointerTy()) {
  3926. RetType =
  3927. llvm::PointerType::get(RetType->getPointerElementType(),
  3928. matBase->getType()->getPointerAddressSpace());
  3929. }
  3930. return EmitHLSLMatrixOperationCallImp(CGF.Builder, HLOpcodeGroup::HLSubscript,
  3931. opcode, RetType, paramList, TheModule);
  3932. }
  3933. Value *CGMSHLSLRuntime::EmitHLSLMatrixLoad(CGBuilderTy &Builder, Value *Ptr,
  3934. QualType Ty) {
  3935. unsigned opcode =
  3936. IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
  3937. ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatLoad)
  3938. : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatLoad);
  3939. return EmitHLSLMatrixOperationCallImp(
  3940. Builder, HLOpcodeGroup::HLMatLoadStore, opcode,
  3941. Ptr->getType()->getPointerElementType(), {Ptr}, TheModule);
  3942. }
  3943. void CGMSHLSLRuntime::EmitHLSLMatrixStore(CGBuilderTy &Builder, Value *Val,
  3944. Value *DestPtr, QualType Ty) {
  3945. unsigned opcode =
  3946. IsRowMajorMatrix(Ty, m_pHLModule->GetHLOptions().bDefaultRowMajor)
  3947. ? static_cast<unsigned>(HLMatLoadStoreOpcode::RowMatStore)
  3948. : static_cast<unsigned>(HLMatLoadStoreOpcode::ColMatStore);
  3949. EmitHLSLMatrixOperationCallImp(Builder, HLOpcodeGroup::HLMatLoadStore, opcode,
  3950. Val->getType(), {DestPtr, Val}, TheModule);
  3951. }
  3952. Value *CGMSHLSLRuntime::EmitHLSLMatrixLoad(CodeGenFunction &CGF, Value *Ptr,
  3953. QualType Ty) {
  3954. return EmitHLSLMatrixLoad(CGF.Builder, Ptr, Ty);
  3955. }
  3956. void CGMSHLSLRuntime::EmitHLSLMatrixStore(CodeGenFunction &CGF, Value *Val,
  3957. Value *DestPtr, QualType Ty) {
  3958. EmitHLSLMatrixStore(CGF.Builder, Val, DestPtr, Ty);
  3959. }
  3960. // Copy data from srcPtr to destPtr.
  3961. static void SimplePtrCopy(Value *DestPtr, Value *SrcPtr,
  3962. ArrayRef<Value *> idxList, CGBuilderTy &Builder) {
  3963. if (idxList.size() > 1) {
  3964. DestPtr = Builder.CreateInBoundsGEP(DestPtr, idxList);
  3965. SrcPtr = Builder.CreateInBoundsGEP(SrcPtr, idxList);
  3966. }
  3967. llvm::LoadInst *ld = Builder.CreateLoad(SrcPtr);
  3968. Builder.CreateStore(ld, DestPtr);
  3969. }
  3970. // Get Element val from SrvVal with extract value.
  3971. static Value *GetEltVal(Value *SrcVal, ArrayRef<Value*> idxList,
  3972. CGBuilderTy &Builder) {
  3973. Value *Val = SrcVal;
  3974. // Skip beginning pointer type.
  3975. for (unsigned i = 1; i < idxList.size(); i++) {
  3976. ConstantInt *idx = cast<ConstantInt>(idxList[i]);
  3977. llvm::Type *Ty = Val->getType();
  3978. if (Ty->isAggregateType()) {
  3979. Val = Builder.CreateExtractValue(Val, idx->getLimitedValue());
  3980. }
  3981. }
  3982. return Val;
  3983. }
  3984. // Copy srcVal to destPtr.
  3985. static void SimpleValCopy(Value *DestPtr, Value *SrcVal,
  3986. ArrayRef<Value*> idxList,
  3987. CGBuilderTy &Builder) {
  3988. Value *DestGEP = Builder.CreateInBoundsGEP(DestPtr, idxList);
  3989. Value *Val = GetEltVal(SrcVal, idxList, Builder);
  3990. Builder.CreateStore(Val, DestGEP);
  3991. }
  3992. static void SimpleCopy(Value *Dest, Value *Src,
  3993. ArrayRef<Value *> idxList,
  3994. CGBuilderTy &Builder) {
  3995. if (Src->getType()->isPointerTy())
  3996. SimplePtrCopy(Dest, Src, idxList, Builder);
  3997. else
  3998. SimpleValCopy(Dest, Src, idxList, Builder);
  3999. }
  4000. void CGMSHLSLRuntime::FlattenAggregatePtrToGepList(
  4001. CodeGenFunction &CGF, Value *Ptr, SmallVector<Value *, 4> &idxList,
  4002. clang::QualType Type, llvm::Type *Ty, SmallVector<Value *, 4> &GepList,
  4003. SmallVector<QualType, 4> &EltTyList) {
  4004. if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
  4005. Constant *idx = Constant::getIntegerValue(
  4006. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  4007. idxList.emplace_back(idx);
  4008. FlattenAggregatePtrToGepList(CGF, Ptr, idxList, Type, PT->getElementType(),
  4009. GepList, EltTyList);
  4010. idxList.pop_back();
  4011. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  4012. // Use matLd/St for matrix.
  4013. unsigned col, row;
  4014. llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
  4015. llvm::PointerType *EltPtrTy =
  4016. llvm::PointerType::get(EltTy, Ptr->getType()->getPointerAddressSpace());
  4017. QualType EltQualTy = hlsl::GetHLSLMatElementType(Type);
  4018. Value *matPtr = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
  4019. // Flatten matrix to elements.
  4020. for (unsigned r = 0; r < row; r++) {
  4021. for (unsigned c = 0; c < col; c++) {
  4022. ConstantInt *cRow = CGF.Builder.getInt32(r);
  4023. ConstantInt *cCol = CGF.Builder.getInt32(c);
  4024. Constant *CV = llvm::ConstantVector::get({cRow, cCol});
  4025. GepList.push_back(
  4026. EmitHLSLMatrixElement(CGF, EltPtrTy, {matPtr, CV}, Type));
  4027. EltTyList.push_back(EltQualTy);
  4028. }
  4029. }
  4030. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  4031. if (HLModule::IsHLSLObjectType(ST)) {
  4032. // Avoid split HLSL object.
  4033. Value *GEP = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
  4034. GepList.push_back(GEP);
  4035. EltTyList.push_back(Type);
  4036. return;
  4037. }
  4038. const clang::RecordType *RT = Type->getAsStructureType();
  4039. RecordDecl *RD = RT->getDecl();
  4040. auto fieldIter = RD->field_begin();
  4041. const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
  4042. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  4043. if (CXXRD->getNumBases()) {
  4044. // Add base as field.
  4045. for (const auto &I : CXXRD->bases()) {
  4046. const CXXRecordDecl *BaseDecl =
  4047. cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
  4048. // Skip empty struct.
  4049. if (BaseDecl->field_empty())
  4050. continue;
  4051. QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
  4052. llvm::Type *parentType = CGF.ConvertType(parentTy);
  4053. unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
  4054. Constant *idx = llvm::Constant::getIntegerValue(
  4055. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4056. idxList.emplace_back(idx);
  4057. FlattenAggregatePtrToGepList(CGF, Ptr, idxList, parentTy, parentType,
  4058. GepList, EltTyList);
  4059. idxList.pop_back();
  4060. }
  4061. }
  4062. }
  4063. for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
  4064. fieldIter != fieldEnd; ++fieldIter) {
  4065. unsigned i = RL.getLLVMFieldNo(*fieldIter);
  4066. llvm::Type *ET = ST->getElementType(i);
  4067. Constant *idx = llvm::Constant::getIntegerValue(
  4068. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4069. idxList.emplace_back(idx);
  4070. FlattenAggregatePtrToGepList(CGF, Ptr, idxList, fieldIter->getType(), ET,
  4071. GepList, EltTyList);
  4072. idxList.pop_back();
  4073. }
  4074. } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
  4075. llvm::Type *ET = AT->getElementType();
  4076. QualType EltType = CGF.getContext().getBaseElementType(Type);
  4077. for (uint32_t i = 0; i < AT->getNumElements(); i++) {
  4078. Constant *idx = Constant::getIntegerValue(
  4079. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4080. idxList.emplace_back(idx);
  4081. FlattenAggregatePtrToGepList(CGF, Ptr, idxList, EltType, ET, GepList,
  4082. EltTyList);
  4083. idxList.pop_back();
  4084. }
  4085. } else if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
  4086. // Flatten vector too.
  4087. QualType EltTy = hlsl::GetHLSLVecElementType(Type);
  4088. for (uint32_t i = 0; i < VT->getNumElements(); i++) {
  4089. Constant *idx = CGF.Builder.getInt8(i); // CGF.Builder.getInt32(i);
  4090. idxList.emplace_back(idx);
  4091. Value *GEP = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
  4092. GepList.push_back(GEP);
  4093. EltTyList.push_back(EltTy);
  4094. idxList.pop_back();
  4095. }
  4096. } else {
  4097. Value *GEP = CGF.Builder.CreateInBoundsGEP(Ptr, idxList);
  4098. GepList.push_back(GEP);
  4099. EltTyList.push_back(Type);
  4100. }
  4101. }
  4102. void CGMSHLSLRuntime::LoadFlattenedGepList(CodeGenFunction &CGF,
  4103. ArrayRef<Value *> GepList,
  4104. ArrayRef<QualType> EltTyList,
  4105. SmallVector<Value *, 4> &EltList) {
  4106. unsigned eltSize = GepList.size();
  4107. for (unsigned i = 0; i < eltSize; i++) {
  4108. Value *Ptr = GepList[i];
  4109. QualType Type = EltTyList[i];
  4110. // Everying is element type.
  4111. EltList.push_back(CGF.Builder.CreateLoad(Ptr));
  4112. }
  4113. }
  4114. void CGMSHLSLRuntime::StoreFlattenedGepList(CodeGenFunction &CGF, ArrayRef<Value *> GepList,
  4115. ArrayRef<QualType> GepTyList, ArrayRef<Value *> EltValList, ArrayRef<QualType> SrcTyList) {
  4116. unsigned eltSize = GepList.size();
  4117. for (unsigned i = 0; i < eltSize; i++) {
  4118. Value *Ptr = GepList[i];
  4119. QualType DestType = GepTyList[i];
  4120. Value *Val = EltValList[i];
  4121. QualType SrcType = SrcTyList[i];
  4122. llvm::Type *Ty = Ptr->getType()->getPointerElementType();
  4123. // Everything is element type.
  4124. if (Ty != Val->getType()) {
  4125. Instruction::CastOps castOp =
  4126. static_cast<Instruction::CastOps>(HLModule::FindCastOp(
  4127. IsUnsigned(SrcType), IsUnsigned(DestType), Val->getType(), Ty));
  4128. Val = CGF.Builder.CreateCast(castOp, Val, Ty);
  4129. }
  4130. CGF.Builder.CreateStore(Val, Ptr);
  4131. }
  4132. }
  4133. // Copy element data from SrcPtr to DestPtr by generate following IR.
  4134. // element = Ld SrcGEP
  4135. // St element, DestGEP
  4136. // idxList stored the index to generate GetElementPtr for current element.
  4137. // Type is QualType of current element.
  4138. // Ty is llvm::Type of current element.
  4139. void CGMSHLSLRuntime::EmitHLSLAggregateCopy(
  4140. CodeGenFunction &CGF, llvm::Value *SrcPtr, llvm::Value *DestPtr,
  4141. SmallVector<Value *, 4> &idxList, clang::QualType Type, llvm::Type *Ty) {
  4142. if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
  4143. Constant *idx = Constant::getIntegerValue(
  4144. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  4145. idxList.emplace_back(idx);
  4146. EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, Type,
  4147. PT->getElementType());
  4148. idxList.pop_back();
  4149. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  4150. // Use matLd/St for matrix.
  4151. Value *srcGEP = CGF.Builder.CreateInBoundsGEP(SrcPtr, idxList);
  4152. Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
  4153. Value *ldMat = EmitHLSLMatrixLoad(CGF, srcGEP, Type);
  4154. EmitHLSLMatrixStore(CGF, ldMat, dstGEP, Type);
  4155. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  4156. if (HLModule::IsHLSLObjectType(ST)) {
  4157. // Avoid split HLSL object.
  4158. SimpleCopy(DestPtr, SrcPtr, idxList, CGF.Builder);
  4159. return;
  4160. }
  4161. const clang::RecordType *RT = Type->getAsStructureType();
  4162. RecordDecl *RD = RT->getDecl();
  4163. auto fieldIter = RD->field_begin();
  4164. const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
  4165. // Take care base.
  4166. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  4167. if (CXXRD->getNumBases()) {
  4168. for (const auto &I : CXXRD->bases()) {
  4169. const CXXRecordDecl *BaseDecl =
  4170. cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
  4171. if (BaseDecl->field_empty())
  4172. continue;
  4173. QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
  4174. unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
  4175. llvm::Type *ET = ST->getElementType(i);
  4176. Constant *idx = llvm::Constant::getIntegerValue(
  4177. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4178. idxList.emplace_back(idx);
  4179. EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList,
  4180. parentTy, ET);
  4181. idxList.pop_back();
  4182. }
  4183. }
  4184. }
  4185. for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
  4186. fieldIter != fieldEnd; ++fieldIter) {
  4187. unsigned i = RL.getLLVMFieldNo(*fieldIter);
  4188. llvm::Type *ET = ST->getElementType(i);
  4189. Constant *idx = llvm::Constant::getIntegerValue(
  4190. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4191. idxList.emplace_back(idx);
  4192. EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, fieldIter->getType(),
  4193. ET);
  4194. idxList.pop_back();
  4195. }
  4196. } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
  4197. llvm::Type *ET = AT->getElementType();
  4198. QualType EltType = CGF.getContext().getBaseElementType(Type);
  4199. for (uint32_t i = 0; i < AT->getNumElements(); i++) {
  4200. Constant *idx = Constant::getIntegerValue(
  4201. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4202. idxList.emplace_back(idx);
  4203. EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, EltType, ET);
  4204. idxList.pop_back();
  4205. }
  4206. } else {
  4207. SimpleCopy(DestPtr, SrcPtr, idxList, CGF.Builder);
  4208. }
  4209. }
  4210. void CGMSHLSLRuntime::EmitHLSLAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
  4211. llvm::Value *DestPtr,
  4212. clang::QualType Ty) {
  4213. SmallVector<Value *, 4> idxList;
  4214. EmitHLSLAggregateCopy(CGF, SrcPtr, DestPtr, idxList, Ty, SrcPtr->getType());
  4215. }
  4216. void CGMSHLSLRuntime::EmitHLSLFlatConversionAggregateCopy(CodeGenFunction &CGF, llvm::Value *SrcPtr,
  4217. clang::QualType SrcTy,
  4218. llvm::Value *DestPtr,
  4219. clang::QualType DestTy) {
  4220. // It is possiable to implement EmitHLSLAggregateCopy, EmitHLSLAggregateStore the same way.
  4221. // But split value to scalar will generate many instruction when src type is same as dest type.
  4222. SmallVector<Value *, 4> idxList;
  4223. SmallVector<Value *, 4> SrcGEPList;
  4224. SmallVector<QualType, 4> SrcEltTyList;
  4225. FlattenAggregatePtrToGepList(CGF, SrcPtr, idxList, SrcTy, SrcPtr->getType(), SrcGEPList,
  4226. SrcEltTyList);
  4227. SmallVector<Value *, 4> LdEltList;
  4228. LoadFlattenedGepList(CGF, SrcGEPList, SrcEltTyList, LdEltList);
  4229. idxList.clear();
  4230. SmallVector<Value *, 4> DestGEPList;
  4231. SmallVector<QualType, 4> DestEltTyList;
  4232. FlattenAggregatePtrToGepList(CGF, DestPtr, idxList, DestTy, DestPtr->getType(), DestGEPList, DestEltTyList);
  4233. StoreFlattenedGepList(CGF, DestGEPList, DestEltTyList, LdEltList, SrcEltTyList);
  4234. }
  4235. // Store element data from Val to DestPtr by generate following IR.
  4236. // element = ExtractVal SrcVal
  4237. // St element, DestGEP
  4238. // idxList stored the index to generate GetElementPtr for current element.
  4239. // Type is QualType of current element.
  4240. // Ty is llvm::Type of current element.
  4241. void CGMSHLSLRuntime::EmitHLSLAggregateStore(
  4242. CodeGenFunction &CGF, llvm::Value *SrcVal, llvm::Value *DestPtr,
  4243. SmallVector<Value *, 4> &idxList, clang::QualType Type, llvm::Type *Ty) {
  4244. if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
  4245. Constant *idx = Constant::getIntegerValue(
  4246. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  4247. idxList.emplace_back(idx);
  4248. EmitHLSLAggregateStore(CGF, SrcVal, DestPtr, idxList, Type, PT->getElementType());
  4249. idxList.pop_back();
  4250. }
  4251. else if (HLMatrixLower::IsMatrixType(Ty)) {
  4252. // Use matLd/St for matrix.
  4253. Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
  4254. Value *ldMat = GetEltVal(SrcVal, idxList, CGF.Builder);
  4255. EmitHLSLMatrixStore(CGF, ldMat, dstGEP, Type);
  4256. }
  4257. else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  4258. if (HLModule::IsHLSLObjectType(ST)) {
  4259. // Avoid split HLSL object.
  4260. SimpleCopy(DestPtr, SrcVal, idxList, CGF.Builder);
  4261. return;
  4262. }
  4263. const clang::RecordType *RT = Type->getAsStructureType();
  4264. RecordDecl *RD = RT->getDecl();
  4265. auto fieldIter = RD->field_begin();
  4266. const CGRecordLayout& RL = CGF.getTypes().getCGRecordLayout(RD);
  4267. // Take care base.
  4268. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  4269. if (CXXRD->getNumBases()) {
  4270. for (const auto &I : CXXRD->bases()) {
  4271. const CXXRecordDecl *BaseDecl = cast<CXXRecordDecl>(
  4272. I.getType()->castAs<RecordType>()->getDecl());
  4273. if (BaseDecl->field_empty())
  4274. continue;
  4275. QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
  4276. unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
  4277. llvm::Type *ET = ST->getElementType(i);
  4278. Constant *idx = llvm::Constant::getIntegerValue(
  4279. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4280. idxList.emplace_back(idx);
  4281. EmitHLSLAggregateStore(CGF, SrcVal, DestPtr, idxList,
  4282. parentTy, ET);
  4283. idxList.pop_back();
  4284. }
  4285. }
  4286. }
  4287. for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
  4288. fieldIter != fieldEnd; ++fieldIter) {
  4289. unsigned i = RL.getLLVMFieldNo(*fieldIter);
  4290. llvm::Type *ET = ST->getElementType(i);
  4291. Constant *idx = llvm::Constant::getIntegerValue(
  4292. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4293. idxList.emplace_back(idx);
  4294. EmitHLSLAggregateStore(CGF, SrcVal, DestPtr, idxList, fieldIter->getType(), ET);
  4295. idxList.pop_back();
  4296. }
  4297. }
  4298. else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
  4299. llvm::Type *ET = AT->getElementType();
  4300. QualType EltType = CGF.getContext().getBaseElementType(Type);
  4301. for (uint32_t i = 0; i < AT->getNumElements(); i++) {
  4302. Constant *idx = Constant::getIntegerValue(
  4303. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4304. idxList.emplace_back(idx);
  4305. EmitHLSLAggregateStore(CGF, SrcVal, DestPtr, idxList, EltType, ET);
  4306. idxList.pop_back();
  4307. }
  4308. }
  4309. else {
  4310. SimpleValCopy(DestPtr, SrcVal, idxList, CGF.Builder);
  4311. }
  4312. }
  4313. void CGMSHLSLRuntime::EmitHLSLAggregateStore(CodeGenFunction &CGF, llvm::Value *SrcVal,
  4314. llvm::Value *DestPtr,
  4315. clang::QualType Ty) {
  4316. SmallVector<Value *, 4> idxList;
  4317. // Add first 0 for DestPtr.
  4318. Constant *idx = Constant::getIntegerValue(
  4319. IntegerType::get(SrcVal->getContext(), 32), APInt(32, 0));
  4320. idxList.emplace_back(idx);
  4321. EmitHLSLAggregateStore(CGF, SrcVal, DestPtr, idxList, Ty, SrcVal->getType());
  4322. }
  4323. static void SimpleFlatValCopy(Value *DestPtr, Value *SrcVal, QualType Ty,
  4324. QualType SrcTy, ArrayRef<Value *> idxList,
  4325. CGBuilderTy &Builder) {
  4326. Value *DestGEP = Builder.CreateInBoundsGEP(DestPtr, idxList);
  4327. llvm::Type *ToTy = DestGEP->getType()->getPointerElementType();
  4328. llvm::Type *EltToTy = ToTy;
  4329. if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(ToTy)) {
  4330. EltToTy = VT->getElementType();
  4331. }
  4332. if (EltToTy != SrcVal->getType()) {
  4333. Instruction::CastOps castOp =
  4334. static_cast<Instruction::CastOps>(HLModule::FindCastOp(
  4335. IsUnsigned(SrcTy), IsUnsigned(Ty), SrcVal->getType(), ToTy));
  4336. SrcVal = Builder.CreateCast(castOp, SrcVal, EltToTy);
  4337. }
  4338. if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(ToTy)) {
  4339. llvm::VectorType *VT1 = llvm::VectorType::get(EltToTy, 1);
  4340. Value *V1 =
  4341. Builder.CreateInsertElement(UndefValue::get(VT1), SrcVal, (uint64_t)0);
  4342. std::vector<int> shufIdx(VT->getNumElements(), 0);
  4343. Value *Vec = Builder.CreateShuffleVector(V1, V1, shufIdx);
  4344. Builder.CreateStore(Vec, DestGEP);
  4345. } else
  4346. Builder.CreateStore(SrcVal, DestGEP);
  4347. }
  4348. void CGMSHLSLRuntime::EmitHLSLFlatConversionToAggregate(
  4349. CodeGenFunction &CGF, Value *SrcVal, llvm::Value *DestPtr,
  4350. SmallVector<Value *, 4> &idxList, QualType Type, QualType SrcType,
  4351. llvm::Type *Ty) {
  4352. if (llvm::PointerType *PT = dyn_cast<llvm::PointerType>(Ty)) {
  4353. Constant *idx = Constant::getIntegerValue(
  4354. IntegerType::get(Ty->getContext(), 32), APInt(32, 0));
  4355. idxList.emplace_back(idx);
  4356. EmitHLSLFlatConversionToAggregate(CGF, SrcVal, DestPtr, idxList, Type,
  4357. SrcType, PT->getElementType());
  4358. idxList.pop_back();
  4359. } else if (HLMatrixLower::IsMatrixType(Ty)) {
  4360. // Use matLd/St for matrix.
  4361. Value *dstGEP = CGF.Builder.CreateInBoundsGEP(DestPtr, idxList);
  4362. unsigned row, col;
  4363. llvm::Type *EltTy = HLMatrixLower::GetMatrixInfo(Ty, col, row);
  4364. llvm::VectorType *VT1 = llvm::VectorType::get(EltTy, 1);
  4365. if (EltTy != SrcVal->getType()) {
  4366. Instruction::CastOps castOp =
  4367. static_cast<Instruction::CastOps>(HLModule::FindCastOp(
  4368. IsUnsigned(SrcType), IsUnsigned(Type), SrcVal->getType(), EltTy));
  4369. SrcVal = CGF.Builder.CreateCast(castOp, SrcVal, EltTy);
  4370. }
  4371. Value *V1 = CGF.Builder.CreateInsertElement(UndefValue::get(VT1), SrcVal,
  4372. (uint64_t)0);
  4373. std::vector<int> shufIdx(col * row, 0);
  4374. Value *VecMat = CGF.Builder.CreateShuffleVector(V1, V1, shufIdx);
  4375. Value *MatInit = EmitHLSLMatrixOperationCallImp(
  4376. CGF.Builder, HLOpcodeGroup::HLInit, 0, Ty, {VecMat}, TheModule);
  4377. EmitHLSLMatrixStore(CGF, MatInit, dstGEP, Type);
  4378. } else if (StructType *ST = dyn_cast<StructType>(Ty)) {
  4379. DXASSERT(!HLModule::IsHLSLObjectType(ST), "cannot cast to hlsl object, Sema should reject");
  4380. const clang::RecordType *RT = Type->getAsStructureType();
  4381. RecordDecl *RD = RT->getDecl();
  4382. auto fieldIter = RD->field_begin();
  4383. const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
  4384. // Take care base.
  4385. if (const CXXRecordDecl *CXXRD = dyn_cast<CXXRecordDecl>(RD)) {
  4386. if (CXXRD->getNumBases()) {
  4387. for (const auto &I : CXXRD->bases()) {
  4388. const CXXRecordDecl *BaseDecl =
  4389. cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
  4390. if (BaseDecl->field_empty())
  4391. continue;
  4392. QualType parentTy = QualType(BaseDecl->getTypeForDecl(), 0);
  4393. unsigned i = RL.getNonVirtualBaseLLVMFieldNo(BaseDecl);
  4394. llvm::Type *ET = ST->getElementType(i);
  4395. Constant *idx = llvm::Constant::getIntegerValue(
  4396. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4397. idxList.emplace_back(idx);
  4398. EmitHLSLFlatConversionToAggregate(CGF, SrcVal, DestPtr, idxList,
  4399. parentTy, SrcType, ET);
  4400. idxList.pop_back();
  4401. }
  4402. }
  4403. }
  4404. for (auto fieldIter = RD->field_begin(), fieldEnd = RD->field_end();
  4405. fieldIter != fieldEnd; ++fieldIter) {
  4406. unsigned i = RL.getLLVMFieldNo(*fieldIter);
  4407. llvm::Type *ET = ST->getElementType(i);
  4408. Constant *idx = llvm::Constant::getIntegerValue(
  4409. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4410. idxList.emplace_back(idx);
  4411. EmitHLSLFlatConversionToAggregate(CGF, SrcVal, DestPtr, idxList,
  4412. fieldIter->getType(), SrcType, ET);
  4413. idxList.pop_back();
  4414. }
  4415. } else if (llvm::ArrayType *AT = dyn_cast<llvm::ArrayType>(Ty)) {
  4416. llvm::Type *ET = AT->getElementType();
  4417. QualType EltType = CGF.getContext().getBaseElementType(Type);
  4418. for (uint32_t i = 0; i < AT->getNumElements(); i++) {
  4419. Constant *idx = Constant::getIntegerValue(
  4420. IntegerType::get(Ty->getContext(), 32), APInt(32, i));
  4421. idxList.emplace_back(idx);
  4422. EmitHLSLFlatConversionToAggregate(CGF, SrcVal, DestPtr, idxList, EltType,
  4423. SrcType, ET);
  4424. idxList.pop_back();
  4425. }
  4426. } else {
  4427. SimpleFlatValCopy(DestPtr, SrcVal, Type, SrcType, idxList, CGF.Builder);
  4428. }
  4429. }
  4430. void CGMSHLSLRuntime::EmitHLSLFlatConversionToAggregate(CodeGenFunction &CGF,
  4431. Value *Val,
  4432. Value *DestPtr,
  4433. QualType Ty,
  4434. QualType SrcTy) {
  4435. if (SrcTy->isBuiltinType()) {
  4436. SmallVector<Value *, 4> idxList;
  4437. // Add first 0 for DestPtr.
  4438. Constant *idx = Constant::getIntegerValue(
  4439. IntegerType::get(Val->getContext(), 32), APInt(32, 0));
  4440. idxList.emplace_back(idx);
  4441. EmitHLSLFlatConversionToAggregate(
  4442. CGF, Val, DestPtr, idxList, Ty, SrcTy,
  4443. DestPtr->getType()->getPointerElementType());
  4444. }
  4445. else {
  4446. SmallVector<Value *, 4> idxList;
  4447. SmallVector<Value *, 4> DestGEPList;
  4448. SmallVector<QualType, 4> DestEltTyList;
  4449. FlattenAggregatePtrToGepList(CGF, DestPtr, idxList, Ty, DestPtr->getType(), DestGEPList, DestEltTyList);
  4450. SmallVector<Value *, 4> EltList;
  4451. SmallVector<QualType, 4> EltTyList;
  4452. FlattenValToInitList(CGF, EltList, EltTyList, SrcTy, Val);
  4453. StoreFlattenedGepList(CGF, DestGEPList, DestEltTyList, EltList, EltTyList);
  4454. }
  4455. }
  4456. void CGMSHLSLRuntime::EmitHLSLRootSignature(CodeGenFunction &CGF,
  4457. HLSLRootSignatureAttr *RSA,
  4458. Function *Fn) {
  4459. StringRef StrRef = RSA->getSignatureName();
  4460. DiagnosticsEngine &Diags = CGF.getContext().getDiagnostics();
  4461. SourceLocation SLoc = RSA->getLocation();
  4462. std::string OSStr;
  4463. raw_string_ostream OS(OSStr);
  4464. hlsl::DxilVersionedRootSignatureDesc *D = nullptr;
  4465. DXASSERT(CGF.getLangOpts().RootSigMajor == 1,
  4466. "else EmitHLSLRootSignature needs to be updated");
  4467. hlsl::DxilRootSignatureVersion Ver;
  4468. if (CGF.getLangOpts().RootSigMinor == 0) {
  4469. Ver = hlsl::DxilRootSignatureVersion::Version_1_0;
  4470. }
  4471. else {
  4472. DXASSERT(CGF.getLangOpts().RootSigMinor == 1,
  4473. "else EmitHLSLRootSignature needs to be updated");
  4474. Ver = hlsl::DxilRootSignatureVersion::Version_1_1;
  4475. }
  4476. if (ParseHLSLRootSignature(StrRef.data(), StrRef.size(), Ver, &D, SLoc,
  4477. Diags)) {
  4478. CComPtr<IDxcBlob> pSignature;
  4479. CComPtr<IDxcBlobEncoding> pErrors;
  4480. hlsl::SerializeRootSignature(D, &pSignature, &pErrors, false);
  4481. if (pSignature == nullptr) {
  4482. DXASSERT(pErrors != nullptr, "else serialize failed with no msg");
  4483. ReportHLSLRootSigError(Diags, SLoc,
  4484. (char *)pErrors->GetBufferPointer(), pErrors->GetBufferSize());
  4485. hlsl::DeleteRootSignature(D);
  4486. }
  4487. else {
  4488. llvm::Module *pModule = Fn->getParent();
  4489. pModule->GetHLModule().GetRootSignature().Assign(D, pSignature);
  4490. }
  4491. }
  4492. }
/// Sets up copy-in/copy-out temporaries for the `out` parameters of a call
/// to \p FD made by the call expression \p E.
///
/// For every parameter whose modifier includes `out`, this:
///  - evaluates the caller's argument as an lvalue (argLV);
///  - creates an unnamed auto VarDecl of the non-reference parameter type, plus
///    a DeclRefExpr to it, and stores that DeclRefExpr into argList[i];
///  - allocates storage for the temporary in the function's entry block and
///    reports the VarDecl -> alloca pairing through \p TmpArgMap;
///  - appends the temporary lvalue followed by the argument lvalue to
///    \p castArgList (EmitHLSLOutParamConversionCopyBack reads these pairs
///    after the call);
///  - if the parameter is also `in` and the temporary is not an HLSL object,
///    copies the argument into the temporary. Non-aggregates (and vectors/
///    matrices) are loaded, cast and stored; aggregates use
///    EmitHLSLAggregateCopy.
///
/// bool parameters get an i8 (CharTy) temporary rather than a bool one.
void CGMSHLSLRuntime::EmitHLSLOutParamConversionInit(
    CodeGenFunction &CGF, const FunctionDecl *FD, const CallExpr *E,
    llvm::SmallVector<LValue, 8> &castArgList,
    llvm::SmallVector<const Stmt *, 8> &argList,
    const std::function<void(const VarDecl *, llvm::Value *)> &TmpArgMap) {
  // Special case: skip first argument of CXXOperatorCall (it is "this").
  unsigned ArgsToSkip = isa<CXXOperatorCallExpr>(E) ? 1 : 0;
  for (uint32_t i = 0; i < FD->getNumParams(); i++) {
    const ParmVarDecl *Param = FD->getParamDecl(i);
    const Expr *Arg = E->getArg(i+ArgsToSkip);
    QualType ParamTy = Param->getType().getNonReferenceType();
    // Only out/inout parameters need a temporary.
    if (!Param->isModifierOut())
      continue;
    // Evaluate the original argument as an lvalue; it is the write-back target.
    LValue argLV = CGF.EmitLValue(Arg);
    // Create an unnamed local variable of the parameter type for the temporary.
    VarDecl *tmpArg =
        VarDecl::Create(CGF.getContext(), const_cast<FunctionDecl *>(FD),
                        SourceLocation(), SourceLocation(),
                        /*IdentifierInfo*/ nullptr, ParamTy,
                        CGF.getContext().getTrivialTypeSourceInfo(ParamTy),
                        StorageClass::SC_Auto);
    // Aggregate (array/record, but not vector/matrix) params are passed
    // indirectly as pointers, so their reference is built as an RValue instead
    // of being turned into a reference type; everything else gets an LValue.
    bool isAggregateType = (ParamTy->isArrayType() || ParamTy->isRecordType()) &&
        !hlsl::IsHLSLVecMatType(ParamTy);
    const DeclRefExpr *tmpRef = DeclRefExpr::Create(
        CGF.getContext(), NestedNameSpecifierLoc(), SourceLocation(), tmpArg,
        /*enclosing*/ false, tmpArg->getLocation(), ParamTy,
        isAggregateType ? VK_RValue : VK_LValue);
    // Make the call use the temporary instead of the original argument.
    // NOTE(review): argList is indexed by i while E->getArg uses
    // i + ArgsToSkip; presumably the caller's argList already excludes the
    // operator-call "this" argument -- confirm.
    argList[i] = tmpRef;
    // Create the alloca for the temporary.
    Value *tmpArgAddr = nullptr;
    BasicBlock *InsertBlock = CGF.Builder.GetInsertBlock();
    Function *F = InsertBlock->getParent();
    BasicBlock *EntryBlock = &F->getEntryBlock();
    if (ParamTy->isBooleanType()) {
      // Store bool temporaries as char (i8). From here on ParamTy is CharTy, so
      // the alloca, tmpLV and the copy-in below all use the char type.
      ParamTy = CGM.getContext().CharTy;
    }
    // Put the alloca in the entry block so that inlining does not have to
    // create stacksave/stackrestore around it.
    IRBuilder<> Builder(EntryBlock->getFirstInsertionPt());
    tmpArgAddr = Builder.CreateAlloca(CGF.ConvertType(ParamTy));
    // Record the temporary's storage in the caller-supplied local decl map.
    TmpArgMap(tmpArg, tmpArgAddr);
    // NOTE(review): tmpLV takes argLV's alignment, but the alloca above is
    // created without an explicit alignment -- confirm this is intended.
    LValue tmpLV = LValue::MakeAddr(tmpArgAddr, ParamTy, argLV.getAlignment(),
                                    CGF.getContext());
    // Save (temporary, argument) as a pair for the copy-back after the call.
    castArgList.emplace_back(tmpLV);
    castArgList.emplace_back(argLV);
    bool isObject = HLModule::IsHLSLObjectType(
        tmpArgAddr->getType()->getPointerElementType());
    // Copy-in before the call, only for inout parameters.
    if (Param->isModifierIn() &&
        // Don't copy object
        !isObject) {
      Value *outVal = nullptr;
      bool isAggrageteTy = ParamTy->isAggregateType();
      isAggrageteTy &= !IsHLSLVecMatType(ParamTy);
      if (!isAggrageteTy) {
        // Scalars and vectors load through the lvalue; matrices use the HLSL
        // matrix load helper.
        if (!IsHLSLMatType(ParamTy)) {
          RValue outRVal = CGF.EmitLoadOfLValue(argLV, SourceLocation());
          outVal = outRVal.getScalarVal();
        } else {
          Value *argAddr = argLV.getAddress();
          outVal = EmitHLSLMatrixLoad(CGF, argAddr, ParamTy);
        }
        // Convert to the temporary's IR type using the signedness of the
        // argument and temporary types.
        llvm::Type *ToTy = tmpArgAddr->getType()->getPointerElementType();
        Instruction::CastOps castOp =
            static_cast<Instruction::CastOps>(HLModule::FindCastOp(
                IsUnsigned(argLV.getType()), IsUnsigned(tmpLV.getType()),
                outVal->getType(), ToTy));
        Value *castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
        if (!HLMatrixLower::IsMatrixType(ToTy))
          CGF.Builder.CreateStore(castVal, tmpArgAddr);
        else
          EmitHLSLMatrixStore(CGF, castVal, tmpArgAddr, ParamTy);
      } else {
        // Aggregates are copied as a whole from argument to temporary.
        EmitHLSLAggregateCopy(CGF, argLV.getAddress(), tmpLV.getAddress(),
                              ParamTy);
      }
    }
  }
}
  4578. void CGMSHLSLRuntime::EmitHLSLOutParamConversionCopyBack(
  4579. CodeGenFunction &CGF, llvm::SmallVector<LValue, 8> &castArgList) {
  4580. for (uint32_t i = 0; i < castArgList.size(); i += 2) {
  4581. // cast after the call
  4582. LValue tmpLV = castArgList[i];
  4583. LValue argLV = castArgList[i + 1];
  4584. QualType argTy = argLV.getType().getNonReferenceType();
  4585. Value *tmpArgAddr = tmpLV.getAddress();
  4586. Value *outVal = nullptr;
  4587. bool isAggrageteTy = argTy->isAggregateType();
  4588. isAggrageteTy &= !IsHLSLVecMatType(argTy);
  4589. bool isObject = HLModule::IsHLSLObjectType(
  4590. tmpArgAddr->getType()->getPointerElementType());
  4591. if (!isObject) {
  4592. if (!isAggrageteTy) {
  4593. if (!IsHLSLMatType(argTy))
  4594. outVal = CGF.Builder.CreateLoad(tmpArgAddr);
  4595. else
  4596. outVal = EmitHLSLMatrixLoad(CGF, tmpArgAddr, argTy);
  4597. llvm::Type *ToTy = CGF.ConvertType(argTy);
  4598. llvm::Type *FromTy = outVal->getType();
  4599. Value *castVal = outVal;
  4600. if (ToTy == FromTy) {
  4601. // Don't need cast.
  4602. } else if (ToTy->getScalarType() == FromTy->getScalarType()) {
  4603. if (ToTy->getScalarType() == ToTy) {
  4604. DXASSERT(FromTy->isVectorTy() &&
  4605. FromTy->getVectorNumElements() == 1,
  4606. "must be vector of 1 element");
  4607. castVal = CGF.Builder.CreateExtractElement(outVal, (uint64_t)0);
  4608. } else {
  4609. DXASSERT(!FromTy->isVectorTy(), "must be scalar type");
  4610. DXASSERT(ToTy->isVectorTy() && ToTy->getVectorNumElements() == 1,
  4611. "must be vector of 1 element");
  4612. castVal = UndefValue::get(ToTy);
  4613. castVal =
  4614. CGF.Builder.CreateInsertElement(castVal, outVal, (uint64_t)0);
  4615. }
  4616. } else {
  4617. Instruction::CastOps castOp =
  4618. static_cast<Instruction::CastOps>(HLModule::FindCastOp(
  4619. IsUnsigned(tmpLV.getType()), IsUnsigned(argLV.getType()),
  4620. outVal->getType(), ToTy));
  4621. castVal = CGF.Builder.CreateCast(castOp, outVal, ToTy);
  4622. }
  4623. if (!HLMatrixLower::IsMatrixType(ToTy))
  4624. CGF.EmitStoreThroughLValue(RValue::get(castVal), argLV);
  4625. else {
  4626. Value *destPtr = argLV.getAddress();
  4627. EmitHLSLMatrixStore(CGF, castVal, destPtr, argTy);
  4628. }
  4629. } else {
  4630. EmitHLSLAggregateCopy(CGF, tmpLV.getAddress(), argLV.getAddress(),
  4631. argTy);
  4632. }
  4633. } else
  4634. tmpArgAddr->replaceAllUsesWith(argLV.getAddress());
  4635. }
  4636. }
  4637. CGHLSLRuntime *CodeGen::CreateMSHLSLRuntime(CodeGenModule &CGM) {
  4638. return new CGMSHLSLRuntime(CGM);
  4639. }