ソースを参照

Refactor builtin type declaration (#2340)

There's a lot of repeated boilerplate needed to declare builtin types with the Clang API. This change encapsulates it in a friendlier "Builder" class.
Tristan Labelle 6 年 前
コミット
2d6803382f

+ 54 - 0
tools/clang/include/clang/AST/HlslBuiltinTypeDeclBuilder.h

@@ -0,0 +1,54 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#ifndef LLVM_CLANG_AST_HLSLBUILTINTYPEDECLBUILDER_H
+#define LLVM_CLANG_AST_HLSLBUILTINTYPEDECLBUILDER_H
+
+#include "clang/AST/Decl.h"
+#include "clang/AST/Type.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace clang {
+  class ASTContext;
+  class DeclContext;
+  class CXXRecordDecl;
+  class ClassTemplateDecl;
+  class NamedDecl;
+}
+
+namespace hlsl {
+// Helper to declare a builtin HLSL type in the clang AST with minimal boilerplate.
+class BuiltinTypeDeclBuilder final {
+public:
+  BuiltinTypeDeclBuilder(clang::DeclContext* declContext, llvm::StringRef name,
+    clang::TagDecl::TagKind tagKind = clang::TagDecl::TagKind::TTK_Class);
+
+  clang::TemplateTypeParmDecl* addTypeTemplateParam(llvm::StringRef name, clang::TypeSourceInfo* defaultValue = nullptr);
+  clang::TemplateTypeParmDecl* addTypeTemplateParam(llvm::StringRef name, clang::QualType defaultValue);
+  clang::NonTypeTemplateParmDecl* addIntegerTemplateParam(llvm::StringRef name, clang::QualType type,
+    llvm::Optional<int64_t> defaultValue = llvm::None);
+
+  void startDefinition();
+
+  clang::FieldDecl* addField(llvm::StringRef name, clang::QualType type,
+    clang::AccessSpecifier access = clang::AccessSpecifier::AS_private);
+
+  clang::CXXRecordDecl* completeDefinition();
+
+  clang::CXXRecordDecl* getRecordDecl() const { return m_recordDecl; }
+  clang::ClassTemplateDecl* getTemplateDecl() const;
+
+private:
+  clang::CXXRecordDecl* m_recordDecl = nullptr;
+  clang::ClassTemplateDecl* m_templateDecl = nullptr;
+  llvm::SmallVector<clang::NamedDecl*, 2> m_templateParams;
+};
+} // end hlsl namespace
+#endif

+ 9 - 20
tools/clang/include/clang/AST/HlslTypes.h

@@ -22,7 +22,9 @@
 #include "dxc/DXIL/DxilConstants.h"
 #include "dxc/Support/WinAdapter.h"
 #include "llvm/Support/Casting.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
 
 namespace clang {
   class ASTContext;
@@ -297,16 +299,10 @@ void AddHLSLVectorTemplate(
   clang::ASTContext& context, 
   _Outptr_ clang::ClassTemplateDecl** vectorTemplateDecl);
 
-void AddRecordTypeWithHandle(
-            clang::ASTContext& context,
-  _Outptr_  clang::CXXRecordDecl** typeDecl, 
-  _In_z_    const char* typeName);
+clang::CXXRecordDecl* DeclareRecordTypeWithHandle(
+  clang::ASTContext& context, llvm::StringRef name);
 
-void AddRayFlags(clang::ASTContext& context);
-void AddHitKinds(clang::ASTContext& context);
-void AddStateObjectFlags(clang::ASTContext& context);
-void AddCommittedStatus(clang::ASTContext& context);
-void AddCandidateType(clang::ASTContext& context);
+void AddRayTracingConstants(clang::ASTContext& context);
 
 /// <summary>Adds the implementation for std::is_equal.</summary>
 void AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema);
@@ -315,23 +311,16 @@ void AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema);
 /// Adds a new template type in the specified context with the given name. The record type will have a handle field.
 /// </summary>
 /// <parm name="context">AST context to which template will be added.</param>
-/// <parm name="typeDecl">After execution, template declaration.</param>
-/// <parm name="recordDecl">After execution, record declaration for template.</param>
-/// <parm name="typeName">Name of template to create.</param>
 /// <parm name="templateArgCount">Number of template arguments (one or two).</param>
 /// <parm name="defaultTypeArgValue">If assigned, the default argument for the element template.</param>
-void AddTemplateTypeWithHandle(
+clang::CXXRecordDecl* DeclareTemplateTypeWithHandle(
             clang::ASTContext& context,
-  _Outptr_  clang::ClassTemplateDecl** typeDecl,
-  _Outptr_  clang::CXXRecordDecl** recordDecl,
-  _In_z_    const char* typeName,
+            llvm::StringRef name,
             uint8_t templateArgCount,
   _In_opt_  clang::TypeSourceInfo* defaultTypeArgValue);
 
-void AddRayQueryTemplate(
-           clang::ASTContext& context,
-  _Outptr_ clang::ClassTemplateDecl** typeDecl,
-  _Outptr_ clang::CXXRecordDecl** recordDecl);
+clang::CXXRecordDecl* DeclareUIntTemplatedTypeWithHandle(
+  clang::ASTContext& context, llvm::StringRef typeName, llvm::StringRef templateParamName);
 
 /// <summary>Create a function template declaration for the specified method.</summary>
 /// <param name="context">AST context in which to work.</param>

+ 96 - 364
tools/clang/lib/AST/ASTContextHLSL.cpp

@@ -17,6 +17,7 @@
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
 #include "clang/AST/ExternalASTSource.h"
+#include "clang/AST/HlslBuiltinTypeDeclBuilder.h"
 #include "clang/AST/TypeLoc.h"
 #include "clang/Sema/SemaDiagnostic.h"
 #include "clang/Sema/Sema.h"
@@ -238,22 +239,6 @@ CanQualType GetHLSLObjectHandleType(ASTContext& context)
   return context.IntTy;
 }
 
-/// <summary>Adds a handle field to the specified record.</summary>
-static 
-void AddHLSLHandleField(ASTContext& context, DeclContext* recordDecl, QualType handleQualType)
-{
-  IdentifierInfo& handleId = context.Idents.get(StringRef("h"), tok::TokenKind::identifier);
-  TypeSourceInfo* fieldTypeSource = context.getTrivialTypeSourceInfo(handleQualType, NoLoc);
-  const bool MutableFalse = false;
-  const InClassInitStyle initStyle = InClassInitStyle::ICIS_NoInit;
-  FieldDecl* handleDecl = FieldDecl::Create(
-    context, recordDecl, NoLoc, NoLoc, &handleId, handleQualType, fieldTypeSource, nullptr, MutableFalse, initStyle);
-  handleDecl->setAccess(AccessSpecifier::AS_private);
-  handleDecl->setImplicit(true); 
-
-  recordDecl->addDecl(handleDecl);
-}
-
 static
 void AddSubscriptOperator(
   ASTContext& context, unsigned int templateDepth, TemplateTypeParmDecl *elementTemplateParamDecl,
@@ -297,52 +282,15 @@ void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorT
   DXASSERT_NOMSG(matrixTemplateDecl != nullptr);
   DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
 
-  DeclContext* currentDeclContext = context.getTranslationUnitDecl();
-
   // Create a matrix template declaration in translation unit scope.
   // template<typename element, int row_count, int col_count> matrix { ... }
-  IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
-  TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
-    context, currentDeclContext, NoLoc, NoLoc,
-    FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
-  elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
-  QualType intType = context.IntTy;
-  Expr *literalIntFour = IntegerLiteral::Create(
-      context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
-  IdentifierInfo& rowCountParamId = context.Idents.get(StringRef("row_count"), tok::TokenKind::identifier);
-  NonTypeTemplateParmDecl* rowCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
-    context, currentDeclContext, NoLoc, NoLoc,
-    FirstTemplateDepth, FirstParamPosition + 1, &rowCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
-  rowCountTemplateParamDecl->setDefaultArgument(literalIntFour);
-  IdentifierInfo& colCountParamId = context.Idents.get(StringRef("col_count"), tok::TokenKind::identifier);
-  NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
-    context, currentDeclContext, NoLoc, NoLoc,
-    FirstTemplateDepth, FirstParamPosition + 2, &colCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
-  colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
-  NamedDecl* templateParameters[] =
-  {
-    elementTemplateParamDecl, rowCountTemplateParamDecl, colCountTemplateParamDecl
-  };
-  TemplateParameterList* templateParameterList = TemplateParameterList::Create(
-    context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);
-
-  IdentifierInfo& matrixId = context.Idents.get(StringRef("matrix"), tok::TokenKind::identifier);
-  CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
-    context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &matrixId,
-    nullptr, DelayTypeCreationTrue);
-  ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
-    context, currentDeclContext, NoLoc, DeclarationName(&matrixId),
-    templateParameterList, templateRecordDecl, nullptr);
-  templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
-  templateRecordDecl->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
-
-  // Requesting the class name specialization will fault in required types.
-  QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
-  T = context.getInjectedClassNameType(templateRecordDecl, T);
-  assert(T->isDependentType() && "Class template type is not dependent?");
-  classTemplateDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->startDefinition();
+  BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "matrix");
+  TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", (QualType)context.FloatTy);
+  NonTypeTemplateParmDecl* rowCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("row_count", context.IntTy, 4);
+  NonTypeTemplateParmDecl* colCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("col_count", context.IntTy, 4);
+  typeDeclBuilder.startDefinition();
+  CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
+  ClassTemplateDecl* classTemplateDecl = typeDeclBuilder.getTemplateDecl();
 
   // Add an 'h' field to hold the handle.
   // The type is vector<element, col>[row].
@@ -352,20 +300,20 @@ void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorT
       context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
       false,
       DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
-      intType, ExprValueKind::VK_RValue);
+      context.IntTy, ExprValueKind::VK_RValue);
 
   Expr *rowSizeExpr = DeclRefExpr::Create(
       context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
       false,
       DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
-      intType, ExprValueKind::VK_RValue);
+    context.IntTy, ExprValueKind::VK_RValue);
 
   QualType vectorType = context.getDependentSizedExtVectorType(
       elementType, rowSizeExpr, SourceLocation());
   QualType vectorArrayType = context.getDependentSizedArrayType(
       vectorType, sizeExpr, ArrayType::Normal, 0, SourceRange());
 
-  AddHLSLHandleField(context, templateRecordDecl, vectorArrayType);
+  typeDeclBuilder.addField("h", vectorArrayType);
 
   // Add an operator[]. The operator ranges from zero to rowcount-1, and returns a vector of colcount elements.
   const unsigned int templateDepth = 0;
@@ -376,22 +324,7 @@ void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorT
                        colCountTemplateParamDecl, context.UnsignedIntTy,
                        templateRecordDecl, vectorTemplateDecl, ForConstTrue);
 
-  templateRecordDecl->completeDefinition();
-
-  classTemplateDecl->setImplicit(true);
-  templateRecordDecl->setImplicit(true);
-
-  // Both declarations need to be present for correct handling.
-  currentDeclContext->addDecl(classTemplateDecl);
-  currentDeclContext->addDecl(templateRecordDecl);
-
-#ifdef DBG
-  // Verify that we can read the field member from the template record.
-  DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
-    DeclarationName(&context.Idents.get(StringRef("h"))));
-  DXASSERT(!lookupResult.empty(), "otherwise matrix handle cannot be looked up");
-#endif
-
+  typeDeclBuilder.completeDefinition();
   *matrixTemplateDecl = classTemplateDecl;
 }
 
@@ -405,53 +338,20 @@ void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vector
 {
   DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
 
-  DeclContext* currentDeclContext = context.getTranslationUnitDecl();
-
   // Create a vector template declaration in translation unit scope.
   // template<typename element, int element_count> vector { ... }
-  IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
-  TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
-    context, currentDeclContext, NoLoc, NoLoc,
-    FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
-  elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
-  QualType intType = context.IntTy;
-  Expr *literalIntFour = IntegerLiteral::Create(
-      context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
-  IdentifierInfo& colCountParamId = context.Idents.get(StringRef("element_count"), tok::TokenKind::identifier);
-  NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
-    context, currentDeclContext, NoLoc, NoLoc,
-    FirstTemplateDepth, FirstParamPosition + 1, &colCountParamId, intType, ParameterPackFalse, nullptr);
-  colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
-  NamedDecl* templateParameters[] =
-  {
-    elementTemplateParamDecl, colCountTemplateParamDecl
-  };
-  TemplateParameterList* templateParameterList = TemplateParameterList::Create(
-    context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);
-
-  IdentifierInfo& vectorId = context.Idents.get(StringRef("vector"), tok::TokenKind::identifier);
-  CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
-    context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &vectorId,
-    nullptr, DelayTypeCreationTrue);
-  ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
-    context, currentDeclContext, NoLoc, DeclarationName(&vectorId),
-    templateParameterList, templateRecordDecl, nullptr);
-  templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
-  templateRecordDecl->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
-
-  // Requesting the class name specialization will fault in required types.
-  QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
-  T = context.getInjectedClassNameType(templateRecordDecl, T);
-  assert(T->isDependentType() && "Class template type is not dependent?");
-  classTemplateDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->startDefinition();
+  BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), "vector");
+  TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", (QualType)context.FloatTy);
+  NonTypeTemplateParmDecl* colCountTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("element_count", context.IntTy, 4);
+  typeDeclBuilder.startDefinition();
+  CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
+  ClassTemplateDecl* classTemplateDecl = typeDeclBuilder.getTemplateDecl();
 
   Expr *vecSizeExpr = DeclRefExpr::Create(
       context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
       false,
       DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
-      intType, ExprValueKind::VK_RValue);
+      context.IntTy, ExprValueKind::VK_RValue);
 
   const unsigned int templateDepth = 0;
   QualType resultType = context.getTemplateTypeParmType(
@@ -459,7 +359,7 @@ void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vector
   QualType vectorType = context.getDependentSizedExtVectorType(
       resultType, vecSizeExpr, SourceLocation());
   // Add an 'h' field to hold the handle.
-  AddHLSLHandleField(context, templateRecordDecl, vectorType);
+  typeDeclBuilder.addField("h", vectorType);
 
   // Add an operator[]. The operator ranges from zero to colcount-1, and returns a scalar.
 
@@ -478,48 +378,18 @@ void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vector
     context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstFalse);
   AddHLSLVectorSubscriptAttr(functionDecl, context);
 
-  templateRecordDecl->completeDefinition();
-
-  classTemplateDecl->setImplicit(true);
-  templateRecordDecl->setImplicit(true);
-
-  // Both declarations need to be present for correct handling.
-  currentDeclContext->addDecl(classTemplateDecl);
-  currentDeclContext->addDecl(templateRecordDecl);
-
-#ifdef DBG
-  // Verify that we can read the field member from the template record.
-  DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
-    DeclarationName(&context.Idents.get(StringRef("h"))));
-  DXASSERT(!lookupResult.empty(), "otherwise vector handle cannot be looked up");
-#endif
-
+  typeDeclBuilder.completeDefinition();
   *vectorTemplateDecl = classTemplateDecl;
 }
 
 /// <summary>
 /// Adds a new record type in the specified context with the given name. The record type will have a handle field.
 /// </summary>
-void hlsl::AddRecordTypeWithHandle(ASTContext& context, _Outptr_ CXXRecordDecl** typeDecl, _In_z_ const char* typeName)
-{
-  DXASSERT_NOMSG(typeDecl != nullptr);
-  DXASSERT_NOMSG(typeName != nullptr);
-  
-  *typeDecl = nullptr;
-
-  DeclContext* currentDeclContext = context.getTranslationUnitDecl();
-  IdentifierInfo& newTypeId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
-  CXXRecordDecl* newDecl = CXXRecordDecl::Create(
-    context, TagDecl::TagKind::TTK_Struct, currentDeclContext, NoLoc, NoLoc, &newTypeId, nullptr);
-  newDecl->setLexicalDeclContext(currentDeclContext);
-  newDecl->setFreeStanding();
-  newDecl->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
-  newDecl->startDefinition();
-  AddHLSLHandleField(context, newDecl, QualType(GetHLSLObjectHandleType(context)));
-  currentDeclContext->addDecl(newDecl);
-  newDecl->completeDefinition();
-
-  *typeDecl = newDecl;
+CXXRecordDecl* hlsl::DeclareRecordTypeWithHandle(ASTContext& context, StringRef name) {
+  BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), name, TagDecl::TagKind::TTK_Struct);
+  typeDeclBuilder.startDefinition();
+  typeDeclBuilder.addField("h", GetHLSLObjectHandleType(context));
+  return typeDeclBuilder.completeDefinition();
 }
 
 // creates a global static constant unsigned integer with value.
@@ -537,71 +407,58 @@ static void AddConstUInt(clang::ASTContext& context, DeclContext *DC, StringRef
   DC->addDecl(varDecl);
 }
 
-/// <summary> Adds a const integers for ray flags </summary>
-void hlsl::AddRayFlags(ASTContext& context) {
-  DeclContext *curDC = context.getTranslationUnitDecl();
-  // typedef uint RAY_FLAG;
-  IdentifierInfo &rayFlagId = context.Idents.get(StringRef("RAY_FLAG"), tok::TokenKind::identifier);
-  TypeSourceInfo *uintTypeSource = context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
-  TypedefDecl *rayFlagDecl = TypedefDecl::Create(context, curDC, NoLoc, NoLoc, &rayFlagId, uintTypeSource);
-  curDC->addDecl(rayFlagDecl);
-  rayFlagDecl->setImplicit(true);
-  // static const uint RAY_FLAG_* = *;
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_NONE"), (unsigned)DXIL::RayFlag::None);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_FORCE_OPAQUE"), (unsigned)DXIL::RayFlag::ForceOpaque);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_FORCE_NON_OPAQUE"), (unsigned)DXIL::RayFlag::ForceNonOpaque);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH"), (unsigned)DXIL::RayFlag::AcceptFirstHitAndEndSearch);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_SKIP_CLOSEST_HIT_SHADER"), (unsigned)DXIL::RayFlag::SkipClosestHitShader);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_BACK_FACING_TRIANGLES"), (unsigned)DXIL::RayFlag::CullBackFacingTriangles);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_FRONT_FACING_TRIANGLES"), (unsigned)DXIL::RayFlag::CullFrontFacingTriangles);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_OPAQUE"), (unsigned)DXIL::RayFlag::CullOpaque);
-  AddConstUInt(context, curDC, StringRef("RAY_FLAG_CULL_NON_OPAQUE"), (unsigned)DXIL::RayFlag::CullNonOpaque);
+static void AddConstUInt(clang::ASTContext& context, StringRef name, unsigned val) {
+  AddConstUInt(context, context.getTranslationUnitDecl(), name, val);
 }
 
-/// <summary> Adds a constant integers for hit kinds </summary>
-void hlsl::AddHitKinds(ASTContext& context) {
-  DeclContext *curDC = context.getTranslationUnitDecl();
-  // static const uint HIT_KIND_* = *;
-  AddConstUInt(context, curDC, StringRef("HIT_KIND_NONE"), (unsigned)DXIL::HitKind::None);
-  AddConstUInt(context, curDC, StringRef("HIT_KIND_TRIANGLE_FRONT_FACE"), (unsigned)DXIL::HitKind::TriangleFrontFace);
-  AddConstUInt(context, curDC, StringRef("HIT_KIND_TRIANGLE_BACK_FACE"), (unsigned)DXIL::HitKind::TriangleBackFace);
-}
-
-/// <summary> Adds a constant integers for state object flags </summary>
-void hlsl::AddStateObjectFlags(ASTContext& context) {
-  DeclContext *curDC = context.getTranslationUnitDecl();
- 
-  AddConstUInt(context, curDC, StringRef("STATE_OBJECT_FLAGS_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS"), (unsigned)DXIL::StateObjectFlags::AllowLocalDependenciesOnExternalDefinitions);
-  AddConstUInt(context, curDC, StringRef("STATE_OBJECT_FLAGS_ALLOW_EXTERNAL_DEPENDENCIES_ON_LOCAL_DEFINITIONS"), (unsigned)DXIL::StateObjectFlags::AllowExternalDependenciesOnLocalDefinitions);
-}
-
-/// <summary> Adds const integers for committed status </summary>
-void hlsl::AddCommittedStatus(ASTContext& context) {
-  DeclContext *curDC = context.getTranslationUnitDecl();
-  // typedef uint COMMITTED_STATUS;
-  IdentifierInfo &enumId = context.Idents.get(StringRef("COMMITTED_STATUS"), tok::TokenKind::identifier);
-  TypeSourceInfo *uintTypeSource = context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
-  TypedefDecl *enumDecl = TypedefDecl::Create(context, curDC, NoLoc, NoLoc, &enumId, uintTypeSource);
+// Adds a top-level enum with the given enumerants.
+struct Enumerant { StringRef name; unsigned value; };
+static void AddTypedefPseudoEnum(ASTContext& context, StringRef name, ArrayRef<Enumerant> enumerants) {
+  DeclContext* curDC = context.getTranslationUnitDecl();
+  // typedef uint <name>;
+  IdentifierInfo& enumId = context.Idents.get(name, tok::TokenKind::identifier);
+  TypeSourceInfo* uintTypeSource = context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
+  TypedefDecl* enumDecl = TypedefDecl::Create(context, curDC, NoLoc, NoLoc, &enumId, uintTypeSource);
   curDC->addDecl(enumDecl);
   enumDecl->setImplicit(true);
-  // static const uint COMMITTED_* = *;
-  AddConstUInt(context, curDC, StringRef("COMMITTED_NOTHING"), (unsigned)DXIL::CommittedStatus::CommittedNothing);
-  AddConstUInt(context, curDC, StringRef("COMMITTED_TRIANGLE_HIT"), (unsigned)DXIL::CommittedStatus::CommittedTriangleHit);
-  AddConstUInt(context, curDC, StringRef("COMMITTED_PROCEDURAL_PRIMITIVE_HIT"), (unsigned)DXIL::CommittedStatus::CommittedProceduralPrimitiveHit);
+  // static const uint <enumerant.name> = <enumerant.value>;
+  for (const Enumerant& enumerant : enumerants) {
+    AddConstUInt(context, curDC, enumerant.name, enumerant.value);
+  }
 }
 
-/// <summary> Adds const integers for candidate type </summary>
-void hlsl::AddCandidateType(ASTContext& context) {
-  DeclContext *curDC = context.getTranslationUnitDecl();
-  // typedef uint CANDIDATE_TYPE;
-  IdentifierInfo &enumId = context.Idents.get(StringRef("CANDIDATE_TYPE"), tok::TokenKind::identifier);
-  TypeSourceInfo *uintTypeSource = context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
-  TypedefDecl *enumDecl = TypedefDecl::Create(context, curDC, NoLoc, NoLoc, &enumId, uintTypeSource);
-  curDC->addDecl(enumDecl);
-  enumDecl->setImplicit(true);
-  // static const uint CANDIDATE_* = *;
-  AddConstUInt(context, curDC, StringRef("CANDIDATE_NON_OPAQUE_TRIANGLE"), (unsigned)DXIL::CandidateType::CandidateNonOpaqueTriangle);
-  AddConstUInt(context, curDC, StringRef("CANDIDATE_PROCEDURAL_PRIMITIVE"), (unsigned)DXIL::CandidateType::CandidateProceduralPrimitive);
+/// <summary> Adds all constants and enums for ray tracing </summary>
+void hlsl::AddRayTracingConstants(ASTContext& context) {
+  AddTypedefPseudoEnum(context, "RAY_FLAG", {
+    { "RAY_FLAG_NONE", (unsigned)DXIL::RayFlag::None },
+    { "RAY_FLAG_FORCE_OPAQUE", (unsigned)DXIL::RayFlag::ForceOpaque },
+    { "RAY_FLAG_FORCE_NON_OPAQUE", (unsigned)DXIL::RayFlag::ForceNonOpaque },
+    { "RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH", (unsigned)DXIL::RayFlag::AcceptFirstHitAndEndSearch },
+    { "RAY_FLAG_SKIP_CLOSEST_HIT_SHADER", (unsigned)DXIL::RayFlag::SkipClosestHitShader },
+    { "RAY_FLAG_CULL_BACK_FACING_TRIANGLES", (unsigned)DXIL::RayFlag::CullBackFacingTriangles },
+    { "RAY_FLAG_CULL_FRONT_FACING_TRIANGLES", (unsigned)DXIL::RayFlag::CullFrontFacingTriangles },
+    { "RAY_FLAG_CULL_OPAQUE", (unsigned)DXIL::RayFlag::CullOpaque },
+    { "RAY_FLAG_CULL_NON_OPAQUE", (unsigned)DXIL::RayFlag::CullNonOpaque },
+  });
+
+  AddTypedefPseudoEnum(context, "COMMITTED_STATUS", {
+    { "COMMITTED_NOTHING", (unsigned)DXIL::CommittedStatus::CommittedNothing },
+    { "COMMITTED_TRIANGLE_HIT", (unsigned)DXIL::CommittedStatus::CommittedTriangleHit },
+    { "COMMITTED_PROCEDURAL_PRIMITIVE_HIT", (unsigned)DXIL::CommittedStatus::CommittedProceduralPrimitiveHit }
+  });
+
+  AddTypedefPseudoEnum(context, "CANDIDATE_TYPE", {
+    { "CANDIDATE_NON_OPAQUE_TRIANGLE", (unsigned)DXIL::CandidateType::CandidateNonOpaqueTriangle },
+    { "CANDIDATE_PROCEDURAL_PRIMITIVE", (unsigned)DXIL::CandidateType::CandidateProceduralPrimitive }
+  });
+
+  // static const uint HIT_KIND_* = *;
+  AddConstUInt(context, StringRef("HIT_KIND_NONE"), (unsigned)DXIL::HitKind::None);
+  AddConstUInt(context, StringRef("HIT_KIND_TRIANGLE_FRONT_FACE"), (unsigned)DXIL::HitKind::TriangleFrontFace);
+  AddConstUInt(context, StringRef("HIT_KIND_TRIANGLE_BACK_FACE"), (unsigned)DXIL::HitKind::TriangleBackFace);
+
+  AddConstUInt(context, StringRef("STATE_OBJECT_FLAGS_ALLOW_LOCAL_DEPENDENCIES_ON_EXTERNAL_DEFINITONS"), (unsigned)DXIL::StateObjectFlags::AllowLocalDependenciesOnExternalDefinitions);
+  AddConstUInt(context, StringRef("STATE_OBJECT_FLAGS_ALLOW_EXTERNAL_DEPENDENCIES_ON_LOCAL_DEFINITIONS"), (unsigned)DXIL::StateObjectFlags::AllowExternalDependenciesOnLocalDefinitions);
 }
 
 static
@@ -765,79 +622,29 @@ void hlsl::AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema&
 /// Adds a new template type in the specified context with the given name. The record type will have a handle field.
 /// </summary>
 /// <parm name="context">AST context to which template will be added.</param>
-/// <parm name="typeDecl">After execution, template declaration.</param>
-/// <parm name="recordDecl">After execution, record declaration for template.</param>
 /// <parm name="typeName">Name of template to create.</param>
 /// <parm name="templateArgCount">Number of template arguments (one or two).</param>
 /// <parm name="defaultTypeArgValue">If assigned, the default argument for the element template.</param>
-void hlsl::AddTemplateTypeWithHandle(
+CXXRecordDecl* hlsl::DeclareTemplateTypeWithHandle(
   ASTContext& context,
-  _Outptr_ ClassTemplateDecl** typeDecl,
-  _Outptr_ CXXRecordDecl** recordDecl,
-  _In_z_ const char* typeName,
+  StringRef name,
   uint8_t templateArgCount, 
-  _In_opt_ TypeSourceInfo* defaultTypeArgValue
-)
+  _In_opt_ TypeSourceInfo* defaultTypeArgValue)
 {
-  DXASSERT_NOMSG(typeDecl != nullptr);
-  DXASSERT_NOMSG(recordDecl != nullptr);
-  DXASSERT_NOMSG(typeName != nullptr);
-
   DXASSERT(templateArgCount != 0, "otherwise caller should be creating a class or struct");
   DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated for a different template pattern");
 
-  DeclContext* currentDeclContext = context.getTranslationUnitDecl();
-
   // Create an object template declaration in translation unit scope.
   // templateArgCount=1: template<typename element> typeName { ... }
   // templateArgCount=2: template<typename element, int count> typeName { ... }
-  IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
-  TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
-    context, currentDeclContext, NoLoc, NoLoc,
-    FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
-  QualType intType = context.IntTy;
-
-  if (defaultTypeArgValue != nullptr)
-  {
-    elementTemplateParamDecl->setDefaultArgument(defaultTypeArgValue);
-  }
-
+  BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), name);
+  TemplateTypeParmDecl* elementTemplateParamDecl = typeDeclBuilder.addTypeTemplateParam("element", defaultTypeArgValue);
   NonTypeTemplateParmDecl* countTemplateParamDecl = nullptr;
-  if (templateArgCount > 1) {
-    IdentifierInfo& countParamId = context.Idents.get(StringRef("count"), tok::TokenKind::identifier);
-    countTemplateParamDecl = NonTypeTemplateParmDecl::Create(
-      context, currentDeclContext, NoLoc, NoLoc,
-      FirstTemplateDepth, FirstParamPosition + 1, &countParamId, intType, ParameterPackFalse, nullptr);
-    // Zero means default here. The count is decided by runtime.
-    Expr *literalIntZero = IntegerLiteral::Create(
-        context, llvm::APInt(context.getIntWidth(intType), 0), intType, NoLoc);
-    countTemplateParamDecl->setDefaultArgument(literalIntZero);
-  }
-  NamedDecl* templateParameters[] =
-  {
-    elementTemplateParamDecl, countTemplateParamDecl
-  };
-  TemplateParameterList* templateParameterList = TemplateParameterList::Create(
-    context, NoLoc, NoLoc, templateParameters, templateArgCount, NoLoc);
-
-  IdentifierInfo& typeId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
-  CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
-    context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &typeId,
-    nullptr, DelayTypeCreationTrue);
-  ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
-    context, currentDeclContext, NoLoc, DeclarationName(&typeId),
-    templateParameterList, templateRecordDecl, nullptr);
-  templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
-  templateRecordDecl->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
-  
-  // Requesting the class name specialization will fault in required types.
-  QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
-  T = context.getInjectedClassNameType(templateRecordDecl, T);
-  assert(T->isDependentType() && "Class template type is not dependent?");
-  classTemplateDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->startDefinition();
-  // Many more things to come here, like constructors and the like....
+  if (templateArgCount > 1)
+    countTemplateParamDecl = typeDeclBuilder.addIntegerTemplateParam("count", context.IntTy, 0);
+
+  typeDeclBuilder.startDefinition();
+  CXXRecordDecl* templateRecordDecl = typeDeclBuilder.getRecordDecl();
 
   // Add an 'h' field to hold the handle.
   QualType elementType = context.getTemplateTypeParmType(
@@ -847,11 +654,11 @@ void hlsl::AddTemplateTypeWithHandle(
       // Only need array type for inputpatch and outputpatch.
       // Avoid Texture2DMS which may use 0 count.
       // TODO: use hlsl types to do the check.
-      !typeId.getName().startswith("Texture")) {
+      !name.startswith("Texture")) {
     Expr *countExpr = DeclRefExpr::Create(
         context, NestedNameSpecifierLoc(), NoLoc, countTemplateParamDecl, false,
         DeclarationNameInfo(countTemplateParamDecl->getDeclName(), NoLoc),
-        intType, ExprValueKind::VK_RValue);
+        context.IntTy, ExprValueKind::VK_RValue);
 
     elementType = context.getDependentSizedArrayType(
         elementType, countExpr, ArrayType::ArraySizeModifier::Normal, 0,
@@ -859,31 +666,17 @@ void hlsl::AddTemplateTypeWithHandle(
 
     // InputPatch and OutputPatch also have a "Length" static const member for the number of control points
     IdentifierInfo& lengthId = context.Idents.get(StringRef("Length"), tok::TokenKind::identifier);
-    TypeSourceInfo* lengthTypeSource = context.getTrivialTypeSourceInfo(intType.withConst());
+    TypeSourceInfo* lengthTypeSource = context.getTrivialTypeSourceInfo(context.IntTy.withConst());
     VarDecl* lengthValueDecl = VarDecl::Create(context, templateRecordDecl, NoLoc, NoLoc, &lengthId,
-      intType.withConst(), lengthTypeSource, SC_Static);
+      context.IntTy.withConst(), lengthTypeSource, SC_Static);
     lengthValueDecl->setInit(countExpr);
     lengthValueDecl->setAccess(AS_public);
     templateRecordDecl->addDecl(lengthValueDecl);
   }
 
-  AddHLSLHandleField(context, templateRecordDecl, elementType);
-
-  templateRecordDecl->completeDefinition();
+  typeDeclBuilder.addField("h", elementType);
 
-  // Both declarations need to be present for correct handling.
-  currentDeclContext->addDecl(classTemplateDecl);
-  currentDeclContext->addDecl(templateRecordDecl);
-
-#ifdef DBG
-  // Verify that we can read the field member from the template record.
-  DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
-    DeclarationName(&context.Idents.get(StringRef("h"))));
-  DXASSERT(!lookupResult.empty(), "otherwise template object handle cannot be looked up");
-#endif
-
-  *typeDecl = classTemplateDecl;
-  *recordDecl = templateRecordDecl;
+  return typeDeclBuilder.completeDefinition();
 }
 
 FunctionTemplateDecl* hlsl::CreateFunctionTemplateDecl(
@@ -988,75 +781,14 @@ CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
   return functionDecl;
 }
 
-void hlsl::AddRayQueryTemplate(
-  ASTContext& context,
-  _Outptr_ ClassTemplateDecl** typeDecl,
-  _Outptr_ CXXRecordDecl** recordDecl
-)
-{
-  DXASSERT_NOMSG(typeDecl != nullptr);
-  DXASSERT_NOMSG(recordDecl != nullptr);
-
-  DeclContext* currentDeclContext = context.getTranslationUnitDecl();
-
-  // Create a RayQuery template declaration in translation unit scope.
-  // template<uint flags> RayQuery { ... }
-  QualType uintType = context.UnsignedIntTy;
-
-  NonTypeTemplateParmDecl* flagsTemplateParamDecl = nullptr;
-  IdentifierInfo& countParamId = context.Idents.get(StringRef("flags"), tok::TokenKind::identifier);
-  flagsTemplateParamDecl = NonTypeTemplateParmDecl::Create(
-    context, currentDeclContext, NoLoc, NoLoc,
-    FirstTemplateDepth, FirstParamPosition, &countParamId, uintType, ParameterPackFalse, nullptr);
-
-  // Should flags default to zero?
-  Expr *literalIntZero = IntegerLiteral::Create(
-    context, llvm::APInt(context.getIntWidth(uintType), 0), uintType, NoLoc);
-  flagsTemplateParamDecl->setDefaultArgument(literalIntZero);
-
-  NamedDecl* templateParameters[] =
-  {
-    flagsTemplateParamDecl
-  };
-  TemplateParameterList* templateParameterList = TemplateParameterList::Create(
-    context, NoLoc, NoLoc, templateParameters, 1, NoLoc);
-
-  IdentifierInfo& typeId = context.Idents.get(StringRef("RayQuery"), tok::TokenKind::identifier);
-  CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
-    context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &typeId,
-    nullptr, DelayTypeCreationTrue);
-  ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
-    context, currentDeclContext, NoLoc, DeclarationName(&typeId),
-    templateParameterList, templateRecordDecl, nullptr);
-  templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
-  templateRecordDecl->addAttr(FinalAttr::CreateImplicit(context, FinalAttr::Keyword_final));
-
-  // Requesting the class name specialization will fault in required types.
-  QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
-  T = context.getInjectedClassNameType(templateRecordDecl, T);
-  assert(T->isDependentType() && "Class template type is not dependent?");
-  classTemplateDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->setLexicalDeclContext(currentDeclContext);
-  templateRecordDecl->startDefinition();
-
-  // Add an 'h' field to hold the handle.
-  AddHLSLHandleField(context, templateRecordDecl, uintType);
-
-  templateRecordDecl->completeDefinition();
-
-  // Both declarations need to be present for correct handling.
-  currentDeclContext->addDecl(classTemplateDecl);
-  currentDeclContext->addDecl(templateRecordDecl);
-
-#ifdef DBG
-  // Verify that we can read the field member from the template record.
-  DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
-    DeclarationName(&context.Idents.get(StringRef("h"))));
-  DXASSERT(!lookupResult.empty(), "otherwise template object handle cannot be looked up");
-#endif
-
-  *typeDecl = classTemplateDecl;
-  *recordDecl = templateRecordDecl;
+CXXRecordDecl* hlsl::DeclareUIntTemplatedTypeWithHandle(
+  ASTContext& context, StringRef typeName, StringRef templateParamName) {
+  // template<uint kind> RayQuery/FeedbackTexture2D[Array] { ... }
+  BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), typeName);
+  typeDeclBuilder.addIntegerTemplateParam(templateParamName, context.UnsignedIntTy);
+  typeDeclBuilder.startDefinition();
+  typeDeclBuilder.addField("h", context.UnsignedIntTy); // Add an 'h' field to hold the handle.
+  return typeDeclBuilder.completeDefinition();
 }
 
 bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {

+ 2 - 1
tools/clang/lib/AST/CMakeLists.txt

@@ -36,7 +36,8 @@ add_clang_library(clangAST
   ExprConstant.cpp
   ExprCXX.cpp
   ExternalASTSource.cpp
-  HlslTypes.cpp
+  HlslBuiltinTypeDeclBuilder.cpp # HLSL Change
+  HlslTypes.cpp # HLSL Change
   InheritViz.cpp
   ItaniumCXXABI.cpp
   ItaniumMangle.cpp

+ 131 - 0
tools/clang/lib/AST/HlslBuiltinTypeDeclBuilder.cpp

@@ -0,0 +1,131 @@
+///////////////////////////////////////////////////////////////////////////////
+//                                                                           //
+// Copyright (C) Microsoft Corporation. All rights reserved.                 //
+// This file is distributed under the University of Illinois Open Source     //
+// License. See LICENSE.TXT for details.                                     //
+//                                                                           //
+///////////////////////////////////////////////////////////////////////////////
+
+#include "clang/AST/HlslBuiltinTypeDeclBuilder.h"
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/DeclTemplate.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/ExprCXX.h"
+#include "clang/AST/TypeLoc.h"
+#include "dxc/Support/Global.h"
+
+using namespace clang;
+using namespace hlsl;
+
+static const SourceLocation NoLoc; // no source location attribution available
+
+BuiltinTypeDeclBuilder::BuiltinTypeDeclBuilder(DeclContext* declContext, StringRef name, TagDecl::TagKind tagKind) {
+  ASTContext& astContext = declContext->getParentASTContext();
+  IdentifierInfo& nameId = astContext.Idents.get(name, tok::TokenKind::identifier);
+  m_recordDecl = CXXRecordDecl::Create(astContext, tagKind, declContext, NoLoc, NoLoc, &nameId, nullptr, /* DelayTypeCreation */ true);
+  m_recordDecl->setImplicit(true);
+  declContext->addDecl(m_recordDecl);
+}
+
+TemplateTypeParmDecl* BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name, TypeSourceInfo* defaultValue) {
+  DXASSERT_NOMSG(!m_recordDecl->isBeingDefined() && !m_recordDecl->isCompleteDefinition());
+
+  ASTContext& astContext = m_recordDecl->getASTContext();
+  unsigned index = (unsigned)m_templateParams.size();
+  TemplateTypeParmDecl* decl = TemplateTypeParmDecl::Create(
+    astContext, m_recordDecl->getDeclContext(), NoLoc, NoLoc, /* TemplateDepth */ 0, index,
+    &astContext.Idents.get(name, tok::TokenKind::identifier), /* Typename */ false, /* ParameterPack */ false);
+  if (defaultValue!= nullptr) decl->setDefaultArgument(defaultValue);
+  m_templateParams.emplace_back(decl);
+  return decl;
+}
+
+TemplateTypeParmDecl* BuiltinTypeDeclBuilder::addTypeTemplateParam(StringRef name, QualType defaultValue) {
+  TypeSourceInfo* defaultValueSourceInfo = nullptr;
+  if (!defaultValue.isNull()) defaultValueSourceInfo = m_recordDecl->getASTContext().getTrivialTypeSourceInfo(defaultValue);
+  return addTypeTemplateParam(name, defaultValueSourceInfo);
+}
+
+NonTypeTemplateParmDecl* BuiltinTypeDeclBuilder::addIntegerTemplateParam(StringRef name, QualType type, Optional<int64_t> defaultValue) {
+  DXASSERT_NOMSG(!m_recordDecl->isBeingDefined() && !m_recordDecl->isCompleteDefinition());
+
+  ASTContext& astContext = m_recordDecl->getASTContext();
+  unsigned index = (unsigned)m_templateParams.size();
+  NonTypeTemplateParmDecl* decl = NonTypeTemplateParmDecl::Create(
+    astContext, m_recordDecl->getDeclContext(), NoLoc, NoLoc, /* TemplateDepth */ 0, index,
+    &astContext.Idents.get(name, tok::TokenKind::identifier), type, /* ParameterPack */ false,
+    astContext.getTrivialTypeSourceInfo(type));
+  if (defaultValue.hasValue()) {
+    Expr* defaultValueLiteral = IntegerLiteral::Create(
+      astContext, llvm::APInt(astContext.getIntWidth(type), defaultValue.getValue()), type, NoLoc);
+    decl->setDefaultArgument(defaultValueLiteral);
+  }
+  m_templateParams.emplace_back(decl);
+  return decl;
+}
+
+void BuiltinTypeDeclBuilder::startDefinition() {
+  DXASSERT_NOMSG(!m_recordDecl->isBeingDefined() && !m_recordDecl->isCompleteDefinition());
+
+  ASTContext& astContext = m_recordDecl->getASTContext();
+  DeclContext* declContext = m_recordDecl->getDeclContext();
+
+  if (!m_templateParams.empty()) {
+    TemplateParameterList* templateParameterList = TemplateParameterList::Create(
+      astContext, NoLoc, NoLoc, m_templateParams.data(), m_templateParams.size(), NoLoc);
+    m_templateDecl = ClassTemplateDecl::Create(
+      astContext, declContext, NoLoc, DeclarationName(m_recordDecl->getIdentifier()),
+      templateParameterList, m_recordDecl, nullptr);
+    m_recordDecl->setDescribedClassTemplate(m_templateDecl);
+    m_templateDecl->setImplicit(true);
+    m_templateDecl->setLexicalDeclContext(declContext);
+    declContext->addDecl(m_templateDecl);
+
+    // Requesting the class name specialization will fault in required types.
+    QualType T = m_templateDecl->getInjectedClassNameSpecialization();
+    T = astContext.getInjectedClassNameType(m_recordDecl, T);
+    assert(T->isDependentType() && "Class template type is not dependent?");
+  }
+
+  m_recordDecl->setLexicalDeclContext(declContext);
+  m_recordDecl->addAttr(FinalAttr::CreateImplicit(astContext, FinalAttr::Keyword_final));
+  m_recordDecl->startDefinition();
+}
+
+FieldDecl* BuiltinTypeDeclBuilder::addField(StringRef name, QualType type, AccessSpecifier access) {
+  DXASSERT_NOMSG(m_recordDecl->isBeingDefined());
+
+  ASTContext& astContext = m_recordDecl->getASTContext();
+
+  IdentifierInfo& nameId = astContext.Idents.get(name, tok::TokenKind::identifier);
+  TypeSourceInfo* fieldTypeSource = astContext.getTrivialTypeSourceInfo(type, NoLoc);
+  const bool MutableFalse = false;
+  const InClassInitStyle initStyle = InClassInitStyle::ICIS_NoInit;
+  FieldDecl* fieldDecl = FieldDecl::Create(
+    astContext, m_recordDecl, NoLoc, NoLoc, &nameId, type, fieldTypeSource, nullptr, MutableFalse, initStyle);
+  fieldDecl->setAccess(access);
+  fieldDecl->setImplicit(true);
+  m_recordDecl->addDecl(fieldDecl);
+
+#ifdef DBG
+  // Verify that we can read the field member from the record.
+  DeclContext::lookup_result lookupResult = m_recordDecl->lookup(DeclarationName(&nameId));
+  DXASSERT(!lookupResult.empty(), "Field cannot be looked up");
+#endif
+
+  return fieldDecl;
+}
+
+CXXRecordDecl* BuiltinTypeDeclBuilder::completeDefinition() {
+  DXASSERT_NOMSG(!m_recordDecl->isCompleteDefinition());
+  if (!m_recordDecl->isBeingDefined()) startDefinition();
+  m_recordDecl->completeDefinition();
+  return m_recordDecl;
+}
+
+ClassTemplateDecl* BuiltinTypeDeclBuilder::getTemplateDecl() const {
+  DXASSERT_NOMSG(m_recordDecl->isBeingDefined() || m_recordDecl->isCompleteDefinition());
+  return m_templateDecl;
+}

+ 5 - 20
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -3292,28 +3292,17 @@ private:
           break;
         }
       } else if (kind == AR_OBJECT_RAY_QUERY) {
-        ClassTemplateDecl* typeDecl = nullptr;
-        AddRayQueryTemplate(*m_context, &typeDecl, &recordDecl);
-        DXASSERT(typeDecl != nullptr, "AddRayQueryTemplate failed to return the object declaration");
-        typeDecl->setImplicit(true);
-        recordDecl->setImplicit(true);
+        recordDecl = DeclareUIntTemplatedTypeWithHandle(*m_context, "RayQuery", "flags");
       }
-      else if (templateArgCount == 0)
-      {
-        AddRecordTypeWithHandle(*m_context, &recordDecl, typeName);
-        DXASSERT(recordDecl != nullptr, "AddRecordTypeWithHandle failed to return the object declaration");
-        recordDecl->setImplicit(true);
+      else if (templateArgCount == 0) {
+        recordDecl = DeclareRecordTypeWithHandle(*m_context, typeName);
       }
       else
       {
         DXASSERT(templateArgCount == 1 || templateArgCount == 2, "otherwise a new case has been added");
 
-        ClassTemplateDecl* typeDecl = nullptr;
         TypeSourceInfo* typeDefault = TemplateHasDefaultType(kind) ? float4TypeSourceInfo : nullptr;
-        AddTemplateTypeWithHandle(*m_context, &typeDecl, &recordDecl, typeName, templateArgCount, typeDefault);
-        DXASSERT(typeDecl != nullptr, "AddTemplateTypeWithHandle failed to return the object declaration");
-        typeDecl->setImplicit(true);
-        recordDecl->setImplicit(true);
+        recordDecl = DeclareTemplateTypeWithHandle(*m_context, typeName, templateArgCount, typeDefault);
       }
       m_objectTypeDecls[i] = recordDecl;
       m_objectTypeDeclsMap[i] = std::make_pair(recordDecl, i);
@@ -4237,11 +4226,7 @@ public:
     DXASSERT(m_matrixTemplateDecl != nullptr, "AddHLSLMatrixTypes failed to return the matrix template declaration");
 
     // Initializing built in integers for ray tracing
-    AddRayFlags(*m_context);
-    AddHitKinds(*m_context);
-    AddStateObjectFlags(*m_context);
-    AddCommittedStatus(*m_context);
-    AddCandidateType(*m_context);
+    AddRayTracingConstants(*m_context);
 
     return true;
   }