ASTContextHLSL.cpp 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226
  1. //===--- ASTContextHLSL.cpp - HLSL support for AST nodes and operations ---===//
  2. ///////////////////////////////////////////////////////////////////////////////
  3. // //
  4. // ASTContextHLSL.cpp //
  5. // Copyright (C) Microsoft Corporation. All rights reserved. //
  6. // This file is distributed under the University of Illinois Open Source //
  7. // License. See LICENSE.TXT for details. //
  8. // //
  9. // This file implements the ASTContext interface for HLSL. //
  10. // //
  11. ///////////////////////////////////////////////////////////////////////////////
  12. #include "clang/AST/ASTContext.h"
  13. #include "clang/AST/Attr.h"
  14. #include "clang/AST/DeclCXX.h"
  15. #include "clang/AST/DeclTemplate.h"
  16. #include "clang/AST/Expr.h"
  17. #include "clang/AST/ExprCXX.h"
  18. #include "clang/AST/ExternalASTSource.h"
  19. #include "clang/AST/TypeLoc.h"
  20. #include "clang/Sema/SemaDiagnostic.h"
  21. #include "clang/Sema/Sema.h"
  22. #include "clang/Sema/Overload.h"
  23. #include "dxc/Support/Global.h"
  24. #include "dxc/HLSL/HLOperations.h"
  25. #include "dxc/HLSL/DxilSemantic.h"
  26. using namespace clang;
  27. using namespace hlsl;
  28. static const int FirstTemplateDepth = 0;
  29. static const int FirstParamPosition = 0;
  30. static const bool ForConstFalse = false; // a construct is targeting a const type
  31. static const bool ForConstTrue = true; // a construct is targeting a non-const type
  32. static const bool ParameterPackFalse = false; // template parameter is not an ellipsis.
  33. static const bool TypenameFalse = false; // 'typename' specified rather than 'class' for a template argument.
  34. static const bool DelayTypeCreationTrue = true; // delay type creation for a declaration
  35. static const SourceLocation NoLoc; // no source location attribution available
  36. static const bool InlineFalse = false; // namespace is not an inline namespace
  37. static const bool InlineSpecifiedFalse = false; // function was not specified as inline
  38. static const bool IsConstexprFalse = false; // function is not constexpr
  39. static const bool VirtualFalse = false; // whether the base class is declares 'virtual'
  40. static const bool BaseClassFalse = false; // whether the base class is declared as 'class' (vs. 'struct')
  41. /// <summary>Names of HLSLScalarType enumeration values, in matching order to HLSLScalarType.</summary>
  42. const char* HLSLScalarTypeNames[] = {
  43. "<unknown>",
  44. "bool",
  45. "int",
  46. "uint",
  47. "dword",
  48. "half",
  49. "float",
  50. "double",
  51. "min10float",
  52. "min16float",
  53. "min12int",
  54. "min16int",
  55. "min16uint",
  56. "literal float",
  57. "literal int",
  58. "int16_t",
  59. "int32_t",
  60. "int64_t",
  61. "uint16_t",
  62. "uint32_t",
  63. "uint64_t",
  64. "float16_t",
  65. "float32_t",
  66. "float64_t"
  67. };
  68. static_assert(HLSLScalarTypeCount == _countof(HLSLScalarTypeNames), "otherwise scalar constants are not aligned");
  69. static HLSLScalarType FindScalarTypeByName(const char *typeName, const size_t typeLen, const LangOptions& langOptions) {
  70. // skipped HLSLScalarType: unknown, literal int, literal float
  71. switch (typeLen) {
  72. case 3: // int
  73. if (typeName[0] == 'i') {
  74. if (strncmp(typeName, "int", 3))
  75. break;
  76. return HLSLScalarType_int;
  77. }
  78. break;
  79. case 4: // bool, uint, half
  80. if (typeName[0] == 'b') {
  81. if (strncmp(typeName, "bool", 4))
  82. break;
  83. return HLSLScalarType_bool;
  84. }
  85. else if (typeName[0] == 'u') {
  86. if (strncmp(typeName, "uint", 4))
  87. break;
  88. return HLSLScalarType_uint;
  89. }
  90. else if (typeName[0] == 'h') {
  91. if (strncmp(typeName, "half", 4))
  92. break;
  93. return HLSLScalarType_half;
  94. }
  95. break;
  96. case 5: // dword, float
  97. if (typeName[0] == 'd') {
  98. if (strncmp(typeName, "dword", 5))
  99. break;
  100. return HLSLScalarType_dword;
  101. }
  102. else if (typeName[0] == 'f') {
  103. if (strncmp(typeName, "float", 5))
  104. break;
  105. return HLSLScalarType_float;
  106. }
  107. break;
  108. case 6: // double
  109. if (typeName[0] == 'd') {
  110. if (strncmp(typeName, "double", 6))
  111. break;
  112. return HLSLScalarType_double;
  113. }
  114. break;
  115. case 7: // int64_t
  116. if (typeName[0] == 'i' && typeName[1] == 'n') {
  117. if (typeName[3] == '6') {
  118. if (strncmp(typeName, "int64_t", 7))
  119. break;
  120. return HLSLScalarType_int64;
  121. }
  122. }
  123. case 8: // min12int, min16int, uint64_t
  124. if (typeName[0] == 'm' && typeName[1] == 'i') {
  125. if (typeName[4] == '2') {
  126. if (strncmp(typeName, "min12int", 8))
  127. break;
  128. return HLSLScalarType_int_min12;
  129. }
  130. else if (typeName[4] == '6') {
  131. if (strncmp(typeName, "min16int", 8))
  132. break;
  133. return HLSLScalarType_int_min16;
  134. }
  135. }
  136. else if (typeName[0] == 'u' && typeName[1] == 'i') {
  137. if (typeName[4] == '6') {
  138. if (strncmp(typeName, "uint64_t", 8))
  139. break;
  140. return HLSLScalarType_uint64;
  141. }
  142. }
  143. break;
  144. case 9: // min16uint
  145. if (typeName[0] == 'm' && typeName[1] == 'i') {
  146. if (strncmp(typeName, "min16uint", 9))
  147. break;
  148. return HLSLScalarType_uint_min16;
  149. }
  150. break;
  151. case 10: // min10float, min16float
  152. if (typeName[0] == 'm' && typeName[1] == 'i') {
  153. if (typeName[4] == '0') {
  154. if (strncmp(typeName, "min10float", 10))
  155. break;
  156. return HLSLScalarType_float_min10;
  157. }
  158. if (typeName[4] == '6') {
  159. if (strncmp(typeName, "min16float", 10))
  160. break;
  161. return HLSLScalarType_float_min16;
  162. }
  163. }
  164. break;
  165. default:
  166. break;
  167. }
  168. // fixed width types (int16_t, uint16_t, int32_t, uint32_t, float16_t, float32_t, float64_t)
  169. // are only supported in HLSL 2018
  170. if (langOptions.HLSLVersion >= 2018) {
  171. switch (typeLen) {
  172. case 7: // int16_t, int32_t
  173. if (typeName[0] == 'i' && typeName[1] == 'n') {
  174. if (!langOptions.UseMinPrecision) {
  175. if (typeName[3] == '1') {
  176. if (strncmp(typeName, "int16_t", 7))
  177. break;
  178. return HLSLScalarType_int16;
  179. }
  180. }
  181. if (typeName[3] == '3') {
  182. if (strncmp(typeName, "int32_t", 7))
  183. break;
  184. return HLSLScalarType_int32;
  185. }
  186. }
  187. case 8: // uint16_t, uint32_t
  188. if (!langOptions.UseMinPrecision) {
  189. if (typeName[0] == 'u' && typeName[1] == 'i') {
  190. if (typeName[4] == '1') {
  191. if (strncmp(typeName, "uint16_t", 8))
  192. break;
  193. return HLSLScalarType_uint16;
  194. }
  195. }
  196. }
  197. if (typeName[4] == '3') {
  198. if (strncmp(typeName, "uint32_t", 8))
  199. break;
  200. return HLSLScalarType_uint32;
  201. }
  202. case 9: // float16_t, float32_t, float64_t
  203. if (typeName[0] == 'f' && typeName[1] == 'l') {
  204. if (!langOptions.UseMinPrecision) {
  205. if (typeName[5] == '1') {
  206. if (strncmp(typeName, "float16_t", 9))
  207. break;
  208. return HLSLScalarType_float16;
  209. }
  210. }
  211. if (typeName[5] == '3') {
  212. if (strncmp(typeName, "float32_t", 9))
  213. break;
  214. return HLSLScalarType_float32;
  215. }
  216. else if (typeName[5] == '6') {
  217. if (strncmp(typeName, "float64_t", 9))
  218. break;
  219. return HLSLScalarType_float64;
  220. }
  221. }
  222. }
  223. }
  224. return HLSLScalarType_unknown;
  225. }
  226. /// <summary>Provides the primitive type for lowering matrix types to IR.</summary>
  227. static
  228. CanQualType GetHLSLObjectHandleType(ASTContext& context)
  229. {
  230. return context.IntTy;
  231. }
  232. /// <summary>Adds a handle field to the specified record.</summary>
  233. static
  234. void AddHLSLHandleField(ASTContext& context, DeclContext* recordDecl, QualType handleQualType)
  235. {
  236. IdentifierInfo& handleId = context.Idents.get(StringRef("h"), tok::TokenKind::identifier);
  237. TypeSourceInfo* fieldTypeSource = context.getTrivialTypeSourceInfo(handleQualType, NoLoc);
  238. const bool MutableFalse = false;
  239. const InClassInitStyle initStyle = InClassInitStyle::ICIS_NoInit;
  240. FieldDecl* handleDecl = FieldDecl::Create(
  241. context, recordDecl, NoLoc, NoLoc, &handleId, handleQualType, fieldTypeSource, nullptr, MutableFalse, initStyle);
  242. handleDecl->setAccess(AccessSpecifier::AS_private);
  243. handleDecl->setImplicit(true);
  244. recordDecl->addDecl(handleDecl);
  245. }
  246. static
  247. void AddSubscriptOperator(
  248. ASTContext& context, unsigned int templateDepth, TemplateTypeParmDecl *elementTemplateParamDecl,
  249. NonTypeTemplateParmDecl* colCountTemplateParamDecl, QualType intType, CXXRecordDecl* templateRecordDecl,
  250. ClassTemplateDecl* vectorTemplateDecl,
  251. bool forConst)
  252. {
  253. QualType elementType = context.getTemplateTypeParmType(
  254. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  255. Expr* sizeExpr = DeclRefExpr::Create(context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl, false,
  256. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  257. intType, ExprValueKind::VK_RValue);
  258. CXXRecordDecl *vecTemplateRecordDecl = vectorTemplateDecl->getTemplatedDecl();
  259. const clang::Type *vecTy = vecTemplateRecordDecl->getTypeForDecl();
  260. TemplateArgument templateArgs[2] =
  261. {
  262. TemplateArgument(elementType),
  263. TemplateArgument(sizeExpr)
  264. };
  265. TemplateName canonName = context.getCanonicalTemplateName(TemplateName(vectorTemplateDecl));
  266. QualType vectorType = context.getTemplateSpecializationType(
  267. canonName, templateArgs, _countof(templateArgs), QualType(vecTy, 0));
  268. vectorType = context.getLValueReferenceType(vectorType);
  269. if (forConst)
  270. vectorType = context.getConstType(vectorType);
  271. QualType indexType = intType;
  272. CreateObjectFunctionDeclarationWithParams(
  273. context, templateRecordDecl, vectorType,
  274. ArrayRef<QualType>(indexType), ArrayRef<StringRef>(StringRef("index")),
  275. context.DeclarationNames.getCXXOperatorName(OO_Subscript), forConst);
  276. }
  277. /// <summary>Adds up-front support for HLSL matrix types (just the template declaration).</summary>
  278. void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorTemplateDecl, ClassTemplateDecl** matrixTemplateDecl)
  279. {
  280. DXASSERT_NOMSG(matrixTemplateDecl != nullptr);
  281. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  282. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  283. // Create a matrix template declaration in translation unit scope.
  284. // template<typename element, int row_count, int col_count> matrix { ... }
  285. IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
  286. TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
  287. context, currentDeclContext, NoLoc, NoLoc,
  288. FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
  289. elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
  290. QualType intType = context.IntTy;
  291. Expr *literalIntFour = IntegerLiteral::Create(
  292. context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
  293. IdentifierInfo& rowCountParamId = context.Idents.get(StringRef("row_count"), tok::TokenKind::identifier);
  294. NonTypeTemplateParmDecl* rowCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  295. context, currentDeclContext, NoLoc, NoLoc,
  296. FirstTemplateDepth, FirstParamPosition + 1, &rowCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
  297. rowCountTemplateParamDecl->setDefaultArgument(literalIntFour);
  298. IdentifierInfo& colCountParamId = context.Idents.get(StringRef("col_count"), tok::TokenKind::identifier);
  299. NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  300. context, currentDeclContext, NoLoc, NoLoc,
  301. FirstTemplateDepth, FirstParamPosition + 2, &colCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
  302. colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
  303. NamedDecl* templateParameters[] =
  304. {
  305. elementTemplateParamDecl, rowCountTemplateParamDecl, colCountTemplateParamDecl
  306. };
  307. TemplateParameterList* templateParameterList = TemplateParameterList::Create(
  308. context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);
  309. IdentifierInfo& matrixId = context.Idents.get(StringRef("matrix"), tok::TokenKind::identifier);
  310. CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
  311. context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &matrixId,
  312. nullptr, DelayTypeCreationTrue);
  313. ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
  314. context, currentDeclContext, NoLoc, DeclarationName(&matrixId),
  315. templateParameterList, templateRecordDecl, nullptr);
  316. templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
  317. // Requesting the class name specialization will fault in required types.
  318. QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
  319. T = context.getInjectedClassNameType(templateRecordDecl, T);
  320. assert(T->isDependentType() && "Class template type is not dependent?");
  321. classTemplateDecl->setLexicalDeclContext(currentDeclContext);
  322. templateRecordDecl->setLexicalDeclContext(currentDeclContext);
  323. templateRecordDecl->startDefinition();
  324. // Add an 'h' field to hold the handle.
  325. // The type is vector<element, col>[row].
  326. QualType elementType = context.getTemplateTypeParmType(
  327. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  328. Expr *sizeExpr = DeclRefExpr::Create(
  329. context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
  330. false,
  331. DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
  332. intType, ExprValueKind::VK_RValue);
  333. Expr *rowSizeExpr = DeclRefExpr::Create(
  334. context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
  335. false,
  336. DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
  337. intType, ExprValueKind::VK_RValue);
  338. QualType vectorType = context.getDependentSizedExtVectorType(
  339. elementType, rowSizeExpr, SourceLocation());
  340. QualType vectorArrayType = context.getDependentSizedArrayType(
  341. vectorType, sizeExpr, ArrayType::Normal, 0, SourceRange());
  342. AddHLSLHandleField(context, templateRecordDecl, vectorArrayType);
  343. // Add an operator[]. The operator ranges from zero to rowcount-1, and returns a vector of colcount elements.
  344. const unsigned int templateDepth = 0;
  345. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  346. colCountTemplateParamDecl, context.UnsignedIntTy,
  347. templateRecordDecl, vectorTemplateDecl, ForConstFalse);
  348. AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
  349. colCountTemplateParamDecl, context.UnsignedIntTy,
  350. templateRecordDecl, vectorTemplateDecl, ForConstTrue);
  351. templateRecordDecl->completeDefinition();
  352. classTemplateDecl->setImplicit(true);
  353. templateRecordDecl->setImplicit(true);
  354. // Both declarations need to be present for correct handling.
  355. currentDeclContext->addDecl(classTemplateDecl);
  356. currentDeclContext->addDecl(templateRecordDecl);
  357. #ifdef DBG
  358. // Verify that we can read the field member from the template record.
  359. DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
  360. DeclarationName(&context.Idents.get(StringRef("h"))));
  361. DXASSERT(!lookupResult.empty(), "otherwise matrix handle cannot be looked up");
  362. #endif
  363. *matrixTemplateDecl = classTemplateDecl;
  364. }
  365. static void AddHLSLVectorSubscriptAttr(Decl *D, ASTContext &context) {
  366. StringRef group = GetHLOpcodeGroupName(HLOpcodeGroup::HLSubscript);
  367. D->addAttr(HLSLIntrinsicAttr::CreateImplicit(context, group, "", static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)));
  368. }
  369. /// <summary>Adds up-front support for HLSL vector types (just the template declaration).</summary>
  370. void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vectorTemplateDecl)
  371. {
  372. DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  373. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  374. // Create a vector template declaration in translation unit scope.
  375. // template<typename element, int element_count> vector { ... }
  376. IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
  377. TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
  378. context, currentDeclContext, NoLoc, NoLoc,
  379. FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
  380. elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
  381. QualType intType = context.IntTy;
  382. Expr *literalIntFour = IntegerLiteral::Create(
  383. context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
  384. IdentifierInfo& colCountParamId = context.Idents.get(StringRef("element_count"), tok::TokenKind::identifier);
  385. NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  386. context, currentDeclContext, NoLoc, NoLoc,
  387. FirstTemplateDepth, FirstParamPosition + 1, &colCountParamId, intType, ParameterPackFalse, nullptr);
  388. colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
  389. NamedDecl* templateParameters[] =
  390. {
  391. elementTemplateParamDecl, colCountTemplateParamDecl
  392. };
  393. TemplateParameterList* templateParameterList = TemplateParameterList::Create(
  394. context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);
  395. IdentifierInfo& vectorId = context.Idents.get(StringRef("vector"), tok::TokenKind::identifier);
  396. CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
  397. context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &vectorId,
  398. nullptr, DelayTypeCreationTrue);
  399. ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
  400. context, currentDeclContext, NoLoc, DeclarationName(&vectorId),
  401. templateParameterList, templateRecordDecl, nullptr);
  402. templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
  403. // Requesting the class name specialization will fault in required types.
  404. QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
  405. T = context.getInjectedClassNameType(templateRecordDecl, T);
  406. assert(T->isDependentType() && "Class template type is not dependent?");
  407. classTemplateDecl->setLexicalDeclContext(currentDeclContext);
  408. templateRecordDecl->setLexicalDeclContext(currentDeclContext);
  409. templateRecordDecl->startDefinition();
  410. // Add an 'h' field to hold the handle.
  411. AddHLSLHandleField(context, templateRecordDecl, QualType(GetHLSLObjectHandleType(context)));
  412. // Add an operator[]. The operator ranges from zero to colcount-1, and returns a scalar.
  413. const unsigned int templateDepth = 0;
  414. QualType resultType = context.getTemplateTypeParmType(
  415. templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  416. // ForConstTrue:
  417. QualType refResultType = context.getConstType(context.getLValueReferenceType(resultType));
  418. CXXMethodDecl* functionDecl = CreateObjectFunctionDeclarationWithParams(
  419. context, templateRecordDecl, refResultType,
  420. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  421. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstTrue);
  422. AddHLSLVectorSubscriptAttr(functionDecl, context);
  423. // ForConstFalse:
  424. resultType = context.getLValueReferenceType(resultType);
  425. functionDecl = CreateObjectFunctionDeclarationWithParams(
  426. context, templateRecordDecl, resultType,
  427. ArrayRef<QualType>(context.UnsignedIntTy), ArrayRef<StringRef>(StringRef("index")),
  428. context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstFalse);
  429. AddHLSLVectorSubscriptAttr(functionDecl, context);
  430. templateRecordDecl->completeDefinition();
  431. classTemplateDecl->setImplicit(true);
  432. templateRecordDecl->setImplicit(true);
  433. // Both declarations need to be present for correct handling.
  434. currentDeclContext->addDecl(classTemplateDecl);
  435. currentDeclContext->addDecl(templateRecordDecl);
  436. #ifdef DBG
  437. // Verify that we can read the field member from the template record.
  438. DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
  439. DeclarationName(&context.Idents.get(StringRef("h"))));
  440. DXASSERT(!lookupResult.empty(), "otherwise vector handle cannot be looked up");
  441. #endif
  442. *vectorTemplateDecl = classTemplateDecl;
  443. }
  444. /// <summary>
  445. /// Adds a new record type in the specified context with the given name. The record type will have a handle field.
  446. /// </summary>
  447. void hlsl::AddRecordTypeWithHandle(ASTContext& context, _Outptr_ CXXRecordDecl** typeDecl, _In_z_ const char* typeName)
  448. {
  449. DXASSERT_NOMSG(typeDecl != nullptr);
  450. DXASSERT_NOMSG(typeName != nullptr);
  451. *typeDecl = nullptr;
  452. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  453. IdentifierInfo& newTypeId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  454. CXXRecordDecl* newDecl = CXXRecordDecl::Create(
  455. context, TagDecl::TagKind::TTK_Struct, currentDeclContext, NoLoc, NoLoc, &newTypeId, nullptr);
  456. newDecl->setLexicalDeclContext(currentDeclContext);
  457. newDecl->setFreeStanding();
  458. newDecl->startDefinition();
  459. AddHLSLHandleField(context, newDecl, QualType(GetHLSLObjectHandleType(context)));
  460. currentDeclContext->addDecl(newDecl);
  461. newDecl->completeDefinition();
  462. *typeDecl = newDecl;
  463. }
  464. // creates a global static constant unsigned integer with value.
  465. // equivalent to: static const uint name = val;
  466. static void AddConstUInt(clang::ASTContext& context, DeclContext *DC, StringRef name, unsigned val) {
  467. IdentifierInfo &Id = context.Idents.get(name, tok::TokenKind::identifier);
  468. QualType type = context.getConstType(context.UnsignedIntTy);
  469. VarDecl *varDecl = VarDecl::Create(context, DC, NoLoc, NoLoc, &Id, type,
  470. context.getTrivialTypeSourceInfo(type),
  471. clang::StorageClass::SC_Static);
  472. Expr *exprVal = IntegerLiteral::Create(
  473. context, llvm::APInt(context.getIntWidth(type), val), type, NoLoc);
  474. varDecl->setInit(exprVal);
  475. varDecl->setImplicit(true);
  476. DC->addDecl(varDecl);
  477. }
  478. /// <summary> Adds a const integers for ray flags </summary>
  479. void hlsl::AddRayFlags(ASTContext& context) {
  480. DeclContext *curDC = context.getTranslationUnitDecl();
  481. // typedef uint RAY_FLAG;
  482. IdentifierInfo &rayFlagId = context.Idents.get(StringRef("RAY_FLAG"), tok::TokenKind::identifier);
  483. TypeSourceInfo *uintTypeSource = context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
  484. TypedefDecl *rayFlagDecl = TypedefDecl::Create(context, curDC, NoLoc, NoLoc, &rayFlagId, uintTypeSource);
  485. curDC->addDecl(rayFlagDecl);
  486. rayFlagDecl->setImplicit(true);
  487. // static const uint RAY_FLAG_* = *;
  488. AddConstUInt(context, curDC, StringRef("RAY_FLAG_NONE"), (unsigned)DXIL::RayFlag::None);
  489. AddConstUInt(context, curDC, StringRef("RAY_FLAG_FORCE_OPAQUE"), (unsigned)DXIL::RayFlag::ForceOpaque);
  490. AddConstUInt(context, curDC, StringRef("RAY_FLAG_FORCE_NON_OPAQUE"), (unsigned)DXIL::RayFlag::ForceNonOpaque);
  491. AddConstUInt(context, curDC, StringRef("RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH"), (unsigned)DXIL::RayFlag::AcceptFirstHitAndEndSearch);
  492. AddConstUInt(context, curDC, StringRef("RAY_FLAG_SKIP_CLOSEST_HIT_SHADER"), (unsigned)DXIL::RayFlag::SkipClosestHitShader);
  493. AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_BACK_FACING_TRIANGLES"), (unsigned)DXIL::RayFlag::CullBackFacingTriangles);
  494. AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_FRONT_FACING_TRIANGLES"), (unsigned)DXIL::RayFlag::CullFrontFacingTriangles);
  495. AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_OPAQUE"), (unsigned)DXIL::RayFlag::CullOpaque);
  496. AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_NON_OPAQUE"), (unsigned)DXIL::RayFlag::CullNonOpaque);
  497. }
  498. /// <summary> Adds a constant integers for hit kinds </summary>
  499. void hlsl::AddHitKinds(ASTContext& context) {
  500. DeclContext *curDC = context.getTranslationUnitDecl();
  501. // static const uint HIT_KIND_* = *;
  502. AddConstUInt(context, curDC, StringRef("HIT_KIND_NONE"), (unsigned)DXIL::HitKind::None);
  503. AddConstUInt(context, curDC, StringRef("HIT_KIND_TRIANGLE_FRONT_FACE"), (unsigned)DXIL::HitKind::TriangleFrontFace);
  504. AddConstUInt(context, curDC, StringRef("HIT_KIND_TRIANGLE_BACK_FACE"), (unsigned)DXIL::HitKind::TriangleBackFace);
  505. }
  506. static
  507. Expr* IntConstantAsBoolExpr(clang::Sema& sema, uint64_t value)
  508. {
  509. return sema.ImpCastExprToType(
  510. sema.ActOnIntegerConstant(NoLoc, value).get(), sema.getASTContext().BoolTy, CK_IntegralToBoolean).get();
  511. }
  512. static
  513. CXXRecordDecl* CreateStdStructWithStaticBool(clang::ASTContext& context, NamespaceDecl* stdNamespace, IdentifierInfo& trueTypeId, IdentifierInfo& valueId, Expr* trueExpression)
  514. {
  515. // struct true_type { static const bool value = true; }
  516. TypeSourceInfo* boolTypeSource = context.getTrivialTypeSourceInfo(context.BoolTy.withConst());
  517. CXXRecordDecl* trueTypeDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &trueTypeId, nullptr, DelayTypeCreationTrue);
  518. // static fields are variables in the AST
  519. VarDecl* trueValueDecl = VarDecl::Create(context, trueTypeDecl, NoLoc, NoLoc, &valueId,
  520. context.BoolTy.withConst(), boolTypeSource, SC_Static);
  521. trueValueDecl->setInit(trueExpression);
  522. trueValueDecl->setConstexpr(true);
  523. trueValueDecl->setAccess(AS_public);
  524. trueTypeDecl->setLexicalDeclContext(stdNamespace);
  525. trueTypeDecl->startDefinition();
  526. trueTypeDecl->addDecl(trueValueDecl);
  527. trueTypeDecl->completeDefinition();
  528. stdNamespace->addDecl(trueTypeDecl);
  529. return trueTypeDecl;
  530. }
  531. static
  532. void DefineRecordWithBase(CXXRecordDecl* decl, DeclContext* lexicalContext, const CXXBaseSpecifier* base)
  533. {
  534. decl->setLexicalDeclContext(lexicalContext);
  535. decl->startDefinition();
  536. decl->setBases(&base, 1);
  537. decl->completeDefinition();
  538. lexicalContext->addDecl(decl);
  539. }
  540. static
  541. void SetPartialExplicitSpecialization(ClassTemplateDecl* templateDecl, ClassTemplatePartialSpecializationDecl* specializationDecl)
  542. {
  543. specializationDecl->setSpecializationKind(TSK_ExplicitSpecialization);
  544. templateDecl->AddPartialSpecialization(specializationDecl, nullptr);
  545. }
  546. static
  547. void CreateIsEqualSpecialization(ASTContext& context, ClassTemplateDecl* templateDecl, TemplateName& templateName,
  548. DeclContext* lexicalContext, const CXXBaseSpecifier* base, TemplateParameterList* templateParamList,
  549. TemplateArgument (&templateArgs)[2])
  550. {
  551. QualType specializationCanonType = context.getTemplateSpecializationType(templateName, templateArgs, _countof(templateArgs));
  552. TemplateArgumentListInfo templateArgsListInfo = TemplateArgumentListInfo(NoLoc, NoLoc);
  553. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[0], context.getTrivialTypeSourceInfo(templateArgs[0].getAsType())));
  554. templateArgsListInfo.addArgument(TemplateArgumentLoc(templateArgs[1], context.getTrivialTypeSourceInfo(templateArgs[1].getAsType())));
  555. ClassTemplatePartialSpecializationDecl* specializationDecl =
  556. ClassTemplatePartialSpecializationDecl::Create(context, TTK_Struct, lexicalContext, NoLoc, NoLoc,
  557. templateParamList, templateDecl, templateArgs, _countof(templateArgs),
  558. templateArgsListInfo, specializationCanonType, nullptr);
  559. context.getTagDeclType(specializationDecl); // Fault this in now.
  560. DefineRecordWithBase(specializationDecl, lexicalContext, base);
  561. SetPartialExplicitSpecialization(templateDecl, specializationDecl);
  562. }
  563. /// <summary>Adds the implementation for std::is_equal.</summary>
  564. void hlsl::AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema)
  565. {
  566. // The goal is to support std::is_same<T, T>::value for testing purposes, in a manner that can
  567. // evolve into a compliant feature in the future.
  568. //
  569. // The definitions necessary are as follows (all in the std namespace).
  570. // template <class T, T v>
  571. // struct integral_constant {
  572. // typedef T value_type;
  573. // static const value_type value = v;
  574. // operator value_type() { return value; }
  575. // };
  576. //
  577. // typedef integral_constant<bool, true> true_type;
  578. // typedef integral_constant<bool, false> false_type;
  579. //
  580. // template<typename T, typename U> struct is_same : public false_type {};
  581. // template<typename T> struct is_same<T, T> : public true_type{};
  582. //
  583. // We instead use these simpler definitions for true_type and false_type.
  584. // struct false_type { static const bool value = false; };
  585. // struct true_type { static const bool value = true; };
  586. DeclContext* tuContext = context.getTranslationUnitDecl();
  587. IdentifierInfo& stdId = context.Idents.get(StringRef("std"), tok::TokenKind::identifier);
  588. IdentifierInfo& trueTypeId = context.Idents.get(StringRef("true_type"), tok::TokenKind::identifier);
  589. IdentifierInfo& falseTypeId = context.Idents.get(StringRef("false_type"), tok::TokenKind::identifier);
  590. IdentifierInfo& valueId = context.Idents.get(StringRef("value"), tok::TokenKind::identifier);
  591. IdentifierInfo& isSameId = context.Idents.get(StringRef("is_same"), tok::TokenKind::identifier);
  592. IdentifierInfo& tId = context.Idents.get(StringRef("T"), tok::TokenKind::identifier);
  593. IdentifierInfo& vId = context.Idents.get(StringRef("V"), tok::TokenKind::identifier);
  594. Expr* trueExpression = IntConstantAsBoolExpr(sema, 1);
  595. Expr* falseExpression = IntConstantAsBoolExpr(sema, 0);
  596. // namespace std
  597. NamespaceDecl* stdNamespace = NamespaceDecl::Create(context, tuContext, InlineFalse, NoLoc, NoLoc, &stdId, nullptr);
  598. CXXRecordDecl* trueTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, trueTypeId, valueId, trueExpression);
  599. CXXRecordDecl* falseTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, falseTypeId, valueId, falseExpression);
  600. // template<typename T, typename U> struct is_same : public false_type {};
  601. CXXRecordDecl* isSameFalseRecordDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &isSameId, nullptr, false);
  602. TemplateTypeParmDecl* tParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition, &tId, TypenameFalse, ParameterPackFalse);
  603. TemplateTypeParmDecl* uParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition + 1, &vId, TypenameFalse, ParameterPackFalse);
  604. NamedDecl* falseParams[] = { tParam, uParam };
  605. TemplateParameterList* falseParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, falseParams, _countof(falseParams), NoLoc);
  606. ClassTemplateDecl* isSameFalseTemplateDecl = ClassTemplateDecl::Create(context, stdNamespace, NoLoc, DeclarationName(&isSameId), falseParamList, isSameFalseRecordDecl, nullptr);
  607. context.getTagDeclType(isSameFalseRecordDecl); // Fault this in now.
  608. CXXBaseSpecifier* falseBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  609. context.getTrivialTypeSourceInfo(context.getTypeDeclType(falseTypeDecl)), NoLoc);
  610. isSameFalseRecordDecl->setDescribedClassTemplate(isSameFalseTemplateDecl);
  611. isSameFalseTemplateDecl->setLexicalDeclContext(stdNamespace);
  612. DefineRecordWithBase(isSameFalseRecordDecl, stdNamespace, falseBase);
  613. // is_same for 'true' is a specialization of is_same for 'false', taking a single T, where both T will match
  614. // template<typename T> struct is_same<T, T> : public true_type{};
  615. TemplateName tn = TemplateName(isSameFalseTemplateDecl);
  616. NamedDecl* trueParams[] = { tParam };
  617. TemplateParameterList* trueParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, trueParams, _countof(trueParams), NoLoc);
  618. CXXBaseSpecifier* trueBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
  619. context.getTrivialTypeSourceInfo(context.getTypeDeclType(trueTypeDecl)), NoLoc);
  620. TemplateArgument ta = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)));
  621. TemplateArgument isSameTrueTemplateArgs[] = { ta, ta };
  622. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueTemplateArgs);
  623. stdNamespace->addDecl(isSameFalseTemplateDecl);
  624. stdNamespace->setImplicit(true);
  625. tuContext->addDecl(stdNamespace);
  626. // This could be a parameter if ever needed.
  627. const bool SupportExtensions = true;
  628. // Consider right-hand const and right-hand ref to be true for is_same:
  629. // template<typename T> struct is_same<T, const T> : public true_type{};
  630. // template<typename T> struct is_same<T, T&> : public true_type{};
  631. if (SupportExtensions)
  632. {
  633. TemplateArgument trueConstArg = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)).withConst());
  634. TemplateArgument isSameTrueConstTemplateArgs[] = { ta, trueConstArg };
  635. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueConstTemplateArgs);
  636. TemplateArgument trueRefArg = TemplateArgument(
  637. context.getLValueReferenceType(context.getCanonicalType(context.getTypeDeclType(tParam))));
  638. TemplateArgument isSameTrueRefTemplateArgs[] = { ta, trueRefArg };
  639. CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueRefTemplateArgs);
  640. }
  641. }
  642. /// <summary>
  643. /// Adds a new template type in the specified context with the given name. The record type will have a handle field.
  644. /// </summary>
  645. /// <parm name="context">AST context to which template will be added.</param>
  646. /// <parm name="typeDecl">After execution, template declaration.</param>
  647. /// <parm name="recordDecl">After execution, record declaration for template.</param>
  648. /// <parm name="typeName">Name of template to create.</param>
  649. /// <parm name="templateArgCount">Number of template arguments (one or two).</param>
  650. /// <parm name="defaultTypeArgValue">If assigned, the default argument for the element template.</param>
  651. void hlsl::AddTemplateTypeWithHandle(
  652. ASTContext& context,
  653. _Outptr_ ClassTemplateDecl** typeDecl,
  654. _Outptr_ CXXRecordDecl** recordDecl,
  655. _In_z_ const char* typeName,
  656. uint8_t templateArgCount,
  657. _In_opt_ TypeSourceInfo* defaultTypeArgValue
  658. )
  659. {
  660. DXASSERT_NOMSG(typeDecl != nullptr);
  661. DXASSERT_NOMSG(recordDecl != nullptr);
  662. DXASSERT_NOMSG(typeName != nullptr);
  663. DXASSERT(templateArgCount != 0, "otherwise caller should be creating a class or struct");
  664. DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated for a different template pattern");
  665. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  666. // Create an object template declaration in translation unit scope.
  667. // templateArgCount=1: template<typename element> typeName { ... }
  668. // templateArgCount=2: template<typename element, int count> typeName { ... }
  669. IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
  670. TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
  671. context, currentDeclContext, NoLoc, NoLoc,
  672. FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
  673. QualType intType = context.IntTy;
  674. if (defaultTypeArgValue != nullptr)
  675. {
  676. elementTemplateParamDecl->setDefaultArgument(defaultTypeArgValue);
  677. }
  678. NonTypeTemplateParmDecl* countTemplateParamDecl = nullptr;
  679. if (templateArgCount > 1) {
  680. IdentifierInfo& countParamId = context.Idents.get(StringRef("count"), tok::TokenKind::identifier);
  681. countTemplateParamDecl = NonTypeTemplateParmDecl::Create(
  682. context, currentDeclContext, NoLoc, NoLoc,
  683. FirstTemplateDepth, FirstParamPosition + 1, &countParamId, intType, ParameterPackFalse, nullptr);
  684. // Zero means default here. The count is decided by runtime.
  685. Expr *literalIntZero = IntegerLiteral::Create(
  686. context, llvm::APInt(context.getIntWidth(intType), 0), intType, NoLoc);
  687. countTemplateParamDecl->setDefaultArgument(literalIntZero);
  688. }
  689. NamedDecl* templateParameters[] =
  690. {
  691. elementTemplateParamDecl, countTemplateParamDecl
  692. };
  693. TemplateParameterList* templateParameterList = TemplateParameterList::Create(
  694. context, NoLoc, NoLoc, templateParameters, templateArgCount, NoLoc);
  695. IdentifierInfo& typeId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  696. CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
  697. context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &typeId,
  698. nullptr, DelayTypeCreationTrue);
  699. ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
  700. context, currentDeclContext, NoLoc, DeclarationName(&typeId),
  701. templateParameterList, templateRecordDecl, nullptr);
  702. templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
  703. // Requesting the class name specialization will fault in required types.
  704. QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
  705. T = context.getInjectedClassNameType(templateRecordDecl, T);
  706. assert(T->isDependentType() && "Class template type is not dependent?");
  707. classTemplateDecl->setLexicalDeclContext(currentDeclContext);
  708. templateRecordDecl->setLexicalDeclContext(currentDeclContext);
  709. templateRecordDecl->startDefinition();
  710. // Many more things to come here, like constructors and the like....
  711. // Add an 'h' field to hold the handle.
  712. QualType elementType = context.getTemplateTypeParmType(
  713. /*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
  714. if (templateArgCount > 1 &&
  715. // Only need array type for inputpatch and outputpatch.
  716. // Avoid Texture2DMS which may use 0 count.
  717. // TODO: use hlsl types to do the check.
  718. !typeId.getName().startswith("Texture")) {
  719. Expr *countExpr = DeclRefExpr::Create(
  720. context, NestedNameSpecifierLoc(), NoLoc, countTemplateParamDecl, false,
  721. DeclarationNameInfo(countTemplateParamDecl->getDeclName(), NoLoc),
  722. intType, ExprValueKind::VK_RValue);
  723. elementType = context.getDependentSizedArrayType(
  724. elementType, countExpr, ArrayType::ArraySizeModifier::Normal, 0,
  725. SourceRange());
  726. // InputPatch and OutputPatch also have a "Length" static const member for the number of control points
  727. IdentifierInfo& lengthId = context.Idents.get(StringRef("Length"), tok::TokenKind::identifier);
  728. TypeSourceInfo* lengthTypeSource = context.getTrivialTypeSourceInfo(intType.withConst());
  729. VarDecl* lengthValueDecl = VarDecl::Create(context, templateRecordDecl, NoLoc, NoLoc, &lengthId,
  730. intType.withConst(), lengthTypeSource, SC_Static);
  731. lengthValueDecl->setInit(countExpr);
  732. lengthValueDecl->setAccess(AS_public);
  733. templateRecordDecl->addDecl(lengthValueDecl);
  734. }
  735. AddHLSLHandleField(context, templateRecordDecl, elementType);
  736. templateRecordDecl->completeDefinition();
  737. // Both declarations need to be present for correct handling.
  738. currentDeclContext->addDecl(classTemplateDecl);
  739. currentDeclContext->addDecl(templateRecordDecl);
  740. #ifdef DBG
  741. // Verify that we can read the field member from the template record.
  742. DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
  743. DeclarationName(&context.Idents.get(StringRef("h"))));
  744. DXASSERT(!lookupResult.empty(), "otherwise template object handle cannot be looked up");
  745. #endif
  746. *typeDecl = classTemplateDecl;
  747. *recordDecl = templateRecordDecl;
  748. }
  749. FunctionTemplateDecl* hlsl::CreateFunctionTemplateDecl(
  750. ASTContext& context,
  751. _In_ CXXRecordDecl* recordDecl,
  752. _In_ CXXMethodDecl* functionDecl,
  753. _In_count_(templateParamNamedDeclsCount) NamedDecl** templateParamNamedDecls,
  754. size_t templateParamNamedDeclsCount)
  755. {
  756. DXASSERT_NOMSG(recordDecl != nullptr);
  757. DXASSERT_NOMSG(templateParamNamedDecls != nullptr);
  758. DXASSERT(templateParamNamedDeclsCount > 0, "otherwise caller shouldn't invoke this function");
  759. TemplateParameterList* templateParams = TemplateParameterList::Create(
  760. context, NoLoc, NoLoc, &templateParamNamedDecls[0], templateParamNamedDeclsCount, NoLoc);
  761. FunctionTemplateDecl* functionTemplate =
  762. FunctionTemplateDecl::Create(context, recordDecl, NoLoc, functionDecl->getDeclName(), templateParams, functionDecl);
  763. functionTemplate->setAccess(AccessSpecifier::AS_public);
  764. functionTemplate->setLexicalDeclContext(recordDecl);
  765. functionDecl->setDescribedFunctionTemplate(functionTemplate);
  766. recordDecl->addDecl(functionTemplate);
  767. return functionTemplate;
  768. }
  769. static
  770. void AssociateParametersToFunctionPrototype(
  771. _In_ TypeSourceInfo* tinfo,
  772. _In_count_(numParams) ParmVarDecl** paramVarDecls,
  773. unsigned int numParams)
  774. {
  775. FunctionProtoTypeLoc protoLoc = tinfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
  776. DXASSERT(protoLoc.getNumParams() == numParams, "otherwise unexpected number of parameters available");
  777. for (unsigned i = 0; i < numParams; i++) {
  778. DXASSERT(protoLoc.getParam(i) == nullptr, "otherwise prototype parameters were already initialized");
  779. protoLoc.setParam(i, paramVarDecls[i]);
  780. }
  781. }
  782. static void CreateObjectFunctionDeclaration(
  783. ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
  784. ArrayRef<QualType> args, DeclarationName declarationName, bool isConst,
  785. _Out_ CXXMethodDecl **functionDecl, _Out_ TypeSourceInfo **tinfo) {
  786. DXASSERT_NOMSG(recordDecl != nullptr);
  787. DXASSERT_NOMSG(functionDecl != nullptr);
  788. FunctionProtoType::ExtProtoInfo functionExtInfo;
  789. functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
  790. QualType functionQT = context.getFunctionType(
  791. resultType, args, functionExtInfo, ArrayRef<ParameterModifier>());
  792. DeclarationNameInfo declNameInfo(declarationName, NoLoc);
  793. *tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
  794. DXASSERT_NOMSG(*tinfo != nullptr);
  795. *functionDecl = CXXMethodDecl::Create(
  796. context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
  797. StorageClass::SC_None, InlineSpecifiedFalse, IsConstexprFalse, NoLoc);
  798. DXASSERT_NOMSG(*functionDecl != nullptr);
  799. (*functionDecl)->setLexicalDeclContext(recordDecl);
  800. (*functionDecl)->setAccess(AccessSpecifier::AS_public);
  801. }
  802. CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
  803. ASTContext& context,
  804. _In_ CXXRecordDecl* recordDecl,
  805. QualType resultType,
  806. ArrayRef<QualType> paramTypes,
  807. ArrayRef<StringRef> paramNames,
  808. DeclarationName declarationName,
  809. bool isConst)
  810. {
  811. DXASSERT_NOMSG(recordDecl != nullptr);
  812. DXASSERT_NOMSG(!resultType.isNull());
  813. DXASSERT_NOMSG(paramTypes.size() == paramNames.size());
  814. TypeSourceInfo* tinfo;
  815. CXXMethodDecl* functionDecl;
  816. CreateObjectFunctionDeclaration(context, recordDecl, resultType, paramTypes,
  817. declarationName, isConst, &functionDecl,
  818. &tinfo);
  819. // Create and associate parameters to method.
  820. SmallVector<ParmVarDecl *, 2> parmVarDecls;
  821. if (!paramTypes.empty()) {
  822. for (unsigned int i = 0; i < paramTypes.size(); ++i) {
  823. IdentifierInfo *argIi = &context.Idents.get(paramNames[i]);
  824. ParmVarDecl *parmVarDecl = ParmVarDecl::Create(
  825. context, functionDecl, NoLoc, NoLoc, argIi, paramTypes[i],
  826. context.getTrivialTypeSourceInfo(paramTypes[i], NoLoc),
  827. StorageClass::SC_None, nullptr);
  828. parmVarDecl->setScopeInfo(0, i);
  829. DXASSERT(parmVarDecl->getFunctionScopeIndex() == i,
  830. "otherwise failed to set correct index");
  831. parmVarDecls.push_back(parmVarDecl);
  832. }
  833. functionDecl->setParams(ArrayRef<ParmVarDecl *>(parmVarDecls));
  834. AssociateParametersToFunctionPrototype(tinfo, &parmVarDecls.front(),
  835. parmVarDecls.size());
  836. }
  837. recordDecl->addDecl(functionDecl);
  838. return functionDecl;
  839. }
  840. bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
  841. return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
  842. }
  843. bool hlsl::GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
  844. llvm::StringRef &group) {
  845. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  846. return false;
  847. }
  848. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  849. opcode = A->getOpcode();
  850. group = A->getGroup();
  851. return true;
  852. }
  853. bool hlsl::GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S) {
  854. if (FD == nullptr || !FD->hasAttr<HLSLIntrinsicAttr>()) {
  855. return false;
  856. }
  857. HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  858. S = A->getLowering();
  859. return true;
  860. }
  861. /// <summary>Parses a column or row digit.</summary>
  862. static
  863. bool TryParseColOrRowChar(const char digit, _Out_ int* count) {
  864. if ('1' <= digit && digit <= '4') {
  865. *count = digit - '0';
  866. return true;
  867. }
  868. *count = 0;
  869. return false;
  870. }
  871. /// <summary>Parses a matrix shorthand identifier (eg, float3x2).</summary>
  872. _Use_decl_annotations_
  873. bool hlsl::TryParseMatrixShorthand(
  874. const char* typeName,
  875. size_t typeNameLen,
  876. HLSLScalarType* parsedType,
  877. int* rowCount,
  878. int* colCount,
  879. const clang::LangOptions& langOptions
  880. )
  881. {
  882. //
  883. // Matrix shorthand format is PrimitiveTypeRxC, where R is the row count and C is the column count.
  884. // R and C should be between 1 and 4 inclusive.
  885. // x is a literal 'x' character.
  886. // PrimitiveType is one of the HLSLScalarTypeNames values.
  887. //
  888. if (TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions) &&
  889. *rowCount != 0 && *colCount != 0) {
  890. // compare scalar component
  891. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-3, langOptions);
  892. if (type!= HLSLScalarType_unknown) {
  893. *parsedType = type;
  894. return true;
  895. }
  896. }
  897. // Unable to parse.
  898. return false;
  899. }
  900. /// <summary>Parses a vector shorthand identifier (eg, float3).</summary>
  901. _Use_decl_annotations_
  902. bool hlsl::TryParseVectorShorthand(
  903. const char* typeName,
  904. size_t typeNameLen,
  905. HLSLScalarType* parsedType,
  906. int* elementCount,
  907. const clang::LangOptions& langOptions
  908. )
  909. {
  910. // At least *something*N characters necessary, where something is at least 'int'
  911. if (TryParseColOrRowChar(typeName[typeNameLen - 1], elementCount)) {
  912. // compare scalar component
  913. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-1, langOptions);
  914. if (type!= HLSLScalarType_unknown) {
  915. *parsedType = type;
  916. return true;
  917. }
  918. }
  919. // Unable to parse.
  920. return false;
  921. }
  922. /// <summary>Parses a hlsl scalar type (e.g min16float, uint3x4) </summary>
  923. _Use_decl_annotations_
  924. bool hlsl::TryParseScalar(
  925. _In_count_(typenameLen)
  926. const char* typeName,
  927. size_t typeNameLen,
  928. _Out_ HLSLScalarType *parsedType,
  929. _In_ const clang::LangOptions& langOptions) {
  930. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen, langOptions);
  931. if (type!= HLSLScalarType_unknown) {
  932. *parsedType = type;
  933. return true;
  934. }
  935. return false; // unable to parse
  936. }
  937. /// <summary>Parse any (scalar, vector, matrix) hlsl types (e.g float, int3x4, uint2) </summary>
  938. _Use_decl_annotations_
  939. bool hlsl::TryParseAny(
  940. _In_count_(typenameLen)
  941. const char* typeName,
  942. size_t typeNameLen,
  943. _Out_ HLSLScalarType *parsedType,
  944. int *rowCount,
  945. int *colCount,
  946. _In_ const clang::LangOptions& langOptions) {
  947. // at least 'int'
  948. const size_t MinValidLen = 3;
  949. if (typeNameLen >= MinValidLen) {
  950. TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions);
  951. int suffixLen = *colCount == 0 ? 0 :
  952. *rowCount == 0 ? 1 : 3;
  953. HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen-suffixLen, langOptions);
  954. if (type!= HLSLScalarType_unknown) {
  955. *parsedType = type;
  956. return true;
  957. }
  958. }
  959. return false;
  960. }
  961. /// <summary>Parse any kind of dimension for vector or matrix (e.g 4,3 in int4x3).
  962. /// If it's a matrix type, rowCount and colCount will be nonzero. If it's a vector type, colCount is 0.
  963. /// Otherwise both rowCount and colCount is 0. Returns true if either matrix or vector dimensions detected. </summary>
  964. _Use_decl_annotations_
  965. bool hlsl::TryParseMatrixOrVectorDimension(
  966. _In_count_(typeNameLen)
  967. const char *typeName,
  968. size_t typeNameLen,
  969. _Out_opt_ int *rowCount,
  970. _Out_opt_ int *colCount,
  971. _In_ const clang::LangOptions& langOptions) {
  972. *rowCount = 0;
  973. *colCount = 0;
  974. size_t MinValidLen = 3; // at least int
  975. if (typeNameLen > MinValidLen) {
  976. if (TryParseColOrRowChar(typeName[typeNameLen - 1], colCount)) {
  977. // Try parse matrix
  978. if (typeName[typeNameLen - 2] == 'x')
  979. TryParseColOrRowChar(typeName[typeNameLen - 3], rowCount);
  980. return true;
  981. }
  982. }
  983. return false;
  984. }
  985. /// <summary>Creates a typedef for a matrix shorthand (eg, float3x2).</summary>
  986. TypedefDecl* hlsl::CreateMatrixSpecializationShorthand(
  987. ASTContext& context,
  988. QualType matrixSpecialization,
  989. HLSLScalarType scalarType,
  990. size_t rowCount,
  991. size_t colCount)
  992. {
  993. DXASSERT(rowCount <= 4, "else caller didn't validate rowCount");
  994. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  995. char typeName[64];
  996. sprintf_s(typeName, _countof(typeName), "%s%ux%u",
  997. HLSLScalarTypeNames[scalarType], (unsigned)rowCount, (unsigned)colCount);
  998. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  999. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  1000. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  1001. context.getTrivialTypeSourceInfo(matrixSpecialization, NoLoc));
  1002. decl->setImplicit(true);
  1003. currentDeclContext->addDecl(decl);
  1004. return decl;
  1005. }
  1006. /// <summary>Creates a typedef for a vector shorthand (eg, float3).</summary>
  1007. TypedefDecl* hlsl::CreateVectorSpecializationShorthand(
  1008. ASTContext& context,
  1009. QualType vectorSpecialization,
  1010. HLSLScalarType scalarType,
  1011. size_t colCount)
  1012. {
  1013. DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  1014. char typeName[64];
  1015. sprintf_s(typeName, _countof(typeName), "%s%u",
  1016. HLSLScalarTypeNames[scalarType], (unsigned)colCount);
  1017. IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  1018. DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  1019. TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
  1020. context.getTrivialTypeSourceInfo(vectorSpecialization, NoLoc));
  1021. decl->setImplicit(true);
  1022. currentDeclContext->addDecl(decl);
  1023. return decl;
  1024. }
  1025. llvm::ArrayRef<hlsl::UnusualAnnotation*>
  1026. hlsl::UnusualAnnotation::CopyToASTContextArray(
  1027. clang::ASTContext& Context, hlsl::UnusualAnnotation** begin, size_t count) {
  1028. if (count == 0) {
  1029. return llvm::ArrayRef<hlsl::UnusualAnnotation*>();
  1030. }
  1031. UnusualAnnotation** arr = ::new (Context) UnusualAnnotation*[count];
  1032. for (size_t i = 0; i < count; ++i) {
  1033. arr[i] = begin[i]->CopyToASTContext(Context);
  1034. }
  1035. return llvm::ArrayRef<hlsl::UnusualAnnotation*>(arr, count);
  1036. }
  1037. UnusualAnnotation* hlsl::UnusualAnnotation::CopyToASTContext(ASTContext& Context) {
  1038. // All UnusualAnnotation instances can be blitted.
  1039. size_t instanceSize;
  1040. switch (Kind) {
  1041. case UA_RegisterAssignment:
  1042. instanceSize = sizeof(hlsl::RegisterAssignment);
  1043. break;
  1044. case UA_ConstantPacking:
  1045. instanceSize = sizeof(hlsl::ConstantPacking);
  1046. break;
  1047. default:
  1048. DXASSERT(Kind == UA_SemanticDecl, "Kind == UA_SemanticDecl -- otherwise switch is incomplete");
  1049. instanceSize = sizeof(hlsl::SemanticDecl);
  1050. break;
  1051. }
  1052. void* result = Context.Allocate(instanceSize);
  1053. memcpy(result, this, instanceSize);
  1054. return (UnusualAnnotation*)result;
  1055. }
  1056. static bool HasTessFactorSemantic(const ValueDecl *decl) {
  1057. for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
  1058. if (it->getKind() == UnusualAnnotation::UA_SemanticDecl) {
  1059. const SemanticDecl *sd = cast<SemanticDecl>(it);
  1060. const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
  1061. if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
  1062. return true;
  1063. }
  1064. }
  1065. return false;
  1066. }
  1067. static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
  1068. if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
  1069. return false;
  1070. if (const RecordType *RT = Ty->getAsStructureType()) {
  1071. RecordDecl *RD = RT->getDecl();
  1072. for (FieldDecl *fieldDecl : RD->fields()) {
  1073. if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
  1074. return true;
  1075. }
  1076. return false;
  1077. }
  1078. if (Ty->getAsArrayTypeUnsafe())
  1079. return HasTessFactorSemantic(decl);
  1080. return false;
  1081. }
  1082. bool ASTContext::IsPatchConstantFunctionDecl(const FunctionDecl *FD) const {
  1083. // This checks whether the function is structurally capable of being a patch
  1084. // constant function, not whether it is in fact the patch constant function
  1085. // for the entry point of a compiled hull shader (which may not have been
  1086. // seen yet). So the answer is conservative.
  1087. if (!FD->getReturnType()->isVoidType()) {
  1088. // Try to find TessFactor in return type.
  1089. if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
  1090. return true;
  1091. }
  1092. // Try to find TessFactor in out param.
  1093. for (const ParmVarDecl *param : FD->params()) {
  1094. if (param->hasAttr<HLSLOutAttr>()) {
  1095. if (HasTessFactorSemanticRecurse(param, param->getType()))
  1096. return true;
  1097. }
  1098. }
  1099. return false;
  1100. }