| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531 |
- //===--- SpirvContext.cpp - SPIR-V SpirvContext implementation-------------===//
- //
- // The LLVM Compiler Infrastructure
- //
- // This file is distributed under the University of Illinois Open Source
- // License. See LICENSE.TXT for details.
- //
- //===----------------------------------------------------------------------===//
- #include <algorithm>
- #include <tuple>
- #include "clang/SPIRV/SpirvContext.h"
- #include "clang/SPIRV/SpirvModule.h"
- namespace clang {
- namespace spirv {
- SpirvContext::SpirvContext()
- : allocator(), voidType(nullptr), boolType(nullptr), sintTypes({}),
- uintTypes({}), floatTypes({}), samplerType(nullptr),
- curShaderModelKind(ShaderModelKind::Invalid), majorVersion(0),
- minorVersion(0), currentLexicalScope(nullptr) {
- voidType = new (this) VoidType;
- boolType = new (this) BoolType;
- samplerType = new (this) SamplerType;
- accelerationStructureTypeNV = new (this) AccelerationStructureTypeNV;
- rayQueryTypeKHR = new (this) RayQueryTypeKHR;
- }
- SpirvContext::~SpirvContext() {
- voidType->~VoidType();
- boolType->~BoolType();
- samplerType->~SamplerType();
- accelerationStructureTypeNV->~AccelerationStructureTypeNV();
- rayQueryTypeKHR->~RayQueryTypeKHR();
- for (auto *sintType : sintTypes)
- if (sintType) // sintTypes may contain nullptr
- sintType->~IntegerType();
- for (auto *uintType : uintTypes)
- if (uintType) // uintTypes may contain nullptr
- uintType->~IntegerType();
- for (auto *floatType : floatTypes)
- if (floatType) // floatTypes may contain nullptr
- floatType->~FloatType();
- for (auto &pair : vecTypes)
- for (auto *vecType : pair.second)
- if (vecType) // vecTypes may contain nullptr
- vecType->~VectorType();
- for (auto &pair : matTypes)
- for (auto *matType : pair.second)
- matType->~MatrixType();
- for (auto *arrType : arrayTypes)
- arrType->~ArrayType();
- for (auto *raType : runtimeArrayTypes)
- raType->~RuntimeArrayType();
- for (auto *fnType : functionTypes)
- fnType->~FunctionType();
- for (auto *structType : structTypes)
- structType->~StructType();
- for (auto *hybridStructType : hybridStructTypes)
- hybridStructType->~HybridStructType();
- for (auto pair : sampledImageTypes)
- pair.second->~SampledImageType();
- for (auto *hybridSampledImageType : hybridSampledImageTypes)
- hybridSampledImageType->~HybridSampledImageType();
- for (auto *imgType : imageTypes)
- imgType->~ImageType();
- for (auto &pair : pointerTypes)
- for (auto &scPtrTypePair : pair.second)
- scPtrTypePair.second->~SpirvPointerType();
- for (auto *hybridPtrType : hybridPointerTypes)
- hybridPtrType->~HybridPointerType();
- for (auto &typePair : debugTypes)
- typePair.second->releaseMemory();
- for (auto &typePair : typeTemplates)
- typePair.second->releaseMemory();
- for (auto &typePair : typeTemplateParams)
- typePair.second->releaseMemory();
- }
- inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
- assert(bitwidth >= 8 && bitwidth <= 64 && llvm::isPowerOf2_32(bitwidth));
- return llvm::Log2_32(bitwidth);
- }
- const IntegerType *SpirvContext::getSIntType(uint32_t bitwidth) {
- auto &type = sintTypes[log2ForBitwidth(bitwidth)];
- if (type == nullptr) {
- type = new (this) IntegerType(bitwidth, true);
- }
- return type;
- }
- const IntegerType *SpirvContext::getUIntType(uint32_t bitwidth) {
- auto &type = uintTypes[log2ForBitwidth(bitwidth)];
- if (type == nullptr) {
- type = new (this) IntegerType(bitwidth, false);
- }
- return type;
- }
- const FloatType *SpirvContext::getFloatType(uint32_t bitwidth) {
- auto &type = floatTypes[log2ForBitwidth(bitwidth)];
- if (type == nullptr) {
- type = new (this) FloatType(bitwidth);
- }
- return type;
- }
- const VectorType *SpirvContext::getVectorType(const SpirvType *elemType,
- uint32_t count) {
- // We are certain this should be a scalar type. Otherwise, cast causes an
- // assertion failure.
- const ScalarType *scalarType = cast<ScalarType>(elemType);
- assert(count == 2 || count == 3 || count == 4);
- auto found = vecTypes.find(scalarType);
- if (found != vecTypes.end()) {
- auto &type = found->second[count];
- if (type != nullptr)
- return type;
- } else {
- // Make sure to initialize since std::array is "an aggregate type with the
- // same semantics as a struct holding a C-style array T[N]".
- vecTypes[scalarType] = {};
- }
- return vecTypes[scalarType][count] = new (this) VectorType(scalarType, count);
- }
- const SpirvType *SpirvContext::getMatrixType(const SpirvType *elemType,
- uint32_t count) {
- // We are certain this should be a vector type. Otherwise, cast causes an
- // assertion failure.
- const VectorType *vecType = cast<VectorType>(elemType);
- assert(count == 2 || count == 3 || count == 4);
- // In the case of non-floating-point matrices, we represent them as array of
- // vectors.
- if (!isa<FloatType>(vecType->getElementType())) {
- return getArrayType(elemType, count, llvm::None);
- }
- auto foundVec = matTypes.find(vecType);
- if (foundVec != matTypes.end()) {
- const auto &matVector = foundVec->second;
- // Create a temporary object for finding in the vector.
- MatrixType type(vecType, count);
- for (const auto *cachedType : matVector)
- if (type == *cachedType)
- return cachedType;
- }
- const auto *ptr = new (this) MatrixType(vecType, count);
- matTypes[vecType].push_back(ptr);
- return ptr;
- }
- const ImageType *
- SpirvContext::getImageType(const ImageType *imageTypeWithUnknownFormat,
- spv::ImageFormat format) {
- return getImageType(imageTypeWithUnknownFormat->getSampledType(),
- imageTypeWithUnknownFormat->getDimension(),
- imageTypeWithUnknownFormat->getDepth(),
- imageTypeWithUnknownFormat->isArrayedImage(),
- imageTypeWithUnknownFormat->isMSImage(),
- imageTypeWithUnknownFormat->withSampler(), format);
- }
- const ImageType *SpirvContext::getImageType(const SpirvType *sampledType,
- spv::Dim dim,
- ImageType::WithDepth depth,
- bool arrayed, bool ms,
- ImageType::WithSampler sampled,
- spv::ImageFormat format) {
- // We are certain this should be a numerical type. Otherwise, cast causes an
- // assertion failure.
- const NumericalType *elemType = cast<NumericalType>(sampledType);
- // Create a temporary object for finding in the set.
- ImageType type(elemType, dim, depth, arrayed, ms, sampled, format);
- auto found = imageTypes.find(&type);
- if (found != imageTypes.end())
- return *found;
- auto inserted = imageTypes.insert(
- new (this) ImageType(elemType, dim, depth, arrayed, ms, sampled, format));
- return *(inserted.first);
- }
- const SampledImageType *
- SpirvContext::getSampledImageType(const ImageType *image) {
- auto found = sampledImageTypes.find(image);
- if (found != sampledImageTypes.end())
- return found->second;
- return sampledImageTypes[image] = new (this) SampledImageType(image);
- }
- const HybridSampledImageType *
- SpirvContext::getSampledImageType(QualType image) {
- const HybridSampledImageType *result =
- new (this) HybridSampledImageType(image);
- hybridSampledImageTypes.push_back(result);
- return result;
- }
- const ArrayType *
- SpirvContext::getArrayType(const SpirvType *elemType, uint32_t elemCount,
- llvm::Optional<uint32_t> arrayStride) {
- ArrayType type(elemType, elemCount, arrayStride);
- auto found = arrayTypes.find(&type);
- if (found != arrayTypes.end())
- return *found;
- auto inserted =
- arrayTypes.insert(new (this) ArrayType(elemType, elemCount, arrayStride));
- // The return value is an (iterator, bool) pair. The boolean indicates whether
- // it was actually added as a new type.
- return *(inserted.first);
- }
- const RuntimeArrayType *
- SpirvContext::getRuntimeArrayType(const SpirvType *elemType,
- llvm::Optional<uint32_t> arrayStride) {
- RuntimeArrayType type(elemType, arrayStride);
- auto found = runtimeArrayTypes.find(&type);
- if (found != runtimeArrayTypes.end())
- return *found;
- auto inserted = runtimeArrayTypes.insert(
- new (this) RuntimeArrayType(elemType, arrayStride));
- return *(inserted.first);
- }
- const StructType *
- SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
- llvm::StringRef name, bool isReadOnly,
- StructInterfaceType interfaceType) {
- // We are creating a temporary struct type here for querying whether the
- // same type was already created. It is a little bit costly, but we can
- // avoid allocating directly from the bump pointer allocator, from which
- // then we are unable to reclaim until the allocator itself is destroyed.
- StructType type(fields, name, isReadOnly, interfaceType);
- auto found = std::find_if(
- structTypes.begin(), structTypes.end(),
- [&type](const StructType *cachedType) { return type == *cachedType; });
- if (found != structTypes.end())
- return *found;
- structTypes.push_back(
- new (this) StructType(fields, name, isReadOnly, interfaceType));
- return structTypes.back();
- }
- const HybridStructType *SpirvContext::getHybridStructType(
- llvm::ArrayRef<HybridStructType::FieldInfo> fields, llvm::StringRef name,
- bool isReadOnly, StructInterfaceType interfaceType) {
- const HybridStructType *result =
- new (this) HybridStructType(fields, name, isReadOnly, interfaceType);
- hybridStructTypes.push_back(result);
- return result;
- }
- const SpirvPointerType *SpirvContext::getPointerType(const SpirvType *pointee,
- spv::StorageClass sc) {
- auto foundPointee = pointerTypes.find(pointee);
- if (foundPointee != pointerTypes.end()) {
- auto &pointeeMap = foundPointee->second;
- auto foundSC = pointeeMap.find(sc);
- if (foundSC != pointeeMap.end())
- return foundSC->second;
- }
- return pointerTypes[pointee][sc] = new (this) SpirvPointerType(pointee, sc);
- }
- const HybridPointerType *SpirvContext::getPointerType(QualType pointee,
- spv::StorageClass sc) {
- const HybridPointerType *result = new (this) HybridPointerType(pointee, sc);
- hybridPointerTypes.push_back(result);
- return result;
- }
- FunctionType *
- SpirvContext::getFunctionType(const SpirvType *ret,
- llvm::ArrayRef<const SpirvType *> param) {
- // Create a temporary object for finding in the set.
- FunctionType type(ret, param);
- auto found = functionTypes.find(&type);
- if (found != functionTypes.end())
- return *found;
- auto inserted = functionTypes.insert(new (this) FunctionType(ret, param));
- return *inserted.first;
- }
- const StructType *SpirvContext::getByteAddressBufferType(bool isWritable) {
- // Create a uint RuntimeArray.
- const auto *raType =
- getRuntimeArrayType(getUIntType(32), /* ArrayStride */ 4);
- // Create a struct containing the runtime array as its only member.
- return getStructType(
- {StructType::FieldInfo(raType, /*name*/ "", /*offset*/ 0)},
- isWritable ? "type.RWByteAddressBuffer" : "type.ByteAddressBuffer",
- !isWritable, StructInterfaceType::StorageBuffer);
- }
- const StructType *SpirvContext::getACSBufferCounterType() {
- // Create int32.
- const auto *int32Type = getSIntType(32);
- // Create a struct containing the integer counter as its only member.
- const StructType *type =
- getStructType({StructType::FieldInfo(int32Type, "counter", /*offset*/ 0)},
- "type.ACSBuffer.counter",
- /*isReadOnly*/ false, StructInterfaceType::StorageBuffer);
- return type;
- }
- SpirvDebugType *SpirvContext::getDebugTypeBasic(const SpirvType *spirvType,
- llvm::StringRef name,
- SpirvConstant *size,
- uint32_t encoding) {
- // Reuse existing debug type if possible.
- if (debugTypes.find(spirvType) != debugTypes.end())
- return debugTypes[spirvType];
- auto *debugType = new (this) SpirvDebugTypeBasic(name, size, encoding);
- debugTypes[spirvType] = debugType;
- return debugType;
- }
- SpirvDebugType *
- SpirvContext::getDebugTypeMember(llvm::StringRef name, SpirvDebugType *type,
- SpirvDebugSource *source, uint32_t line,
- uint32_t column, SpirvDebugInstruction *parent,
- uint32_t flags, uint32_t offsetInBits,
- uint32_t sizeInBits, const APValue *value) {
- // NOTE: Do not search it in debugTypes because it would have the same
- // spirvType but has different parent i.e., type composite.
- SpirvDebugTypeMember *debugType =
- new (this) SpirvDebugTypeMember(name, type, source, line, column, parent,
- flags, offsetInBits, sizeInBits, value);
- return debugType;
- }
- SpirvDebugTypeComposite *SpirvContext::getDebugTypeComposite(
- const SpirvType *spirvType, llvm::StringRef name, SpirvDebugSource *source,
- uint32_t line, uint32_t column, SpirvDebugInstruction *parent,
- llvm::StringRef linkageName, uint32_t flags, uint32_t tag) {
- // Reuse existing debug type if possible.
- auto it = debugTypes.find(spirvType);
- if (it != debugTypes.end()) {
- assert(it->second != nullptr && isa<SpirvDebugTypeComposite>(it->second));
- return dyn_cast<SpirvDebugTypeComposite>(it->second);
- }
- auto *debugType = new (this) SpirvDebugTypeComposite(
- name, source, line, column, parent, linkageName, flags, tag);
- debugType->setDebugSpirvType(spirvType);
- debugTypes[spirvType] = debugType;
- return debugType;
- }
- SpirvDebugType *SpirvContext::getDebugType(const SpirvType *spirvType) {
- auto it = debugTypes.find(spirvType);
- if (it != debugTypes.end())
- return it->second;
- return nullptr;
- }
- SpirvDebugType *
- SpirvContext::getDebugTypeArray(const SpirvType *spirvType,
- SpirvDebugInstruction *elemType,
- llvm::ArrayRef<uint32_t> elemCount) {
- // Reuse existing debug type if possible.
- if (debugTypes.find(spirvType) != debugTypes.end())
- return debugTypes[spirvType];
- auto *eTy = dyn_cast<SpirvDebugType>(elemType);
- assert(eTy && "Element type must be a SpirvDebugType.");
- auto *debugType = new (this) SpirvDebugTypeArray(eTy, elemCount);
- debugTypes[spirvType] = debugType;
- return debugType;
- }
- SpirvDebugType *
- SpirvContext::getDebugTypeVector(const SpirvType *spirvType,
- SpirvDebugInstruction *elemType,
- uint32_t elemCount) {
- // Reuse existing debug type if possible.
- if (debugTypes.find(spirvType) != debugTypes.end())
- return debugTypes[spirvType];
- auto *eTy = dyn_cast<SpirvDebugType>(elemType);
- assert(eTy && "Element type must be a SpirvDebugType.");
- auto *debugType = new (this) SpirvDebugTypeVector(eTy, elemCount);
- debugTypes[spirvType] = debugType;
- return debugType;
- }
- SpirvDebugType *
- SpirvContext::getDebugTypeFunction(const SpirvType *spirvType, uint32_t flags,
- SpirvDebugType *ret,
- llvm::ArrayRef<SpirvDebugType *> params) {
- // Reuse existing debug type if possible.
- if (debugTypes.find(spirvType) != debugTypes.end())
- return debugTypes[spirvType];
- auto *debugType = new (this) SpirvDebugTypeFunction(flags, ret, params);
- debugTypes[spirvType] = debugType;
- return debugType;
- }
- SpirvDebugTypeTemplate *SpirvContext::createDebugTypeTemplate(
- const ClassTemplateSpecializationDecl *templateType,
- SpirvDebugInstruction *target,
- const llvm::SmallVector<SpirvDebugTypeTemplateParameter *, 2> ¶ms) {
- auto *tempTy = getDebugTypeTemplate(templateType);
- if (tempTy != nullptr)
- return tempTy;
- tempTy = new (this) SpirvDebugTypeTemplate(target, params);
- typeTemplates[templateType] = tempTy;
- return tempTy;
- }
- SpirvDebugTypeTemplate *SpirvContext::getDebugTypeTemplate(
- const ClassTemplateSpecializationDecl *templateType) {
- auto it = typeTemplates.find(templateType);
- if (it != typeTemplates.end())
- return it->second;
- return nullptr;
- }
- SpirvDebugTypeTemplateParameter *SpirvContext::createDebugTypeTemplateParameter(
- const TemplateArgument *templateArg, llvm::StringRef name,
- SpirvDebugType *type, SpirvInstruction *value, SpirvDebugSource *source,
- uint32_t line, uint32_t column) {
- auto *param = getDebugTypeTemplateParameter(templateArg);
- if (param != nullptr)
- return param;
- param = new (this)
- SpirvDebugTypeTemplateParameter(name, type, value, source, line, column);
- typeTemplateParams[templateArg] = param;
- return param;
- }
- SpirvDebugTypeTemplateParameter *SpirvContext::getDebugTypeTemplateParameter(
- const TemplateArgument *templateArg) {
- auto it = typeTemplateParams.find(templateArg);
- if (it != typeTemplateParams.end())
- return it->second;
- return nullptr;
- }
- void SpirvContext::pushDebugLexicalScope(RichDebugInfo *info,
- SpirvDebugInstruction *scope) {
- assert((isa<SpirvDebugLexicalBlock>(scope) ||
- isa<SpirvDebugFunction>(scope) ||
- isa<SpirvDebugCompilationUnit>(scope) ||
- isa<SpirvDebugTypeComposite>(scope)) &&
- "Given scope is not a lexical scope");
- currentLexicalScope = scope;
- info->scopeStack.push_back(scope);
- }
- void SpirvContext::moveDebugTypesToModule(SpirvModule *module) {
- for (const auto &typePair : debugTypes) {
- module->addDebugInfo(typePair.second);
- if (auto *composite = dyn_cast<SpirvDebugTypeComposite>(typePair.second)) {
- for (auto *member : composite->getMembers()) {
- module->addDebugInfo(member);
- }
- }
- }
- for (const auto &typePair : typeTemplates) {
- module->addDebugInfo(typePair.second);
- }
- for (const auto &typePair : typeTemplateParams) {
- module->addDebugInfo(typePair.second);
- }
- debugTypes.clear();
- typeTemplates.clear();
- typeTemplateParams.clear();
- }
- } // end namespace spirv
- } // end namespace clang
|