ASTContextHLSL.cpp 50 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191
  1. //===--- ASTContextHLSL.cpp - HLSL support for AST nodes and operations ---===//
  2. ///////////////////////////////////////////////////////////////////////////////
  3. // //
  4. // ASTContextHLSL.cpp //
  5. // Copyright (C) Microsoft Corporation. All rights reserved. //
  6. // This file is distributed under the University of Illinois Open Source //
  7. // License. See LICENSE.TXT for details. //
  8. // //
  9. // This file implements the ASTContext interface for HLSL. //
  10. // //
  11. ///////////////////////////////////////////////////////////////////////////////
  12. #include "clang/AST/ASTContext.h"
  13. #include "clang/AST/Attr.h"
  14. #include "clang/AST/DeclCXX.h"
  15. #include "clang/AST/DeclTemplate.h"
  16. #include "clang/AST/Expr.h"
  17. #include "clang/AST/ExprCXX.h"
  18. #include "clang/AST/ExternalASTSource.h"
  19. #include "clang/AST/TypeLoc.h"
  20. #include "clang/Sema/SemaDiagnostic.h"
  21. #include "clang/Sema/Sema.h"
  22. #include "clang/Sema/Overload.h"
  23. #include "dxc/Support/Global.h"
  24. #include "dxc/HLSL/HLOperations.h"
  25. #include "dxc/HLSL/DxilSemantic.h"
  26. using namespace clang;
  27. using namespace hlsl;
  28. static const int FirstTemplateDepth = 0;
  29. static const int FirstParamPosition = 0;
  30. static const bool ForConstFalse = false; // a construct is targeting a const type
  31. static const bool ForConstTrue = true; // a construct is targeting a non-const type
  32. static const bool ExplicitConversionFalse = false;// a conversion operation is not the result of an explicit cast
  33. static const bool InheritedFalse = false; // template parameter default value is not inherited.
  34. static const bool ParameterPackFalse = false; // template parameter is not an ellipsis.
  35. static const bool TypenameFalse = false; // 'typename' specified rather than 'class' for a template argument.
  36. static const bool DelayTypeCreationTrue = true; // delay type creation for a declaration
  37. static const bool DelayTypeCreationFalse = false; // immediately create a type when the declaration is created
  38. static const unsigned int NoQuals = 0; // no qualifiers in effect
  39. static const SourceLocation NoLoc; // no source location attribution available
  40. static const bool HasWrittenPrototypeTrue = true; // function had the prototype written
  41. static const bool InlineFalse = false; // namespace is not an inline namespace
  42. static const bool InlineSpecifiedFalse = false; // function was not specified as inline
  43. static const bool IsConstexprFalse = false; // function is not constexpr
  44. static const bool ListInitializationFalse = false;// not performing a list initialization
  45. static const bool SuppressDiagTrue = true; // suppress diagnostics
  46. static const bool VirtualFalse = false; // whether the base class is declares 'virtual'
  47. static const bool BaseClassFalse = false; // whether the base class is declared as 'class' (vs. 'struct')
  48. /// <summary>Names of HLSLScalarType enumeration values, in matching order to HLSLScalarType.</summary>
  49. const char* HLSLScalarTypeNames[] = {
  50. "<unknown>",
  51. "bool",
  52. "int",
  53. "uint",
  54. "dword",
  55. "half",
  56. "float",
  57. "double",
  58. "min10float",
  59. "min16float",
  60. "min12int",
  61. "min16int",
  62. "min16uint",
  63. "literal float",
  64. "literal int",
  65. "int16_t",
  66. "int32_t",
  67. "int64_t",
  68. "uint16_t",
  69. "uint32_t",
  70. "uint64_t",
  71. "float16_t",
  72. "float32_t",
  73. "float64_t"
  74. };
  75. static_assert(HLSLScalarTypeCount == _countof(HLSLScalarTypeNames), "otherwise scalar constants are not aligned");
  76. static HLSLScalarType FindScalarTypeByName(const char *typeName, const size_t typeLen, const LangOptions& langOptions) {
  77. // skipped HLSLScalarType: unknown, literal int, literal float
  78. switch (typeLen) {
  79. case 3: // int
  80. if (typeName[0] == 'i') {
  81. if (strncmp(typeName, "int", 3))
  82. break;
  83. return HLSLScalarType_int;
  84. }
  85. break;
  86. case 4: // bool, uint, half
  87. if (typeName[0] == 'b') {
  88. if (strncmp(typeName, "bool", 4))
  89. break;
  90. return HLSLScalarType_bool;
  91. }
  92. else if (typeName[0] == 'u') {
  93. if (strncmp(typeName, "uint", 4))
  94. break;
  95. return HLSLScalarType_uint;
  96. }
  97. else if (typeName[0] == 'h') {
  98. if (strncmp(typeName, "half", 4))
  99. break;
  100. return HLSLScalarType_half;
  101. }
  102. break;
  103. case 5: // dword, float
  104. if (typeName[0] == 'd') {
  105. if (strncmp(typeName, "dword", 5))
  106. break;
  107. return HLSLScalarType_dword;
  108. }
  109. else if (typeName[0] == 'f') {
  110. if (strncmp(typeName, "float", 5))
  111. break;
  112. return HLSLScalarType_float;
  113. }
  114. break;
  115. case 6: // double
  116. if (typeName[0] == 'd') {
  117. if (strncmp(typeName, "double", 6))
  118. break;
  119. return HLSLScalarType_double;
  120. }
  121. break;
  122. case 7: // int64_t
  123. if (typeName[0] == 'i' && typeName[1] == 'n') {
  124. if (typeName[3] == '6') {
  125. if (strncmp(typeName, "int64_t", 7))
  126. break;
  127. return HLSLScalarType_int64;
  128. }
  129. }
  130. case 8: // min12int, min16int, uint64_t
  131. if (typeName[0] == 'm' && typeName[1] == 'i') {
  132. if (typeName[4] == '2') {
  133. if (strncmp(typeName, "min12int", 8))
  134. break;
  135. return HLSLScalarType_int_min12;
  136. }
  137. else if (typeName[4] == '6') {
  138. if (strncmp(typeName, "min16int", 8))
  139. break;
  140. return HLSLScalarType_int_min16;
  141. }
  142. }
  143. else if (typeName[0] == 'u' && typeName[1] == 'i') {
  144. if (typeName[4] == '6') {
  145. if (strncmp(typeName, "uint64_t", 8))
  146. break;
  147. return HLSLScalarType_uint64;
  148. }
  149. }
  150. break;
  151. case 9: // min16uint
  152. if (typeName[0] == 'm' && typeName[1] == 'i') {
  153. if (strncmp(typeName, "min16uint", 9))
  154. break;
  155. return HLSLScalarType_uint_min16;
  156. }
  157. break;
  158. case 10: // min10float, min16float
  159. if (typeName[0] == 'm' && typeName[1] == 'i') {
  160. if (typeName[4] == '0') {
  161. if (strncmp(typeName, "min10float", 10))
  162. break;
  163. return HLSLScalarType_float_min10;
  164. }
  165. if (typeName[4] == '6') {
  166. if (strncmp(typeName, "min16float", 10))
  167. break;
  168. return HLSLScalarType_float_min16;
  169. }
  170. }
  171. break;
  172. default:
  173. break;
  174. }
  175. // fixed width types (int16_t, uint16_t, int32_t, uint32_t, float16_t, float32_t, float64_t)
  176. // are only supported in HLSL 2018
  177. if (langOptions.HLSLVersion >= 2018) {
  178. switch (typeLen) {
  179. case 7: // int16_t, int32_t
  180. if (typeName[0] == 'i' && typeName[1] == 'n') {
  181. if (!langOptions.UseMinPrecision) {
  182. if (typeName[3] == '1') {
  183. if (strncmp(typeName, "int16_t", 7))
  184. break;
  185. return HLSLScalarType_int16;
  186. }
  187. }
  188. if (typeName[3] == '3') {
  189. if (strncmp(typeName, "int32_t", 7))
  190. break;
  191. return HLSLScalarType_int32;
  192. }
  193. }
  194. case 8: // uint16_t, uint32_t
  195. if (!langOptions.UseMinPrecision) {
  196. if (typeName[0] == 'u' && typeName[1] == 'i') {
  197. if (typeName[4] == '1') {
  198. if (strncmp(typeName, "uint16_t", 8))
  199. break;
  200. return HLSLScalarType_uint16;
  201. }
  202. }
  203. }
  204. if (typeName[4] == '3') {
  205. if (strncmp(typeName, "uint32_t", 8))
  206. break;
  207. return HLSLScalarType_uint32;
  208. }
  209. case 9: // float16_t, float32_t, float64_t
  210. if (typeName[0] == 'f' && typeName[1] == 'l') {
  211. if (!langOptions.UseMinPrecision) {
  212. if (typeName[5] == '1') {
  213. if (strncmp(typeName, "float16_t", 9))
  214. break;
  215. return HLSLScalarType_float16;
  216. }
  217. }
  218. if (typeName[5] == '3') {
  219. if (strncmp(typeName, "float32_t", 9))
  220. break;
  221. return HLSLScalarType_float32;
  222. }
  223. else if (typeName[5] == '6') {
  224. if (strncmp(typeName, "float64_t", 9))
  225. break;
  226. return HLSLScalarType_float64;
  227. }
  228. }
  229. }
  230. }
  231. return HLSLScalarType_unknown;
  232. }
  233. /// <summary>Provides the primitive type for lowering matrix types to IR.</summary>
  234. static
  235. CanQualType GetHLSLObjectHandleType(ASTContext& context)
  236. {
  237. return context.IntTy;
  238. }
  239. /// <summary>Adds a handle field to the specified record.</summary>
  240. static
  241. void AddHLSLHandleField(ASTContext& context, DeclContext* recordDecl, QualType handleQualType)
  242. {
  243. IdentifierInfo& handleId = context.Idents.get(StringRef("h"), tok::TokenKind::identifier);
  244. TypeSourceInfo* fieldTypeSource = context.getTrivialTypeSourceInfo(handleQualType, NoLoc);
  245. const bool MutableFalse = false;
  246. const InClassInitStyle initStyle = InClassInitStyle::ICIS_NoInit;
  247. FieldDecl* handleDecl = FieldDecl::Create(
  248. context, recordDecl, NoLoc, NoLoc, &handleId, handleQualType, fieldTypeSource, nullptr, MutableFalse, initStyle);
  249. handleDecl->setAccess(AccessSpecifier::AS_private);
  250. handleDecl->setImplicit(true);
  251. recordDecl->addDecl(handleDecl);
  252. }
  253. static
  254. void AddSubscriptOperator(
  255. ASTContext& context, unsigned int templateDepth, TemplateTypeParmDecl *elementTemplateParamDecl,
  256. NonTypeTemplateParmDecl* colCountTemplateParamDecl, QualType intType, CXXRecordDecl* templateRecordDecl,
  257. ClassTemplateDecl* vectorTemplateDecl,
  258. bool forConst)
  259. {
  260. QualType elementType = context.getTemplateTypeParmType(
  261. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  262. Expr* sizeExpr = DeclRefExpr::Create(context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl, false,
  263. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  264. intType, ExprValueKind::VK_RValue);
  265. CXXRecordDecl *vecTemplateRecordDecl = vectorTemplateDecl->getTemplatedDecl();
  266. const clang::Type *vecTy = vecTemplateRecordDecl->getTypeForDecl();
  267. TemplateArgument templateArgs[2] =
  268. {
  269. TemplateArgument(elementType),
  270. TemplateArgument(sizeExpr)
  271. };
  272. TemplateName canonName = context.getCanonicalTemplateName(TemplateName(vectorTemplateDecl));
  273. QualType vectorType = context.getTemplateSpecializationType(
  274. canonName, templateArgs, _countof(templateArgs), QualType(vecTy, 0));
  275. vectorType = context.getLValueReferenceType(vectorType);
  276. if (forConst)
  277. vectorType = context.getConstType(vectorType);
  278. QualType indexType = intType;
  279. CXXMethodDecl* functionDecl = CreateObjectFunctionDeclarationWithParams(
  280. context, templateRecordDecl, vectorType,
  281. ArrayRef<QualType>(indexType), ArrayRef<StringRef>(StringRef("index")),
  282. context.DeclarationNames.getCXXOperatorName(OO_Subscript), forConst);
  283. }
  284. /// <summary>Adds up-front support for HLSL matrix types (just the template declaration).</summary>
  285. void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorTemplateDecl, ClassTemplateDecl** matrixTemplateDecl)
  286. {
  287. DXASSERT_NOMSG(matrixTemplateDecl != nullptr);
  288. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  289. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  290. // Create a matrix template declaration in translation unit scope.
  291. // template<typename element, int row_count, int col_count> matrix { ... }
  292. IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
  293. TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
  294. context, currentDeclContext, NoLoc, NoLoc,
  295. FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
  296. elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
  297. QualType intType = context.IntTy;
  298. Expr *literalIntFour = IntegerLiteral::Create(
  299. context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
  300. IdentifierInfo& rowCountParamId = context.Idents.get(StringRef("row_count"), tok::TokenKind::identifier);
  301. NonTypeTemplateParmDecl* rowCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  302. context, currentDeclContext, NoLoc, NoLoc,
  303. FirstTemplateDepth, FirstParamPosition + 1, &rowCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
  304. rowCountTemplateParamDecl->setDefaultArgument(literalIntFour);
  305. IdentifierInfo& colCountParamId = context.Idents.get(StringRef("col_count"), tok::TokenKind::identifier);
  306. NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  307. context, currentDeclContext, NoLoc, NoLoc,
  308. FirstTemplateDepth, FirstParamPosition + 2, &colCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
  309. colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
  310. NamedDecl* templateParameters[] =
  311. {
  312. elementTemplateParamDecl, rowCountTemplateParamDecl, colCountTemplateParamDecl
  313. };
  314. TemplateParameterList* templateParameterList = TemplateParameterList::Create(
  315. context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);
  316. IdentifierInfo& matrixId = context.Idents.get(StringRef("matrix"), tok::TokenKind::identifier);
  317. CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
  318. context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &matrixId,
  319. nullptr, DelayTypeCreationTrue);
  320. ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
  321. context, currentDeclContext, NoLoc, DeclarationName(&matrixId),
  322. templateParameterList, templateRecordDecl, nullptr);
  323. templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
  324. // Requesting the class name specialization will fault in required types.
  325. QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
  326. T = context.getInjectedClassNameType(templateRecordDecl, T);
  327. assert(T->isDependentType() && "Class template type is not dependent?");
  328. classTemplateDecl->setLexicalDeclContext(currentDeclContext);
  329. templateRecordDecl->setLexicalDeclContext(currentDeclContext);
  330. templateRecordDecl->startDefinition();
  331. // Add an 'h' field to hold the handle.
  332. // The type is vector<element, col>[row].
  333. QualType elementType = context.getTemplateTypeParmType(
  334. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  335. Expr *sizeExpr = DeclRefExpr::Create(
  336. context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
  337. false,
  338. DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
  339. intType, ExprValueKind::VK_RValue);
  340. Expr *rowSizeExpr = DeclRefExpr::Create(
  341. context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
  342. false,
  343. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  344. intType, ExprValueKind::VK_RValue);
  345. QualType vectorType = context.getDependentSizedExtVectorType(
  346. elementType, rowSizeExpr, SourceLocation());
  347. QualType vectorArrayType = context.getDependentSizedArrayType(
  348. vectorType, sizeExpr, ArrayType::Normal, 0, SourceRange());
  349. AddHLSLHandleField(context, templateRecordDecl, vectorArrayType);
  350. // Add an operator[]. The operator ranges from zero to rowcount-1, and returns a vector of colcount elements.
  351. const unsigned int templateDepth = 0;
  352. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  353. colCountTemplateParamDecl, context.UnsignedIntTy,
  354. templateRecordDecl, vectorTemplateDecl, ForConstFalse);
  355. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  356. colCountTemplateParamDecl, context.UnsignedIntTy,
  357. templateRecordDecl, vectorTemplateDecl, ForConstTrue);
  358. templateRecordDecl->completeDefinition();
  359. classTemplateDecl->setImplicit(true);
  360. templateRecordDecl->setImplicit(true);
  361. // Both declarations need to be present for correct handling.
  362. currentDeclContext->addDecl(classTemplateDecl);
  363. currentDeclContext->addDecl(templateRecordDecl);
  364. #ifdef DBG
  365. // Verify that we can read the field member from the template record.
  366. DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
  367. DeclarationName(&context.Idents.get(StringRef("h"))));
  368. DXASSERT(!lookupResult.empty(), "otherwise matrix handle cannot be looked up");
  369. #endif
  370. *matrixTemplateDecl = classTemplateDecl;
  371. }
  372. static void AddHLSLVectorSubscriptAttr(Decl *D, ASTContext &context) {
  373. StringRef group = GetHLOpcodeGroupName(HLOpcodeGroup::HLSubscript);
  374. D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, "", static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)));
  375. }
  376. /// <summary>Adds up-front support for HLSL vector types (just the template declaration).</summary>
  377. void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vectorTemplateDecl)
  378. {
  379. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  380. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  381. // Create a vector template declaration in translation unit scope.
  382. // template<typename element, int element_count> vector { ... }
  383. IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
  384. TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
  385. context, currentDeclContext, NoLoc, NoLoc,
  386. FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
  387. elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
  388. QualType intType = context.IntTy;
  389. Expr *literalIntFour = IntegerLiteral::Create(
  390. context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
  391. IdentifierInfo& colCountParamId = context.Idents.get(StringRef("element_count"), tok::TokenKind::identifier);
  392. NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  393. context, currentDeclContext, NoLoc, NoLoc,
  394. FirstTemplateDepth, FirstParamPosition + 1, &colCountParamId, intType, ParameterPackFalse, nullptr);
  395. colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
  396. NamedDecl* templateParameters[] =
  397. {
  398. elementTemplateParamDecl, colCountTemplateParamDecl
  399. };
  400. TemplateParameterList* templateParameterList = TemplateParameterList::Create(
  401. context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);
  402. IdentifierInfo& vectorId = context.Idents.get(StringRef("vector"), tok::TokenKind::identifier);
  403. CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
  404. context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &vectorId,
  405. nullptr, DelayTypeCreationTrue);
  406. ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
  407. context, currentDeclContext, NoLoc, DeclarationName(&vectorId),
  408. templateParameterList, templateRecordDecl, nullptr);
  409. templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
  410. // Requesting the class name specialization will fault in required types.
  411. QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
  412. T = context.getInjectedClassNameType(templateRecordDecl, T);
  413. assert(T->isDependentType() && "Class template type is not dependent?");
  414. classTemplateDecl->setLexicalDeclContext(currentDeclContext);
  415. templateRecordDecl->setLexicalDeclContext(currentDeclContext);
  416. templateRecordDecl->startDefinition();
  417. // Add an 'h' field to hold the handle.
  418. AddHLSLHandleField(context, templateRecordDecl, QualType(GetHLSLObjectHandleType(context)));
  419. // Add an operator[]. The operator ranges from zero to colcount-1, and returns a scalar.
  420. const unsigned int templateDepth = 0;
  421. QualType resultType = context.getTemplateTypeParmType(
  422. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  423. // ForConstTrue:
  424. QualType refResultType = context.getConstType(context.getLValueReferenceType(resultType));
  425. CXXMethodDecl* functionDecl = CreateObjectFunctionDeclarationWithParams(
  426. context, templateRecordDecl, refResultType,
  427. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  428. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstTrue);
  429. AddHLSLVectorSubscriptAttr(functionDecl, context);
  430. // ForConstFalse:
  431. resultType = context.getLValueReferenceType(resultType);
  432. functionDecl = CreateObjectFunctionDeclarationWithParams(
  433. context, templateRecordDecl, resultType,
  434. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  435. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstFalse);
  436. AddHLSLVectorSubscriptAttr(functionDecl, context);
  437. templateRecordDecl->completeDefinition();
  438. classTemplateDecl->setImplicit(true);
  439. templateRecordDecl->setImplicit(true);
  440. // Both declarations need to be present for correct handling.
  441. currentDeclContext->addDecl(classTemplateDecl);
  442. currentDeclContext->addDecl(templateRecordDecl);
  443. #ifdef DBG
  444. // Verify that we can read the field member from the template record.
  445. DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
  446. DeclarationName(&context.Idents.get(StringRef("h"))));
  447. DXASSERT(!lookupResult.empty(), "otherwise vector handle cannot be looked up");
  448. #endif
  449. *vectorTemplateDecl = classTemplateDecl;
  450. }
  451. /// <summary>
  452. /// Adds a new record type in the specified context with the given name. The record type will have a handle field.
  453. /// </summary>
  454. void hlsl::AddRecordTypeWithHandle(ASTContext& context, _Outptr_ CXXRecordDecl** typeDecl, _In_z_ const char* typeName)
  455. {
  456. DXASSERT_NOMSG(typeDecl != nullptr);
  457. DXASSERT_NOMSG(typeName != nullptr);
  458. *typeDecl = nullptr;
  459. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  460. IdentifierInfo& newTypeId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  461. CXXRecordDecl* newDecl = CXXRecordDecl::Create(
  462. context, TagDecl::TagKind::TTK_Struct, currentDeclContext, NoLoc, NoLoc, &newTypeId, nullptr);
  463. newDecl->setLexicalDeclContext(currentDeclContext);
  464. newDecl->setFreeStanding();
  465. newDecl->startDefinition();
  466. AddHLSLHandleField(context, newDecl, QualType(GetHLSLObjectHandleType(context)));
  467. currentDeclContext->addDecl(newDecl);
  468. newDecl->completeDefinition();
  469. *typeDecl = newDecl;
  470. }
  471. static
  472. Expr* IntConstantAsBoolExpr(clang::Sema& sema, uint64_t value)
  473. {
  474. return sema.ImpCastExprToType(
  475. sema.ActOnIntegerConstant(NoLoc, value).get(), sema.getASTContext().BoolTy, CK_IntegralToBoolean).get();
  476. }
  477. static
  478. CXXRecordDecl* CreateStdStructWithStaticBool(clang::ASTContext& context, NamespaceDecl* stdNamespace, IdentifierInfo& trueTypeId, IdentifierInfo& valueId, Expr* trueExpression)
  479. {
  480. // struct true_type { static const bool value = true; }
  481. TypeSourceInfo* boolTypeSource = context.getTrivialTypeSourceInfo(context.BoolTy.withConst());
  482. CXXRecordDecl* trueTypeDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &trueTypeId, nullptr, DelayTypeCreationTrue);
  483. QualType trueTypeQT = context.getTagDeclType(trueTypeDecl); // Fault this in now.
  484. // static fields are variables in the AST
  485. VarDecl* trueValueDecl = VarDecl::Create(context, trueTypeDecl, NoLoc, NoLoc, &valueId,
  486. context.BoolTy.withConst(), boolTypeSource, SC_Static);
  487. trueValueDecl->setInit(trueExpression);
  488. trueValueDecl->setConstexpr(true);
  489. trueValueDecl->setAccess(AS_public);
  490. trueTypeDecl->setLexicalDeclContext(stdNamespace);
  491. trueTypeDecl->startDefinition();
  492. trueTypeDecl->addDecl(trueValueDecl);
  493. trueTypeDecl->completeDefinition();
  494. stdNamespace->addDecl(trueTypeDecl);
  495. return trueTypeDecl;
  496. }
  497. static
  498. void DefineRecordWithBase(CXXRecordDecl* decl, DeclContext* lexicalContext, const CXXBaseSpecifier* base)
  499. {
  500. decl->setLexicalDeclContext(lexicalContext);
  501. decl->startDefinition();
  502. decl->setBases(&base, 1);
  503. decl->completeDefinition();
  504. lexicalContext->addDecl(decl);
  505. }
  506. static
  507. void SetPartialExplicitSpecialization(ClassTemplateDecl* templateDecl, ClassTemplatePartialSpecializationDecl* specializationDecl)
  508. {
  509. specializationDecl->setSpecializationKind(TSK_ExplicitSpecialization);
  510. templateDecl->AddPartialSpecialization(specializationDecl, nullptr);
  511. }
  512. static
  513. void CreateIsEqualSpecialization(ASTContext& context, ClassTemplateDecl* templateDecl, TemplateName& templateName,
  514. DeclContext* lexicalContext, const CXXBaseSpecifier* base, TemplateParameterList* templateParamList,
  515. TemplateArgument (&templateArgs)[2])
  516. {
  517. QualType specializationCanonType = context.getTemplateSpecializationType(templateName, templateArgs, _countof(templateArgs));
  518. TemplateArgumentListInfo templateArgsListInfo = TemplateArgumentListInfo(NoLoc, NoLoc);
  519. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[0], context.getTrivialTypeSourceInfo(templateArgs[0].getAsType())));
  520. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[1], context.getTrivialTypeSourceInfo(templateArgs[1].getAsType())));
  521. ClassTemplatePartialSpecializationDecl* specializationDecl =
  522. ClassTemplatePartialSpecializationDecl::Create(context, TTK_Struct, lexicalContext, NoLoc, NoLoc,
  523. templateParamList, templateDecl, templateArgs, _countof(templateArgs),
  524. templateArgsListInfo, specializationCanonType, nullptr);
  525. context.getTagDeclType(specializationDecl); // Fault this in now.
  526. DefineRecordWithBase(specializationDecl, lexicalContext, base);
  527. SetPartialExplicitSpecialization(templateDecl, specializationDecl);
  528. }
  529. /// <summary>Adds the implementation for std::is_equal.</summary>
  530. void hlsl::AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema)
  531. {
  532. // The goal is to support std::is_same<T, T>::value for testing purposes, in a manner that can
  533. // evolve into a compliant feature in the future.
  534. //
  535. // The definitions necessary are as follows (all in the std namespace).
  536. // template <class T, T v>
  537. // struct integral_constant {
  538. // typedef T value_type;
  539. // static const value_type value = v;
  540. // operator value_type() { return value; }
  541. // };
  542. //
  543. // typedef integral_constant<bool, true> true_type;
  544. // typedef integral_constant<bool, false> false_type;
  545. //
  546. // template<typename T, typename U> struct is_same : public false_type {};
  547. // template<typename T> struct is_same<T, T> : public true_type{};
  548. //
  549. // We instead use these simpler definitions for true_type and false_type.
  550. // struct false_type { static const bool value = false; };
  551. // struct true_type { static const bool value = true; };
  552. DeclContext* tuContext = context.getTranslationUnitDecl();
  553. IdentifierInfo& stdId = context.Idents.get(StringRef("std"), tok::TokenKind::identifier);
  554. IdentifierInfo& trueTypeId = context.Idents.get(StringRef("true_type"), tok::TokenKind::identifier);
  555. IdentifierInfo& falseTypeId = context.Idents.get(StringRef("false_type"), tok::TokenKind::identifier);
  556. IdentifierInfo& valueId = context.Idents.get(StringRef("value"), tok::TokenKind::identifier);
  557. IdentifierInfo& isSameId = context.Idents.get(StringRef("is_same"), tok::TokenKind::identifier);
  558. IdentifierInfo& tId = context.Idents.get(StringRef("T"), tok::TokenKind::identifier);
  559. IdentifierInfo& vId = context.Idents.get(StringRef("V"), tok::TokenKind::identifier);
  560. Expr* trueExpression = IntConstantAsBoolExpr(sema, 1);
  561. Expr* falseExpression = IntConstantAsBoolExpr(sema, 0);
  562. // namespace std
  563. NamespaceDecl* stdNamespace = NamespaceDecl::Create(context, tuContext, InlineFalse, NoLoc, NoLoc, &stdId, nullptr);
  564. CXXRecordDecl* trueTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, trueTypeId, valueId, trueExpression);
  565. CXXRecordDecl* falseTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, falseTypeId, valueId, falseExpression);
  566. // template<typename T, typename U> struct is_same : public false_type {};
  567. CXXRecordDecl* isSameFalseRecordDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &isSameId, nullptr, false);
  568. TemplateTypeParmDecl* tParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition, &tId, TypenameFalse, ParameterPackFalse);
  569. TemplateTypeParmDecl* uParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition + 1, &vId, TypenameFalse, ParameterPackFalse);
  570. NamedDecl* falseParams[] = { tParam, uParam };
  571. TemplateParameterList* falseParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, falseParams, _countof(falseParams), NoLoc);
  572. ClassTemplateDecl* isSameFalseTemplateDecl = ClassTemplateDecl::Create(context, stdNamespace, NoLoc, DeclarationName(&isSameId), falseParamList, isSameFalseRecordDecl, nullptr);
  573. context.getTagDeclType(isSameFalseRecordDecl); // Fault this in now.
  574. CXXBaseSpecifier* falseBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  575. context.getTrivialTypeSourceInfo(context.getTypeDeclType(falseTypeDecl)), NoLoc);
  576. isSameFalseRecordDecl->setDescribedClassTemplate(isSameFalseTemplateDecl);
  577. isSameFalseTemplateDecl->setLexicalDeclContext(stdNamespace);
  578. DefineRecordWithBase(isSameFalseRecordDecl, stdNamespace, falseBase);
  579. // is_same for 'true' is a specialization of is_same for 'false', taking a single T, where both T will match
  580. // template<typename T> struct is_same<T, T> : public true_type{};
  581. TemplateName tn = TemplateName(isSameFalseTemplateDecl);
  582. NamedDecl* trueParams[] = { tParam };
  583. TemplateParameterList* trueParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, trueParams, _countof(trueParams), NoLoc);
  584. CXXBaseSpecifier* trueBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  585. context.getTrivialTypeSourceInfo(context.getTypeDeclType(trueTypeDecl)), NoLoc);
  586. TemplateArgument ta = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)));
  587. TemplateArgument isSameTrueTemplateArgs[] = { ta, ta };
  588. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueTemplateArgs);
  589. stdNamespace->addDecl(isSameFalseTemplateDecl);
  590. stdNamespace->setImplicit(true);
  591. tuContext->addDecl(stdNamespace);
  592. // This could be a parameter if ever needed.
  593. const bool SupportExtensions = true;
  594. // Consider right-hand const and right-hand ref to be true for is_same:
  595. // template<typename T> struct is_same<T, const T> : public true_type{};
  596. // template<typename T> struct is_same<T, T&> : public true_type{};
  597. if (SupportExtensions)
  598. {
  599. TemplateArgument trueConstArg = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)).withConst());
  600. TemplateArgument isSameTrueConstTemplateArgs[] = { ta, trueConstArg };
  601. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueConstTemplateArgs);
  602. TemplateArgument trueRefArg = TemplateArgument(
  603. context.getLValueReferenceType(context.getCanonicalType(context.getTypeDeclType(tParam))));
  604. TemplateArgument isSameTrueRefTemplateArgs[] = { ta, trueRefArg };
  605. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueRefTemplateArgs);
  606. }
  607. }
  608. /// <summary>
  609. /// Adds a new template type in the specified context with the given name. The record type will have a handle field.
  610. /// </summary>
  611. /// <parm name="context">AST context to which template will be added.</param>
  612. /// <parm name="typeDecl">After execution, template declaration.</param>
  613. /// <parm name="recordDecl">After execution, record declaration for template.</param>
  614. /// <parm name="typeName">Name of template to create.</param>
  615. /// <parm name="templateArgCount">Number of template arguments (one or two).</param>
  616. /// <parm name="defaultTypeArgValue">If assigned, the default argument for the element template.</param>
  617. void hlsl::AddTemplateTypeWithHandle(
  618. ASTContext& context,
  619. _Outptr_ ClassTemplateDecl** typeDecl,
  620. _Outptr_ CXXRecordDecl** recordDecl,
  621. _In_z_ const char* typeName,
  622. uint8_t templateArgCount,
  623. _In_opt_ TypeSourceInfo* defaultTypeArgValue
  624. )
  625. {
  626. DXASSERT_NOMSG(typeDecl != nullptr);
  627. DXASSERT_NOMSG(recordDecl != nullptr);
  628. DXASSERT_NOMSG(typeName != nullptr);
  629. DXASSERT(templateArgCount != 0, "otherwise caller should be creating a class or struct");
  630. DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated for a different template pattern");
  631. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  632. // Create an object template declaration in translation unit scope.
  633. // templateArgCount=1: template<typename element> typeName { ... }
  634. // templateArgCount=2: template<typename element, int count> typeName { ... }
  635. IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
  636. TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
  637. context, currentDeclContext, NoLoc, NoLoc,
  638. FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
  639. QualType intType = context.IntTy;
  640. if (defaultTypeArgValue != nullptr)
  641. {
  642. elementTemplateParamDecl->setDefaultArgument(defaultTypeArgValue);
  643. }
  644. NonTypeTemplateParmDecl* countTemplateParamDecl = nullptr;
  645. if (templateArgCount > 1) {
  646. IdentifierInfo& countParamId = context.Idents.get(StringRef("count"), tok::TokenKind::identifier);
  647. countTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  648. context, currentDeclContext, NoLoc, NoLoc,
  649. FirstTemplateDepth, FirstParamPosition + 1, &countParamId, intType, ParameterPackFalse, nullptr);
  650. // Zero means default here. The count is decided by runtime.
  651. Expr *literalIntZero = IntegerLiteral::Create(
  652. context, llvm::APInt(context.getIntWidth(intType), 0), intType, NoLoc);
  653. countTemplateParamDecl->setDefaultArgument(literalIntZero);
  654. }
  655. NamedDecl* templateParameters[] =
  656. {
  657. elementTemplateParamDecl, countTemplateParamDecl
  658. };
  659. TemplateParameterList* templateParameterList = TemplateParameterList::Create(
  660. context, NoLoc, NoLoc, templateParameters, templateArgCount, NoLoc);
  661. IdentifierInfo& typeId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  662. CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
  663. context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &typeId,
  664. nullptr, DelayTypeCreationTrue);
  665. ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
  666. context, currentDeclContext, NoLoc, DeclarationName(&typeId),
  667. templateParameterList, templateRecordDecl, nullptr);
  668. templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
  669. // Requesting the class name specialization will fault in required types.
  670. QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
  671. T = context.getInjectedClassNameType(templateRecordDecl, T);
  672. assert(T->isDependentType() && "Class template type is not dependent?");
  673. classTemplateDecl->setLexicalDeclContext(currentDeclContext);
  674. templateRecordDecl->setLexicalDeclContext(currentDeclContext);
  675. templateRecordDecl->startDefinition();
  676. // Many more things to come here, like constructors and the like....
  677. // Add an 'h' field to hold the handle.
  678. QualType elementType = context.getTemplateTypeParmType(
  679. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  680. if (templateArgCount > 1 &&
  681. // Only need array type for inputpatch and outputpatch.
  682. // Avoid Texture2DMS which may use 0 count.
  683. // TODO: use hlsl types to do the check.
  684. !typeId.getName().startswith("Texture")) {
  685. Expr *countExpr = DeclRefExpr::Create(
  686. context, NestedNameSpecifierLoc(), NoLoc, countTemplateParamDecl, false,
  687. DeclarationNameInfo(countTemplateParamDecl->getDeclName(), NoLoc),
  688. intType, ExprValueKind::VK_RValue);
  689. elementType = context.getDependentSizedArrayType(
  690. elementType, countExpr, ArrayType::ArraySizeModifier::Normal, 0,
  691. SourceRange());
  692. // InputPatch and OutputPatch also have a "Length" static const member for the number of control points
  693. IdentifierInfo& lengthId = context.Idents.get(StringRef("Length"), tok::TokenKind::identifier);
  694. TypeSourceInfo* lengthTypeSource = context.getTrivialTypeSourceInfo(intType.withConst());
  695. VarDecl* lengthValueDecl = VarDecl::Create(context, templateRecordDecl, NoLoc, NoLoc, &lengthId,
  696. intType.withConst(), lengthTypeSource, SC_Static);
  697. lengthValueDecl->setInit(countExpr);
  698. lengthValueDecl->setAccess(AS_public);
  699. templateRecordDecl->addDecl(lengthValueDecl);
  700. }
  701. AddHLSLHandleField(context, templateRecordDecl, elementType);
  702. templateRecordDecl->completeDefinition();
  703. // Both declarations need to be present for correct handling.
  704. currentDeclContext->addDecl(classTemplateDecl);
  705. currentDeclContext->addDecl(templateRecordDecl);
  706. #ifdef DBG
  707. // Verify that we can read the field member from the template record.
  708. DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
  709. DeclarationName(&context.Idents.get(StringRef("h"))));
  710. DXASSERT(!lookupResult.empty(), "otherwise template object handle cannot be looked up");
  711. #endif
  712. *typeDecl = classTemplateDecl;
  713. *recordDecl = templateRecordDecl;
  714. }
  715. FunctionTemplateDecl* hlsl::CreateFunctionTemplateDecl(
  716. ASTContext& context,
  717. _In_ CXXRecordDecl* recordDecl,
  718. _In_ CXXMethodDecl* functionDecl,
  719. _In_count_(templateParamNamedDeclsCount) NamedDecl** templateParamNamedDecls,
  720. size_t templateParamNamedDeclsCount)
  721. {
  722. DXASSERT_NOMSG(recordDecl != nullptr);
  723. DXASSERT_NOMSG(templateParamNamedDecls != nullptr);
  724. DXASSERT(templateParamNamedDeclsCount > 0, "otherwise caller shouldn't invoke this function");
  725. TemplateParameterList* templateParams = TemplateParameterList::Create(
  726. context, NoLoc, NoLoc, &templateParamNamedDecls[0], templateParamNamedDeclsCount, NoLoc);
  727. FunctionTemplateDecl* functionTemplate =
  728. FunctionTemplateDecl::Create(context, recordDecl, NoLoc, functionDecl->getDeclName(), templateParams, functionDecl);
  729. functionTemplate->setAccess(AccessSpecifier::AS_public);
  730. functionTemplate->setLexicalDeclContext(recordDecl);
  731. functionDecl->setDescribedFunctionTemplate(functionTemplate);
  732. recordDecl->addDecl(functionTemplate);
  733. return functionTemplate;
  734. }
  735. static
  736. void AssociateParametersToFunctionPrototype(
  737. _In_ TypeSourceInfo* tinfo,
  738. _In_count_(numParams) ParmVarDecl** paramVarDecls,
  739. unsigned int numParams)
  740. {
  741. FunctionProtoTypeLoc protoLoc = tinfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
  742. DXASSERT(protoLoc.getNumParams() == numParams, "otherwise unexpected number of parameters available");
  743. for (unsigned i = 0; i < numParams; i++) {
  744. DXASSERT(protoLoc.getParam(i) == nullptr, "otherwise prototype parameters were already initialized");
  745. protoLoc.setParam(i, paramVarDecls[i]);
  746. }
  747. }
  748. static void CreateObjectFunctionDeclaration(
  749. ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
  750. ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
  751. _Out_ CXXMethodDecl **functionDecl, _Out_ TypeSourceInfo **tinfo) {
  752. DXASSERT_NOMSG(recordDecl != nullptr);
  753. DXASSERT_NOMSG(functionDecl != nullptr);
  754. FunctionProtoType::ExtProtoInfo functionExtInfo;
  755. functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
  756. QualType functionQT = context.getFunctionType(
  757. resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
  758. DeclarationNameInfo declNameInfo(declarationName, NoLoc);
  759. *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
  760. DXASSERT_NOMSG(*tinfo != nullptr);
  761. *functionDecl = CXXMethodDecl::Create(
  762. context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
  763. StorageClass::SC_None, InlineSpecifiedFalse, IsConstexprFalse, NoLoc);
  764. DXASSERT_NOMSG(*functionDecl != nullptr);
  765. (*functionDecl)->setLexicalDeclContext(recordDecl);
  766. (*functionDecl)->setAccess(AccessSpecifier::AS_public);
  767. }
  768. CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
  769. ASTContext& context,
  770. _In_ CXXRecordDecl* recordDecl,
  771. QualType resultType,
  772. ArrayRef<QualType> paramTypes,
  773. ArrayRef<StringRef> paramNames,
  774. DeclarationName declarationName,
  775. bool isConst)
  776. {
  777. DXASSERT_NOMSG(recordDecl != nullptr);
  778. DXASSERT_NOMSG(!resultType.isNull());
  779. DXASSERT_NOMSG(paramTypes.size() == paramNames.size());
  780. TypeSourceInfo* tinfo;
  781. CXXMethodDecl* functionDecl;
  782. CreateObjectFunctionDeclaration(context, recordDecl, resultType, paramTypes,
  783. declarationName, isConst, &functionDecl,
  784. &tinfo);
  785. // Create and associate parameters to method.
  786. SmallVector<ParmVarDecl *, 2> parmVarDecls;
  787. if (!paramTypes.empty()) {
  788. for (unsigned int i = 0; i < paramTypes.size(); ++i) {
  789. IdentifierInfo *argIi = &context.Idents.get(paramNames[i]);
  790. ParmVarDecl *parmVarDecl = ParmVarDecl::Create(
  791. context, functionDecl, NoLoc, NoLoc, argIi, paramTypes[i],
  792. context.getTrivialTypeSourceInfo(paramTypes[i], NoLoc),
  793. StorageClass::SC_None, nullptr);
  794. parmVarDecl->setScopeInfo(0, i);
  795. DXASSERT(parmVarDecl->getFunctionScopeIndex() == i,
  796. "otherwise failed to set correct index");
  797. parmVarDecls.push_back(parmVarDecl);
  798. }
  799. functionDecl->setParams(ArrayRef<ParmVarDecl *>(parmVarDecls));
  800. AssociateParametersToFunctionPrototype(tinfo, &parmVarDecls.front(),
  801. parmVarDecls.size());
  802. }
  803. recordDecl->addDecl(functionDecl);
  804. return functionDecl;
  805. }
  806. bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
  807. return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
  808. }
  809. bool hlsl::GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
  810. llvm::StringRef &group) {
  811. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  812. return false;
  813. }
  814. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  815. opcode = A->getOpcode();
  816. group = A->getGroup();
  817. return true;
  818. }
  819. bool hlsl::GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S) {
  820. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  821. return false;
  822. }
  823. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  824. S = A->getLowering();
  825. return true;
  826. }
  827. /// <summary>Parses a column or row digit.</summary>
  828. static
  829. bool TryParseColOrRowChar(const char digit, _Out_ int* count) {
  830. if ('1' <= digit && digit <= '4') {
  831. *count = digit - '0';
  832. return true;
  833. }
  834. *count = 0;
  835. return false;
  836. }
  837. /// <summary>Parses a matrix shorthand identifier (eg, float3x2).</summary>
  838. _Use_decl_annotations_
  839. bool hlsl::TryParseMatrixShorthand(
  840. const char* typeName,
  841. size_t typeNameLen,
  842. HLSLScalarType* parsedType,
  843. int* rowCount,
  844. int* colCount,
  845. const clang::LangOptions& langOptions
  846. )
  847. {
  848. //
  849. // Matrix shorthand format is PrimitiveTypeRxC, where R is the row count and C is the column count.
  850. // R and C should be between 1 and 4 inclusive.
  851. // x is a literal 'x' character.
  852. // PrimitiveType is one of the HLSLScalarTypeNames values.
  853. //
  854. if (TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions) &&
  855. *rowCount != 0 && *colCount != 0) {
  856. // compare scalar component
  857. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-3, langOptions);
  858. if (type!= HLSLScalarType_unknown) {
  859. *parsedType = type;
  860. return true;
  861. }
  862. }
  863. // Unable to parse.
  864. return false;
  865. }
  866. /// <summary>Parses a vector shorthand identifier (eg, float3).</summary>
  867. _Use_decl_annotations_
  868. bool hlsl::TryParseVectorShorthand(
  869. const char* typeName,
  870. size_t typeNameLen,
  871. HLSLScalarType* parsedType,
  872. int* elementCount,
  873. const clang::LangOptions& langOptions
  874. )
  875. {
  876. // At least *something*N characters necessary, where something is at least 'int'
  877. if (TryParseColOrRowChar(typeName[typeNameLen - 1], elementCount)) {
  878. // compare scalar component
  879. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-1, langOptions);
  880. if (type!= HLSLScalarType_unknown) {
  881. *parsedType = type;
  882. return true;
  883. }
  884. }
  885. // Unable to parse.
  886. return false;
  887. }
  888. /// <summary>Parses a hlsl scalar type (e.g min16float, uint3x4) </summary>
  889. _Use_decl_annotations_
  890. bool hlsl::TryParseScalar(
  891. _In_count_(typenameLen)
  892. const char* typeName,
  893. size_t typeNameLen,
  894. _Out_ HLSLScalarType *parsedType,
  895. _In_ const clang::LangOptions& langOptions) {
  896. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen, langOptions);
  897. if (type!= HLSLScalarType_unknown) {
  898. *parsedType = type;
  899. return true;
  900. }
  901. return false; // unable to parse
  902. }
  903. /// <summary>Parse any (scalar, vector, matrix) hlsl types (e.g float, int3x4, uint2) </summary>
  904. _Use_decl_annotations_
  905. bool hlsl::TryParseAny(
  906. _In_count_(typenameLen)
  907. const char* typeName,
  908. size_t typeNameLen,
  909. _Out_ HLSLScalarType *parsedType,
  910. int *rowCount,
  911. int *colCount,
  912. _In_ const clang::LangOptions& langOptions) {
  913. // at least 'int'
  914. const size_t MinValidLen = 3;
  915. if (typeNameLen >= MinValidLen) {
  916. TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions);
  917. int suffixLen = *colCount == 0 ? 0 :
  918. *rowCount == 0 ? 1 : 3;
  919. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-suffixLen, langOptions);
  920. if (type!= HLSLScalarType_unknown) {
  921. *parsedType = type;
  922. return true;
  923. }
  924. }
  925. return false;
  926. }
  927. /// <summary>Parse any kind of dimension for vector or matrix (e.g 4,3 in int4x3).
  928. /// If it's a matrix type, rowCount and colCount will be nonzero. If it's a vector type, colCount is 0.
  929. /// Otherwise both rowCount and colCount is 0. Returns true if either matrix or vector dimensions detected. </summary>
  930. _Use_decl_annotations_
  931. bool hlsl::TryParseMatrixOrVectorDimension(
  932. _In_count_(typeNameLen)
  933. const char *typeName,
  934. size_t typeNameLen,
  935. _Out_opt_ int *rowCount,
  936. _Out_opt_ int *colCount,
  937. _In_ const clang::LangOptions& langOptions) {
  938. *rowCount = 0;
  939. *colCount = 0;
  940. size_t MinValidLen = 3; // at least int
  941. if (typeNameLen > MinValidLen) {
  942. if (TryParseColOrRowChar(typeName[typeNameLen - 1], colCount)) {
  943. // Try parse matrix
  944. if (typeName[typeNameLen - 2] == 'x')
  945. TryParseColOrRowChar(typeName[typeNameLen - 3], rowCount);
  946. return true;
  947. }
  948. }
  949. return false;
  950. }
  951. /// <summary>Creates a typedef for a matrix shorthand (eg, float3x2).</summary>
  952. TypedefDecl* hlsl::CreateMatrixSpecializationShorthand(
  953. ASTContext& context,
  954. QualType matrixSpecialization,
  955. HLSLScalarType scalarType,
  956. size_t rowCount,
  957. size_t colCount)
  958. {
  959. DXASSERT(rowCount <= 4, "else caller didn't validate rowCount");
  960. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  961. char typeName[64];
  962. sprintf_s(typeName, _countof(typeName), "%s%ux%u",
  963. HLSLScalarTypeNames[scalarType], (unsigned)rowCount, (unsigned)colCount);
  964. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  965. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  966. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  967. context.getTrivialTypeSourceInfo(matrixSpecialization, NoLoc));
  968. decl->setImplicit(true);
  969. currentDeclContext->addDecl(decl);
  970. return decl;
  971. }
  972. /// <summary>Creates a typedef for a vector shorthand (eg, float3).</summary>
  973. TypedefDecl* hlsl::CreateVectorSpecializationShorthand(
  974. ASTContext& context,
  975. QualType vectorSpecialization,
  976. HLSLScalarType scalarType,
  977. size_t colCount)
  978. {
  979. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  980. char typeName[64];
  981. sprintf_s(typeName, _countof(typeName), "%s%u",
  982. HLSLScalarTypeNames[scalarType], (unsigned)colCount);
  983. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  984. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  985. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  986. context.getTrivialTypeSourceInfo(vectorSpecialization, NoLoc));
  987. decl->setImplicit(true);
  988. currentDeclContext->addDecl(decl);
  989. return decl;
  990. }
  991. llvm::ArrayRef<hlsl::UnusualAnnotation*>
  992. hlsl::UnusualAnnotation::CopyToASTContextArray(
  993. clang::ASTContext& Context, hlsl::UnusualAnnotation** begin, size_t count) {
  994. if (count == 0) {
  995. return llvm::ArrayRef<hlsl::UnusualAnnotation*>();
  996. }
  997. UnusualAnnotation** arr = ::new (Context) UnusualAnnotation*[count];
  998. for (size_t i = 0; i < count; ++i) {
  999. arr[i] = begin[i]->CopyToASTContext(Context);
  1000. }
  1001. return llvm::ArrayRef<hlsl::UnusualAnnotation*>(arr, count);
  1002. }
  1003. UnusualAnnotation* hlsl::UnusualAnnotation::CopyToASTContext(ASTContext& Context) {
  1004. // All UnusualAnnotation instances can be blitted.
  1005. size_t instanceSize;
  1006. switch (Kind) {
  1007. case UA_RegisterAssignment:
  1008. instanceSize = sizeof(hlsl::RegisterAssignment);
  1009. break;
  1010. case UA_ConstantPacking:
  1011. instanceSize = sizeof(hlsl::ConstantPacking);
  1012. break;
  1013. default:
  1014. DXASSERT(Kind == UA_SemanticDecl, "Kind == UA_SemanticDecl -- otherwise switch is incomplete");
  1015. instanceSize = sizeof(hlsl::SemanticDecl);
  1016. break;
  1017. }
  1018. void* result = Context.Allocate(instanceSize);
  1019. memcpy(result, this, instanceSize);
  1020. return (UnusualAnnotation*)result;
  1021. }
  1022. static bool HasTessFactorSemantic(const ValueDecl *decl) {
  1023. for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  1024. switch (it->getKind()) {
  1025. case UnusualAnnotation::UA_SemanticDecl: {
  1026. const SemanticDecl *sd = cast<SemanticDecl>(it);
  1027. const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
  1028. if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
  1029. return true;
  1030. }
  1031. }
  1032. }
  1033. return false;
  1034. }
  1035. static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
  1036. if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
  1037. return false;
  1038. if (const RecordType *RT = Ty->getAsStructureType()) {
  1039. RecordDecl *RD = RT->getDecl();
  1040. for (FieldDecl *fieldDecl : RD->fields()) {
  1041. if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
  1042. return true;
  1043. }
  1044. return false;
  1045. }
  1046. if (const clang::ArrayType *arrayTy = Ty->getAsArrayTypeUnsafe())
  1047. return HasTessFactorSemantic(decl);
  1048. return false;
  1049. }
  1050. bool ASTContext::IsPatchConstantFunctionDecl(const FunctionDecl *FD) const {
  1051. // This checks whether the function is structurally capable of being a patch
  1052. // constant function, not whether it is in fact the patch constant function
  1053. // for the entry point of a compiled hull shader (which may not have been
  1054. // seen yet). So the answer is conservative.
  1055. if (!FD->getReturnType()->isVoidType()) {
  1056. // Try to find TessFactor in return type.
  1057. if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
  1058. return true;
  1059. }
  1060. // Try to find TessFactor in out param.
  1061. for (const ParmVarDecl *param : FD->params()) {
  1062. if (param->hasAttr<HLSLOutAttr>()) {
  1063. if (HasTessFactorSemanticRecurse(param, param->getType()))
  1064. return true;
  1065. }
  1066. }
  1067. return false;
  1068. }