//===--- ASTContextHLSL.cpp - HLSL support for AST nodes and operations ---===//
///////////////////////////////////////////////////////////////////////////////
// //
// ASTContextHLSL.cpp //
// Copyright (C) Microsoft Corporation. All rights reserved. //
// This file is distributed under the University of Illinois Open Source //
// License. See LICENSE.TXT for details. //
// //
// This file implements the ASTContext interface for HLSL. //
// //
///////////////////////////////////////////////////////////////////////////////
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/Expr.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/ExternalASTSource.h"
#include "clang/AST/TypeLoc.h"
#include "clang/Sema/SemaDiagnostic.h"
#include "clang/Sema/Sema.h"
#include "clang/Sema/Overload.h"
#include "dxc/Support/Global.h"
#include "dxc/HLSL/HLOperations.h"
#include "dxc/HLSL/DxilSemantic.h"
using namespace clang;
using namespace hlsl;
// Named constants passed as arguments to clang AST factory functions, so that
// call sites read as self-describing rather than as bare literals.
static const int FirstTemplateDepth = 0;            // depth of a template's own (outermost) parameters
static const int FirstParamPosition = 0;            // index of the first template parameter
static const bool ForConstFalse = false;            // a construct is targeting a non-const type
static const bool ForConstTrue = true;              // a construct is targeting a const type
static const bool ExplicitConversionFalse = false;  // a conversion operation is not the result of an explicit cast
static const bool InheritedFalse = false;           // template parameter default value is not inherited.
static const bool ParameterPackFalse = false;       // template parameter is not an ellipsis.
static const bool TypenameFalse = false;            // 'class' specified rather than 'typename' for a template argument.
static const bool DelayTypeCreationTrue = true;     // delay type creation for a declaration
static const bool DelayTypeCreationFalse = false;   // immediately create a type when the declaration is created
static const unsigned int NoQuals = 0;              // no qualifiers in effect
static const SourceLocation NoLoc;                  // no source location attribution available
static const bool HasWrittenPrototypeTrue = true;   // function had the prototype written
static const bool InlineFalse = false;              // namespace is not an inline namespace
static const bool InlineSpecifiedFalse = false;     // function was not specified as inline
static const bool IsConstexprFalse = false;         // function is not constexpr
static const bool ListInitializationFalse = false;  // not performing a list initialization
static const bool SuppressDiagTrue = true;          // suppress diagnostics
static const bool VirtualFalse = false;             // the base class is not declared 'virtual'
static const bool BaseClassFalse = false;           // the base specifier belongs to a 'struct' rather than a 'class' (affects default base access)
/// Display names for the HLSLScalarType enumeration. The entry at index i is
/// the name of the enumerator whose value is i, so this table must stay in the
/// same order as the enumeration. The static_assert below catches count drift.
/// Index 0 is the empty string (presumably HLSLScalarType_unknown; confirm).
const char* HLSLScalarTypeNames[] = {
  // no name
  "",
  // classic scalar types
  "bool", "int", "uint", "dword", "half", "float", "double",
  // minimum-precision types
  "min10float", "min16float", "min12int", "min16int", "min16uint",
  // types of untyped literals
  "literal float", "literal int",
  // fixed-width types
  "int16_t", "int32_t", "int64_t",
  "uint16_t", "uint32_t", "uint64_t",
  "float16_t", "float32_t", "float64_t"
};
static_assert(_countof(HLSLScalarTypeNames) == HLSLScalarTypeCount,
              "otherwise scalar constants are not aligned");
/// Maps the spelling of an HLSL scalar type keyword to its HLSLScalarType.
///
/// \param typeName    Characters of the candidate name. Only the first
///                    \p typeLen characters are examined.
/// \param typeLen     Number of characters in \p typeName.
/// \param langOptions Language options. The fixed-width spellings int16_t,
///                    int32_t, uint16_t, uint32_t, float16_t, float32_t and
///                    float64_t require HLSLVersion >= 2018. The 16-bit ones
///                    are also unavailable when UseMinPrecision is set.
///                    int64_t and uint64_t are always recognized.
/// \returns The matching scalar type, or HLSLScalarType_unknown when the name
///          is not a recognized (or not currently enabled) spelling. The
///          unknown, "literal int" and "literal float" types are never matched.
static HLSLScalarType FindScalarTypeByName(const char *typeName, const size_t typeLen, const LangOptions& langOptions) {
  // One recognized spelling and the language conditions that enable it.
  struct ScalarTypeSpelling {
    const char *Name;
    size_t NameLen;
    HLSLScalarType Type;
    bool RequiresHLSL2018;       // only recognized when HLSLVersion >= 2018
    bool RequiresNativeLowPrec;  // not recognized when UseMinPrecision is set
  };
  static const ScalarTypeSpelling Spellings[] = {
    { "int",        3,  HLSLScalarType_int,         false, false },
    { "bool",       4,  HLSLScalarType_bool,        false, false },
    { "uint",       4,  HLSLScalarType_uint,        false, false },
    { "half",       4,  HLSLScalarType_half,        false, false },
    { "dword",      5,  HLSLScalarType_dword,       false, false },
    { "float",      5,  HLSLScalarType_float,       false, false },
    { "double",     6,  HLSLScalarType_double,      false, false },
    { "int64_t",    7,  HLSLScalarType_int64,       false, false },
    { "min12int",   8,  HLSLScalarType_int_min12,   false, false },
    { "min16int",   8,  HLSLScalarType_int_min16,   false, false },
    { "uint64_t",   8,  HLSLScalarType_uint64,      false, false },
    { "min16uint",  9,  HLSLScalarType_uint_min16,  false, false },
    { "min10float", 10, HLSLScalarType_float_min10, false, false },
    { "min16float", 10, HLSLScalarType_float_min16, false, false },
    { "int16_t",    7,  HLSLScalarType_int16,       true,  true  },
    { "int32_t",    7,  HLSLScalarType_int32,       true,  false },
    { "uint16_t",   8,  HLSLScalarType_uint16,      true,  true  },
    { "uint32_t",   8,  HLSLScalarType_uint32,      true,  false },
    { "float16_t",  9,  HLSLScalarType_float16,     true,  true  },
    { "float32_t",  9,  HLSLScalarType_float32,     true,  false },
    { "float64_t",  9,  HLSLScalarType_float64,     true,  false },
  };

  for (const ScalarTypeSpelling &spelling : Spellings) {
    // Check the length first so strncmp never compares more than typeLen
    // characters of typeName. (The previous switch fell through between cases
    // and compared 8 or 9 characters of 7- or 8-character names.)
    if (spelling.NameLen != typeLen ||
        strncmp(typeName, spelling.Name, typeLen) != 0)
      continue;
    if (spelling.RequiresHLSL2018 && langOptions.HLSLVersion < 2018)
      return HLSLScalarType_unknown;
    if (spelling.RequiresNativeLowPrec && langOptions.UseMinPrecision)
      return HLSLScalarType_unknown;
    return spelling.Type;
  }
  return HLSLScalarType_unknown;
}
/// Returns the type used for the implicit 'h' handle field that is added to
/// the HLSL vector template and to handle-bearing record types. Matrices use a
/// different field type. The handle type is currently plain 'int'.
static
CanQualType GetHLSLObjectHandleType(ASTContext& context)
{
  return context.IntTy;
}
/// Appends to \p recordDecl an implicit, private field named "h" of type
/// \p handleQualType. The field has no bit-width, is not mutable, and has no
/// in-class initializer.
static
void AddHLSLHandleField(ASTContext& context, DeclContext* recordDecl, QualType handleQualType)
{
  // Field name plus a location-free type-source for its type.
  IdentifierInfo &fieldName =
      context.Idents.get(StringRef("h"), tok::TokenKind::identifier);
  TypeSourceInfo *fieldTypeInfo =
      context.getTrivialTypeSourceInfo(handleQualType, NoLoc);

  FieldDecl *field = FieldDecl::Create(
      context, recordDecl, NoLoc, NoLoc, &fieldName, handleQualType,
      fieldTypeInfo, /*BW=*/nullptr, /*Mutable=*/false,
      InClassInitStyle::ICIS_NoInit);

  // Private access, flagged as compiler-generated, then attached to the record.
  field->setAccess(AccessSpecifier::AS_private);
  field->setImplicit(true);
  recordDecl->addDecl(field);
}
/// Declares an operator[] on the matrix template record \p templateRecordDecl.
///
/// The operator takes one index parameter named "index" of type \p intType. It
/// returns an lvalue reference to vector<element, col_count>, i.e. one row of
/// the matrix.
///
/// \param templateDepth             Depth of the matrix template parameters.
/// \param elementTemplateParamDecl  The matrix 'element' type parameter.
/// \param colCountTemplateParamDecl The matrix 'col_count' parameter. It is
///                                  used as the vector length.
/// \param intType                   Type of the index parameter.
/// \param templateRecordDecl        Record that receives the operator.
/// \param vectorTemplateDecl        The HLSL vector class template.
/// \param forConst                  When true, declares the const overload and
///                                  const-qualifies the returned type.
static
void AddSubscriptOperator(
  ASTContext& context, unsigned int templateDepth, TemplateTypeParmDecl *elementTemplateParamDecl,
  NonTypeTemplateParmDecl* colCountTemplateParamDecl, QualType intType, CXXRecordDecl* templateRecordDecl,
  ClassTemplateDecl* vectorTemplateDecl,
  bool forConst)
{
  QualType elementType = context.getTemplateTypeParmType(
    templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);

  // Reference to the col_count parameter. The expression is typed with the
  // parameter's declared type, not the index type, so it agrees with the
  // declaration it names.
  Expr* sizeExpr = DeclRefExpr::Create(context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl, false,
    DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
    colCountTemplateParamDecl->getType(), ExprValueKind::VK_RValue);

  // Build the dependent type vector<element, col_count>.
  CXXRecordDecl *vecTemplateRecordDecl = vectorTemplateDecl->getTemplatedDecl();
  const clang::Type *vecTy = vecTemplateRecordDecl->getTypeForDecl();
  TemplateArgument templateArgs[2] =
  {
    TemplateArgument(elementType),
    TemplateArgument(sizeExpr)
  };
  TemplateName canonName = context.getCanonicalTemplateName(TemplateName(vectorTemplateDecl));
  QualType vectorType = context.getTemplateSpecializationType(
    canonName, templateArgs, _countof(templateArgs), QualType(vecTy, 0));
  vectorType = context.getLValueReferenceType(vectorType);
  // NOTE(review): const is applied to the reference type itself, not to the
  // referenced vector. A reference-to-const would be
  // getLValueReferenceType(getConstType(...)). Confirm this is intended before
  // changing it, because it affects the const overload's return type.
  if (forConst)
    vectorType = context.getConstType(vectorType);

  QualType indexType = intType;
  StringRef indexName("index");
  CreateObjectFunctionDeclarationWithParams(
    context, templateRecordDecl, vectorType,
    ArrayRef<QualType>(indexType), ArrayRef<StringRef>(indexName),
    context.DeclarationNames.getCXXOperatorName(OO_Subscript), forConst);
}
/// Adds up-front support for HLSL matrix types (just the template declaration).
///
/// Declares, at translation-unit scope, the implicit class template
///   template <class element = float, int row_count = 4, int col_count = 4> class matrix;
/// Its record holds a private handle field 'h' and two operator[] overloads
/// (non-const and const). Each overload returns a reference to one row, typed
/// vector<element, col_count>. vectorTemplateDecl must be the already-created
/// vector template. The new template is returned through matrixTemplateDecl.
void hlsl::AddHLSLMatrixTemplate(ASTContext& context, ClassTemplateDecl* vectorTemplateDecl, ClassTemplateDecl** matrixTemplateDecl)
{
DXASSERT_NOMSG(matrixTemplateDecl != nullptr);
DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
DeclContext* currentDeclContext = context.getTranslationUnitDecl();
// Create a matrix template declaration in translation unit scope.
// template <class element = float, int row_count = 4, int col_count = 4> class matrix { ... }
IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
context, currentDeclContext, NoLoc, NoLoc,
FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
QualType intType = context.IntTy;
// The same literal 4 node is the default argument for both row_count and col_count.
Expr *literalIntFour = IntegerLiteral::Create(
context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
IdentifierInfo& rowCountParamId = context.Idents.get(StringRef("row_count"), tok::TokenKind::identifier);
NonTypeTemplateParmDecl* rowCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
context, currentDeclContext, NoLoc, NoLoc,
FirstTemplateDepth, FirstParamPosition + 1, &rowCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
rowCountTemplateParamDecl->setDefaultArgument(literalIntFour);
IdentifierInfo& colCountParamId = context.Idents.get(StringRef("col_count"), tok::TokenKind::identifier);
NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
context, currentDeclContext, NoLoc, NoLoc,
FirstTemplateDepth, FirstParamPosition + 2, &colCountParamId, intType, ParameterPackFalse, context.getTrivialTypeSourceInfo(intType));
colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
NamedDecl* templateParameters[] =
{
elementTemplateParamDecl, rowCountTemplateParamDecl, colCountTemplateParamDecl
};
TemplateParameterList* templateParameterList = TemplateParameterList::Create(
context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);
IdentifierInfo& matrixId = context.Idents.get(StringRef("matrix"), tok::TokenKind::identifier);
CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &matrixId,
nullptr, DelayTypeCreationTrue);
ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
context, currentDeclContext, NoLoc, DeclarationName(&matrixId),
templateParameterList, templateRecordDecl, nullptr);
templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
// Requesting the class name specialization will fault in required types.
QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
T = context.getInjectedClassNameType(templateRecordDecl, T);
assert(T->isDependentType() && "Class template type is not dependent?");
classTemplateDecl->setLexicalDeclContext(currentDeclContext);
templateRecordDecl->setLexicalDeclContext(currentDeclContext);
templateRecordDecl->startDefinition();
// Add an 'h' field to hold the handle.
// The field type is a dependent array of row_count extended vectors, each
// col_count elements long: element ext_vector(col_count) [row_count].
// Naming note: sizeExpr refers to row_count (the array length), while
// rowSizeExpr refers to col_count (the length of one row).
QualType elementType = context.getTemplateTypeParmType(
/*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
Expr *sizeExpr = DeclRefExpr::Create(
context, NestedNameSpecifierLoc(), NoLoc, rowCountTemplateParamDecl,
false,
DeclarationNameInfo(rowCountTemplateParamDecl->getDeclName(), NoLoc),
intType, ExprValueKind::VK_RValue);
Expr *rowSizeExpr = DeclRefExpr::Create(
context, NestedNameSpecifierLoc(), NoLoc, colCountTemplateParamDecl,
false,
DeclarationNameInfo(colCountTemplateParamDecl->getDeclName(), NoLoc),
intType, ExprValueKind::VK_RValue);
QualType vectorType = context.getDependentSizedExtVectorType(
elementType, rowSizeExpr, SourceLocation());
QualType vectorArrayType = context.getDependentSizedArrayType(
vectorType, sizeExpr, ArrayType::Normal, 0, SourceRange());
AddHLSLHandleField(context, templateRecordDecl, vectorArrayType);
// Add an operator[]. The operator ranges from zero to rowcount-1, and returns a vector of colcount elements.
// Both a non-const and a const overload are declared, each taking a uint index.
const unsigned int templateDepth = 0;
AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
colCountTemplateParamDecl, context.UnsignedIntTy,
templateRecordDecl, vectorTemplateDecl, ForConstFalse);
AddSubscriptOperator(context, templateDepth, elementTemplateParamDecl,
colCountTemplateParamDecl, context.UnsignedIntTy,
templateRecordDecl, vectorTemplateDecl, ForConstTrue);
templateRecordDecl->completeDefinition();
classTemplateDecl->setImplicit(true);
templateRecordDecl->setImplicit(true);
// Both declarations need to be present for correct handling.
currentDeclContext->addDecl(classTemplateDecl);
currentDeclContext->addDecl(templateRecordDecl);
#ifdef DBG
// Verify that we can read the field member from the template record.
DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
DeclarationName(&context.Idents.get(StringRef("h"))));
DXASSERT(!lookupResult.empty(), "otherwise matrix handle cannot be looked up");
#endif
*matrixTemplateDecl = classTemplateDecl;
}
/// Attaches an implicit HLSLIntrinsicAttr to \p D. The attribute marks D as the
/// vector subscript operation: opcode group HLSubscript, opcode VectorSubscript,
/// and an empty lowering string.
static void AddHLSLVectorSubscriptAttr(Decl *D, ASTContext &context) {
  StringRef group = GetHLOpcodeGroupName(HLOpcodeGroup::HLSubscript);
  // static_cast needs an explicit target type. The scoped opcode enumerator is
  // passed as a plain unsigned value.
  D->addAttr(HLSLIntrinsicAttr::CreateImplicit(
      context, group, "",
      static_cast<unsigned>(HLSubscriptOpcode::VectorSubscript)));
}
/// Adds up-front support for HLSL vector types (just the template declaration).
///
/// Declares, at translation-unit scope, the implicit class template
///   template <class element = float, int element_count = 4> class vector;
/// Its record holds a private handle field 'h' and two operator[] overloads,
/// const and non-const, each taking a uint index and returning a reference to
/// an element. Both overloads carry the vector-subscript intrinsic attribute.
/// The new template is returned through vectorTemplateDecl.
void hlsl::AddHLSLVectorTemplate(ASTContext& context, ClassTemplateDecl** vectorTemplateDecl)
{
  DXASSERT_NOMSG(vectorTemplateDecl != nullptr);
  DeclContext* currentDeclContext = context.getTranslationUnitDecl();

  // Create a vector template declaration in translation unit scope.
  // template <class element = float, int element_count = 4> class vector { ... }
  IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
  TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
    context, currentDeclContext, NoLoc, NoLoc,
    FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
  elementTemplateParamDecl->setDefaultArgument(context.getTrivialTypeSourceInfo(context.FloatTy));
  QualType intType = context.IntTy;
  Expr *literalIntFour = IntegerLiteral::Create(
    context, llvm::APInt(context.getIntWidth(intType), 4), intType, NoLoc);
  IdentifierInfo& colCountParamId = context.Idents.get(StringRef("element_count"), tok::TokenKind::identifier);
  NonTypeTemplateParmDecl* colCountTemplateParamDecl = NonTypeTemplateParmDecl::Create(
    context, currentDeclContext, NoLoc, NoLoc,
    FirstTemplateDepth, FirstParamPosition + 1, &colCountParamId, intType, ParameterPackFalse, nullptr);
  colCountTemplateParamDecl->setDefaultArgument(literalIntFour);
  NamedDecl* templateParameters[] =
  {
    elementTemplateParamDecl, colCountTemplateParamDecl
  };
  TemplateParameterList* templateParameterList = TemplateParameterList::Create(
    context, NoLoc, NoLoc, templateParameters, _countof(templateParameters), NoLoc);

  IdentifierInfo& vectorId = context.Idents.get(StringRef("vector"), tok::TokenKind::identifier);
  CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
    context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &vectorId,
    nullptr, DelayTypeCreationTrue);
  ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
    context, currentDeclContext, NoLoc, DeclarationName(&vectorId),
    templateParameterList, templateRecordDecl, nullptr);
  templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
  // Requesting the class name specialization will fault in required types.
  QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
  T = context.getInjectedClassNameType(templateRecordDecl, T);
  assert(T->isDependentType() && "Class template type is not dependent?");
  classTemplateDecl->setLexicalDeclContext(currentDeclContext);
  templateRecordDecl->setLexicalDeclContext(currentDeclContext);
  templateRecordDecl->startDefinition();

  // Add an 'h' field to hold the handle.
  AddHLSLHandleField(context, templateRecordDecl, QualType(GetHLSLObjectHandleType(context)));

  // Add an operator[]. Valid indices are 0..element_count-1, and the operator
  // returns a reference to one element.
  const unsigned int templateDepth = 0;
  QualType resultType = context.getTemplateTypeParmType(
    templateDepth, 0, ParameterPackFalse, elementTemplateParamDecl);
  // Explicit element types are required. ArrayRef(context.UnsignedIntTy) would
  // be ill-formed before C++17, and under CTAD would deduce ArrayRef<CanQualType>.
  QualType indexType = context.UnsignedIntTy;
  StringRef indexName("index");

  // ForConstTrue:
  // NOTE(review): const is applied to the reference type itself, not to the
  // referenced element. A reference-to-const would be
  // getLValueReferenceType(getConstType(resultType)). Confirm this is intended.
  QualType refResultType = context.getConstType(context.getLValueReferenceType(resultType));
  CXXMethodDecl* functionDecl = CreateObjectFunctionDeclarationWithParams(
    context, templateRecordDecl, refResultType,
    ArrayRef<QualType>(indexType), ArrayRef<StringRef>(indexName),
    context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstTrue);
  AddHLSLVectorSubscriptAttr(functionDecl, context);

  // ForConstFalse:
  resultType = context.getLValueReferenceType(resultType);
  functionDecl = CreateObjectFunctionDeclarationWithParams(
    context, templateRecordDecl, resultType,
    ArrayRef<QualType>(indexType), ArrayRef<StringRef>(indexName),
    context.DeclarationNames.getCXXOperatorName(OO_Subscript), ForConstFalse);
  AddHLSLVectorSubscriptAttr(functionDecl, context);

  templateRecordDecl->completeDefinition();
  classTemplateDecl->setImplicit(true);
  templateRecordDecl->setImplicit(true);
  // Both declarations need to be present for correct handling.
  currentDeclContext->addDecl(classTemplateDecl);
  currentDeclContext->addDecl(templateRecordDecl);
#ifdef DBG
  // Verify that we can read the field member from the template record.
  DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
    DeclarationName(&context.Idents.get(StringRef("h"))));
  DXASSERT(!lookupResult.empty(), "otherwise vector handle cannot be looked up");
#endif
  *vectorTemplateDecl = classTemplateDecl;
}
/// Creates a struct named \p typeName at translation-unit scope. The struct
/// contains only a handle field (added by AddHLSLHandleField) and is completed
/// before returning. \p typeDecl is cleared to nullptr first and receives the
/// new declaration at the end.
void hlsl::AddRecordTypeWithHandle(ASTContext& context, _Outptr_ CXXRecordDecl** typeDecl, _In_z_ const char* typeName)
{
  DXASSERT_NOMSG(typeDecl != nullptr);
  DXASSERT_NOMSG(typeName != nullptr);
  *typeDecl = nullptr;

  DeclContext *tu = context.getTranslationUnitDecl();
  IdentifierInfo &recordId =
      context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  CXXRecordDecl *record = CXXRecordDecl::Create(
      context, TagDecl::TagKind::TTK_Struct, tu, NoLoc, NoLoc, &recordId,
      /*PrevDecl=*/nullptr);
  record->setLexicalDeclContext(tu);
  record->setFreeStanding();

  // Body: just the handle member. The record is added to the translation unit
  // before its definition is completed.
  record->startDefinition();
  const QualType handleType(GetHLSLObjectHandleType(context));
  AddHLSLHandleField(context, record, handleType);
  tu->addDecl(record);
  record->completeDefinition();

  *typeDecl = record;
}
// Declares an implicit variable in DC, equivalent to:
//   static const uint name = val;
static void AddConstInt(clang::ASTContext& context, DeclContext *DC, StringRef name, int val) {
  IdentifierInfo &Id = context.Idents.get(name, tok::TokenKind::identifier);
  QualType type = context.getConstType(context.UnsignedIntTy);
  VarDecl *varDecl = VarDecl::Create(context, DC, NoLoc, NoLoc, &Id, type,
                                     context.getTrivialTypeSourceInfo(type),
                                     clang::StorageClass::SC_Static);
  // The initializer is a prvalue, so it carries the unqualified 'uint' type,
  // matching what Sema produces. Only the variable itself is const-qualified.
  QualType literalType = context.UnsignedIntTy;
  Expr *exprVal = IntegerLiteral::Create(
      context, llvm::APInt(context.getIntWidth(literalType), val), literalType, NoLoc);
  varDecl->setInit(exprVal);
  varDecl->setImplicit(true);
  DC->addDecl(varDecl);
}
/// Declares the ray-flag names at translation-unit scope: a RAY_FLAG typedef
/// for uint, followed by one implicit 'static const uint RAY_FLAG_*' constant
/// per flag, in the order listed below.
void hlsl::AddRayFlags(ASTContext& context) {
  DeclContext *tuDC = context.getTranslationUnitDecl();

  // typedef uint RAY_FLAG;
  IdentifierInfo &typedefName =
      context.Idents.get(StringRef("RAY_FLAG"), tok::TokenKind::identifier);
  TypeSourceInfo *uintTypeInfo =
      context.getTrivialTypeSourceInfo(context.UnsignedIntTy, NoLoc);
  TypedefDecl *rayFlagTypedef =
      TypedefDecl::Create(context, tuDC, NoLoc, NoLoc, &typedefName, uintTypeInfo);
  tuDC->addDecl(rayFlagTypedef);
  rayFlagTypedef->setImplicit(true);

  // static const uint RAY_FLAG_* = ...;
  static const struct {
    const char *Name;
    int Value;
  } kRayFlags[] = {
    { "RAY_FLAG_NONE",                            0x00 },
    { "RAY_FLAG_FORCE_OPAQUE",                    0x01 },
    { "RAY_FLAG_FORCE_NON_OPAQUE",                0x02 },
    { "RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH", 0x04 },
    { "RAY_FLAG_SKIP_CLOSEST_HIT_SHADER",         0x08 },
    { "RAY_FLAG_CULL_BACK_FACING_TRIANGLES",      0x10 },
    { "RAY_FLAG_CULL_FRONT_FACING_TRIANGLES",     0x20 },
    { "RAY_FLAG_CULL_OPAQUE",                     0x40 },
    { "RAY_FLAG_CULL_NON_OPAQUE",                 0x80 },
  };
  for (const auto &flag : kRayFlags)
    AddConstInt(context, tuDC, StringRef(flag.Name), flag.Value);
}
/// Declares the hit-kind constants at translation-unit scope as implicit
/// 'static const uint' variables, in the order listed below.
void hlsl::AddHitKinds(ASTContext& context) {
  DeclContext *tuDC = context.getTranslationUnitDecl();
  static const struct {
    const char *Name;
    int Value;
  } kHitKinds[] = {
    { "HIT_KIND_TRIANGLE_FRONT_FACE", 0xfe },
    { "HIT_KIND_TRIANGLE_BACK_FACE",  0xff },
  };
  for (const auto &hitKind : kHitKinds)
    AddConstInt(context, tuDC, StringRef(hitKind.Name), hitKind.Value);
}
/// Builds an integer-constant expression for \p value and wraps it in an
/// implicit integral-to-boolean cast, yielding a bool-typed expression.
static
Expr* IntConstantAsBoolExpr(clang::Sema& sema, uint64_t value)
{
  Expr *intLiteral = sema.ActOnIntegerConstant(NoLoc, value).get();
  QualType boolType = sema.getASTContext().BoolTy;
  return sema.ImpCastExprToType(intLiteral, boolType, CK_IntegralToBoolean).get();
}
/// Creates a struct in \p stdNamespace, adds it there, and returns it. The
/// struct is equivalent to
///   struct <typeId> { static constexpr const bool <valueId> = <valueExpression>; };
/// It is used for both std::true_type and std::false_type.
static
CXXRecordDecl* CreateStdStructWithStaticBool(clang::ASTContext& context, NamespaceDecl* stdNamespace, IdentifierInfo& typeId, IdentifierInfo& valueId, Expr* valueExpression)
{
  const QualType constBoolType = context.BoolTy.withConst();
  TypeSourceInfo* boolTypeSource = context.getTrivialTypeSourceInfo(constBoolType);
  CXXRecordDecl* structDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &typeId, nullptr, DelayTypeCreationTrue);
  // Type creation was delayed when the record was created; fault it in now.
  (void)context.getTagDeclType(structDecl);

  // Static data members are VarDecls in the AST.
  VarDecl* valueDecl = VarDecl::Create(context, structDecl, NoLoc, NoLoc, &valueId,
    constBoolType, boolTypeSource, SC_Static);
  valueDecl->setInit(valueExpression);
  valueDecl->setConstexpr(true);
  valueDecl->setAccess(AS_public);

  structDecl->setLexicalDeclContext(stdNamespace);
  structDecl->startDefinition();
  structDecl->addDecl(valueDecl);
  structDecl->completeDefinition();
  stdNamespace->addDecl(structDecl);
  return structDecl;
}
/// Gives \p decl a definition whose only base is \p base, sets its lexical
/// context, and adds it to \p lexicalContext.
static
void DefineRecordWithBase(CXXRecordDecl* decl, DeclContext* lexicalContext, const CXXBaseSpecifier* base)
{
  const CXXBaseSpecifier *const bases[] = { base };
  decl->setLexicalDeclContext(lexicalContext);
  decl->startDefinition();
  decl->setBases(bases, 1);
  decl->completeDefinition();
  lexicalContext->addDecl(decl);
}
/// Marks \p specializationDecl as an explicit specialization and registers it
/// as a partial specialization of \p templateDecl.
static
void SetPartialExplicitSpecialization(ClassTemplateDecl* templateDecl, ClassTemplatePartialSpecializationDecl* specializationDecl)
{
  const TemplateSpecializationKind kind = TSK_ExplicitSpecialization;
  specializationDecl->setSpecializationKind(kind);
  templateDecl->AddPartialSpecialization(specializationDecl, /*InsertPos=*/nullptr);
}
/// Declares a partial specialization of \p templateDecl for the two type
/// arguments in \p templateArgs. The specialization uses parameter list
/// \p templateParamList and derives from \p base. It is defined in
/// \p lexicalContext and registered with the primary template.
static
void CreateIsEqualSpecialization(ASTContext& context, ClassTemplateDecl* templateDecl, TemplateName& templateName,
  DeclContext* lexicalContext, const CXXBaseSpecifier* base, TemplateParameterList* templateParamList,
  TemplateArgument (&templateArgs)[2])
{
  const unsigned argCount = _countof(templateArgs);
  QualType canonType = context.getTemplateSpecializationType(templateName, templateArgs, argCount);

  // Every argument is a type; give each one a trivial (location-free) TypeSourceInfo.
  TemplateArgumentListInfo argsAsWritten(NoLoc, NoLoc);
  for (const TemplateArgument &arg : templateArgs)
    argsAsWritten.addArgument(TemplateArgumentLoc(arg, context.getTrivialTypeSourceInfo(arg.getAsType())));

  ClassTemplatePartialSpecializationDecl* partialSpec =
    ClassTemplatePartialSpecializationDecl::Create(context, TTK_Struct, lexicalContext, NoLoc, NoLoc,
      templateParamList, templateDecl, templateArgs, argCount,
      argsAsWritten, canonType, nullptr);
  (void)context.getTagDeclType(partialSpec); // Fault this in now.
  DefineRecordWithBase(partialSpec, lexicalContext, base);
  SetPartialExplicitSpecialization(templateDecl, partialSpec);
}
/// Adds a minimal implementation of std::is_same.
///
/// Creates an implicit 'std' namespace in the translation unit containing
/// true_type, false_type, the primary is_same<T, V> template (deriving from
/// false_type) and partial specializations deriving from true_type. The
/// specializations are is_same<T, T>, plus, as extensions, is_same<T, const T>
/// and is_same<T, T&>.
void hlsl::AddStdIsEqualImplementation(clang::ASTContext& context, clang::Sema& sema)
{
// The goal is to support std::is_same<T, U>::value for testing purposes, in a manner that can
// evolve into a compliant feature in the future.
//
// The definitions necessary are as follows (all in the std namespace).
// template <class T, T v>
// struct integral_constant {
// typedef T value_type;
// static const value_type value = v;
// operator value_type() { return value; }
// };
//
// typedef integral_constant<bool, true> true_type;
// typedef integral_constant<bool, false> false_type;
//
// template <typename T, typename U> struct is_same : public false_type {};
// template <typename T> struct is_same<T, T> : public true_type{};
//
// We instead use these simpler definitions for true_type and false_type.
// struct false_type { static const bool value = false; };
// struct true_type { static const bool value = true; };
DeclContext* tuContext = context.getTranslationUnitDecl();
IdentifierInfo& stdId = context.Idents.get(StringRef("std"), tok::TokenKind::identifier);
IdentifierInfo& trueTypeId = context.Idents.get(StringRef("true_type"), tok::TokenKind::identifier);
IdentifierInfo& falseTypeId = context.Idents.get(StringRef("false_type"), tok::TokenKind::identifier);
IdentifierInfo& valueId = context.Idents.get(StringRef("value"), tok::TokenKind::identifier);
IdentifierInfo& isSameId = context.Idents.get(StringRef("is_same"), tok::TokenKind::identifier);
IdentifierInfo& tId = context.Idents.get(StringRef("T"), tok::TokenKind::identifier);
IdentifierInfo& vId = context.Idents.get(StringRef("V"), tok::TokenKind::identifier);
Expr* trueExpression = IntConstantAsBoolExpr(sema, 1);
Expr* falseExpression = IntConstantAsBoolExpr(sema, 0);
// namespace std
NamespaceDecl* stdNamespace = NamespaceDecl::Create(context, tuContext, InlineFalse, NoLoc, NoLoc, &stdId, nullptr);
CXXRecordDecl* trueTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, trueTypeId, valueId, trueExpression);
CXXRecordDecl* falseTypeDecl = CreateStdStructWithStaticBool(context, stdNamespace, falseTypeId, valueId, falseExpression);
// template <typename T, typename V> struct is_same : public false_type {};
CXXRecordDecl* isSameFalseRecordDecl = CXXRecordDecl::Create(context, TagTypeKind::TTK_Struct, stdNamespace, NoLoc, NoLoc, &isSameId, nullptr, false);
TemplateTypeParmDecl* tParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition, &tId, TypenameFalse, ParameterPackFalse);
TemplateTypeParmDecl* uParam = TemplateTypeParmDecl::Create(context, stdNamespace, NoLoc, NoLoc, FirstTemplateDepth, FirstParamPosition + 1, &vId, TypenameFalse, ParameterPackFalse);
NamedDecl* falseParams[] = { tParam, uParam };
TemplateParameterList* falseParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, falseParams, _countof(falseParams), NoLoc);
ClassTemplateDecl* isSameFalseTemplateDecl = ClassTemplateDecl::Create(context, stdNamespace, NoLoc, DeclarationName(&isSameId), falseParamList, isSameFalseRecordDecl, nullptr);
context.getTagDeclType(isSameFalseRecordDecl); // Fault this in now.
CXXBaseSpecifier* falseBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
context.getTrivialTypeSourceInfo(context.getTypeDeclType(falseTypeDecl)), NoLoc);
isSameFalseRecordDecl->setDescribedClassTemplate(isSameFalseTemplateDecl);
isSameFalseTemplateDecl->setLexicalDeclContext(stdNamespace);
DefineRecordWithBase(isSameFalseRecordDecl, stdNamespace, falseBase);
// is_same for 'true' is a specialization of is_same for 'false', taking a single T, where both T will match
// template <typename T> struct is_same<T, T> : public true_type{};
TemplateName tn = TemplateName(isSameFalseTemplateDecl);
NamedDecl* trueParams[] = { tParam };
TemplateParameterList* trueParamList = TemplateParameterList::Create(context, NoLoc, NoLoc, trueParams, _countof(trueParams), NoLoc);
CXXBaseSpecifier* trueBase = new (context)CXXBaseSpecifier(SourceRange(), VirtualFalse, BaseClassFalse, AS_public,
context.getTrivialTypeSourceInfo(context.getTypeDeclType(trueTypeDecl)), NoLoc);
TemplateArgument ta = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)));
TemplateArgument isSameTrueTemplateArgs[] = { ta, ta };
CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueTemplateArgs);
stdNamespace->addDecl(isSameFalseTemplateDecl);
stdNamespace->setImplicit(true);
tuContext->addDecl(stdNamespace);
// This could be a parameter if ever needed.
const bool SupportExtensions = true;
// Consider right-hand const and right-hand ref to be true for is_same:
// template <typename T> struct is_same<T, const T> : public true_type{};
// template <typename T> struct is_same<T, T&> : public true_type{};
if (SupportExtensions)
{
TemplateArgument trueConstArg = TemplateArgument(context.getCanonicalType(context.getTypeDeclType(tParam)).withConst());
TemplateArgument isSameTrueConstTemplateArgs[] = { ta, trueConstArg };
CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueConstTemplateArgs);
TemplateArgument trueRefArg = TemplateArgument(
context.getLValueReferenceType(context.getCanonicalType(context.getTypeDeclType(tParam))));
TemplateArgument isSameTrueRefTemplateArgs[] = { ta, trueRefArg };
CreateIsEqualSpecialization(context, isSameFalseTemplateDecl, tn, stdNamespace, trueBase, trueParamList, isSameTrueRefTemplateArgs);
}
}
///
/// Adds a new template type in the specified context with the given name. The record type will have a handle field.
///
/// \param context AST context to which template will be added.
/// \param typeDecl After execution, template declaration.
/// \param recordDecl After execution, record declaration for template.
/// \param typeName Name of template to create.
/// \param templateArgCount Number of template arguments (one or two).
/// \param defaultTypeArgValue If assigned, the default argument for the element template.
void hlsl::AddTemplateTypeWithHandle(
ASTContext& context,
_Outptr_ ClassTemplateDecl** typeDecl,
_Outptr_ CXXRecordDecl** recordDecl,
_In_z_ const char* typeName,
uint8_t templateArgCount,
_In_opt_ TypeSourceInfo* defaultTypeArgValue
)
{
DXASSERT_NOMSG(typeDecl != nullptr);
DXASSERT_NOMSG(recordDecl != nullptr);
DXASSERT_NOMSG(typeName != nullptr);
DXASSERT(templateArgCount != 0, "otherwise caller should be creating a class or struct");
DXASSERT(templateArgCount <= 2, "otherwise the function needs to be updated for a different template pattern");
DeclContext* currentDeclContext = context.getTranslationUnitDecl();
// Create an object template declaration in translation unit scope.
// templateArgCount=1: template typeName { ... }
// templateArgCount=2: template typeName { ... }
IdentifierInfo& elementTemplateParamId = context.Idents.get(StringRef("element"), tok::TokenKind::identifier);
TemplateTypeParmDecl *elementTemplateParamDecl = TemplateTypeParmDecl::Create(
context, currentDeclContext, NoLoc, NoLoc,
FirstTemplateDepth, FirstParamPosition, &elementTemplateParamId, TypenameFalse, ParameterPackFalse);
QualType intType = context.IntTy;
if (defaultTypeArgValue != nullptr)
{
elementTemplateParamDecl->setDefaultArgument(defaultTypeArgValue);
}
NonTypeTemplateParmDecl* countTemplateParamDecl = nullptr;
if (templateArgCount > 1) {
IdentifierInfo& countParamId = context.Idents.get(StringRef("count"), tok::TokenKind::identifier);
countTemplateParamDecl = NonTypeTemplateParmDecl::Create(
context, currentDeclContext, NoLoc, NoLoc,
FirstTemplateDepth, FirstParamPosition + 1, &countParamId, intType, ParameterPackFalse, nullptr);
// Zero means default here. The count is decided by runtime.
Expr *literalIntZero = IntegerLiteral::Create(
context, llvm::APInt(context.getIntWidth(intType), 0), intType, NoLoc);
countTemplateParamDecl->setDefaultArgument(literalIntZero);
}
NamedDecl* templateParameters[] =
{
elementTemplateParamDecl, countTemplateParamDecl
};
TemplateParameterList* templateParameterList = TemplateParameterList::Create(
context, NoLoc, NoLoc, templateParameters, templateArgCount, NoLoc);
IdentifierInfo& typeId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
CXXRecordDecl* templateRecordDecl = CXXRecordDecl::Create(
context, TagDecl::TagKind::TTK_Class, currentDeclContext, NoLoc, NoLoc, &typeId,
nullptr, DelayTypeCreationTrue);
ClassTemplateDecl* classTemplateDecl = ClassTemplateDecl::Create(
context, currentDeclContext, NoLoc, DeclarationName(&typeId),
templateParameterList, templateRecordDecl, nullptr);
templateRecordDecl->setDescribedClassTemplate(classTemplateDecl);
// Requesting the class name specialization will fault in required types.
QualType T = classTemplateDecl->getInjectedClassNameSpecialization();
T = context.getInjectedClassNameType(templateRecordDecl, T);
assert(T->isDependentType() && "Class template type is not dependent?");
classTemplateDecl->setLexicalDeclContext(currentDeclContext);
templateRecordDecl->setLexicalDeclContext(currentDeclContext);
templateRecordDecl->startDefinition();
// Many more things to come here, like constructors and the like....
// Add an 'h' field to hold the handle.
QualType elementType = context.getTemplateTypeParmType(
/*templateDepth*/ 0, 0, ParameterPackFalse, elementTemplateParamDecl);
if (templateArgCount > 1 &&
// Only need array type for inputpatch and outputpatch.
// Avoid Texture2DMS which may use 0 count.
// TODO: use hlsl types to do the check.
!typeId.getName().startswith("Texture")) {
Expr *countExpr = DeclRefExpr::Create(
context, NestedNameSpecifierLoc(), NoLoc, countTemplateParamDecl, false,
DeclarationNameInfo(countTemplateParamDecl->getDeclName(), NoLoc),
intType, ExprValueKind::VK_RValue);
elementType = context.getDependentSizedArrayType(
elementType, countExpr, ArrayType::ArraySizeModifier::Normal, 0,
SourceRange());
// InputPatch and OutputPatch also have a "Length" static const member for the number of control points
IdentifierInfo& lengthId = context.Idents.get(StringRef("Length"), tok::TokenKind::identifier);
TypeSourceInfo* lengthTypeSource = context.getTrivialTypeSourceInfo(intType.withConst());
VarDecl* lengthValueDecl = VarDecl::Create(context, templateRecordDecl, NoLoc, NoLoc, &lengthId,
intType.withConst(), lengthTypeSource, SC_Static);
lengthValueDecl->setInit(countExpr);
lengthValueDecl->setAccess(AS_public);
templateRecordDecl->addDecl(lengthValueDecl);
}
AddHLSLHandleField(context, templateRecordDecl, elementType);
templateRecordDecl->completeDefinition();
// Both declarations need to be present for correct handling.
currentDeclContext->addDecl(classTemplateDecl);
currentDeclContext->addDecl(templateRecordDecl);
#ifdef DBG
// Verify that we can read the field member from the template record.
DeclContext::lookup_result lookupResult = templateRecordDecl->lookup(
DeclarationName(&context.Idents.get(StringRef("h"))));
DXASSERT(!lookupResult.empty(), "otherwise template object handle cannot be looked up");
#endif
*typeDecl = classTemplateDecl;
*recordDecl = templateRecordDecl;
}
/// Wraps a member function declaration in a public function template and adds
/// that template to the record.
///
/// \param context AST context that owns the new declarations.
/// \param recordDecl Record that receives the function template as a member.
/// \param functionDecl Method declaration that becomes the templated declaration.
/// \param templateParamNamedDecls Template parameter declarations, in order.
/// \param templateParamNamedDeclsCount Number of entries in templateParamNamedDecls; must be non-zero.
/// \returns The new function template, already added to recordDecl.
FunctionTemplateDecl* hlsl::CreateFunctionTemplateDecl(
  ASTContext& context,
  _In_ CXXRecordDecl* recordDecl,
  _In_ CXXMethodDecl* functionDecl,
  _In_count_(templateParamNamedDeclsCount) NamedDecl** templateParamNamedDecls,
  size_t templateParamNamedDeclsCount)
{
  DXASSERT_NOMSG(recordDecl != nullptr);
  // functionDecl is dereferenced below; validate it like the other inputs.
  DXASSERT_NOMSG(functionDecl != nullptr);
  DXASSERT_NOMSG(templateParamNamedDecls != nullptr);
  DXASSERT(templateParamNamedDeclsCount > 0, "otherwise caller shouldn't invoke this function");

  TemplateParameterList* templateParams = TemplateParameterList::Create(
    context, NoLoc, NoLoc, templateParamNamedDecls, templateParamNamedDeclsCount, NoLoc);
  FunctionTemplateDecl* functionTemplate = FunctionTemplateDecl::Create(
    context, recordDecl, NoLoc, functionDecl->getDeclName(), templateParams, functionDecl);
  functionTemplate->setAccess(AccessSpecifier::AS_public);
  functionTemplate->setLexicalDeclContext(recordDecl);

  // Link the method back to its template so lookups through either side agree.
  functionDecl->setDescribedFunctionTemplate(functionTemplate);
  recordDecl->addDecl(functionTemplate);
  return functionTemplate;
}
/// Stores the given parameter declarations into the FunctionProtoTypeLoc held
/// by tinfo, one per prototype parameter slot.
///
/// \param tinfo Type source info for a function prototype.
/// \param paramVarDecls Parameter declarations, in order.
/// \param numParams Number of entries in paramVarDecls; must match the prototype.
static
void AssociateParametersToFunctionPrototype(
  _In_ TypeSourceInfo* tinfo,
  _In_count_(numParams) ParmVarDecl** paramVarDecls,
  unsigned int numParams)
{
  DXASSERT_NOMSG(tinfo != nullptr);
  DXASSERT_NOMSG(numParams == 0 || paramVarDecls != nullptr);
  FunctionProtoTypeLoc protoLoc = tinfo->getTypeLoc().getAs<FunctionProtoTypeLoc>();
  // getAs yields a null TypeLoc when tinfo does not describe a prototype.
  DXASSERT(!protoLoc.isNull(), "otherwise type source info does not describe a function prototype");
  DXASSERT(protoLoc.getNumParams() == numParams, "otherwise unexpected number of parameters available");
  for (unsigned i = 0; i < numParams; i++) {
    DXASSERT(protoLoc.getParam(i) == nullptr, "otherwise prototype parameters were already initialized");
    protoLoc.setParam(i, paramVarDecls[i]);
  }
}
static void CreateObjectFunctionDeclaration(
ASTContext &context, _In_ CXXRecordDecl *recordDecl, QualType resultType,
ArrayRef args, DeclarationName declarationName, bool isConst,
_Out_ CXXMethodDecl **functionDecl, _Out_ TypeSourceInfo **tinfo) {
DXASSERT_NOMSG(recordDecl != nullptr);
DXASSERT_NOMSG(functionDecl != nullptr);
FunctionProtoType::ExtProtoInfo functionExtInfo;
functionExtInfo.TypeQuals = isConst ? Qualifiers::Const : 0;
QualType functionQT = context.getFunctionType(
resultType, args, functionExtInfo, ArrayRef());
DeclarationNameInfo declNameInfo(declarationName, NoLoc);
*tinfo = context.getTrivialTypeSourceInfo(functionQT, NoLoc);
DXASSERT_NOMSG(*tinfo != nullptr);
*functionDecl = CXXMethodDecl::Create(
context, recordDecl, NoLoc, declNameInfo, functionQT, *tinfo,
StorageClass::SC_None, InlineSpecifiedFalse, IsConstexprFalse, NoLoc);
DXASSERT_NOMSG(*functionDecl != nullptr);
(*functionDecl)->setLexicalDeclContext(recordDecl);
(*functionDecl)->setAccess(AccessSpecifier::AS_public);
}
/// Creates a public method named declarationName on recordDecl. It has one
/// ParmVarDecl per entry of paramTypes, named by the matching paramNames
/// entry. The method is added to the record.
///
/// \param context AST context that owns the new declarations.
/// \param recordDecl Record that receives the method.
/// \param resultType Method return type; must not be null.
/// \param paramTypes Parameter types, in order.
/// \param paramNames Parameter names; must have the same length as paramTypes.
/// \param declarationName Name of the method.
/// \param isConst Whether the method is const-qualified.
/// \returns The new method declaration.
CXXMethodDecl* hlsl::CreateObjectFunctionDeclarationWithParams(
  ASTContext& context,
  _In_ CXXRecordDecl* recordDecl,
  QualType resultType,
  ArrayRef<QualType> paramTypes,
  ArrayRef<StringRef> paramNames,
  DeclarationName declarationName,
  bool isConst)
{
  DXASSERT_NOMSG(recordDecl != nullptr);
  DXASSERT_NOMSG(!resultType.isNull());
  DXASSERT_NOMSG(paramTypes.size() == paramNames.size());

  TypeSourceInfo* tinfo;
  CXXMethodDecl* functionDecl;
  CreateObjectFunctionDeclaration(context, recordDecl, resultType, paramTypes,
                                  declarationName, isConst, &functionDecl,
                                  &tinfo);

  // Create and associate parameters to method.
  if (!paramTypes.empty()) {
    SmallVector<ParmVarDecl *, 4> parmVarDecls;
    parmVarDecls.reserve(paramTypes.size());
    for (unsigned int i = 0; i < paramTypes.size(); ++i) {
      IdentifierInfo *argIi = &context.Idents.get(paramNames[i]);
      ParmVarDecl *parmVarDecl = ParmVarDecl::Create(
          context, functionDecl, NoLoc, NoLoc, argIi, paramTypes[i],
          context.getTrivialTypeSourceInfo(paramTypes[i], NoLoc),
          StorageClass::SC_None, nullptr);
      parmVarDecl->setScopeInfo(0, i);
      DXASSERT(parmVarDecl->getFunctionScopeIndex() == i,
               "otherwise failed to set correct index");
      parmVarDecls.push_back(parmVarDecl);
    }
    functionDecl->setParams(ArrayRef<ParmVarDecl *>(parmVarDecls));
    // The prototype TypeLoc from CreateObjectFunctionDeclaration has empty
    // parameter slots; fill them with the same ParmVarDecls.
    AssociateParametersToFunctionPrototype(tinfo, &parmVarDecls.front(),
                                           parmVarDecls.size());
  }
  recordDecl->addDecl(functionDecl);
  return functionDecl;
}
/// Returns true when FD is non-null and carries an HLSLIntrinsicAttr.
bool hlsl::IsIntrinsicOp(const clang::FunctionDecl *FD) {
  return FD != nullptr && FD->hasAttr<HLSLIntrinsicAttr>();
}
/// Reads the opcode and group of FD's HLSLIntrinsicAttr.
///
/// \returns true and fills opcode/group when FD is non-null and has the
/// attribute; otherwise returns false and leaves the outputs untouched.
bool hlsl::GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode,
                          llvm::StringRef &group) {
  if (FD == nullptr) {
    return false;
  }
  // getAttr yields nullptr when the attribute is absent, so a single lookup
  // replaces the hasAttr + getAttr pair.
  const HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  if (A == nullptr) {
    return false;
  }
  opcode = A->getOpcode();
  group = A->getGroup();
  return true;
}
/// Reads the lowering string of FD's HLSLIntrinsicAttr.
///
/// \returns true and fills S when FD is non-null and has the attribute;
/// otherwise returns false and leaves S untouched.
bool hlsl::GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S) {
  if (FD == nullptr) {
    return false;
  }
  // Single lookup: getAttr returns nullptr when the attribute is absent.
  const HLSLIntrinsicAttr *A = FD->getAttr<HLSLIntrinsicAttr>();
  if (A == nullptr) {
    return false;
  }
  S = A->getLowering();
  return true;
}
/// Parses a single row/column dimension character.
/// Characters '1' through '4' are accepted: *count receives the digit's value
/// and true is returned. Any other character stores 0 and returns false.
static
bool TryParseColOrRowChar(const char digit, _Out_ int* count) {
  const bool isDimensionDigit = digit >= '1' && digit <= '4';
  *count = isDimensionDigit ? (digit - '0') : 0;
  return isDimensionDigit;
}
/// Parses a matrix shorthand identifier (eg, float3x2).
///
/// The expected shape is ScalarTypeRxC: R (rows) and C (columns) are digits
/// 1-4, and 'x' is a literal character. ScalarType must be one of the
/// HLSLScalarTypeNames values. On success *parsedType, *rowCount and
/// *colCount hold the parsed pieces and true is returned.
_Use_decl_annotations_
bool hlsl::TryParseMatrixShorthand(
  const char* typeName,
  size_t typeNameLen,
  HLSLScalarType* parsedType,
  int* rowCount,
  int* colCount,
  const clang::LangOptions& langOptions
  )
{
  const bool hasRowsAndCols =
      TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions) &&
      *rowCount != 0 && *colCount != 0;
  if (!hasRowsAndCols) {
    return false; // not a matrix suffix
  }

  // Drop the three-character "RxC" suffix and look up what remains.
  const HLSLScalarType scalarKind = FindScalarTypeByName(typeName, typeNameLen - 3, langOptions);
  if (scalarKind == HLSLScalarType_unknown) {
    return false; // unknown scalar prefix
  }
  *parsedType = scalarKind;
  return true;
}
/// Parses a vector shorthand identifier (eg, float3).
///
/// The expected shape is ScalarTypeN, where N is a digit 1-4 and ScalarType
/// is a known scalar type name. On success *parsedType and *elementCount are
/// set and true is returned.
_Use_decl_annotations_
bool hlsl::TryParseVectorShorthand(
  const char* typeName,
  size_t typeNameLen,
  HLSLScalarType* parsedType,
  int* elementCount,
  const clang::LangOptions& langOptions
  )
{
  // At least *something*N characters necessary, where something is at least 'int'.
  // Checking the length first also keeps typeName[typeNameLen - 1] from reading
  // before the start of the buffer when typeNameLen is 0.
  const size_t MinScalarNameLen = 3; // 'int'
  if (typeNameLen <= MinScalarNameLen) {
    *elementCount = 0;
    return false;
  }
  if (TryParseColOrRowChar(typeName[typeNameLen - 1], elementCount)) {
    // compare scalar component
    HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen - 1, langOptions);
    if (type != HLSLScalarType_unknown) {
      *parsedType = type;
      return true;
    }
  }
  // Unable to parse.
  return false;
}
/// Parses a hlsl scalar type name (e.g. min16float, uint).
/// On success stores the type in *parsedType and returns true. Otherwise
/// *parsedType is left untouched and false is returned.
_Use_decl_annotations_
bool hlsl::TryParseScalar(
  _In_count_(typeNameLen)
  const char* typeName,
  size_t typeNameLen,
  _Out_ HLSLScalarType *parsedType,
  _In_ const clang::LangOptions& langOptions) {
  HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen, langOptions);
  if (type == HLSLScalarType_unknown) {
    return false; // unable to parse
  }
  *parsedType = type;
  return true;
}
/// Parse any (scalar, vector, matrix) hlsl type name (e.g. float, int3x4, uint2).
///
/// When typeNameLen is at least 3, *rowCount and *colCount receive the
/// dimensions reported by TryParseMatrixOrVectorDimension. Both are 0 for a
/// scalar; rowCount is 0 for a vector. On success *parsedType holds the
/// scalar component type.
_Use_decl_annotations_
bool hlsl::TryParseAny(
  _In_count_(typeNameLen)
  const char* typeName,
  size_t typeNameLen,
  _Out_ HLSLScalarType *parsedType,
  int *rowCount,
  int *colCount,
  _In_ const clang::LangOptions& langOptions) {
  // at least 'int'
  const size_t MinValidLen = 3;
  if (typeNameLen < MinValidLen) {
    return false;
  }

  TryParseMatrixOrVectorDimension(typeName, typeNameLen, rowCount, colCount, langOptions);

  // Length of the dimension suffix to strip before the scalar lookup:
  // "RxC" for matrices, "N" for vectors, nothing for scalars.
  size_t suffixLen = 0;
  if (*colCount != 0) {
    suffixLen = (*rowCount != 0) ? 3 : 1;
  }

  HLSLScalarType type = FindScalarTypeByName(typeName, typeNameLen - suffixLen, langOptions);
  if (type == HLSLScalarType_unknown) {
    return false;
  }
  *parsedType = type;
  return true;
}
/// Parse any kind of dimension for vector or matrix (e.g 4,3 in int4x3).
/// If it's a matrix type, rowCount and colCount will be nonzero. If it's a vector type, rowCount is 0.
/// Otherwise both rowCount and colCount are 0. Returns true if either matrix or vector dimensions detected.
/// rowCount and colCount may each be null, in which case that value is not reported.
_Use_decl_annotations_
bool hlsl::TryParseMatrixOrVectorDimension(
  _In_count_(typeNameLen)
  const char *typeName,
  size_t typeNameLen,
  _Out_opt_ int *rowCount,
  _Out_opt_ int *colCount,
  _In_ const clang::LangOptions& langOptions) {
  // The outputs are annotated optional. Parse into locals and publish only
  // through non-null pointers, so a null argument is never dereferenced.
  int rows = 0;
  int cols = 0;
  bool foundDimension = false;

  const size_t MinValidLen = 3; // at least int
  if (typeNameLen > MinValidLen &&
      TryParseColOrRowChar(typeName[typeNameLen - 1], &cols)) {
    // Try parse matrix
    if (typeName[typeNameLen - 2] == 'x')
      TryParseColOrRowChar(typeName[typeNameLen - 3], &rows);
    foundDimension = true;
  }

  if (rowCount != nullptr)
    *rowCount = rows;
  if (colCount != nullptr)
    *colCount = cols;
  return foundDimension;
}
/// Declares an implicit typedef named typeName that aliases type in
/// translation-unit scope, adds it there, and returns it. Shared by the matrix
/// and vector shorthand creators that follow.
static TypedefDecl* CreateImplicitShorthandTypedef(
  ASTContext& context,
  _In_z_ const char* typeName,
  QualType type)
{
  IdentifierInfo& typedefId = context.Idents.get(StringRef(typeName), tok::TokenKind::identifier);
  DeclContext* currentDeclContext = context.getTranslationUnitDecl();
  TypedefDecl* decl = TypedefDecl::Create(context, currentDeclContext, NoLoc, NoLoc, &typedefId,
    context.getTrivialTypeSourceInfo(type, NoLoc));
  decl->setImplicit(true);
  currentDeclContext->addDecl(decl);
  return decl;
}

/// Creates a typedef for a matrix shorthand (eg, float3x2).
/// The name is the scalar type name followed by "RxC".
TypedefDecl* hlsl::CreateMatrixSpecializationShorthand(
  ASTContext& context,
  QualType matrixSpecialization,
  HLSLScalarType scalarType,
  size_t rowCount,
  size_t colCount)
{
  DXASSERT(rowCount <= 4, "else caller didn't validate rowCount");
  DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  char typeName[64];
  sprintf_s(typeName, _countof(typeName), "%s%ux%u",
    HLSLScalarTypeNames[scalarType], (unsigned)rowCount, (unsigned)colCount);
  return CreateImplicitShorthandTypedef(context, typeName, matrixSpecialization);
}

/// Creates a typedef for a vector shorthand (eg, float3).
/// The name is the scalar type name followed by the element count.
TypedefDecl* hlsl::CreateVectorSpecializationShorthand(
  ASTContext& context,
  QualType vectorSpecialization,
  HLSLScalarType scalarType,
  size_t colCount)
{
  DXASSERT(colCount <= 4, "else caller didn't validate colCount");
  char typeName[64];
  sprintf_s(typeName, _countof(typeName), "%s%u",
    HLSLScalarTypeNames[scalarType], (unsigned)colCount);
  return CreateImplicitShorthandTypedef(context, typeName, vectorSpecialization);
}
/// Copies count annotations from begin into storage owned by Context.
/// Each annotation is copied with CopyToASTContext.
/// \returns A view over the copied pointers; empty when count is 0.
llvm::ArrayRef<hlsl::UnusualAnnotation*>
hlsl::UnusualAnnotation::CopyToASTContextArray(
  clang::ASTContext& Context, hlsl::UnusualAnnotation** begin, size_t count) {
  if (count == 0) {
    return llvm::ArrayRef<hlsl::UnusualAnnotation*>();
  }
  DXASSERT_NOMSG(begin != nullptr);

  // The pointer array lives in the ASTContext allocator, like the copies.
  UnusualAnnotation** arr = ::new (Context) UnusualAnnotation*[count];
  for (size_t i = 0; i < count; ++i) {
    arr[i] = begin[i]->CopyToASTContext(Context);
  }
  return llvm::ArrayRef<hlsl::UnusualAnnotation*>(arr, count);
}
/// Returns a bitwise copy of this annotation, allocated from Context.
/// The number of bytes copied is the size of the concrete subclass selected
/// by Kind.
UnusualAnnotation* hlsl::UnusualAnnotation::CopyToASTContext(ASTContext& Context) {
  // All UnusualAnnotation instances can be blitted, so only the concrete
  // object size depends on the kind.
  size_t byteCount = 0;
  if (Kind == UA_RegisterAssignment) {
    byteCount = sizeof(hlsl::RegisterAssignment);
  } else if (Kind == UA_ConstantPacking) {
    byteCount = sizeof(hlsl::ConstantPacking);
  } else {
    DXASSERT(Kind == UA_SemanticDecl, "Kind == UA_SemanticDecl -- otherwise switch is incomplete");
    byteCount = sizeof(hlsl::SemanticDecl);
  }

  void* copy = Context.Allocate(byteCount);
  memcpy(copy, this, byteCount);
  return static_cast<UnusualAnnotation*>(copy);
}
/// Returns true when one of decl's semantic annotations resolves to the
/// TessFactor semantic kind.
static bool HasTessFactorSemantic(const ValueDecl *decl) {
  for (const UnusualAnnotation *it : decl->getUnusualAnnotations()) {
    // Only semantic annotations can name TessFactor. Other annotation kinds
    // (register assignments, constant packing) are skipped. An explicit guard
    // avoids a single-case switch, which triggers -Wswitch.
    if (it->getKind() != UnusualAnnotation::UA_SemanticDecl)
      continue;
    const SemanticDecl *sd = cast<SemanticDecl>(it);
    const Semantic *pSemantic = Semantic::GetByName(sd->SemanticName);
    if (pSemantic && pSemantic->GetKind() == Semantic::Kind::TessFactor)
      return true;
  }
  return false;
}
/// Returns true if a value of type Ty, declared by decl, carries a TessFactor
/// semantic.
/// - Builtin, vector and matrix types never match.
/// - Struct types are searched field by field, using each field's own annotations.
/// - Array types are checked against decl's annotations.
static bool HasTessFactorSemanticRecurse(const ValueDecl *decl, QualType Ty) {
  if (Ty->isBuiltinType() || hlsl::IsHLSLVecMatType(Ty))
    return false;

  if (const RecordType *RT = Ty->getAsStructureType()) {
    RecordDecl *RD = RT->getDecl();
    for (FieldDecl *fieldDecl : RD->fields()) {
      if (HasTessFactorSemanticRecurse(fieldDecl, fieldDecl->getType()))
        return true;
    }
    return false;
  }

  // Only whether Ty is an array matters; the array type itself is not needed.
  if (Ty->getAsArrayTypeUnsafe() != nullptr)
    return HasTessFactorSemantic(decl);

  return false;
}
/// Returns true when FD has a TessFactor semantic in its return type or in an
/// 'out' parameter.
bool ASTContext::IsPatchConstantFunctionDecl(const FunctionDecl *FD) const {
  // This checks whether the function is structurally capable of being a patch
  // constant function, not whether it is in fact the patch constant function
  // for the entry point of a compiled hull shader (which may not have been
  // seen yet). So the answer is conservative.
  if (!FD->getReturnType()->isVoidType()) {
    // Try to find TessFactor in return type.
    if (HasTessFactorSemanticRecurse(FD, FD->getReturnType()))
      return true;
  }

  // Try to find TessFactor in out param.
  for (const ParmVarDecl *param : FD->params()) {
    if (param->hasAttr<HLSLOutAttr>() &&
        HasTessFactorSemanticRecurse(param, param->getType()))
      return true;
  }
  return false;
}