ASTContextHLSL.cpp 50 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162
  1. //===--- ASTContextHLSL.cpp - HLSL support for AST nodes and operations ---===//
  2. ///////////////////////////////////////////////////////////////////////////////
  3. // //
  4. // ASTContextHLSL.cpp //
  5. // Copyright (C) Microsoft Corporation. All rights reserved. //
  6. // This file is distributed under the University of Illinois Open Source //
  7. // License. See LICENSE.TXT for details. //
  8. // //
  9. // This file implements the ASTContext interface for HLSL. //
  10. // //
  11. ///////////////////////////////////////////////////////////////////////////////
  12. #include "clang/AST/ASTContext.h"
  13. #include "clang/AST/Attr.h"
  14. #include "clang/AST/DeclCXX.h"
  15. #include "clang/AST/DeclTemplate.h"
  16. #include "clang/AST/Expr.h"
  17. #include "clang/AST/ExprCXX.h"
  18. #include "clang/AST/ExternalASTSource.h"
  19. #include "clang/AST/HlslBuiltinTypeDeclBuilder.h"
  20. #include "clang/AST/TypeLoc.h"
  21. #include "clang/Sema/SemaDiagnostic.h"
  22. #include "clang/Sema/Sema.h"
  23. #include "clang/Sema/Overload.h"
  24. #include "dxc/Support/Global.h"
  25. #include "dxc/HLSL/HLOperations.h"
  26. #include "dxc/DXIL/DxilSemantic.h"
  27. using namespace clang;
  28. using namespace hlsl;
  29. static const int FirstTemplateDepth = 0;
  30. static const int FirstParamPosition = 0;
  31. static const bool ForConstFalse = false; // a construct is targeting a const type
  32. static const bool ForConstTrue = true; // a construct is targeting a non-const type
  33. static const bool ParameterPackFalse = false; // template parameter is not an ellipsis.
  34. static const bool TypenameFalse = false; // 'typename' specified rather than 'class' for a template argument.
  35. static const bool DelayTypeCreationTrue = true; // delay type creation for a declaration
  36. static const SourceLocation NoLoc; // no source location attribution available
  37. static const bool InlineFalse = false; // namespace is not an inline namespace
  38. static const bool InlineSpecifiedFalse = false; // function was not specified as inline
  39. static const bool ExplicitFalse = false; // constructor was not specified as explicit
  40. static const bool IsConstexprFalse = false; // function is not constexpr
  41. static const bool VirtualFalse = false; // whether the base class is declares 'virtual'
  42. static const bool BaseClassFalse = false; // whether the base class is declared as 'class' (vs. 'struct')
  43. /// <summary>Names of HLSLScalarType enumeration values, in matching order to HLSLScalarType.</summary>
  44. const char* HLSLScalarTypeNames[] = {
  45. "<unknown>",
  46. "bool",
  47. "int",
  48. "uint",
  49. "dword",
  50. "half",
  51. "float",
  52. "double",
  53. "min10float",
  54. "min16float",
  55. "min12int",
  56. "min16int",
  57. "min16uint",
  58. "literal float",
  59. "literal int",
  60. "int16_t",
  61. "int32_t",
  62. "int64_t",
  63. "uint16_t",
  64. "uint32_t",
  65. "uint64_t",
  66. "float16_t",
  67. "float32_t",
  68. "float64_t"
  69. };
  70. static_assert(HLSLScalarTypeCount == _countof(HLSLScalarTypeNames), "otherwise scalar constants are not aligned");
  71. static HLSLScalarType FindScalarTypeByName(const char *typeName, const size_t typeLen, const LangOptions& langOptions) {
  72. // skipped HLSLScalarType: unknown, literal int, literal float
  73. switch (typeLen) {
  74. case 3: // int
  75. if (typeName[0] == 'i') {
  76. if (strncmp(typeName, "int", 3))
  77. break;
  78. return HLSLScalarType_int;
  79. }
  80. break;
  81. case 4: // bool, uint, half
  82. if (typeName[0] == 'b') {
  83. if (strncmp(typeName, "bool", 4))
  84. break;
  85. return HLSLScalarType_bool;
  86. }
  87. else if (typeName[0] == 'u') {
  88. if (strncmp(typeName, "uint", 4))
  89. break;
  90. return HLSLScalarType_uint;
  91. }
  92. else if (typeName[0] == 'h') {
  93. if (strncmp(typeName, "half", 4))
  94. break;
  95. return HLSLScalarType_half;
  96. }
  97. break;
  98. case 5: // dword, float
  99. if (typeName[0] == 'd') {
  100. if (strncmp(typeName, "dword", 5))
  101. break;
  102. return HLSLScalarType_dword;
  103. }
  104. else if (typeName[0] == 'f') {
  105. if (strncmp(typeName, "float", 5))
  106. break;
  107. return HLSLScalarType_float;
  108. }
  109. break;
  110. case 6: // double
  111. if (typeName[0] == 'd') {
  112. if (strncmp(typeName, "double", 6))
  113. break;
  114. return HLSLScalarType_double;
  115. }
  116. break;
  117. case 7: // int64_t
  118. if (typeName[0] == 'i' && typeName[1] == 'n') {
  119. if (typeName[3] == '6') {
  120. if (strncmp(typeName, "int64_t", 7))
  121. break;
  122. return HLSLScalarType_int64;
  123. }
  124. }
  125. case 8: // min12int, min16int, uint64_t
  126. if (typeName[0] == 'm' && typeName[1] == 'i') {
  127. if (typeName[4] == '2') {
  128. if (strncmp(typeName, "min12int", 8))
  129. break;
  130. return HLSLScalarType_int_min12;
  131. }
  132. else if (typeName[4] == '6') {
  133. if (strncmp(typeName, "min16int", 8))
  134. break;
  135. return HLSLScalarType_int_min16;
  136. }
  137. }
  138. else if (typeName[0] == 'u' && typeName[1] == 'i') {
  139. if (typeName[4] == '6') {
  140. if (strncmp(typeName, "uint64_t", 8))
  141. break;
  142. return HLSLScalarType_uint64;
  143. }
  144. }
  145. break;
  146. case 9: // min16uint
  147. if (typeName[0] == 'm' && typeName[1] == 'i') {
  148. if (strncmp(typeName, "min16uint", 9))
  149. break;
  150. return HLSLScalarType_uint_min16;
  151. }
  152. break;
  153. case 10: // min10float, min16float
  154. if (typeName[0] == 'm' && typeName[1] == 'i') {
  155. if (typeName[4] == '0') {
  156. if (strncmp(typeName, "min10float", 10))
  157. break;
  158. return HLSLScalarType_float_min10;
  159. }
  160. if (typeName[4] == '6') {
  161. if (strncmp(typeName, "min16float", 10))
  162. break;
  163. return HLSLScalarType_float_min16;
  164. }
  165. }
  166. break;
  167. default:
  168. break;
  169. }
  170. // fixed width types (int16_t, uint16_t, int32_t, uint32_t, float16_t, float32_t, float64_t)
  171. // are only supported in HLSL 2018
  172. if (langOptions.HLSLVersion >= 2018) {
  173. switch (typeLen) {
  174. case 7: // int16_t, int32_t
  175. if (typeName[0] == 'i' && typeName[1] == 'n') {
  176. if (!langOptions.UseMinPrecision) {
  177. if (typeName[3] == '1') {
  178. if (strncmp(typeName, "int16_t", 7))
  179. break;
  180. return HLSLScalarType_int16;
  181. }
  182. }
  183. if (typeName[3] == '3') {
  184. if (strncmp(typeName, "int32_t", 7))
  185. break;
  186. return HLSLScalarType_int32;
  187. }
  188. }
  189. case 8: // uint16_t, uint32_t
  190. if (!langOptions.UseMinPrecision) {
  191. if (typeName[0] == 'u' && typeName[1] == 'i') {
  192. if (typeName[4] == '1') {
  193. if (strncmp(typeName, "uint16_t", 8))
  194. break;
  195. return HLSLScalarType_uint16;
  196. }
  197. }
  198. }
  199. if (typeName[4] == '3') {
  200. if (strncmp(typeName, "uint32_t", 8))
  201. break;
  202. return HLSLScalarType_uint32;
  203. }
  204. case 9: // float16_t, float32_t, float64_t
  205. if (typeName[0] == 'f' && typeName[1] == 'l') {
  206. if (!langOptions.UseMinPrecision) {
  207. if (typeName[5] == '1') {
  208. if (strncmp(typeName, "float16_t", 9))
  209. break;
  210. return HLSLScalarType_float16;
  211. }
  212. }
  213. if (typeName[5] == '3') {
  214. if (strncmp(typeName, "float32_t", 9))
  215. break;
  216. return HLSLScalarType_float32;
  217. }
  218. else if (typeName[5] == '6') {
  219. if (strncmp(typeName, "float64_t", 9))
  220. break;
  221. return HLSLScalarType_float64;
  222. }
  223. }
  224. }
  225. }
  226. return HLSLScalarType_unknown;
  227. }
  228. /// <summary>Provides the primitive type for lowering matrix types to IR.</summary>
  229. static
  230. CanQualType GetHLSLObjectHandleType(ASTContext& context)
  231. {
  232. return context.IntTy;
  233. }
  234. static
  235. void AddSubscriptOperator(
  236. ASTContext& context, unsigned int templateDepth, TemplateTypeParmDecl *elementTemplateParamDecl,
  237. NonTypeTemplateParmDecl* colCountTemplateParamDecl, QualType intType, CXXRecordDecl* templateRecordDecl,
  238. ClassTemplateDecl* vectorTemplateDecl,
  239. bool forConst)
  240. {
  241. QualType elementType = context.getTemplateTypeParmType(
  242. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  243. Expr* sizeExpr = DeclRefExpr::Create(context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl, false,
  244. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  245. intType, ExprValueKind::VK_RValue);
  246. CXXRecordDecl *vecTemplateRecordDecl = vectorTemplateDecl->getTemplatedDecl();
  247. const clang::Type *vecTy = vecTemplateRecordDecl->getTypeForDecl();
  248. TemplateArgument templateArgs[2] =
  249. {
  250. TemplateArgument(elementType),
  251. TemplateArgument(sizeExpr)
  252. };
  253. TemplateName canonName = context.getCanonicalTemplateName(TemplateName(vectorTemplateDecl));
  254. QualType vectorType = context.getTemplateSpecializationType(
  255. canonName, templateArgs, _countof(templateArgs), QualType(vecTy, 0));
  256. vectorType = context.getLValueReferenceType(vectorType);
  257. if (forConst)
  258. vectorType = context.getConstType(vectorType);
  259. QualType indexType = intType;
  260. CreateObjectFunctionDeclarationWithParams(
  261. context, templateRecordDecl, vectorType,
  262. ArrayRef<QualType>(indexType), ArrayRef<StringRef>(StringRef("index")),
  263. context.DeclarationNames.getCXXOperatorName(OO_Subscript), forConst);
  264. }
  265. /// <summary>Adds up-front support for HLSL matrix types (just the template declaration).</summary>
  266. void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorTemplateDecl, ClassTemplateDecl** matrixTemplateDecl)
  267. {
  268. DXASSERT_NOMSG(matrixTemplateDecl != nullptr);
  269. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  270. // Create a matrix template declaration in translation unit scope.
  271. // template<typename element, int row_count, int col_count> matrix { ... }
  272. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "matrix");
  273. TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", (QualType)context.FloatTy);
  274. NonTypeTemplateParmDecl* rowCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("row_count", context.IntTy, 4);
  275. NonTypeTemplateParmDecl* colCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("col_count", context.IntTy, 4);
  276. typeDeclBuilder.startDefinition();
  277. CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
  278. ClassTemplateDecl* classTemplateDecl = typeDeclBuilder.getTemplateDecl();
  279. // Add an 'h' field to hold the handle.
  280. // The type is vector<element, col>[row].
  281. QualType elementType = context.getTemplateTypeParmType(
  282. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  283. Expr *sizeExpr = DeclRefExpr::Create(
  284. context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
  285. false,
  286. DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
  287. context.IntTy, ExprValueKind::VK_RValue);
  288. Expr *rowSizeExpr = DeclRefExpr::Create(
  289. context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
  290. false,
  291. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  292. context.IntTy, ExprValueKind::VK_RValue);
  293. QualType vectorType = context.getDependentSizedExtVectorType(
  294. elementType, rowSizeExpr, SourceLocation());
  295. QualType vectorArrayType = context.getDependentSizedArrayType(
  296. vectorType, sizeExpr, ArrayType::Normal, 0, SourceRange());
  297. typeDeclBuilder.addField("h", vectorArrayType);
  298. // Add an operator[]. The operator ranges from zero to rowcount-1, and returns a vector of colcount elements.
  299. const unsigned int templateDepth = 0;
  300. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  301. colCountTemplateParamDecl, context.UnsignedIntTy,
  302. templateRecordDecl, vectorTemplateDecl, ForConstFalse);
  303. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  304. colCountTemplateParamDecl, context.UnsignedIntTy,
  305. templateRecordDecl, vectorTemplateDecl, ForConstTrue);
  306. typeDeclBuilder.completeDefinition();
  307. *matrixTemplateDecl = classTemplateDecl;
  308. }
  309. static void AddHLSLVectorSubscriptAttr(Decl *D, ASTContext &context) {
  310. StringRef group = GetHLOpcodeGroupName(HLOpcodeGroup::HLSubscript);
  311. D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, "", static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)));
  312. }
  313. /// <summary>Adds up-front support for HLSL vector types (just the template declaration).</summary>
  314. void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vectorTemplateDecl)
  315. {
  316. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  317. // Create a vector template declaration in translation unit scope.
  318. // template<typename element, int element_count> vector { ... }
  319. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "vector");
  320. TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", (QualType)context.FloatTy);
  321. NonTypeTemplateParmDecl* colCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("element_count", context.IntTy, 4);
  322. typeDeclBuilder.startDefinition();
  323. CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
  324. ClassTemplateDecl* classTemplateDecl = typeDeclBuilder.getTemplateDecl();
  325. Expr *vecSizeExpr = DeclRefExpr::Create(
  326. context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
  327. false,
  328. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  329. context.IntTy, ExprValueKind::VK_RValue);
  330. const unsigned int templateDepth = 0;
  331. QualType resultType = context.getTemplateTypeParmType(
  332. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  333. QualType vectorType = context.getDependentSizedExtVectorType(
  334. resultType, vecSizeExpr, SourceLocation());
  335. // Add an 'h' field to hold the handle.
  336. typeDeclBuilder.addField("h", vectorType);
  337. // Add an operator[]. The operator ranges from zero to colcount-1, and returns a scalar.
  338. // ForConstTrue:
  339. QualType refResultType = context.getConstType(context.getLValueReferenceType(resultType));
  340. CXXMethodDecl* functionDecl = CreateObjectFunctionDeclarationWithParams(
  341. context, templateRecordDecl, refResultType,
  342. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  343. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstTrue);
  344. AddHLSLVectorSubscriptAttr(functionDecl, context);
  345. // ForConstFalse:
  346. resultType = context.getLValueReferenceType(resultType);
  347. functionDecl = CreateObjectFunctionDeclarationWithParams(
  348. context, templateRecordDecl, resultType,
  349. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  350. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstFalse);
  351. AddHLSLVectorSubscriptAttr(functionDecl, context);
  352. typeDeclBuilder.completeDefinition();
  353. *vectorTemplateDecl = classTemplateDecl;
  354. }
  355. /// <summary>
  356. /// Adds a new record type in the specified context with the given name. The record type will have a handle field.
  357. /// </summary>
  358. CXXRecordDecl* hlsl::DeclareRecordTypeWithHandle(ASTContext& context, StringRef name) {
  359. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), name, TagDecl::TagKind::TTK_Struct);
  360. typeDeclBuilder.startDefinition();
  361. typeDeclBuilder.addField("h", GetHLSLObjectHandleType(context));
  362. return typeDeclBuilder.completeDefinition();
  363. }
  364. // creates a global static constant unsigned integer with value.
  365. // equivalent to: static const uint name = val;
  366. static void AddConstUInt(clang::ASTContext& context, DeclContext *DC, StringRef name, unsigned val) {
  367. IdentifierInfo &Id = context.Idents.get(name, tok::TokenKind::identifier);
  368. QualType type = context.getConstType(context.UnsignedIntTy);
  369. VarDecl *varDecl = VarDecl::Create(context, DC, NoLoc, NoLoc, &Id, type,
  370. context.getTrivialTypeSourceInfo(type),
  371. clang::StorageClass::SC_Static);
  372. Expr *exprVal = IntegerLiteral::Create(
  373. context, llvm::APInt(context.getIntWidth(type), val), type, NoLoc);
  374. varDecl->setInit(exprVal);
  375. varDecl->setImplicit(true);
  376. DC->addDecl(varDecl);
  377. }
  378. static void AddConstUInt(clang::ASTContext& context, StringRef name, unsigned val) {
  379. AddConstUInt(context, context.getTranslationUnitDecl(), name, val);
  380. }
  381. // Adds a top-level enum with the given enumerants.
  382. struct Enumerant { StringRef name; unsigned value; };
  383. static void AddTypedefPseudoEnum(ASTContext& context, StringRef name, ArrayRef<Enumerant> enumerants) {
  384. DeclContext* curDC = context.getTranslationUnitDecl();
  385. // typedef uint <name>;
  386. IdentifierInfo& enumId = context.Idents.get(name, tok::TokenKind::identifier);
  387. TypeSourceInfo* uintTypeSource = context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
  388. TypedefDecl* enumDecl = TypedefDecl::Create(context, curDC, NoLoc, NoLoc, &enumId, uintTypeSource);
  389. curDC->addDecl(enumDecl);
  390. enumDecl->setImplicit(true);
  391. // static const uint <enumerant.name> = <enumerant.value>;
  392. for (const Enumerant& enumerant : enumerants) {
  393. AddConstUInt(context, curDC, enumerant.name, enumerant.value);
  394. }
  395. }
  396. /// <summary> Adds all constants and enums for ray tracing </summary>
  397. void hlsl::AddRaytracingConstants(ASTContext& context) {
  398. AddTypedefPseudoEnum(context, "RAY_FLAG", {
  399. { "RAY_FLAG_NONE", (unsigned)DXIL::RayFlag::None },
  400. { "RAY_FLAG_FORCE_OPAQUE", (unsigned)DXIL::RayFlag::ForceOpaque },
  401. { "RAY_FLAG_FORCE_NON_OPAQUE", (unsigned)DXIL::RayFlag::ForceNonOpaque },
  402. { "RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH", (unsigned)DXIL::RayFlag::AcceptFirstHitAndEndSearch },
  403. { "RAY_FLAG_SKIP_CLOSEST_HIT_SHADER", (unsigned)DXIL::RayFlag::SkipClosestHitShader },
  404. { "RAY_FLAG_CULL_BACK_FACING_TRIANGLES", (unsigned)DXIL::RayFlag::CullBackFacingTriangles },
  405. { "RAY_FLAG_CULL_FRONT_FACING_TRIANGLES", (unsigned)DXIL::RayFlag::CullFrontFacingTriangles },
  406. { "RAY_FLAG_CULL_OPAQUE", (unsigned)DXIL::RayFlag::CullOpaque },
  407. { "RAY_FLAG_CULL_NON_OPAQUE", (unsigned)DXIL::RayFlag::CullNonOpaque },
  408. { "RAY_FLAG_SKIP_TRIANGLES", (unsigned)DXIL::RayFlag::SkipTriangles },
  409. { "RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES", (unsigned)DXIL::RayFlag::SkipProceduralPrimitives },
  410. });
  411. AddTypedefPseudoEnum(context, "COMMITTED_STATUS", {
  412. { "COMMITTED_NOTHING", (unsigned)DXIL::CommittedStatus::CommittedNothing },
  413. { "COMMITTED_TRIANGLE_HIT", (unsigned)DXIL::CommittedStatus::CommittedTriangleHit },
  414. { "COMMITTED_PROCEDURAL_PRIMITIVE_HIT", (unsigned)DXIL::CommittedStatus::CommittedProceduralPrimitiveHit }
  415. });
  416. AddTypedefPseudoEnum(context, "CANDIDATE_TYPE", {
  417. { "CANDIDATE_NON_OPAQUE_TRIANGLE", (unsigned)DXIL::CandidateType::CandidateNonOpaqueTriangle },
  418. { "CANDIDATE_PROCEDURAL_PRIMITIVE", (unsigned)DXIL::CandidateType::CandidateProceduralPrimitive }
  419. });
  420. // static const uint HIT_KIND_* = *;
  421. AddConstUInt(context, StringRef("HIT_KIND_NONE"), (unsigned)DXIL::HitKind::None);
  422. AddConstUInt(context, StringRef("HIT_KIND_TRIANGLE_FRONT_FACE"), (unsigned)DXIL::HitKind::TriangleFrontFace);
  423. AddConstUInt(context, StringRef("HIT_KIND_TRIANGLE_BACK_FACE"), (unsigned)DXIL::HitKind::TriangleBackFace);
  424. AddConstUInt(context, StringRef("STATE_OBJECT_FLAGS_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS"), (unsigned)DXIL::StateObjectFlags::AllowLocalDependenciesOnExternalDefinitions);
  425. AddConstUInt(context, StringRef("STATE_OBJECT_FLAGS_ALLOW_EXTERNAL_DEPENDENCIES_ON_LOCAL_DEFINITIONS"), (unsigned)DXIL::StateObjectFlags::AllowExternalDependenciesOnLocalDefinitions);
  426. // The above "_FLAGS_" was a typo, leaving in to avoid breaking anyone. Supposed to be _FLAG_ below.
  427. AddConstUInt(context, StringRef("STATE_OBJECT_FLAG_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS"), (unsigned)DXIL::StateObjectFlags::AllowLocalDependenciesOnExternalDefinitions);
  428. AddConstUInt(context, StringRef("STATE_OBJECT_FLAG_ALLOW_EXTERNAL_DEPENDENCIES_ON_LOCAL_DEFINITIONS"), (unsigned)DXIL::StateObjectFlags::AllowExternalDependenciesOnLocalDefinitions);
  429. AddConstUInt(context, StringRef("STATE_OBJECT_FLAG_ALLOW_STATE_OBJECT_ADDITIONS"), (unsigned)DXIL::StateObjectFlags::AllowStateObjectAdditions);
  430. AddConstUInt(context, StringRef("RAYTRACING_PIPELINE_FLAG_NONE"), (unsigned)DXIL::RaytracingPipelineFlags::None);
  431. AddConstUInt(context, StringRef("RAYTRACING_PIPELINE_FLAG_SKIP_TRIANGLES"), (unsigned)DXIL::RaytracingPipelineFlags::SkipTriangles);
  432. AddConstUInt(context, StringRef("RAYTRACING_PIPELINE_FLAG_SKIP_PROCEDURAL_PRIMITIVES"), (unsigned)DXIL::RaytracingPipelineFlags::SkipProceduralPrimitives);
  433. }
  434. /// <summary> Adds all constants and enums for sampler feedback </summary>
  435. void hlsl::AddSamplerFeedbackConstants(ASTContext& context) {
  436. AddConstUInt(context, StringRef("SAMPLER_FEEDBACK_MIN_MIP"), (unsigned)DXIL::SamplerFeedbackType::MinMip);
  437. AddConstUInt(context, StringRef("SAMPLER_FEEDBACK_MIP_REGION_USED"), (unsigned)DXIL::SamplerFeedbackType::MipRegionUsed);
  438. }
  439. static
  440. Expr* IntConstantAsBoolExpr(clang::Sema& sema, uint64_t value)
  441. {
  442. return sema.ImpCastExprToType(
  443. sema.ActOnIntegerConstant(NoLoc, value).get(), sema.getASTContext().BoolTy, CK_IntegralToBoolean).get();
  444. }
  445. static
  446. CXXRecordDecl* CreateStdStructWithStaticBool(clang::ASTContext& context, NamespaceDecl* stdNamespace, IdentifierInfo& trueTypeId, IdentifierInfo& valueId, Expr* trueExpression)
  447. {
  448. // struct true_type { static const bool value = true; }
  449. TypeSourceInfo* boolTypeSource = context.getTrivialTypeSourceInfo(context.BoolTy.withConst());
  450. CXXRecordDecl* trueTypeDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &trueTypeId, nullptr, DelayTypeCreationTrue);
  451. // static fields are variables in the AST
  452. VarDecl* trueValueDecl = VarDecl::Create(context, trueTypeDecl, NoLoc, NoLoc, &valueId,
  453. context.BoolTy.withConst(), boolTypeSource, SC_Static);
  454. trueValueDecl->setInit(trueExpression);
  455. trueValueDecl->setConstexpr(true);
  456. trueValueDecl->setAccess(AS_public);
  457. trueTypeDecl->setLexicalDeclContext(stdNamespace);
  458. trueTypeDecl->startDefinition();
  459. trueTypeDecl->addDecl(trueValueDecl);
  460. trueTypeDecl->completeDefinition();
  461. stdNamespace->addDecl(trueTypeDecl);
  462. return trueTypeDecl;
  463. }
  464. static
  465. void DefineRecordWithBase(CXXRecordDecl* decl, DeclContext* lexicalContext, const CXXBaseSpecifier* base)
  466. {
  467. decl->setLexicalDeclContext(lexicalContext);
  468. decl->startDefinition();
  469. decl->setBases(&base, 1);
  470. decl->completeDefinition();
  471. lexicalContext->addDecl(decl);
  472. }
  473. static
  474. void SetPartialExplicitSpecialization(ClassTemplateDecl* templateDecl, ClassTemplatePartialSpecializationDecl* specializationDecl)
  475. {
  476. specializationDecl->setSpecializationKind(TSK_ExplicitSpecialization);
  477. templateDecl->AddPartialSpecialization(specializationDecl, nullptr);
  478. }
  479. static
  480. void CreateIsEqualSpecialization(ASTContext& context, ClassTemplateDecl* templateDecl, TemplateName& templateName,
  481. DeclContext* lexicalContext, const CXXBaseSpecifier* base, TemplateParameterList* templateParamList,
  482. TemplateArgument (&templateArgs)[2])
  483. {
  484. QualType specializationCanonType = context.getTemplateSpecializationType(templateName, templateArgs, _countof(templateArgs));
  485. TemplateArgumentListInfo templateArgsListInfo = TemplateArgumentListInfo(NoLoc, NoLoc);
  486. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[0], context.getTrivialTypeSourceInfo(templateArgs[0].getAsType())));
  487. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[1], context.getTrivialTypeSourceInfo(templateArgs[1].getAsType())));
  488. ClassTemplatePartialSpecializationDecl* specializationDecl =
  489. ClassTemplatePartialSpecializationDecl::Create(context, TTK_Struct, lexicalContext, NoLoc, NoLoc,
  490. templateParamList, templateDecl, templateArgs, _countof(templateArgs),
  491. templateArgsListInfo, specializationCanonType, nullptr);
  492. context.getTagDeclType(specializationDecl); // Fault this in now.
  493. DefineRecordWithBase(specializationDecl, lexicalContext, base);
  494. SetPartialExplicitSpecialization(templateDecl, specializationDecl);
  495. }
  496. /// <summary>Adds the implementation for std::is_equal.</summary>
  497. void hlsl::AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema)
  498. {
  499. // The goal is to support std::is_same<T, T>::value for testing purposes, in a manner that can
  500. // evolve into a compliant feature in the future.
  501. //
  502. // The definitions necessary are as follows (all in the std namespace).
  503. // template <class T, T v>
  504. // struct integral_constant {
  505. // typedef T value_type;
  506. // static const value_type value = v;
  507. // operator value_type() { return value; }
  508. // };
  509. //
  510. // typedef integral_constant<bool, true> true_type;
  511. // typedef integral_constant<bool, false> false_type;
  512. //
  513. // template<typename T, typename U> struct is_same : public false_type {};
  514. // template<typename T> struct is_same<T, T> : public true_type{};
  515. //
  516. // We instead use these simpler definitions for true_type and false_type.
  517. // struct false_type { static const bool value = false; };
  518. // struct true_type { static const bool value = true; };
  519. DeclContext* tuContext = context.getTranslationUnitDecl();
  520. IdentifierInfo& stdId = context.Idents.get(StringRef("std"), tok::TokenKind::identifier);
  521. IdentifierInfo& trueTypeId = context.Idents.get(StringRef("true_type"), tok::TokenKind::identifier);
  522. IdentifierInfo& falseTypeId = context.Idents.get(StringRef("false_type"), tok::TokenKind::identifier);
  523. IdentifierInfo& valueId = context.Idents.get(StringRef("value"), tok::TokenKind::identifier);
  524. IdentifierInfo& isSameId = context.Idents.get(StringRef("is_same"), tok::TokenKind::identifier);
  525. IdentifierInfo& tId = context.Idents.get(StringRef("T"), tok::TokenKind::identifier);
  526. IdentifierInfo& vId = context.Idents.get(StringRef("V"), tok::TokenKind::identifier);
  527. Expr* trueExpression = IntConstantAsBoolExpr(sema, 1);
  528. Expr* falseExpression = IntConstantAsBoolExpr(sema, 0);
  529. // namespace std
  530. NamespaceDecl* stdNamespace = NamespaceDecl::Create(context, tuContext, InlineFalse, NoLoc, NoLoc, &stdId, nullptr);
  531. CXXRecordDecl* trueTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, trueTypeId, valueId, trueExpression);
  532. CXXRecordDecl* falseTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, falseTypeId, valueId, falseExpression);
  533. // template<typename T, typename U> struct is_same : public false_type {};
  534. CXXRecordDecl* isSameFalseRecordDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &isSameId, nullptr, false);
  535. TemplateTypeParmDecl* tParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition, &tId, TypenameFalse, ParameterPackFalse);
  536. TemplateTypeParmDecl* uParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition + 1, &vId, TypenameFalse, ParameterPackFalse);
  537. NamedDecl* falseParams[] = { tParam, uParam };
  538. TemplateParameterList* falseParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, falseParams, _countof(falseParams), NoLoc);
  539. ClassTemplateDecl* isSameFalseTemplateDecl = ClassTemplateDecl::Create(context, stdNamespace, NoLoc, DeclarationName(&isSameId), falseParamList, isSameFalseRecordDecl, nullptr);
  540. context.getTagDeclType(isSameFalseRecordDecl); // Fault this in now.
  541. CXXBaseSpecifier* falseBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  542. context.getTrivialTypeSourceInfo(context.getTypeDeclType(falseTypeDecl)), NoLoc);
  543. isSameFalseRecordDecl->setDescribedClassTemplate(isSameFalseTemplateDecl);
  544. isSameFalseTemplateDecl->setLexicalDeclContext(stdNamespace);
  545. DefineRecordWithBase(isSameFalseRecordDecl, stdNamespace, falseBase);
  546. // is_same for 'true' is a specialization of is_same for 'false', taking a single T, where both T will match
  547. // template<typename T> struct is_same<T, T> : public true_type{};
  548. TemplateName tn = TemplateName(isSameFalseTemplateDecl);
  549. NamedDecl* trueParams[] = { tParam };
  550. TemplateParameterList* trueParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, trueParams, _countof(trueParams), NoLoc);
  551. CXXBaseSpecifier* trueBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  552. context.getTrivialTypeSourceInfo(context.getTypeDeclType(trueTypeDecl)), NoLoc);
  553. TemplateArgument ta = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)));
  554. TemplateArgument isSameTrueTemplateArgs[] = { ta, ta };
  555. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueTemplateArgs);
  556. stdNamespace->addDecl(isSameFalseTemplateDecl);
  557. stdNamespace->setImplicit(true);
  558. tuContext->addDecl(stdNamespace);
  559. // This could be a parameter if ever needed.
  560. const bool SupportExtensions = true;
  561. // Consider right-hand const and right-hand ref to be true for is_same:
  562. // template<typename T> struct is_same<T, const T> : public true_type{};
  563. // template<typename T> struct is_same<T, T&> : public true_type{};
  564. if (SupportExtensions)
  565. {
  566. TemplateArgument trueConstArg = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)).withConst());
  567. TemplateArgument isSameTrueConstTemplateArgs[] = { ta, trueConstArg };
  568. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueConstTemplateArgs);
  569. TemplateArgument trueRefArg = TemplateArgument(
  570. context.getLValueReferenceType(context.getCanonicalType(context.getTypeDeclType(tParam))));
  571. TemplateArgument isSameTrueRefTemplateArgs[] = { ta, trueRefArg };
  572. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueRefTemplateArgs);
  573. }
  574. }
  575. /// <summary>
  576. /// Adds a new template type in the specified context with the given name. The record type will have a handle field.
  577. /// </summary>
  578. /// <parm name="context">AST context to which template will be added.</param>
  579. /// <parm name="typeName">Name of template to create.</param>
  580. /// <parm name="templateArgCount">Number of template arguments (one or two).</param>
  581. /// <parm name="defaultTypeArgValue">If assigned, the default argument for the element template.</param>
  582. CXXRecordDecl* hlsl::DeclareTemplateTypeWithHandle(
  583. ASTContext& context,
  584. StringRef name,
  585. uint8_t templateArgCount,
  586. _In_opt_ TypeSourceInfo* defaultTypeArgValue)
  587. {
  588. DXASSERT(templateArgCount != 0, "otherwise caller should be creating a class or struct");
  589. DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated for a different template pattern");
  590. // Create an object template declaration in translation unit scope.
  591. // templateArgCount=1: template<typename element> typeName { ... }
  592. // templateArgCount=2: template<typename element, int count> typeName { ... }
  593. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), name);
  594. TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", defaultTypeArgValue);
  595. NonTypeTemplateParmDecl* countTemplateParamDecl = nullptr;
  596. if (templateArgCount > 1)
  597. countTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("count", context.IntTy, 0);
  598. typeDeclBuilder.startDefinition();
  599. CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
  600. // Add an 'h' field to hold the handle.
  601. QualType elementType = context.getTemplateTypeParmType(
  602. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  603. if (templateArgCount > 1 &&
  604. // Only need array type for inputpatch and outputpatch.
  605. // Avoid Texture2DMS which may use 0 count.
  606. // TODO: use hlsl types to do the check.
  607. !name.startswith("Texture")) {
  608. Expr *countExpr = DeclRefExpr::Create(
  609. context, NestedNameSpecifierLoc(), NoLoc, countTemplateParamDecl, false,
  610. DeclarationNameInfo(countTemplateParamDecl->getDeclName(), NoLoc),
  611. context.IntTy, ExprValueKind::VK_RValue);
  612. elementType = context.getDependentSizedArrayType(
  613. elementType, countExpr, ArrayType::ArraySizeModifier::Normal, 0,
  614. SourceRange());
  615. // InputPatch and OutputPatch also have a "Length" static const member for the number of control points
  616. IdentifierInfo& lengthId = context.Idents.get(StringRef("Length"), tok::TokenKind::identifier);
  617. TypeSourceInfo* lengthTypeSource = context.getTrivialTypeSourceInfo(context.IntTy.withConst());
  618. VarDecl* lengthValueDecl = VarDecl::Create(context, templateRecordDecl, NoLoc, NoLoc, &lengthId,
  619. context.IntTy.withConst(), lengthTypeSource, SC_Static);
  620. lengthValueDecl->setInit(countExpr);
  621. lengthValueDecl->setAccess(AS_public);
  622. templateRecordDecl->addDecl(lengthValueDecl);
  623. }
  624. typeDeclBuilder.addField("h", elementType);
  625. return typeDeclBuilder.completeDefinition();
  626. }
  627. FunctionTemplateDecl* hlsl::CreateFunctionTemplateDecl(
  628. ASTContext& context,
  629. _In_ CXXRecordDecl* recordDecl,
  630. _In_ CXXMethodDecl* functionDecl,
  631. _In_count_(templateParamNamedDeclsCount) NamedDecl** templateParamNamedDecls,
  632. size_t templateParamNamedDeclsCount)
  633. {
  634. DXASSERT_NOMSG(recordDecl != nullptr);
  635. DXASSERT_NOMSG(templateParamNamedDecls != nullptr);
  636. DXASSERT(templateParamNamedDeclsCount > 0, "otherwise caller shouldn't invoke this function");
  637. TemplateParameterList* templateParams = TemplateParameterList::Create(
  638. context, NoLoc, NoLoc, &templateParamNamedDecls[0], templateParamNamedDeclsCount, NoLoc);
  639. FunctionTemplateDecl* functionTemplate =
  640. FunctionTemplateDecl::Create(context, recordDecl, NoLoc, functionDecl->getDeclName(), templateParams, functionDecl);
  641. functionTemplate->setAccess(AccessSpecifier::AS_public);
  642. functionTemplate->setLexicalDeclContext(recordDecl);
  643. functionDecl->setDescribedFunctionTemplate(functionTemplate);
  644. recordDecl->addDecl(functionTemplate);
  645. return functionTemplate;
  646. }
  647. static
  648. void AssociateParametersToFunctionPrototype(
  649. _In_ TypeSourceInfo* tinfo,
  650. _In_count_(numParams) ParmVarDecl** paramVarDecls,
  651. unsigned int numParams)
  652. {
  653. FunctionProtoTypeLoc protoLoc = tinfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
  654. DXASSERT(protoLoc.getNumParams() == numParams, "otherwise unexpected number of parameters available");
  655. for (unsigned i = 0; i < numParams; i++) {
  656. DXASSERT(protoLoc.getParam(i) == nullptr, "otherwise prototype parameters were already initialized");
  657. protoLoc.setParam(i, paramVarDecls[i]);
  658. }
  659. }
  660. static void CreateConstructorDeclaration(
  661. ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
  662. ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
  663. _Out_ CXXConstructorDecl **constructorDecl, _Out_ TypeSourceInfo **tinfo) {
  664. DXASSERT_NOMSG(recordDecl != nullptr);
  665. DXASSERT_NOMSG(constructorDecl != nullptr);
  666. FunctionProtoType::ExtProtoInfo functionExtInfo;
  667. functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
  668. QualType functionQT = context.getFunctionType(
  669. resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
  670. DeclarationNameInfo declNameInfo(declarationName, NoLoc);
  671. *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
  672. DXASSERT_NOMSG(*tinfo != nullptr);
  673. *constructorDecl = CXXConstructorDecl::Create(
  674. context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
  675. StorageClass::SC_None, ExplicitFalse, InlineSpecifiedFalse, IsConstexprFalse);
  676. DXASSERT_NOMSG(*constructorDecl != nullptr);
  677. (*constructorDecl)->setLexicalDeclContext(recordDecl);
  678. (*constructorDecl)->setAccess(AccessSpecifier::AS_public);
  679. }
  680. static void CreateObjectFunctionDeclaration(
  681. ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
  682. ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
  683. _Out_ CXXMethodDecl **functionDecl, _Out_ TypeSourceInfo **tinfo) {
  684. DXASSERT_NOMSG(recordDecl != nullptr);
  685. DXASSERT_NOMSG(functionDecl != nullptr);
  686. FunctionProtoType::ExtProtoInfo functionExtInfo;
  687. functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
  688. QualType functionQT = context.getFunctionType(
  689. resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
  690. DeclarationNameInfo declNameInfo(declarationName, NoLoc);
  691. *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
  692. DXASSERT_NOMSG(*tinfo != nullptr);
  693. *functionDecl = CXXMethodDecl::Create(
  694. context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
  695. StorageClass::SC_None, InlineSpecifiedFalse, IsConstexprFalse, NoLoc);
  696. DXASSERT_NOMSG(*functionDecl != nullptr);
  697. (*functionDecl)->setLexicalDeclContext(recordDecl);
  698. (*functionDecl)->setAccess(AccessSpecifier::AS_public);
  699. }
  700. CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
  701. ASTContext& context,
  702. _In_ CXXRecordDecl* recordDecl,
  703. QualType resultType,
  704. ArrayRef<QualType> paramTypes,
  705. ArrayRef<StringRef> paramNames,
  706. DeclarationName declarationName,
  707. bool isConst)
  708. {
  709. DXASSERT_NOMSG(recordDecl != nullptr);
  710. DXASSERT_NOMSG(!resultType.isNull());
  711. DXASSERT_NOMSG(paramTypes.size() == paramNames.size());
  712. TypeSourceInfo* tinfo;
  713. CXXMethodDecl* functionDecl;
  714. CreateObjectFunctionDeclaration(context, recordDecl, resultType, paramTypes,
  715. declarationName, isConst, &functionDecl,
  716. &tinfo);
  717. // Create and associate parameters to method.
  718. SmallVector<ParmVarDecl *, 2> parmVarDecls;
  719. if (!paramTypes.empty()) {
  720. for (unsigned int i = 0; i < paramTypes.size(); ++i) {
  721. IdentifierInfo *argIi = &context.Idents.get(paramNames[i]);
  722. ParmVarDecl *parmVarDecl = ParmVarDecl::Create(
  723. context, functionDecl, NoLoc, NoLoc, argIi, paramTypes[i],
  724. context.getTrivialTypeSourceInfo(paramTypes[i], NoLoc),
  725. StorageClass::SC_None, nullptr);
  726. parmVarDecl->setScopeInfo(0, i);
  727. DXASSERT(parmVarDecl->getFunctionScopeIndex() == i,
  728. "otherwise failed to set correct index");
  729. parmVarDecls.push_back(parmVarDecl);
  730. }
  731. functionDecl->setParams(ArrayRef<ParmVarDecl *>(parmVarDecls));
  732. AssociateParametersToFunctionPrototype(tinfo, &parmVarDecls.front(),
  733. parmVarDecls.size());
  734. }
  735. recordDecl->addDecl(functionDecl);
  736. return functionDecl;
  737. }
  738. CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle(
  739. ASTContext& context, StringRef typeName, StringRef templateParamName) {
  740. // template<uint kind> FeedbackTexture2D[Array] { ... }
  741. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName);
  742. typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy);
  743. typeDeclBuilder.startDefinition();
  744. typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
  745. return typeDeclBuilder.completeDefinition();
  746. }
  747. CXXRecordDecl* hlsl::DeclareRayQueryType(ASTContext& context) {
  748. // template<uint kind> RayQuery { ... }
  749. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "RayQuery");
  750. typeDeclBuilder.addIntegerTemplateParam("flags", context.UnsignedIntTy);
  751. typeDeclBuilder.startDefinition();
  752. typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
  753. // Add constructor that will be lowered to the intrinsic that produces
  754. // the RayQuery handle for this object.
  755. CanQualType canQualType = typeDeclBuilder.getRecordDecl()->getTypeForDecl()->getCanonicalTypeUnqualified();
  756. CXXConstructorDecl *pConstructorDecl = nullptr;
  757. TypeSourceInfo *pTypeSourceInfo = nullptr;
  758. CreateConstructorDeclaration(context, typeDeclBuilder.getRecordDecl(), context.VoidTy, {}, context.DeclarationNames.getCXXConstructorName(canQualType), false, &pConstructorDecl, &pTypeSourceInfo);
  759. typeDeclBuilder.getRecordDecl()->addDecl(pConstructorDecl);
  760. return typeDeclBuilder.completeDefinition();
  761. }
  762. CXXRecordDecl* hlsl::DeclareResourceType(ASTContext& context) {
  763. // struct ResourceDescriptor { uint8 desc; }
  764. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(),
  765. ".Resource",
  766. TagDecl::TagKind::TTK_Struct);
  767. typeDeclBuilder.startDefinition();
  768. typeDeclBuilder.addField("h", GetHLSLObjectHandleType(context));
  769. return typeDeclBuilder.completeDefinition();
  770. }
  771. bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
  772. return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
  773. }
  774. bool hlsl::GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
  775. llvm::StringRef &group) {
  776. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  777. return false;
  778. }
  779. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  780. opcode = A->getOpcode();
  781. group = A->getGroup();
  782. return true;
  783. }
  784. bool hlsl::GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S) {
  785. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  786. return false;
  787. }
  788. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  789. S = A->getLowering();
  790. return true;
  791. }
  792. /// <summary>Parses a column or row digit.</summary>
  793. static
  794. bool TryParseColOrRowChar(const char digit, _Out_ int* count) {
  795. if ('1' <= digit && digit <= '4') {
  796. *count = digit - '0';
  797. return true;
  798. }
  799. *count = 0;
  800. return false;
  801. }
  802. /// <summary>Parses a matrix shorthand identifier (eg, float3x2).</summary>
  803. _Use_decl_annotations_
  804. bool hlsl::TryParseMatrixShorthand(
  805. const char* typeName,
  806. size_t typeNameLen,
  807. HLSLScalarType* parsedType,
  808. int* rowCount,
  809. int* colCount,
  810. const clang::LangOptions& langOptions
  811. )
  812. {
  813. //
  814. // Matrix shorthand format is PrimitiveTypeRxC, where R is the row count and C is the column count.
  815. // R and C should be between 1 and 4 inclusive.
  816. // x is a literal 'x' character.
  817. // PrimitiveType is one of the HLSLScalarTypeNames values.
  818. //
  819. if (TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions) &&
  820. *rowCount != 0 && *colCount != 0) {
  821. // compare scalar component
  822. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-3, langOptions);
  823. if (type!= HLSLScalarType_unknown) {
  824. *parsedType = type;
  825. return true;
  826. }
  827. }
  828. // Unable to parse.
  829. return false;
  830. }
  831. /// <summary>Parses a vector shorthand identifier (eg, float3).</summary>
  832. _Use_decl_annotations_
  833. bool hlsl::TryParseVectorShorthand(
  834. const char* typeName,
  835. size_t typeNameLen,
  836. HLSLScalarType* parsedType,
  837. int* elementCount,
  838. const clang::LangOptions& langOptions
  839. )
  840. {
  841. // At least *something*N characters necessary, where something is at least 'int'
  842. if (TryParseColOrRowChar(typeName[typeNameLen - 1], elementCount)) {
  843. // compare scalar component
  844. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-1, langOptions);
  845. if (type!= HLSLScalarType_unknown) {
  846. *parsedType = type;
  847. return true;
  848. }
  849. }
  850. // Unable to parse.
  851. return false;
  852. }
  853. /// <summary>Parses a hlsl scalar type (e.g min16float, uint3x4) </summary>
  854. _Use_decl_annotations_
  855. bool hlsl::TryParseScalar(
  856. _In_count_(typenameLen)
  857. const char* typeName,
  858. size_t typeNameLen,
  859. _Out_ HLSLScalarType *parsedType,
  860. _In_ const clang::LangOptions& langOptions) {
  861. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen, langOptions);
  862. if (type!= HLSLScalarType_unknown) {
  863. *parsedType = type;
  864. return true;
  865. }
  866. return false; // unable to parse
  867. }
  868. /// <summary>Parse any (scalar, vector, matrix) hlsl types (e.g float, int3x4, uint2) </summary>
  869. _Use_decl_annotations_
  870. bool hlsl::TryParseAny(
  871. _In_count_(typenameLen)
  872. const char* typeName,
  873. size_t typeNameLen,
  874. _Out_ HLSLScalarType *parsedType,
  875. int *rowCount,
  876. int *colCount,
  877. _In_ const clang::LangOptions& langOptions) {
  878. // at least 'int'
  879. const size_t MinValidLen = 3;
  880. if (typeNameLen >= MinValidLen) {
  881. TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions);
  882. int suffixLen = *colCount == 0 ? 0 :
  883. *rowCount == 0 ? 1 : 3;
  884. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-suffixLen, langOptions);
  885. if (type!= HLSLScalarType_unknown) {
  886. *parsedType = type;
  887. return true;
  888. }
  889. }
  890. return false;
  891. }
  892. /// <summary>Parse string hlsl type</summary>
  893. _Use_decl_annotations_
  894. bool hlsl::TryParseString(
  895. _In_count_(typenameLen)
  896. const char* typeName,
  897. size_t typeNameLen,
  898. _In_ const clang::LangOptions& langOptions) {
  899. if (typeNameLen == 6 && typeName[0] == 's' && strncmp(typeName, "string", 6) == 0) {
  900. return true;
  901. }
  902. return false;
  903. }
  904. /// <summary>Parse any kind of dimension for vector or matrix (e.g 4,3 in int4x3).
  905. /// If it's a matrix type, rowCount and colCount will be nonzero. If it's a vector type, colCount is 0.
  906. /// Otherwise both rowCount and colCount is 0. Returns true if either matrix or vector dimensions detected. </summary>
  907. _Use_decl_annotations_
  908. bool hlsl::TryParseMatrixOrVectorDimension(
  909. _In_count_(typeNameLen)
  910. const char *typeName,
  911. size_t typeNameLen,
  912. _Out_opt_ int *rowCount,
  913. _Out_opt_ int *colCount,
  914. _In_ const clang::LangOptions& langOptions) {
  915. *rowCount = 0;
  916. *colCount = 0;
  917. size_t MinValidLen = 3; // at least int
  918. if (typeNameLen > MinValidLen) {
  919. if (TryParseColOrRowChar(typeName[typeNameLen - 1], colCount)) {
  920. // Try parse matrix
  921. if (typeName[typeNameLen - 2] == 'x')
  922. TryParseColOrRowChar(typeName[typeNameLen - 3], rowCount);
  923. return true;
  924. }
  925. }
  926. return false;
  927. }
  928. /// <summary>Creates a typedef for a matrix shorthand (eg, float3x2).</summary>
  929. TypedefDecl* hlsl::CreateMatrixSpecializationShorthand(
  930. ASTContext& context,
  931. QualType matrixSpecialization,
  932. HLSLScalarType scalarType,
  933. size_t rowCount,
  934. size_t colCount)
  935. {
  936. DXASSERT(rowCount <= 4, "else caller didn't validate rowCount");
  937. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  938. char typeName[64];
  939. sprintf_s(typeName, _countof(typeName), "%s%ux%u",
  940. HLSLScalarTypeNames[scalarType], (unsigned)rowCount, (unsigned)colCount);
  941. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  942. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  943. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  944. context.getTrivialTypeSourceInfo(matrixSpecialization, NoLoc));
  945. decl->setImplicit(true);
  946. currentDeclContext->addDecl(decl);
  947. return decl;
  948. }
  949. /// <summary>Creates a typedef for a vector shorthand (eg, float3).</summary>
  950. TypedefDecl* hlsl::CreateVectorSpecializationShorthand(
  951. ASTContext& context,
  952. QualType vectorSpecialization,
  953. HLSLScalarType scalarType,
  954. size_t colCount)
  955. {
  956. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  957. char typeName[64];
  958. sprintf_s(typeName, _countof(typeName), "%s%u",
  959. HLSLScalarTypeNames[scalarType], (unsigned)colCount);
  960. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  961. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  962. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  963. context.getTrivialTypeSourceInfo(vectorSpecialization, NoLoc));
  964. decl->setImplicit(true);
  965. currentDeclContext->addDecl(decl);
  966. return decl;
  967. }
  968. llvm::ArrayRef<hlsl::UnusualAnnotation*>
  969. hlsl::UnusualAnnotation::CopyToASTContextArray(
  970. clang::ASTContext& Context, hlsl::UnusualAnnotation** begin, size_t count) {
  971. if (count == 0) {
  972. return llvm::ArrayRef<hlsl::UnusualAnnotation*>();
  973. }
  974. UnusualAnnotation** arr = ::new (Context) UnusualAnnotation*[count];
  975. for (size_t i = 0; i < count; ++i) {
  976. arr[i] = begin[i]->CopyToASTContext(Context);
  977. }
  978. return llvm::ArrayRef<hlsl::UnusualAnnotation*>(arr, count);
  979. }
  980. UnusualAnnotation* hlsl::UnusualAnnotation::CopyToASTContext(ASTContext& Context) {
  981. // All UnusualAnnotation instances can be blitted.
  982. size_t instanceSize;
  983. switch (Kind) {
  984. case UA_RegisterAssignment:
  985. instanceSize = sizeof(hlsl::RegisterAssignment);
  986. break;
  987. case UA_ConstantPacking:
  988. instanceSize = sizeof(hlsl::ConstantPacking);
  989. break;
  990. default:
  991. DXASSERT(Kind == UA_SemanticDecl, "Kind == UA_SemanticDecl -- otherwise switch is incomplete");
  992. instanceSize = sizeof(hlsl::SemanticDecl);
  993. break;
  994. }
  995. void* result = Context.Allocate(instanceSize);
  996. memcpy(result, this, instanceSize);
  997. return (UnusualAnnotation*)result;
  998. }
  999. static bool HasTessFactorSemantic(const ValueDecl *decl) {
  1000. for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  1001. if (it->getKind() == UnusualAnnotation::UA_SemanticDecl) {
  1002. const SemanticDecl *sd = cast<SemanticDecl>(it);
  1003. const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
  1004. if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
  1005. return true;
  1006. }
  1007. }
  1008. return false;
  1009. }
  1010. static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
  1011. if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
  1012. return false;
  1013. if (const RecordType *RT = Ty->getAsStructureType()) {
  1014. RecordDecl *RD = RT->getDecl();
  1015. for (FieldDecl *fieldDecl : RD->fields()) {
  1016. if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
  1017. return true;
  1018. }
  1019. return false;
  1020. }
  1021. if (Ty->getAsArrayTypeUnsafe())
  1022. return HasTessFactorSemantic(decl);
  1023. return false;
  1024. }
  1025. bool ASTContext::IsPatchConstantFunctionDecl(const FunctionDecl *FD) const {
  1026. // This checks whether the function is structurally capable of being a patch
  1027. // constant function, not whether it is in fact the patch constant function
  1028. // for the entry point of a compiled hull shader (which may not have been
  1029. // seen yet). So the answer is conservative.
  1030. if (!FD->getReturnType()->isVoidType()) {
  1031. // Try to find TessFactor in return type.
  1032. if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
  1033. return true;
  1034. }
  1035. // Try to find TessFactor in out param.
  1036. for (const ParmVarDecl *param : FD->params()) {
  1037. if (param->hasAttr<HLSLOutAttr>()) {
  1038. if (HasTessFactorSemanticRecurse(param, param->getType()))
  1039. return true;
  1040. }
  1041. }
  1042. return false;
  1043. }