ASTContextHLSL.cpp 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236
  1. //===--- ASTContextHLSL.cpp - HLSL support for AST nodes and operations ---===//
  2. ///////////////////////////////////////////////////////////////////////////////
  3. // //
  4. // ASTContextHLSL.cpp //
  5. // Copyright (C) Microsoft Corporation. All rights reserved. //
  6. // This file is distributed under the University of Illinois Open Source //
  7. // License. See LICENSE.TXT for details. //
  8. // //
  9. // This file implements the ASTContext interface for HLSL. //
  10. // //
  11. ///////////////////////////////////////////////////////////////////////////////
  12. #include "clang/AST/ASTContext.h"
  13. #include "clang/AST/Attr.h"
  14. #include "clang/AST/DeclCXX.h"
  15. #include "clang/AST/DeclTemplate.h"
  16. #include "clang/AST/Expr.h"
  17. #include "clang/AST/ExprCXX.h"
  18. #include "clang/AST/ExternalASTSource.h"
  19. #include "clang/AST/HlslBuiltinTypeDeclBuilder.h"
  20. #include "clang/AST/TypeLoc.h"
  21. #include "clang/Sema/SemaDiagnostic.h"
  22. #include "clang/Sema/Sema.h"
  23. #include "clang/Sema/Overload.h"
  24. #include "dxc/Support/Global.h"
  25. #include "dxc/HLSL/HLOperations.h"
  26. #include "dxc/DXIL/DxilSemantic.h"
  27. #include "dxc/HlslIntrinsicOp.h"
  28. using namespace clang;
  29. using namespace hlsl;
  30. static const int FirstTemplateDepth = 0; // depth of the outermost template parameter list
  31. static const int FirstParamPosition = 0; // position of the first parameter in a template parameter list
  32. static const bool ForConstFalse = false; // a construct is targeting a non-const type
  33. static const bool ForConstTrue = true; // a construct is targeting a const type
  34. static const bool ParameterPackFalse = false; // template parameter is not an ellipsis.
  35. static const bool TypenameFalse = false; // template type parameter is declared with 'class' rather than 'typename'.
  36. static const bool DelayTypeCreationTrue = true; // delay type creation for a declaration
  37. static const SourceLocation NoLoc; // no source location attribution available
  38. static const bool InlineFalse = false; // namespace is not an inline namespace
  39. static const bool InlineSpecifiedFalse = false; // function was not specified as inline
  40. static const bool ExplicitFalse = false; // constructor was not specified as explicit
  41. static const bool IsConstexprFalse = false; // function is not constexpr
  42. static const bool VirtualFalse = false; // base class is not declared 'virtual'
  43. static const bool BaseClassFalse = false; // base specifier belongs to a 'struct' rather than a 'class'
  44. /// <summary>Display names of the HLSLScalarType enumeration values, indexed by HLSLScalarType; the static_assert after the table checks that the entry count equals HLSLScalarTypeCount.</summary>
  45. const char* HLSLScalarTypeNames[] = {
  46. "<unknown>",
  47. "bool",
  48. "int",
  49. "uint",
  50. "dword",
  51. "half",
  52. "float",
  53. "double",
  54. "min10float",
  55. "min16float",
  56. "min12int",
  57. "min16int",
  58. "min16uint",
  59. "literal float",
  60. "literal int",
  61. "int16_t",
  62. "int32_t",
  63. "int64_t",
  64. "uint16_t",
  65. "uint32_t",
  66. "uint64_t",
  67. "float16_t",
  68. "float32_t",
  69. "float64_t",
  70. "int8_t4_packed",
  71. "uint8_t4_packed"
  72. };
  73. static_assert(HLSLScalarTypeCount == _countof(HLSLScalarTypeNames), "otherwise scalar constants are not aligned");
  74. static HLSLScalarType FindScalarTypeByName(const char *typeName, const size_t typeLen, const LangOptions& langOptions) {
  75. // skipped HLSLScalarType: unknown, literal int, literal float
  76. switch (typeLen) {
  77. case 3: // int
  78. if (typeName[0] == 'i') {
  79. if (strncmp(typeName, "int", 3))
  80. break;
  81. return HLSLScalarType_int;
  82. }
  83. break;
  84. case 4: // bool, uint, half
  85. if (typeName[0] == 'b') {
  86. if (strncmp(typeName, "bool", 4))
  87. break;
  88. return HLSLScalarType_bool;
  89. }
  90. else if (typeName[0] == 'u') {
  91. if (strncmp(typeName, "uint", 4))
  92. break;
  93. return HLSLScalarType_uint;
  94. }
  95. else if (typeName[0] == 'h') {
  96. if (strncmp(typeName, "half", 4))
  97. break;
  98. return HLSLScalarType_half;
  99. }
  100. break;
  101. case 5: // dword, float
  102. if (typeName[0] == 'd') {
  103. if (strncmp(typeName, "dword", 5))
  104. break;
  105. return HLSLScalarType_dword;
  106. }
  107. else if (typeName[0] == 'f') {
  108. if (strncmp(typeName, "float", 5))
  109. break;
  110. return HLSLScalarType_float;
  111. }
  112. break;
  113. case 6: // double
  114. if (typeName[0] == 'd') {
  115. if (strncmp(typeName, "double", 6))
  116. break;
  117. return HLSLScalarType_double;
  118. }
  119. break;
  120. case 7: // int64_t
  121. if (typeName[0] == 'i' && typeName[1] == 'n') {
  122. if (typeName[3] == '6') {
  123. if (strncmp(typeName, "int64_t", 7))
  124. break;
  125. return HLSLScalarType_int64;
  126. }
  127. }
  128. case 8: // min12int, min16int, uint64_t
  129. if (typeName[0] == 'm' && typeName[1] == 'i') {
  130. if (typeName[4] == '2') {
  131. if (strncmp(typeName, "min12int", 8))
  132. break;
  133. return HLSLScalarType_int_min12;
  134. }
  135. else if (typeName[4] == '6') {
  136. if (strncmp(typeName, "min16int", 8))
  137. break;
  138. return HLSLScalarType_int_min16;
  139. }
  140. }
  141. else if (typeName[0] == 'u' && typeName[1] == 'i') {
  142. if (typeName[4] == '6') {
  143. if (strncmp(typeName, "uint64_t", 8))
  144. break;
  145. return HLSLScalarType_uint64;
  146. }
  147. }
  148. break;
  149. case 9: // min16uint
  150. if (typeName[0] == 'm' && typeName[1] == 'i') {
  151. if (strncmp(typeName, "min16uint", 9))
  152. break;
  153. return HLSLScalarType_uint_min16;
  154. }
  155. break;
  156. case 10: // min10float, min16float
  157. if (typeName[0] == 'm' && typeName[1] == 'i') {
  158. if (typeName[4] == '0') {
  159. if (strncmp(typeName, "min10float", 10))
  160. break;
  161. return HLSLScalarType_float_min10;
  162. }
  163. if (typeName[4] == '6') {
  164. if (strncmp(typeName, "min16float", 10))
  165. break;
  166. return HLSLScalarType_float_min16;
  167. }
  168. }
  169. break;
  170. case 14: // int8_t4_packed
  171. if (typeName[0] == 'i' && typeName[1] == 'n') {
  172. if (strncmp(typeName, "int8_t4_packed", 14))
  173. break;
  174. return HLSLScalarType_int8_4packed;
  175. }
  176. break;
  177. case 15: // uint8_t4_packed
  178. if (typeName[0] == 'u' && typeName[1] == 'i') {
  179. if (strncmp(typeName, "uint8_t4_packed", 15))
  180. break;
  181. return HLSLScalarType_uint8_4packed;
  182. }
  183. break;
  184. default:
  185. break;
  186. }
  187. // fixed width types (int16_t, uint16_t, int32_t, uint32_t, float16_t, float32_t, float64_t)
  188. // are only supported in HLSL 2018
  189. if (langOptions.HLSLVersion >= 2018) {
  190. switch (typeLen) {
  191. case 7: // int16_t, int32_t
  192. if (typeName[0] == 'i' && typeName[1] == 'n') {
  193. if (!langOptions.UseMinPrecision) {
  194. if (typeName[3] == '1') {
  195. if (strncmp(typeName, "int16_t", 7))
  196. break;
  197. return HLSLScalarType_int16;
  198. }
  199. }
  200. if (typeName[3] == '3') {
  201. if (strncmp(typeName, "int32_t", 7))
  202. break;
  203. return HLSLScalarType_int32;
  204. }
  205. }
  206. case 8: // uint16_t, uint32_t
  207. if (!langOptions.UseMinPrecision) {
  208. if (typeName[0] == 'u' && typeName[1] == 'i') {
  209. if (typeName[4] == '1') {
  210. if (strncmp(typeName, "uint16_t", 8))
  211. break;
  212. return HLSLScalarType_uint16;
  213. }
  214. }
  215. }
  216. if (typeName[4] == '3') {
  217. if (strncmp(typeName, "uint32_t", 8))
  218. break;
  219. return HLSLScalarType_uint32;
  220. }
  221. case 9: // float16_t, float32_t, float64_t
  222. if (typeName[0] == 'f' && typeName[1] == 'l') {
  223. if (!langOptions.UseMinPrecision) {
  224. if (typeName[5] == '1') {
  225. if (strncmp(typeName, "float16_t", 9))
  226. break;
  227. return HLSLScalarType_float16;
  228. }
  229. }
  230. if (typeName[5] == '3') {
  231. if (strncmp(typeName, "float32_t", 9))
  232. break;
  233. return HLSLScalarType_float32;
  234. }
  235. else if (typeName[5] == '6') {
  236. if (strncmp(typeName, "float64_t", 9))
  237. break;
  238. return HLSLScalarType_float64;
  239. }
  240. }
  241. }
  242. }
  243. return HLSLScalarType_unknown;
  244. }
  245. /// <summary>Provides the primitive type for lowering matrix types to IR.</summary>
  246. static
  247. CanQualType GetHLSLObjectHandleType(ASTContext& context)
  248. {
  249. return context.IntTy;
  250. }
  251. static
  252. void AddSubscriptOperator(
  253. ASTContext& context, unsigned int templateDepth, TemplateTypeParmDecl *elementTemplateParamDecl,
  254. NonTypeTemplateParmDecl* colCountTemplateParamDecl, QualType intType, CXXRecordDecl* templateRecordDecl,
  255. ClassTemplateDecl* vectorTemplateDecl,
  256. bool forConst)
  257. {
  258. QualType elementType = context.getTemplateTypeParmType(
  259. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  260. Expr* sizeExpr = DeclRefExpr::Create(context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl, false,
  261. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  262. intType, ExprValueKind::VK_RValue);
  263. CXXRecordDecl *vecTemplateRecordDecl = vectorTemplateDecl->getTemplatedDecl();
  264. const clang::Type *vecTy = vecTemplateRecordDecl->getTypeForDecl();
  265. TemplateArgument templateArgs[2] =
  266. {
  267. TemplateArgument(elementType),
  268. TemplateArgument(sizeExpr)
  269. };
  270. TemplateName canonName = context.getCanonicalTemplateName(TemplateName(vectorTemplateDecl));
  271. QualType vectorType = context.getTemplateSpecializationType(
  272. canonName, templateArgs, _countof(templateArgs), QualType(vecTy, 0));
  273. vectorType = context.getLValueReferenceType(vectorType);
  274. if (forConst)
  275. vectorType = context.getConstType(vectorType);
  276. QualType indexType = intType;
  277. CreateObjectFunctionDeclarationWithParams(
  278. context, templateRecordDecl, vectorType,
  279. ArrayRef<QualType>(indexType), ArrayRef<StringRef>(StringRef("index")),
  280. context.DeclarationNames.getCXXOperatorName(OO_Subscript), forConst);
  281. }
  282. /// <summary>Adds up-front support for HLSL matrix types (just the template declaration).</summary>
  283. void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorTemplateDecl, ClassTemplateDecl** matrixTemplateDecl)
  284. {
  285. DXASSERT_NOMSG(matrixTemplateDecl != nullptr);
  286. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  287. // Create a matrix template declaration in translation unit scope.
  288. // template<typename element, int row_count, int col_count> matrix { ... }
  289. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "matrix");
  290. TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", (QualType)context.FloatTy);
  291. NonTypeTemplateParmDecl* rowCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("row_count", context.IntTy, 4);
  292. NonTypeTemplateParmDecl* colCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("col_count", context.IntTy, 4);
  293. typeDeclBuilder.startDefinition();
  294. CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
  295. ClassTemplateDecl* classTemplateDecl = typeDeclBuilder.getTemplateDecl();
  296. // Add an 'h' field to hold the handle.
  297. // The type is vector<element, col>[row].
  298. QualType elementType = context.getTemplateTypeParmType(
  299. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  300. Expr *sizeExpr = DeclRefExpr::Create(
  301. context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
  302. false,
  303. DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
  304. context.IntTy, ExprValueKind::VK_RValue);
  305. Expr *rowSizeExpr = DeclRefExpr::Create(
  306. context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
  307. false,
  308. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  309. context.IntTy, ExprValueKind::VK_RValue);
  310. QualType vectorType = context.getDependentSizedExtVectorType(
  311. elementType, rowSizeExpr, SourceLocation());
  312. QualType vectorArrayType = context.getDependentSizedArrayType(
  313. vectorType, sizeExpr, ArrayType::Normal, 0, SourceRange());
  314. typeDeclBuilder.addField("h", vectorArrayType);
  315. // Add an operator[]. The operator ranges from zero to rowcount-1, and returns a vector of colcount elements.
  316. const unsigned int templateDepth = 0;
  317. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  318. colCountTemplateParamDecl, context.UnsignedIntTy,
  319. templateRecordDecl, vectorTemplateDecl, ForConstFalse);
  320. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  321. colCountTemplateParamDecl, context.UnsignedIntTy,
  322. templateRecordDecl, vectorTemplateDecl, ForConstTrue);
  323. typeDeclBuilder.completeDefinition();
  324. *matrixTemplateDecl = classTemplateDecl;
  325. }
  326. static void AddHLSLVectorSubscriptAttr(Decl *D, ASTContext &context) {
  327. StringRef group = GetHLOpcodeGroupName(HLOpcodeGroup::HLSubscript);
  328. D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, "", static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)));
  329. }
  330. /// <summary>Adds up-front support for HLSL vector types (just the template declaration).</summary>
  331. void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vectorTemplateDecl)
  332. {
  333. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  334. // Create a vector template declaration in translation unit scope.
  335. // template<typename element, int element_count> vector { ... }
  336. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "vector");
  337. TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", (QualType)context.FloatTy);
  338. NonTypeTemplateParmDecl* colCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("element_count", context.IntTy, 4);
  339. typeDeclBuilder.startDefinition();
  340. CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
  341. ClassTemplateDecl* classTemplateDecl = typeDeclBuilder.getTemplateDecl();
  342. Expr *vecSizeExpr = DeclRefExpr::Create(
  343. context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
  344. false,
  345. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  346. context.IntTy, ExprValueKind::VK_RValue);
  347. const unsigned int templateDepth = 0;
  348. QualType resultType = context.getTemplateTypeParmType(
  349. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  350. QualType vectorType = context.getDependentSizedExtVectorType(
  351. resultType, vecSizeExpr, SourceLocation());
  352. // Add an 'h' field to hold the handle.
  353. typeDeclBuilder.addField("h", vectorType);
  354. // Add an operator[]. The operator ranges from zero to colcount-1, and returns a scalar.
  355. // ForConstTrue:
  356. QualType refResultType = context.getConstType(context.getLValueReferenceType(resultType));
  357. CXXMethodDecl* functionDecl = CreateObjectFunctionDeclarationWithParams(
  358. context, templateRecordDecl, refResultType,
  359. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  360. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstTrue);
  361. AddHLSLVectorSubscriptAttr(functionDecl, context);
  362. // ForConstFalse:
  363. resultType = context.getLValueReferenceType(resultType);
  364. functionDecl = CreateObjectFunctionDeclarationWithParams(
  365. context, templateRecordDecl, resultType,
  366. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  367. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstFalse);
  368. AddHLSLVectorSubscriptAttr(functionDecl, context);
  369. typeDeclBuilder.completeDefinition();
  370. *vectorTemplateDecl = classTemplateDecl;
  371. }
  372. /// <summary>
  373. /// Adds a new record type in the specified context with the given name. The record type will have a handle field.
  374. /// </summary>
  375. CXXRecordDecl* hlsl::DeclareRecordTypeWithHandle(ASTContext& context, StringRef name) {
  376. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), name, TagDecl::TagKind::TTK_Struct);
  377. typeDeclBuilder.startDefinition();
  378. typeDeclBuilder.addField("h", GetHLSLObjectHandleType(context));
  379. return typeDeclBuilder.completeDefinition();
  380. }
  381. // creates a global static constant unsigned integer with value.
  382. // equivalent to: static const uint name = val;
  383. static void AddConstUInt(clang::ASTContext& context, DeclContext *DC, StringRef name, unsigned val) {
  384. IdentifierInfo &Id = context.Idents.get(name, tok::TokenKind::identifier);
  385. QualType type = context.getConstType(context.UnsignedIntTy);
  386. VarDecl *varDecl = VarDecl::Create(context, DC, NoLoc, NoLoc, &Id, type,
  387. context.getTrivialTypeSourceInfo(type),
  388. clang::StorageClass::SC_Static);
  389. Expr *exprVal = IntegerLiteral::Create(
  390. context, llvm::APInt(context.getIntWidth(type), val), type, NoLoc);
  391. varDecl->setInit(exprVal);
  392. varDecl->setImplicit(true);
  393. DC->addDecl(varDecl);
  394. }
  395. static void AddConstUInt(clang::ASTContext& context, StringRef name, unsigned val) {
  396. AddConstUInt(context, context.getTranslationUnitDecl(), name, val);
  397. }
  398. // Adds a top-level enum with the given enumerants.
  399. struct Enumerant { StringRef name; unsigned value; };
  400. static void AddTypedefPseudoEnum(ASTContext& context, StringRef name, ArrayRef<Enumerant> enumerants) {
  401. DeclContext* curDC = context.getTranslationUnitDecl();
  402. // typedef uint <name>;
  403. IdentifierInfo& enumId = context.Idents.get(name, tok::TokenKind::identifier);
  404. TypeSourceInfo* uintTypeSource = context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
  405. TypedefDecl* enumDecl = TypedefDecl::Create(context, curDC, NoLoc, NoLoc, &enumId, uintTypeSource);
  406. curDC->addDecl(enumDecl);
  407. enumDecl->setImplicit(true);
  408. // static const uint <enumerant.name> = <enumerant.value>;
  409. for (const Enumerant& enumerant : enumerants) {
  410. AddConstUInt(context, curDC, enumerant.name, enumerant.value);
  411. }
  412. }
  413. /// <summary> Adds all constants and enums for ray tracing </summary>
  414. void hlsl::AddRaytracingConstants(ASTContext& context) {
  415. AddTypedefPseudoEnum(context, "RAY_FLAG", {
  416. { "RAY_FLAG_NONE", (unsigned)DXIL::RayFlag::None },
  417. { "RAY_FLAG_FORCE_OPAQUE", (unsigned)DXIL::RayFlag::ForceOpaque },
  418. { "RAY_FLAG_FORCE_NON_OPAQUE", (unsigned)DXIL::RayFlag::ForceNonOpaque },
  419. { "RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH", (unsigned)DXIL::RayFlag::AcceptFirstHitAndEndSearch },
  420. { "RAY_FLAG_SKIP_CLOSEST_HIT_SHADER", (unsigned)DXIL::RayFlag::SkipClosestHitShader },
  421. { "RAY_FLAG_CULL_BACK_FACING_TRIANGLES", (unsigned)DXIL::RayFlag::CullBackFacingTriangles },
  422. { "RAY_FLAG_CULL_FRONT_FACING_TRIANGLES", (unsigned)DXIL::RayFlag::CullFrontFacingTriangles },
  423. { "RAY_FLAG_CULL_OPAQUE", (unsigned)DXIL::RayFlag::CullOpaque },
  424. { "RAY_FLAG_CULL_NON_OPAQUE", (unsigned)DXIL::RayFlag::CullNonOpaque },
  425. { "RAY_FLAG_SKIP_TRIANGLES", (unsigned)DXIL::RayFlag::SkipTriangles },
  426. { "RAY_FLAG_SKIP_PROCEDURAL_PRIMITIVES", (unsigned)DXIL::RayFlag::SkipProceduralPrimitives },
  427. });
  428. AddTypedefPseudoEnum(context, "COMMITTED_STATUS", {
  429. { "COMMITTED_NOTHING", (unsigned)DXIL::CommittedStatus::CommittedNothing },
  430. { "COMMITTED_TRIANGLE_HIT", (unsigned)DXIL::CommittedStatus::CommittedTriangleHit },
  431. { "COMMITTED_PROCEDURAL_PRIMITIVE_HIT", (unsigned)DXIL::CommittedStatus::CommittedProceduralPrimitiveHit }
  432. });
  433. AddTypedefPseudoEnum(context, "CANDIDATE_TYPE", {
  434. { "CANDIDATE_NON_OPAQUE_TRIANGLE", (unsigned)DXIL::CandidateType::CandidateNonOpaqueTriangle },
  435. { "CANDIDATE_PROCEDURAL_PRIMITIVE", (unsigned)DXIL::CandidateType::CandidateProceduralPrimitive }
  436. });
  437. // static const uint HIT_KIND_* = *;
  438. AddConstUInt(context, StringRef("HIT_KIND_NONE"), (unsigned)DXIL::HitKind::None);
  439. AddConstUInt(context, StringRef("HIT_KIND_TRIANGLE_FRONT_FACE"), (unsigned)DXIL::HitKind::TriangleFrontFace);
  440. AddConstUInt(context, StringRef("HIT_KIND_TRIANGLE_BACK_FACE"), (unsigned)DXIL::HitKind::TriangleBackFace);
  441. AddConstUInt(context, StringRef("STATE_OBJECT_FLAGS_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS"), (unsigned)DXIL::StateObjectFlags::AllowLocalDependenciesOnExternalDefinitions);
  442. AddConstUInt(context, StringRef("STATE_OBJECT_FLAGS_ALLOW_EXTERNAL_DEPENDENCIES_ON_LOCAL_DEFINITIONS"), (unsigned)DXIL::StateObjectFlags::AllowExternalDependenciesOnLocalDefinitions);
  443. // The above "_FLAGS_" was a typo, leaving in to avoid breaking anyone. Supposed to be _FLAG_ below.
  444. AddConstUInt(context, StringRef("STATE_OBJECT_FLAG_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS"), (unsigned)DXIL::StateObjectFlags::AllowLocalDependenciesOnExternalDefinitions);
  445. AddConstUInt(context, StringRef("STATE_OBJECT_FLAG_ALLOW_EXTERNAL_DEPENDENCIES_ON_LOCAL_DEFINITIONS"), (unsigned)DXIL::StateObjectFlags::AllowExternalDependenciesOnLocalDefinitions);
  446. AddConstUInt(context, StringRef("STATE_OBJECT_FLAG_ALLOW_STATE_OBJECT_ADDITIONS"), (unsigned)DXIL::StateObjectFlags::AllowStateObjectAdditions);
  447. AddConstUInt(context, StringRef("RAYTRACING_PIPELINE_FLAG_NONE"), (unsigned)DXIL::RaytracingPipelineFlags::None);
  448. AddConstUInt(context, StringRef("RAYTRACING_PIPELINE_FLAG_SKIP_TRIANGLES"), (unsigned)DXIL::RaytracingPipelineFlags::SkipTriangles);
  449. AddConstUInt(context, StringRef("RAYTRACING_PIPELINE_FLAG_SKIP_PROCEDURAL_PRIMITIVES"), (unsigned)DXIL::RaytracingPipelineFlags::SkipProceduralPrimitives);
  450. }
  451. /// <summary> Adds all constants and enums for sampler feedback </summary>
  452. void hlsl::AddSamplerFeedbackConstants(ASTContext& context) {
  453. AddConstUInt(context, StringRef("SAMPLER_FEEDBACK_MIN_MIP"), (unsigned)DXIL::SamplerFeedbackType::MinMip);
  454. AddConstUInt(context, StringRef("SAMPLER_FEEDBACK_MIP_REGION_USED"), (unsigned)DXIL::SamplerFeedbackType::MipRegionUsed);
  455. }
  456. static
  457. Expr* IntConstantAsBoolExpr(clang::Sema& sema, uint64_t value)
  458. {
  459. return sema.ImpCastExprToType(
  460. sema.ActOnIntegerConstant(NoLoc, value).get(), sema.getASTContext().BoolTy, CK_IntegralToBoolean).get();
  461. }
  462. static
  463. CXXRecordDecl* CreateStdStructWithStaticBool(clang::ASTContext& context, NamespaceDecl* stdNamespace, IdentifierInfo& trueTypeId, IdentifierInfo& valueId, Expr* trueExpression)
  464. {
  465. // struct true_type { static const bool value = true; }
  466. TypeSourceInfo* boolTypeSource = context.getTrivialTypeSourceInfo(context.BoolTy.withConst());
  467. CXXRecordDecl* trueTypeDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &trueTypeId, nullptr, DelayTypeCreationTrue);
  468. // static fields are variables in the AST
  469. VarDecl* trueValueDecl = VarDecl::Create(context, trueTypeDecl, NoLoc, NoLoc, &valueId,
  470. context.BoolTy.withConst(), boolTypeSource, SC_Static);
  471. trueValueDecl->setInit(trueExpression);
  472. trueValueDecl->setConstexpr(true);
  473. trueValueDecl->setAccess(AS_public);
  474. trueTypeDecl->setLexicalDeclContext(stdNamespace);
  475. trueTypeDecl->startDefinition();
  476. trueTypeDecl->addDecl(trueValueDecl);
  477. trueTypeDecl->completeDefinition();
  478. stdNamespace->addDecl(trueTypeDecl);
  479. return trueTypeDecl;
  480. }
  481. static
  482. void DefineRecordWithBase(CXXRecordDecl* decl, DeclContext* lexicalContext, const CXXBaseSpecifier* base)
  483. {
  484. decl->setLexicalDeclContext(lexicalContext);
  485. decl->startDefinition();
  486. decl->setBases(&base, 1);
  487. decl->completeDefinition();
  488. lexicalContext->addDecl(decl);
  489. }
  490. static
  491. void SetPartialExplicitSpecialization(ClassTemplateDecl* templateDecl, ClassTemplatePartialSpecializationDecl* specializationDecl)
  492. {
  493. specializationDecl->setSpecializationKind(TSK_ExplicitSpecialization);
  494. templateDecl->AddPartialSpecialization(specializationDecl, nullptr);
  495. }
  496. static
  497. void CreateIsEqualSpecialization(ASTContext& context, ClassTemplateDecl* templateDecl, TemplateName& templateName,
  498. DeclContext* lexicalContext, const CXXBaseSpecifier* base, TemplateParameterList* templateParamList,
  499. TemplateArgument (&templateArgs)[2])
  500. {
  501. QualType specializationCanonType = context.getTemplateSpecializationType(templateName, templateArgs, _countof(templateArgs));
  502. TemplateArgumentListInfo templateArgsListInfo = TemplateArgumentListInfo(NoLoc, NoLoc);
  503. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[0], context.getTrivialTypeSourceInfo(templateArgs[0].getAsType())));
  504. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[1], context.getTrivialTypeSourceInfo(templateArgs[1].getAsType())));
  505. ClassTemplatePartialSpecializationDecl* specializationDecl =
  506. ClassTemplatePartialSpecializationDecl::Create(context, TTK_Struct, lexicalContext, NoLoc, NoLoc,
  507. templateParamList, templateDecl, templateArgs, _countof(templateArgs),
  508. templateArgsListInfo, specializationCanonType, nullptr);
  509. context.getTagDeclType(specializationDecl); // Fault this in now.
  510. DefineRecordWithBase(specializationDecl, lexicalContext, base);
  511. SetPartialExplicitSpecialization(templateDecl, specializationDecl);
  512. }
  513. /// <summary>Adds the implementation for std::is_same.</summary>
  514. void hlsl::AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema)
  515. {
  516. // The goal is to support std::is_same<T, T>::value for testing purposes, in a manner that can
  517. // evolve into a compliant feature in the future.
  518. //
  519. // The definitions necessary are as follows (all in the std namespace).
  520. // template <class T, T v>
  521. // struct integral_constant {
  522. // typedef T value_type;
  523. // static const value_type value = v;
  524. // operator value_type() { return value; }
  525. // };
  526. //
  527. // typedef integral_constant<bool, true> true_type;
  528. // typedef integral_constant<bool, false> false_type;
  529. //
  530. // template<typename T, typename U> struct is_same : public false_type {};
  531. // template<typename T> struct is_same<T, T> : public true_type{};
  532. //
  533. // We instead use these simpler definitions for true_type and false_type.
  534. // struct false_type { static const bool value = false; };
  535. // struct true_type { static const bool value = true; };
  536. DeclContext* tuContext = context.getTranslationUnitDecl();
  537. IdentifierInfo& stdId = context.Idents.get(StringRef("std"), tok::TokenKind::identifier);
  538. IdentifierInfo& trueTypeId = context.Idents.get(StringRef("true_type"), tok::TokenKind::identifier);
  539. IdentifierInfo& falseTypeId = context.Idents.get(StringRef("false_type"), tok::TokenKind::identifier);
  540. IdentifierInfo& valueId = context.Idents.get(StringRef("value"), tok::TokenKind::identifier);
  541. IdentifierInfo& isSameId = context.Idents.get(StringRef("is_same"), tok::TokenKind::identifier);
  542. IdentifierInfo& tId = context.Idents.get(StringRef("T"), tok::TokenKind::identifier);
  543. IdentifierInfo& vId = context.Idents.get(StringRef("V"), tok::TokenKind::identifier);
  544. Expr* trueExpression = IntConstantAsBoolExpr(sema, 1);
  545. Expr* falseExpression = IntConstantAsBoolExpr(sema, 0);
  546. // namespace std
  547. NamespaceDecl* stdNamespace = NamespaceDecl::Create(context, tuContext, InlineFalse, NoLoc, NoLoc, &stdId, nullptr);
  548. CXXRecordDecl* trueTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, trueTypeId, valueId, trueExpression);
  549. CXXRecordDecl* falseTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, falseTypeId, valueId, falseExpression);
  550. // template<typename T, typename U> struct is_same : public false_type {};
  551. CXXRecordDecl* isSameFalseRecordDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &isSameId, nullptr, false);
  552. TemplateTypeParmDecl* tParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition, &tId, TypenameFalse, ParameterPackFalse);
  553. TemplateTypeParmDecl* uParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition + 1, &vId, TypenameFalse, ParameterPackFalse);
  554. NamedDecl* falseParams[] = { tParam, uParam };
  555. TemplateParameterList* falseParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, falseParams, _countof(falseParams), NoLoc);
  556. ClassTemplateDecl* isSameFalseTemplateDecl = ClassTemplateDecl::Create(context, stdNamespace, NoLoc, DeclarationName(&isSameId), falseParamList, isSameFalseRecordDecl, nullptr);
  557. context.getTagDeclType(isSameFalseRecordDecl); // Fault this in now.
  558. CXXBaseSpecifier* falseBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  559. context.getTrivialTypeSourceInfo(context.getTypeDeclType(falseTypeDecl)), NoLoc);
  560. isSameFalseRecordDecl->setDescribedClassTemplate(isSameFalseTemplateDecl);
  561. isSameFalseTemplateDecl->setLexicalDeclContext(stdNamespace);
  562. DefineRecordWithBase(isSameFalseRecordDecl, stdNamespace, falseBase);
  563. // is_same for 'true' is a specialization of is_same for 'false', taking a single T, where both T will match
  564. // template<typename T> struct is_same<T, T> : public true_type{};
  565. TemplateName tn = TemplateName(isSameFalseTemplateDecl);
  566. NamedDecl* trueParams[] = { tParam };
  567. TemplateParameterList* trueParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, trueParams, _countof(trueParams), NoLoc);
  568. CXXBaseSpecifier* trueBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  569. context.getTrivialTypeSourceInfo(context.getTypeDeclType(trueTypeDecl)), NoLoc);
  570. TemplateArgument ta = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)));
  571. TemplateArgument isSameTrueTemplateArgs[] = { ta, ta };
  572. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueTemplateArgs);
  573. stdNamespace->addDecl(isSameFalseTemplateDecl);
  574. stdNamespace->setImplicit(true);
  575. tuContext->addDecl(stdNamespace);
  576. // This could be a parameter if ever needed.
  577. const bool SupportExtensions = true;
  578. // Consider right-hand const and right-hand ref to be true for is_same:
  579. // template<typename T> struct is_same<T, const T> : public true_type{};
  580. // template<typename T> struct is_same<T, T&> : public true_type{};
  581. if (SupportExtensions)
  582. {
  583. TemplateArgument trueConstArg = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)).withConst());
  584. TemplateArgument isSameTrueConstTemplateArgs[] = { ta, trueConstArg };
  585. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueConstTemplateArgs);
  586. TemplateArgument trueRefArg = TemplateArgument(
  587. context.getLValueReferenceType(context.getCanonicalType(context.getTypeDeclType(tParam))));
  588. TemplateArgument isSameTrueRefTemplateArgs[] = { ta, trueRefArg };
  589. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueRefTemplateArgs);
  590. }
  591. }
  592. /// <summary>
  593. /// Adds a new template type in the specified context with the given name. The record type will have a handle field.
  594. /// </summary>
  595. /// <parm name="context">AST context to which template will be added.</param>
  596. /// <parm name="typeName">Name of template to create.</param>
  597. /// <parm name="templateArgCount">Number of template arguments (one or two).</param>
  598. /// <parm name="defaultTypeArgValue">If assigned, the default argument for the element template.</param>
  599. CXXRecordDecl* hlsl::DeclareTemplateTypeWithHandle(
  600. ASTContext& context,
  601. StringRef name,
  602. uint8_t templateArgCount,
  603. _In_opt_ TypeSourceInfo* defaultTypeArgValue)
  604. {
  605. DXASSERT(templateArgCount != 0, "otherwise caller should be creating a class or struct");
  606. DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated for a different template pattern");
  607. // Create an object template declaration in translation unit scope.
  608. // templateArgCount=1: template<typename element> typeName { ... }
  609. // templateArgCount=2: template<typename element, int count> typeName { ... }
  610. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), name);
  611. TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", defaultTypeArgValue);
  612. NonTypeTemplateParmDecl* countTemplateParamDecl = nullptr;
  613. if (templateArgCount > 1)
  614. countTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("count", context.IntTy, 0);
  615. typeDeclBuilder.startDefinition();
  616. CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
  617. // Add an 'h' field to hold the handle.
  618. QualType elementType = context.getTemplateTypeParmType(
  619. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  620. if (templateArgCount > 1 &&
  621. // Only need array type for inputpatch and outputpatch.
  622. // Avoid Texture2DMS which may use 0 count.
  623. // TODO: use hlsl types to do the check.
  624. !name.startswith("Texture")) {
  625. Expr *countExpr = DeclRefExpr::Create(
  626. context, NestedNameSpecifierLoc(), NoLoc, countTemplateParamDecl, false,
  627. DeclarationNameInfo(countTemplateParamDecl->getDeclName(), NoLoc),
  628. context.IntTy, ExprValueKind::VK_RValue);
  629. elementType = context.getDependentSizedArrayType(
  630. elementType, countExpr, ArrayType::ArraySizeModifier::Normal, 0,
  631. SourceRange());
  632. // InputPatch and OutputPatch also have a "Length" static const member for the number of control points
  633. IdentifierInfo& lengthId = context.Idents.get(StringRef("Length"), tok::TokenKind::identifier);
  634. TypeSourceInfo* lengthTypeSource = context.getTrivialTypeSourceInfo(context.IntTy.withConst());
  635. VarDecl* lengthValueDecl = VarDecl::Create(context, templateRecordDecl, NoLoc, NoLoc, &lengthId,
  636. context.IntTy.withConst(), lengthTypeSource, SC_Static);
  637. lengthValueDecl->setInit(countExpr);
  638. lengthValueDecl->setAccess(AS_public);
  639. templateRecordDecl->addDecl(lengthValueDecl);
  640. }
  641. typeDeclBuilder.addField("h", elementType);
  642. return typeDeclBuilder.completeDefinition();
  643. }
  644. FunctionTemplateDecl* hlsl::CreateFunctionTemplateDecl(
  645. ASTContext& context,
  646. _In_ CXXRecordDecl* recordDecl,
  647. _In_ CXXMethodDecl* functionDecl,
  648. _In_count_(templateParamNamedDeclsCount) NamedDecl** templateParamNamedDecls,
  649. size_t templateParamNamedDeclsCount)
  650. {
  651. DXASSERT_NOMSG(recordDecl != nullptr);
  652. DXASSERT_NOMSG(templateParamNamedDecls != nullptr);
  653. DXASSERT(templateParamNamedDeclsCount > 0, "otherwise caller shouldn't invoke this function");
  654. TemplateParameterList* templateParams = TemplateParameterList::Create(
  655. context, NoLoc, NoLoc, &templateParamNamedDecls[0], templateParamNamedDeclsCount, NoLoc);
  656. FunctionTemplateDecl* functionTemplate =
  657. FunctionTemplateDecl::Create(context, recordDecl, NoLoc, functionDecl->getDeclName(), templateParams, functionDecl);
  658. functionTemplate->setAccess(AccessSpecifier::AS_public);
  659. functionTemplate->setLexicalDeclContext(recordDecl);
  660. functionDecl->setDescribedFunctionTemplate(functionTemplate);
  661. recordDecl->addDecl(functionTemplate);
  662. return functionTemplate;
  663. }
  664. static
  665. void AssociateParametersToFunctionPrototype(
  666. _In_ TypeSourceInfo* tinfo,
  667. _In_count_(numParams) ParmVarDecl** paramVarDecls,
  668. unsigned int numParams)
  669. {
  670. FunctionProtoTypeLoc protoLoc = tinfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
  671. DXASSERT(protoLoc.getNumParams() == numParams, "otherwise unexpected number of parameters available");
  672. for (unsigned i = 0; i < numParams; i++) {
  673. DXASSERT(protoLoc.getParam(i) == nullptr, "otherwise prototype parameters were already initialized");
  674. protoLoc.setParam(i, paramVarDecls[i]);
  675. }
  676. }
  677. static void CreateConstructorDeclaration(
  678. ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
  679. ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
  680. _Out_ CXXConstructorDecl **constructorDecl, _Out_ TypeSourceInfo **tinfo) {
  681. DXASSERT_NOMSG(recordDecl != nullptr);
  682. DXASSERT_NOMSG(constructorDecl != nullptr);
  683. FunctionProtoType::ExtProtoInfo functionExtInfo;
  684. functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
  685. QualType functionQT = context.getFunctionType(
  686. resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
  687. DeclarationNameInfo declNameInfo(declarationName, NoLoc);
  688. *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
  689. DXASSERT_NOMSG(*tinfo != nullptr);
  690. *constructorDecl = CXXConstructorDecl::Create(
  691. context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
  692. StorageClass::SC_None, ExplicitFalse, InlineSpecifiedFalse, IsConstexprFalse);
  693. DXASSERT_NOMSG(*constructorDecl != nullptr);
  694. (*constructorDecl)->setLexicalDeclContext(recordDecl);
  695. (*constructorDecl)->setAccess(AccessSpecifier::AS_public);
  696. }
  697. static void CreateObjectFunctionDeclaration(
  698. ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
  699. ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
  700. _Out_ CXXMethodDecl **functionDecl, _Out_ TypeSourceInfo **tinfo) {
  701. DXASSERT_NOMSG(recordDecl != nullptr);
  702. DXASSERT_NOMSG(functionDecl != nullptr);
  703. FunctionProtoType::ExtProtoInfo functionExtInfo;
  704. functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
  705. QualType functionQT = context.getFunctionType(
  706. resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
  707. DeclarationNameInfo declNameInfo(declarationName, NoLoc);
  708. *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
  709. DXASSERT_NOMSG(*tinfo != nullptr);
  710. *functionDecl = CXXMethodDecl::Create(
  711. context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
  712. StorageClass::SC_None, InlineSpecifiedFalse, IsConstexprFalse, NoLoc);
  713. DXASSERT_NOMSG(*functionDecl != nullptr);
  714. (*functionDecl)->setLexicalDeclContext(recordDecl);
  715. (*functionDecl)->setAccess(AccessSpecifier::AS_public);
  716. }
  717. CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
  718. ASTContext& context,
  719. _In_ CXXRecordDecl* recordDecl,
  720. QualType resultType,
  721. ArrayRef<QualType> paramTypes,
  722. ArrayRef<StringRef> paramNames,
  723. DeclarationName declarationName,
  724. bool isConst)
  725. {
  726. DXASSERT_NOMSG(recordDecl != nullptr);
  727. DXASSERT_NOMSG(!resultType.isNull());
  728. DXASSERT_NOMSG(paramTypes.size() == paramNames.size());
  729. TypeSourceInfo* tinfo;
  730. CXXMethodDecl* functionDecl;
  731. CreateObjectFunctionDeclaration(context, recordDecl, resultType, paramTypes,
  732. declarationName, isConst, &functionDecl,
  733. &tinfo);
  734. // Create and associate parameters to method.
  735. SmallVector<ParmVarDecl *, 2> parmVarDecls;
  736. if (!paramTypes.empty()) {
  737. for (unsigned int i = 0; i < paramTypes.size(); ++i) {
  738. IdentifierInfo *argIi = &context.Idents.get(paramNames[i]);
  739. ParmVarDecl *parmVarDecl = ParmVarDecl::Create(
  740. context, functionDecl, NoLoc, NoLoc, argIi, paramTypes[i],
  741. context.getTrivialTypeSourceInfo(paramTypes[i], NoLoc),
  742. StorageClass::SC_None, nullptr);
  743. parmVarDecl->setScopeInfo(0, i);
  744. DXASSERT(parmVarDecl->getFunctionScopeIndex() == i,
  745. "otherwise failed to set correct index");
  746. parmVarDecls.push_back(parmVarDecl);
  747. }
  748. functionDecl->setParams(ArrayRef<ParmVarDecl *>(parmVarDecls));
  749. AssociateParametersToFunctionPrototype(tinfo, &parmVarDecls.front(),
  750. parmVarDecls.size());
  751. }
  752. recordDecl->addDecl(functionDecl);
  753. return functionDecl;
  754. }
  755. CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle(
  756. ASTContext& context, StringRef typeName, StringRef templateParamName) {
  757. // template<uint kind> FeedbackTexture2D[Array] { ... }
  758. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName);
  759. typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy);
  760. typeDeclBuilder.startDefinition();
  761. typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
  762. return typeDeclBuilder.completeDefinition();
  763. }
  764. clang::CXXRecordDecl *
  765. hlsl::DeclareConstantBufferViewType(clang::ASTContext &context, bool bTBuf) {
  766. // Create ConstantBufferView template declaration in translation unit scope
  767. // like other resource.
  768. // template<typename T> ConstantBuffer { int h; }
  769. DeclContext *DC = context.getTranslationUnitDecl();
  770. BuiltinTypeDeclBuilder typeDeclBuilder(DC,
  771. bTBuf ? "TextureBuffer"
  772. : "ConstantBuffer",
  773. TagDecl::TagKind::TTK_Struct);
  774. (void)typeDeclBuilder.addTypeTemplateParam("T");
  775. typeDeclBuilder.startDefinition();
  776. CXXRecordDecl *templateRecordDecl = typeDeclBuilder.getRecordDecl();
  777. typeDeclBuilder.addField(
  778. "h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
  779. typeDeclBuilder.completeDefinition();
  780. return templateRecordDecl;
  781. }
  782. CXXRecordDecl* hlsl::DeclareRayQueryType(ASTContext& context) {
  783. // template<uint kind> RayQuery { ... }
  784. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "RayQuery");
  785. typeDeclBuilder.addIntegerTemplateParam("flags", context.UnsignedIntTy);
  786. typeDeclBuilder.startDefinition();
  787. typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
  788. // Add constructor that will be lowered to the intrinsic that produces
  789. // the RayQuery handle for this object.
  790. CanQualType canQualType = typeDeclBuilder.getRecordDecl()->getTypeForDecl()->getCanonicalTypeUnqualified();
  791. CXXConstructorDecl *pConstructorDecl = nullptr;
  792. TypeSourceInfo *pTypeSourceInfo = nullptr;
  793. CreateConstructorDeclaration(context, typeDeclBuilder.getRecordDecl(), context.VoidTy, {}, context.DeclarationNames.getCXXConstructorName(canQualType), false, &pConstructorDecl, &pTypeSourceInfo);
  794. typeDeclBuilder.getRecordDecl()->addDecl(pConstructorDecl);
  795. return typeDeclBuilder.completeDefinition();
  796. }
  797. CXXRecordDecl* hlsl::DeclareResourceType(ASTContext& context, bool bSampler) {
  798. // struct ResourceDescriptor { uint8 desc; }
  799. StringRef Name = bSampler?".Sampler":".Resource";
  800. BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(),
  801. Name,
  802. TagDecl::TagKind::TTK_Struct);
  803. typeDeclBuilder.startDefinition();
  804. typeDeclBuilder.addField("h", GetHLSLObjectHandleType(context));
  805. CXXRecordDecl *recordDecl = typeDeclBuilder.completeDefinition();
  806. QualType indexType = context.UnsignedIntTy;
  807. QualType resultType = context.getRecordType(recordDecl);
  808. resultType = context.getConstType(resultType);
  809. CXXMethodDecl *functionDecl = CreateObjectFunctionDeclarationWithParams(
  810. context, recordDecl, resultType, ArrayRef<QualType>(indexType),
  811. ArrayRef<StringRef>(StringRef("index")),
  812. context.DeclarationNames.getCXXOperatorName(OO_Subscript), true);
  813. // Mark function as createResourceFromHeap intrinsic.
  814. functionDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit(
  815. context, "op", "",
  816. static_cast<int>(hlsl::IntrinsicOp::IOP_CreateResourceFromHeap)));
  817. return recordDecl;
  818. }
  819. VarDecl *hlsl::DeclareBuiltinGlobal(llvm::StringRef name, clang::QualType Ty,
  820. clang::ASTContext &context) {
  821. IdentifierInfo &II = context.Idents.get(name);
  822. auto *curDeclCtx =context.getTranslationUnitDecl();
  823. VarDecl *varDecl = VarDecl::Create(context, curDeclCtx,
  824. SourceLocation(), SourceLocation(), &II, Ty,
  825. context.getTrivialTypeSourceInfo(Ty),
  826. StorageClass::SC_Extern);
  827. // Mark implicit to avoid print it when rewrite.
  828. varDecl->setImplicit();
  829. curDeclCtx->addDecl(varDecl);
  830. return varDecl;
  831. }
  832. bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
  833. return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
  834. }
  835. bool hlsl::GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
  836. llvm::StringRef &group) {
  837. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  838. return false;
  839. }
  840. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  841. opcode = A->getOpcode();
  842. group = A->getGroup();
  843. return true;
  844. }
  845. bool hlsl::GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S) {
  846. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  847. return false;
  848. }
  849. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  850. S = A->getLowering();
  851. return true;
  852. }
  853. /// <summary>Parses a column or row digit.</summary>
  854. static
  855. bool TryParseColOrRowChar(const char digit, _Out_ int* count) {
  856. if ('1' <= digit && digit <= '4') {
  857. *count = digit - '0';
  858. return true;
  859. }
  860. *count = 0;
  861. return false;
  862. }
  863. /// <summary>Parses a matrix shorthand identifier (eg, float3x2).</summary>
  864. _Use_decl_annotations_
  865. bool hlsl::TryParseMatrixShorthand(
  866. const char* typeName,
  867. size_t typeNameLen,
  868. HLSLScalarType* parsedType,
  869. int* rowCount,
  870. int* colCount,
  871. const clang::LangOptions& langOptions
  872. )
  873. {
  874. //
  875. // Matrix shorthand format is PrimitiveTypeRxC, where R is the row count and C is the column count.
  876. // R and C should be between 1 and 4 inclusive.
  877. // x is a literal 'x' character.
  878. // PrimitiveType is one of the HLSLScalarTypeNames values.
  879. //
  880. if (TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions) &&
  881. *rowCount != 0 && *colCount != 0) {
  882. // compare scalar component
  883. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-3, langOptions);
  884. if (type!= HLSLScalarType_unknown) {
  885. *parsedType = type;
  886. return true;
  887. }
  888. }
  889. // Unable to parse.
  890. return false;
  891. }
  892. /// <summary>Parses a vector shorthand identifier (eg, float3).</summary>
  893. _Use_decl_annotations_
  894. bool hlsl::TryParseVectorShorthand(
  895. const char* typeName,
  896. size_t typeNameLen,
  897. HLSLScalarType* parsedType,
  898. int* elementCount,
  899. const clang::LangOptions& langOptions
  900. )
  901. {
  902. // At least *something*N characters necessary, where something is at least 'int'
  903. if (TryParseColOrRowChar(typeName[typeNameLen - 1], elementCount)) {
  904. // compare scalar component
  905. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-1, langOptions);
  906. if (type!= HLSLScalarType_unknown) {
  907. *parsedType = type;
  908. return true;
  909. }
  910. }
  911. // Unable to parse.
  912. return false;
  913. }
  914. /// <summary>Parses a hlsl scalar type (e.g min16float, uint3x4) </summary>
  915. _Use_decl_annotations_
  916. bool hlsl::TryParseScalar(
  917. _In_count_(typenameLen)
  918. const char* typeName,
  919. size_t typeNameLen,
  920. _Out_ HLSLScalarType *parsedType,
  921. _In_ const clang::LangOptions& langOptions) {
  922. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen, langOptions);
  923. if (type!= HLSLScalarType_unknown) {
  924. *parsedType = type;
  925. return true;
  926. }
  927. return false; // unable to parse
  928. }
  929. /// <summary>Parse any (scalar, vector, matrix) hlsl types (e.g float, int3x4, uint2) </summary>
  930. _Use_decl_annotations_
  931. bool hlsl::TryParseAny(
  932. _In_count_(typenameLen)
  933. const char* typeName,
  934. size_t typeNameLen,
  935. _Out_ HLSLScalarType *parsedType,
  936. int *rowCount,
  937. int *colCount,
  938. _In_ const clang::LangOptions& langOptions) {
  939. // at least 'int'
  940. const size_t MinValidLen = 3;
  941. if (typeNameLen >= MinValidLen) {
  942. TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions);
  943. int suffixLen = *colCount == 0 ? 0 :
  944. *rowCount == 0 ? 1 : 3;
  945. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-suffixLen, langOptions);
  946. if (type!= HLSLScalarType_unknown) {
  947. *parsedType = type;
  948. return true;
  949. }
  950. }
  951. return false;
  952. }
  953. /// <summary>Parse string hlsl type</summary>
  954. _Use_decl_annotations_
  955. bool hlsl::TryParseString(
  956. _In_count_(typenameLen)
  957. const char* typeName,
  958. size_t typeNameLen,
  959. _In_ const clang::LangOptions& langOptions) {
  960. if (typeNameLen == 6 && typeName[0] == 's' && strncmp(typeName, "string", 6) == 0) {
  961. return true;
  962. }
  963. return false;
  964. }
  965. /// <summary>Parse any kind of dimension for vector or matrix (e.g 4,3 in int4x3).
  966. /// If it's a matrix type, rowCount and colCount will be nonzero. If it's a vector type, colCount is 0.
  967. /// Otherwise both rowCount and colCount is 0. Returns true if either matrix or vector dimensions detected. </summary>
  968. _Use_decl_annotations_
  969. bool hlsl::TryParseMatrixOrVectorDimension(
  970. _In_count_(typeNameLen)
  971. const char *typeName,
  972. size_t typeNameLen,
  973. _Out_opt_ int *rowCount,
  974. _Out_opt_ int *colCount,
  975. _In_ const clang::LangOptions& langOptions) {
  976. *rowCount = 0;
  977. *colCount = 0;
  978. size_t MinValidLen = 3; // at least int
  979. if (typeNameLen > MinValidLen) {
  980. if (TryParseColOrRowChar(typeName[typeNameLen - 1], colCount)) {
  981. // Try parse matrix
  982. if (typeName[typeNameLen - 2] == 'x')
  983. TryParseColOrRowChar(typeName[typeNameLen - 3], rowCount);
  984. return true;
  985. }
  986. }
  987. return false;
  988. }
  989. /// <summary>Creates a typedef for a matrix shorthand (eg, float3x2).</summary>
  990. TypedefDecl* hlsl::CreateMatrixSpecializationShorthand(
  991. ASTContext& context,
  992. QualType matrixSpecialization,
  993. HLSLScalarType scalarType,
  994. size_t rowCount,
  995. size_t colCount)
  996. {
  997. DXASSERT(rowCount <= 4, "else caller didn't validate rowCount");
  998. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  999. char typeName[64];
  1000. sprintf_s(typeName, _countof(typeName), "%s%ux%u",
  1001. HLSLScalarTypeNames[scalarType], (unsigned)rowCount, (unsigned)colCount);
  1002. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  1003. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  1004. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  1005. context.getTrivialTypeSourceInfo(matrixSpecialization, NoLoc));
  1006. decl->setImplicit(true);
  1007. currentDeclContext->addDecl(decl);
  1008. return decl;
  1009. }
  1010. /// <summary>Creates a typedef for a vector shorthand (eg, float3).</summary>
  1011. TypedefDecl* hlsl::CreateVectorSpecializationShorthand(
  1012. ASTContext& context,
  1013. QualType vectorSpecialization,
  1014. HLSLScalarType scalarType,
  1015. size_t colCount)
  1016. {
  1017. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  1018. char typeName[64];
  1019. sprintf_s(typeName, _countof(typeName), "%s%u",
  1020. HLSLScalarTypeNames[scalarType], (unsigned)colCount);
  1021. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  1022. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  1023. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  1024. context.getTrivialTypeSourceInfo(vectorSpecialization, NoLoc));
  1025. decl->setImplicit(true);
  1026. currentDeclContext->addDecl(decl);
  1027. return decl;
  1028. }
  1029. llvm::ArrayRef<hlsl::UnusualAnnotation*>
  1030. hlsl::UnusualAnnotation::CopyToASTContextArray(
  1031. clang::ASTContext& Context, hlsl::UnusualAnnotation** begin, size_t count) {
  1032. if (count == 0) {
  1033. return llvm::ArrayRef<hlsl::UnusualAnnotation*>();
  1034. }
  1035. UnusualAnnotation** arr = ::new (Context) UnusualAnnotation*[count];
  1036. for (size_t i = 0; i < count; ++i) {
  1037. arr[i] = begin[i]->CopyToASTContext(Context);
  1038. }
  1039. return llvm::ArrayRef<hlsl::UnusualAnnotation*>(arr, count);
  1040. }
  1041. UnusualAnnotation* hlsl::UnusualAnnotation::CopyToASTContext(ASTContext& Context) {
  1042. // All UnusualAnnotation instances can be blitted.
  1043. size_t instanceSize;
  1044. switch (Kind) {
  1045. case UA_RegisterAssignment:
  1046. instanceSize = sizeof(hlsl::RegisterAssignment);
  1047. break;
  1048. case UA_ConstantPacking:
  1049. instanceSize = sizeof(hlsl::ConstantPacking);
  1050. break;
  1051. case UA_PayloadAccessQualifier:
  1052. instanceSize = sizeof(hlsl::PayloadAccessAnnotation);
  1053. break;
  1054. default:
  1055. DXASSERT(Kind == UA_SemanticDecl, "Kind == UA_SemanticDecl -- otherwise switch is incomplete");
  1056. instanceSize = sizeof(hlsl::SemanticDecl);
  1057. break;
  1058. }
  1059. void* result = Context.Allocate(instanceSize);
  1060. memcpy(result, this, instanceSize);
  1061. return (UnusualAnnotation*)result;
  1062. }
  1063. static bool HasTessFactorSemantic(const ValueDecl *decl) {
  1064. for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  1065. if (it->getKind() == UnusualAnnotation::UA_SemanticDecl) {
  1066. const SemanticDecl *sd = cast<SemanticDecl>(it);
  1067. const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
  1068. if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
  1069. return true;
  1070. }
  1071. }
  1072. return false;
  1073. }
  1074. static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
  1075. if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
  1076. return false;
  1077. if (const RecordType *RT = Ty->getAsStructureType()) {
  1078. RecordDecl *RD = RT->getDecl();
  1079. for (FieldDecl *fieldDecl : RD->fields()) {
  1080. if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
  1081. return true;
  1082. }
  1083. return false;
  1084. }
  1085. if (Ty->getAsArrayTypeUnsafe())
  1086. return HasTessFactorSemantic(decl);
  1087. return false;
  1088. }
  1089. bool ASTContext::IsPatchConstantFunctionDecl(const FunctionDecl *FD) const {
  1090. // This checks whether the function is structurally capable of being a patch
  1091. // constant function, not whether it is in fact the patch constant function
  1092. // for the entry point of a compiled hull shader (which may not have been
  1093. // seen yet). So the answer is conservative.
  1094. if (!FD->getReturnType()->isVoidType()) {
  1095. // Try to find TessFactor in return type.
  1096. if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
  1097. return true;
  1098. }
  1099. // Try to find TessFactor in out param.
  1100. for (const ParmVarDecl *param : FD->params()) {
  1101. if (param->hasAttr<HLSLOutAttr>()) {
  1102. if (HasTessFactorSemanticRecurse(param, param->getType()))
  1103. return true;
  1104. }
  1105. }
  1106. return false;
  1107. }