//===------- SpirvEmitter.cpp - SPIR-V Binary Code Emitter ------*- C++ -*-===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. //===----------------------------------------------------------------------===// // // This file implements a SPIR-V emitter class that takes in HLSL AST and emits // SPIR-V binary words. // //===----------------------------------------------------------------------===// #include "SpirvEmitter.h" #include "AlignmentSizeCalculator.h" #include "RawBufferMethods.h" #include "dxc/HlslIntrinsicOp.h" #include "spirv-tools/optimizer.hpp" #include "clang/SPIRV/AstTypeProbe.h" #include "clang/Sema/Sema.h" #include "llvm/ADT/StringExtras.h" #include "InitListHandler.h" #include "dxc/DXIL/DxilConstants.h" #ifdef SUPPORT_QUERY_GIT_COMMIT_INFO #include "clang/Basic/Version.h" #else namespace clang { uint32_t getGitCommitCount() { return 0; } const char *getGitCommitHash() { return ""; } } // namespace clang #endif // SUPPORT_QUERY_GIT_COMMIT_INFO namespace clang { namespace spirv { namespace { // Returns true if the given decl is an implicit variable declaration inside the // "vk" namespace. bool isImplicitVarDeclInVkNamespace(const Decl *decl) { if (!decl) return false; if (auto *varDecl = dyn_cast(decl)) { // Check whether it is implicitly defined. if (!decl->isImplicit()) return false; if (auto *nsDecl = dyn_cast(varDecl->getDeclContext())) if (nsDecl->getName().equals("vk")) return true; } return false; } // Returns true if the given decl has the given semantic. 
bool hasSemantic(const DeclaratorDecl *decl, hlsl::DXIL::SemanticKind semanticKind) { using namespace hlsl; for (auto *annotation : decl->getUnusualAnnotations()) { if (auto *semanticDecl = dyn_cast(annotation)) { llvm::StringRef semanticName; uint32_t semanticIndex = 0; Semantic::DecomposeNameAndIndex(semanticDecl->SemanticName, &semanticName, &semanticIndex); const auto *semantic = Semantic::GetByName(semanticName); if (semantic->GetKind() == semanticKind) return true; } } return false; } const ParmVarDecl *patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) { for (const auto *param : pcf->parameters()) if (hlsl::IsHLSLOutputPatchType(param->getType())) return param; return nullptr; } inline bool isSpirvMatrixOp(spv::Op opcode) { return opcode == spv::Op::OpMatrixTimesMatrix || opcode == spv::Op::OpMatrixTimesVector || opcode == spv::Op::OpMatrixTimesScalar; } /// If expr is a (RW)StructuredBuffer.Load(), returns the object and writes /// index. Otherwiser, returns false. // TODO: The following doesn't handle Load(int, int) yet. And it is basically a // duplicate of doCXXMemberCallExpr. const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) { using namespace hlsl; if (const auto *indexing = dyn_cast(expr)) { const auto *callee = indexing->getDirectCallee(); uint32_t opcode = static_cast(IntrinsicOp::Num_Intrinsics); llvm::StringRef group; if (GetIntrinsicOp(callee, opcode, group)) { if (static_cast(opcode) == IntrinsicOp::MOP_Load) { const auto *object = indexing->getImplicitObjectArgument(); if (isStructuredBuffer(object->getType())) { *index = indexing->getArg(0); return indexing->getImplicitObjectArgument(); } } } } return nullptr; } /// Returns true if the given VarDecl will be translated into a SPIR-V variable /// not in the Private or Function storage class. inline bool isExternalVar(const VarDecl *var) { // Class static variables should be put in the Private storage class. // groupshared variables are allowed to be declared as "static". 
But we still // need to put them in the Workgroup storage class. That is, when seeing // "static groupshared", ignore "static". return var->hasExternalFormalLinkage() ? !var->isStaticDataMember() : (var->getAttr() != nullptr); } /// Returns the referenced variable's DeclContext if the given expr is /// a DeclRefExpr referencing a ConstantBuffer/TextureBuffer. Otherwise, /// returns nullptr. const DeclContext *isConstantTextureBufferDeclRef(const Expr *expr) { if (const auto *declRefExpr = dyn_cast(expr->IgnoreParenCasts())) if (const auto *varDecl = dyn_cast(declRefExpr->getFoundDecl())) if (isConstantTextureBuffer(varDecl->getType())) return hlsl::GetHLSLResourceResultType(varDecl->getType()) ->getAs() ->getDecl(); return nullptr; } /// Returns true if /// * the given expr is an DeclRefExpr referencing a kind of structured or byte /// buffer and it is non-alias one, or /// * the given expr is an CallExpr returning a kind of structured or byte /// buffer. /// * the given expr is an ArraySubscriptExpr referencing a kind of structured /// or byte buffer. /// /// Note: legalization specific code bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) { expr = expr->IgnoreParenCasts(); if (const auto *declRefExpr = dyn_cast(expr)) { if (const auto *varDecl = dyn_cast(declRefExpr->getFoundDecl())) if (isAKindOfStructuredOrByteBuffer(varDecl->getType())) return isExternalVar(varDecl); } else if (const auto *callExpr = dyn_cast(expr)) { if (isAKindOfStructuredOrByteBuffer(callExpr->getType())) return true; } else if (const auto *arrSubExpr = dyn_cast(expr)) { return isReferencingNonAliasStructuredOrByteBuffer(arrSubExpr->getBase()); } return false; } /// Translates atomic HLSL opcodes into the equivalent SPIR-V opcode. 
/// Each atomic exists both as a free intrinsic (IOP_*) and as a resource
/// method (MOP_*); both spellings map to the same SPIR-V opcode. Any
/// non-atomic opcode trips an assertion and yields spv::Op::Max.
spv::Op translateAtomicHlslOpcodeToSpirvOpcode(hlsl::IntrinsicOp opcode) {
  using namespace hlsl;
  using namespace spv;

  // One row per atomic operation: {intrinsic form, method form, SPIR-V op}.
  struct AtomicOpRow {
    IntrinsicOp intrinsicForm;
    IntrinsicOp methodForm;
    Op spirvOp;
  };
  static const AtomicOpRow kAtomicOpTable[] = {
      {IntrinsicOp::IOP_InterlockedAdd, IntrinsicOp::MOP_InterlockedAdd,
       Op::OpAtomicIAdd},
      {IntrinsicOp::IOP_InterlockedAnd, IntrinsicOp::MOP_InterlockedAnd,
       Op::OpAtomicAnd},
      {IntrinsicOp::IOP_InterlockedOr, IntrinsicOp::MOP_InterlockedOr,
       Op::OpAtomicOr},
      {IntrinsicOp::IOP_InterlockedXor, IntrinsicOp::MOP_InterlockedXor,
       Op::OpAtomicXor},
      {IntrinsicOp::IOP_InterlockedUMax, IntrinsicOp::MOP_InterlockedUMax,
       Op::OpAtomicUMax},
      {IntrinsicOp::IOP_InterlockedUMin, IntrinsicOp::MOP_InterlockedUMin,
       Op::OpAtomicUMin},
      {IntrinsicOp::IOP_InterlockedMax, IntrinsicOp::MOP_InterlockedMax,
       Op::OpAtomicSMax},
      {IntrinsicOp::IOP_InterlockedMin, IntrinsicOp::MOP_InterlockedMin,
       Op::OpAtomicSMin},
      {IntrinsicOp::IOP_InterlockedExchange,
       IntrinsicOp::MOP_InterlockedExchange, Op::OpAtomicExchange},
  };

  for (const auto &row : kAtomicOpTable) {
    if (opcode == row.intrinsicForm || opcode == row.methodForm)
      return row.spirvOp;
  }

  // Only atomic opcodes are relevant; reaching here is a caller error.
  assert(false && "unimplemented hlsl intrinsic opcode");
  return Op::Max;
}

// Returns true if the given opcode is an accepted binary opcode in
// OpSpecConstantOp.
bool isAcceptedSpecConstantBinaryOp(spv::Op op) { switch (op) { case spv::Op::OpIAdd: case spv::Op::OpISub: case spv::Op::OpIMul: case spv::Op::OpUDiv: case spv::Op::OpSDiv: case spv::Op::OpUMod: case spv::Op::OpSRem: case spv::Op::OpSMod: case spv::Op::OpShiftRightLogical: case spv::Op::OpShiftRightArithmetic: case spv::Op::OpShiftLeftLogical: case spv::Op::OpBitwiseOr: case spv::Op::OpBitwiseXor: case spv::Op::OpBitwiseAnd: case spv::Op::OpVectorShuffle: case spv::Op::OpCompositeExtract: case spv::Op::OpCompositeInsert: case spv::Op::OpLogicalOr: case spv::Op::OpLogicalAnd: case spv::Op::OpLogicalNot: case spv::Op::OpLogicalEqual: case spv::Op::OpLogicalNotEqual: case spv::Op::OpIEqual: case spv::Op::OpINotEqual: case spv::Op::OpULessThan: case spv::Op::OpSLessThan: case spv::Op::OpUGreaterThan: case spv::Op::OpSGreaterThan: case spv::Op::OpULessThanEqual: case spv::Op::OpSLessThanEqual: case spv::Op::OpUGreaterThanEqual: case spv::Op::OpSGreaterThanEqual: return true; default: // Accepted binary opcodes return true. Anything else is false. return false; } return false; } /// Returns true if the given expression is an accepted initializer for a spec /// constant. bool isAcceptedSpecConstantInit(const Expr *init) { // Allow numeric casts init = init->IgnoreParenCasts(); if (isa(init) || isa(init) || isa(init)) return true; // Allow the minus operator which is used to specify negative values if (const auto *unaryOp = dyn_cast(init)) return unaryOp->getOpcode() == UO_Minus && isAcceptedSpecConstantInit(unaryOp->getSubExpr()); return false; } /// Returns true if the given function parameter can act as shader stage /// input parameter. inline bool canActAsInParmVar(const ParmVarDecl *param) { // If the parameter has no in/out/inout attribute, it is defaulted to // an in parameter. return !param->hasAttr() && // GS output streams are marked as inout, but it should not be // used as in parameter. 
!hlsl::IsHLSLStreamOutputType(param->getType()); } /// Returns true if the given function parameter can act as shader stage /// output parameter. inline bool canActAsOutParmVar(const ParmVarDecl *param) { return param->hasAttr() || param->hasAttr() || hlsl::IsHLSLRayQueryType(param->getType()); } /// Returns true if the given expression is of builtin type and can be evaluated /// to a constant zero. Returns false otherwise. inline bool evaluatesToConstZero(const Expr *expr, ASTContext &astContext) { const auto type = expr->getType(); if (!type->isBuiltinType()) return false; Expr::EvalResult evalResult; if (expr->EvaluateAsRValue(evalResult, astContext) && !evalResult.HasSideEffects) { const auto &val = evalResult.Val; return ((type->isBooleanType() && !val.getInt().getBoolValue()) || (type->isIntegerType() && !val.getInt().getBoolValue()) || (type->isFloatingType() && val.getFloat().isZero())); } return false; } /// Returns the real definition of the callee of the given CallExpr. /// /// If we are calling a forward-declared function, callee will be the /// FunctionDecl for the foward-declared function, not the actual /// definition. The foward-delcaration and defintion are two completely /// different AST nodes. inline const FunctionDecl *getCalleeDefinition(const CallExpr *expr) { const auto *callee = expr->getDirectCallee(); if (callee->isThisDeclarationADefinition()) return callee; // We need to update callee to the actual definition here if (!callee->isDefined(callee)) return nullptr; return callee; } /// Returns the referenced definition. The given expr is expected to be a /// DeclRefExpr or CallExpr after ignoring casts. Returns nullptr otherwise. 
const DeclaratorDecl *getReferencedDef(const Expr *expr) { if (!expr) return nullptr; expr = expr->IgnoreParenCasts(); if (const auto *declRefExpr = dyn_cast(expr)) { return dyn_cast_or_null(declRefExpr->getDecl()); } if (const auto *callExpr = dyn_cast(expr)) { return getCalleeDefinition(callExpr); } return nullptr; } /// Returns the number of base classes if this type is a derived class/struct. /// Returns zero otherwise. inline uint32_t getNumBaseClasses(QualType type) { if (const auto *cxxDecl = type->getAsCXXRecordDecl()) return cxxDecl->getNumBases(); return 0; } /// Gets the index sequence of casting a derived object to a base object by /// following the cast chain. void getBaseClassIndices(const CastExpr *expr, llvm::SmallVectorImpl *indices) { assert(expr->getCastKind() == CK_UncheckedDerivedToBase || expr->getCastKind() == CK_HLSLDerivedToBase); indices->clear(); QualType derivedType = expr->getSubExpr()->getType(); // There are two types of UncheckedDerivedToBase/HLSLDerivedToBase casts: // // The first is when a derived object tries to access a member in the base. // For example: derived.base_member. // ImplicitCastExpr 'Base' lvalue // `-DeclRefExpr 'Derived' lvalue Var 0x1f0d9bb2890 'derived' 'Derived' // // The second is when a pointer of the dervied is used to access members or // methods of the base. There are currently no pointers in HLSL, but the // method defintions can use the "this" pointer. // For example: // class Base { float value; }; // class Derviced : Base { // float4 getBaseValue() { return value; } // }; // // In this example, the 'this' pointer (pointing to Derived) is used inside // 'getBaseValue', which is then cast to a Base pointer: // // ImplicitCastExpr 'Base *' // `-CXXThisExpr 'Derviced *' this // // Therefore in order to obtain the derivedDecl below, we must make sure that // we handle the second case too by using the pointee type. 
if (derivedType->isPointerType()) derivedType = derivedType->getPointeeType(); const auto *derivedDecl = derivedType->getAsCXXRecordDecl(); // Go through the base cast chain: for each of the derived to base cast, find // the index of the base in question in the derived's bases. for (auto pathIt = expr->path_begin(), pathIe = expr->path_end(); pathIt != pathIe; ++pathIt) { // The type of the base in question const auto baseType = (*pathIt)->getType(); uint32_t index = 0; for (auto baseIt = derivedDecl->bases_begin(), baseIe = derivedDecl->bases_end(); baseIt != baseIe; ++baseIt, ++index) if (baseIt->getType() == baseType) { indices->push_back(index); break; } assert(index < derivedDecl->getNumBases()); // Continue to proceed the next base in the chain derivedType = baseType; if (derivedType->isPointerType()) derivedType = derivedType->getPointeeType(); derivedDecl = derivedType->getAsCXXRecordDecl(); } } std::string getNamespacePrefix(const Decl *decl) { std::string nsPrefix = ""; const DeclContext *dc = decl->getDeclContext(); while (dc && !dc->isTranslationUnit()) { if (const NamespaceDecl *ns = dyn_cast(dc)) { if (!ns->isAnonymousNamespace()) { nsPrefix = ns->getName().str() + "::" + nsPrefix; } } dc = dc->getParent(); } return nsPrefix; } std::string getFnName(const FunctionDecl *fn) { // Prefix the function name with the struct name if necessary std::string classOrStructName = ""; if (const auto *memberFn = dyn_cast(fn)) if (const auto *st = dyn_cast(memberFn->getDeclContext())) classOrStructName = st->getName().str() + "."; return getNamespacePrefix(fn) + classOrStructName + fn->getName().str(); } bool isMemoryObjectDeclaration(SpirvInstruction *inst) { return isa(inst) || isa(inst); } } // namespace SpirvEmitter::SpirvEmitter(CompilerInstance &ci) : theCompilerInstance(ci), astContext(ci.getASTContext()), diags(ci.getDiagnostics()), spirvOptions(ci.getCodeGenOpts().SpirvOptions), entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction), spvContext(), 
featureManager(diags, spirvOptions), spvBuilder(astContext, spvContext, spirvOptions), declIdMapper(astContext, spvContext, spvBuilder, *this, featureManager, spirvOptions), entryFunction(nullptr), curFunction(nullptr), curThis(nullptr), seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false), beforeHlslLegalization(false), mainSourceFile(nullptr) { // Get ShaderModel from command line hlsl profile option. const hlsl::ShaderModel *shaderModel = hlsl::ShaderModel::GetByName(ci.getCodeGenOpts().HLSLProfile.c_str()); if (shaderModel->GetKind() == hlsl::ShaderModel::Kind::Invalid) emitError("unknown shader module: %0", {}) << shaderModel->GetName(); if (spirvOptions.invertY && !shaderModel->IsVS() && !shaderModel->IsDS() && !shaderModel->IsGS()) emitError("-fvk-invert-y can only be used in VS/DS/GS", {}); if (spirvOptions.useGlLayout && spirvOptions.useDxLayout) emitError("cannot specify both -fvk-use-dx-layout and -fvk-use-gl-layout", {}); // Set shader model kind and hlsl major/minor version. 
spvContext.setCurrentShaderModelKind(shaderModel->GetKind()); spvContext.setMajorVersion(shaderModel->GetMajor()); spvContext.setMinorVersion(shaderModel->GetMinor()); if (spirvOptions.useDxLayout) { spirvOptions.cBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer; spirvOptions.tBufferLayoutRule = SpirvLayoutRule::FxcCTBuffer; spirvOptions.sBufferLayoutRule = SpirvLayoutRule::FxcSBuffer; spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::FxcSBuffer; } else if (spirvOptions.useGlLayout) { spirvOptions.cBufferLayoutRule = SpirvLayoutRule::GLSLStd140; spirvOptions.tBufferLayoutRule = SpirvLayoutRule::GLSLStd430; spirvOptions.sBufferLayoutRule = SpirvLayoutRule::GLSLStd430; spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::GLSLStd430; } else if (spirvOptions.useScalarLayout) { spirvOptions.cBufferLayoutRule = SpirvLayoutRule::Scalar; spirvOptions.tBufferLayoutRule = SpirvLayoutRule::Scalar; spirvOptions.sBufferLayoutRule = SpirvLayoutRule::Scalar; spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::Scalar; } else { spirvOptions.cBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd140; spirvOptions.tBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430; spirvOptions.sBufferLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430; spirvOptions.ampPayloadLayoutRule = SpirvLayoutRule::RelaxedGLSLStd430; } // Set shader module version, source file name, and source file content (if // needed). 
llvm::StringRef source; std::vector fileNames; const auto &inputFiles = ci.getFrontendOpts().Inputs; // File name if (spirvOptions.debugInfoFile && !inputFiles.empty()) { for (const auto &inputFile : inputFiles) { fileNames.push_back(inputFile.getFile()); } } // Source code if (spirvOptions.debugInfoSource) { const auto &sm = ci.getSourceManager(); const llvm::MemoryBuffer *mainFile = sm.getBuffer(sm.getMainFileID(), SourceLocation()); source = StringRef(mainFile->getBufferStart(), mainFile->getBufferSize()); } mainSourceFile = spvBuilder.setDebugSource(spvContext.getMajorVersion(), spvContext.getMinorVersion(), fileNames, source); // OpenCL.DebugInfo.100 DebugSource if (spirvOptions.debugInfoRich) { auto *dbgSrc = spvBuilder.createDebugSource(mainSourceFile->getString()); // spvContext.getDebugInfo().insert() inserts {string key, RichDebugInfo} // pair and returns {{string key, RichDebugInfo}, true /*Success*/}. // spvContext.getDebugInfo().insert().first->second is a RichDebugInfo. auto *richDebugInfo = &spvContext.getDebugInfo() .insert( {mainSourceFile->getString(), RichDebugInfo(dbgSrc, spvBuilder.createDebugCompilationUnit(dbgSrc))}) .first->second; spvContext.pushDebugLexicalScope(richDebugInfo, richDebugInfo->scopeStack.back()); } if (spirvOptions.debugInfoTool && featureManager.isTargetEnvVulkan1p1OrAbove()) { // Emit OpModuleProcessed to indicate the commit information. std::string commitHash = std::string("dxc-commit-hash: ") + clang::getGitCommitHash(); spvBuilder.addModuleProcessed(commitHash); // Emit OpModuleProcessed to indicate the command line options that were // used to generate this module. if (!spirvOptions.clOptions.empty()) { // Using this format: "dxc-cl-option: XXXXXX" std::string clOptionStr = "dxc-cl-option:" + spirvOptions.clOptions; spvBuilder.addModuleProcessed(clOptionStr); } } } void SpirvEmitter::HandleTranslationUnit(ASTContext &context) { // Stop translating if there are errors in previous compilation stages. 
if (context.getDiagnostics().hasErrorOccurred()) return; TranslationUnitDecl *tu = context.getTranslationUnitDecl(); uint32_t numEntryPoints = 0; // The entry function is the seed of the queue. for (auto *decl : tu->decls()) { if (auto *funcDecl = dyn_cast(decl)) { if (spvContext.isLib()) { if (const auto *shaderAttr = funcDecl->getAttr()) { // If we are compiling as a library then add everything that has a // ShaderAttr. addFunctionToWorkQueue(getShaderModelKind(shaderAttr->getStage()), funcDecl, /*isEntryFunction*/ true); numEntryPoints++; } else if (funcDecl->getAttr()) { addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(), funcDecl, /*isEntryFunction*/ false); } } else { if (funcDecl->getName() == entryFunctionName) { addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(), funcDecl, /*isEntryFunction*/ true); numEntryPoints++; } } } else { doDecl(decl); } if (context.getDiagnostics().hasErrorOccurred()) return; } // Translate all functions reachable from the entry function. // The queue can grow in the meanwhile; so need to keep evaluating // workQueue.size(). for (uint32_t i = 0; i < workQueue.size(); ++i) { const FunctionInfo *curEntryOrCallee = workQueue[i]; spvContext.setCurrentShaderModelKind(curEntryOrCallee->shaderModelKind); doDecl(curEntryOrCallee->funcDecl); if (context.getDiagnostics().hasErrorOccurred()) return; } // Addressing and memory model are required in a valid SPIR-V module. spvBuilder.setMemoryModel(spv::AddressingModel::Logical, spv::MemoryModel::GLSL450); // Even though the 'workQueue' grows due to the above loop, the first // 'numEntryPoints' entries in the 'workQueue' are the ones with the HLSL // 'shader' attribute, and must therefore be entry functions. assert(numEntryPoints <= workQueue.size()); for (uint32_t i = 0; i < numEntryPoints; ++i) { // TODO: assign specific StageVars w.r.t. 
to entry point const FunctionInfo *entryInfo = workQueue[i]; assert(entryInfo->isEntryFunction); spvBuilder.addEntryPoint( getSpirvShaderStage(entryInfo->shaderModelKind), entryInfo->entryFunction, entryInfo->funcDecl->getName(), featureManager.isTargetEnvVulkan1p2OrAbove() ? spvBuilder.getModule()->getVariables() : llvm::ArrayRef(declIdMapper.collectStageVars())); } // Add Location decorations to stage input/output variables. if (!declIdMapper.decorateStageIOLocations()) return; // Add descriptor set and binding decorations to resource variables. if (!declIdMapper.decorateResourceBindings()) return; // Add Coherent docrations to resource variables. if (!declIdMapper.decorateResourceCoherent()) return; // Output the constructed module. std::vector m = spvBuilder.takeModule(); if (!spirvOptions.codeGenHighLevel) { // In order to flatten composite resources, we must also unroll loops. // Therefore we should run legalization before optimization. needsLegalization = needsLegalization || declIdMapper.requiresLegalization() || spirvOptions.flattenResourceArrays || declIdMapper.requiresFlatteningCompositeResources(); // Run legalization passes if (needsLegalization) { std::string messages; if (!spirvToolsLegalize(&m, &messages)) { emitFatalError("failed to legalize SPIR-V: %0", {}) << messages; emitNote("please file a bug report on " "https://github.com/Microsoft/DirectXShaderCompiler/issues " "with source code if possible", {}); return; } else if (!messages.empty()) { emitWarning("SPIR-V legalization: %0", {}) << messages; } } // Run optimization passes if (theCompilerInstance.getCodeGenOpts().OptimizationLevel > 0) { std::string messages; if (!spirvToolsOptimize(&m, &messages)) { emitFatalError("failed to optimize SPIR-V: %0", {}) << messages; emitNote("please file a bug report on " "https://github.com/Microsoft/DirectXShaderCompiler/issues " "with source code if possible", {}); return; } } } // Validate the generated SPIR-V code if (!spirvOptions.disableValidation) { 
std::string messages; if (!spirvToolsValidate(&m, &messages)) { emitFatalError("generated SPIR-V is invalid: %0", {}) << messages; emitNote("please file a bug report on " "https://github.com/Microsoft/DirectXShaderCompiler/issues " "with source code if possible", {}); return; } } theCompilerInstance.getOutStream()->write( reinterpret_cast(m.data()), m.size() * 4); } void SpirvEmitter::doDecl(const Decl *decl) { if (isa(decl) || isa(decl)) return; // Implicit decls are lazily created when needed. if (decl->isImplicit()) { return; } if (const auto *varDecl = dyn_cast(decl)) { doVarDecl(varDecl); } else if (const auto *namespaceDecl = dyn_cast(decl)) { for (auto *subDecl : namespaceDecl->decls()) // Note: We only emit functions as they are discovered through the call // graph starting from the entry-point. We should not emit unused // functions inside namespaces. if (!isa(subDecl)) doDecl(subDecl); } else if (const auto *funcDecl = dyn_cast(decl)) { doFunctionDecl(funcDecl); } else if (const auto *bufferDecl = dyn_cast(decl)) { doHLSLBufferDecl(bufferDecl); } else if (const auto *recordDecl = dyn_cast(decl)) { doRecordDecl(recordDecl); } else if (const auto *enumDecl = dyn_cast(decl)) { doEnumDecl(enumDecl); } else { emitError("decl type %0 unimplemented", decl->getLocation()) << decl->getDeclKindName(); } } RichDebugInfo * SpirvEmitter::getOrCreateRichDebugInfo(const SourceLocation &loc) { const StringRef file = astContext.getSourceManager().getPresumedLoc(loc).getFilename(); auto &debugInfo = spvContext.getDebugInfo(); auto it = debugInfo.find(file); if (it != debugInfo.end()) return &it->second; auto *dbgSrc = spvBuilder.createDebugSource(file); // debugInfo.insert() inserts {string key, RichDebugInfo} pair and // returns {{string key, RichDebugInfo}, true /*Success*/}. // debugInfo.insert().first->second is a RichDebugInfo. 
return &debugInfo .insert({file, RichDebugInfo( dbgSrc, spvBuilder.createDebugCompilationUnit( dbgSrc))}) .first->second; } void SpirvEmitter::doStmt(const Stmt *stmt, llvm::ArrayRef attrs) { if (const auto *compoundStmt = dyn_cast(stmt)) { if (spirvOptions.debugInfoRich) { // Any opening of curly braces ('{') starts a CompoundStmt in the AST // tree. It also means we have a new lexical block! const auto loc = stmt->getLocStart(); const auto &sm = astContext.getSourceManager(); const uint32_t line = sm.getPresumedLineNumber(loc); const uint32_t column = sm.getPresumedColumnNumber(loc); RichDebugInfo *info = getOrCreateRichDebugInfo(loc); auto *debugLexicalBlock = spvBuilder.createDebugLexicalBlock( info->source, line, column, info->scopeStack.back()); // Add this lexical block to the stack of lexical scopes. spvContext.pushDebugLexicalScope(info, debugLexicalBlock); // Update or add DebugScope. if (spvBuilder.getInsertPoint()->empty()) { spvBuilder.getInsertPoint()->updateDebugScope( new (spvContext) SpirvDebugScope(debugLexicalBlock)); } else if (!spvBuilder.isCurrentBasicBlockTerminated()) { spvBuilder.createDebugScope(debugLexicalBlock); } // Iterate over sub-statements for (auto *st : compoundStmt->body()) doStmt(st); // We are done with processing this compound statement. Remove its lexical // block from the stack of lexical scopes. 
spvContext.popDebugLexicalScope(info); if (!spvBuilder.isCurrentBasicBlockTerminated()) { spvBuilder.createDebugScope(spvContext.getCurrentLexicalScope()); } } else { // Iterate over sub-statements for (auto *st : compoundStmt->body()) doStmt(st); } } else if (const auto *retStmt = dyn_cast(stmt)) { doReturnStmt(retStmt); } else if (const auto *declStmt = dyn_cast(stmt)) { doDeclStmt(declStmt); } else if (const auto *ifStmt = dyn_cast(stmt)) { doIfStmt(ifStmt, attrs); } else if (const auto *switchStmt = dyn_cast(stmt)) { doSwitchStmt(switchStmt, attrs); } else if (dyn_cast(stmt)) { processCaseStmtOrDefaultStmt(stmt); } else if (dyn_cast(stmt)) { processCaseStmtOrDefaultStmt(stmt); } else if (const auto *breakStmt = dyn_cast(stmt)) { doBreakStmt(breakStmt); } else if (const auto *theDoStmt = dyn_cast(stmt)) { doDoStmt(theDoStmt, attrs); } else if (const auto *discardStmt = dyn_cast(stmt)) { doDiscardStmt(discardStmt); } else if (const auto *continueStmt = dyn_cast(stmt)) { doContinueStmt(continueStmt); } else if (const auto *whileStmt = dyn_cast(stmt)) { doWhileStmt(whileStmt, attrs); } else if (const auto *forStmt = dyn_cast(stmt)) { doForStmt(forStmt, attrs); } else if (dyn_cast(stmt)) { // For the null statement ";". We don't need to do anything. 
} else if (const auto *expr = dyn_cast(stmt)) { // All cases for expressions used as statements doExpr(expr); } else if (const auto *attrStmt = dyn_cast(stmt)) { doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs()); } else { emitError("statement class '%0' unimplemented", stmt->getLocStart()) << stmt->getStmtClassName() << stmt->getSourceRange(); } } SpirvInstruction *SpirvEmitter::doExpr(const Expr *expr) { SpirvInstruction *result = nullptr; expr = expr->IgnoreParens(); if (const auto *declRefExpr = dyn_cast(expr)) { auto *decl = declRefExpr->getDecl(); if (isImplicitVarDeclInVkNamespace(declRefExpr->getDecl())) { result = doExpr(cast(decl)->getInit()); } else { result = declIdMapper.getDeclEvalInfo(decl, expr->getLocStart()); } } else if (const auto *memberExpr = dyn_cast(expr)) { result = doMemberExpr(memberExpr); } else if (const auto *castExpr = dyn_cast(expr)) { result = doCastExpr(castExpr); } else if (const auto *initListExpr = dyn_cast(expr)) { result = doInitListExpr(initListExpr); } else if (const auto *boolLiteral = dyn_cast(expr)) { result = spvBuilder.getConstantBool(boolLiteral->getValue(), isSpecConstantMode); result->setRValue(); } else if (const auto *intLiteral = dyn_cast(expr)) { result = translateAPInt(intLiteral->getValue(), expr->getType()); result->setRValue(); } else if (const auto *floatLiteral = dyn_cast(expr)) { result = translateAPFloat(floatLiteral->getValue(), expr->getType()); result->setRValue(); } else if (const auto *stringLiteral = dyn_cast(expr)) { result = spvBuilder.getString(stringLiteral->getString()); } else if (const auto *compoundAssignOp = dyn_cast(expr)) { // CompoundAssignOperator is a subclass of BinaryOperator. It should be // checked before BinaryOperator. 
result = doCompoundAssignOperator(compoundAssignOp);
  } else if (const auto *binOp = dyn_cast(expr)) {
    result = doBinaryOperator(binOp);
  } else if (const auto *unaryOp = dyn_cast(expr)) {
    result = doUnaryOperator(unaryOp);
  } else if (const auto *vecElemExpr = dyn_cast(expr)) {
    result = doHLSLVectorElementExpr(vecElemExpr);
  } else if (const auto *matElemExpr = dyn_cast(expr)) {
    result = doExtMatrixElementExpr(matElemExpr);
  } else if (const auto *funcCall = dyn_cast(expr)) {
    result = doCallExpr(funcCall);
  } else if (const auto *subscriptExpr = dyn_cast(expr)) {
    result = doArraySubscriptExpr(subscriptExpr);
  } else if (const auto *condExpr = dyn_cast(expr)) {
    result = doConditionalOperator(condExpr);
  } else if (const auto *defaultArgExpr = dyn_cast(expr)) {
    result = doExpr(defaultArgExpr->getParam()->getDefaultArg());
  } else if (isa(expr)) {
    assert(curThis);
    result = curThis;
  } else if (isa(expr)) {
    result = curThis;
  } else if (const auto *unaryExpr = dyn_cast(expr)) {
    result = doUnaryExprOrTypeTraitExpr(unaryExpr);
  } else {
    emitError("expression class '%0' unimplemented", expr->getExprLoc())
        << expr->getStmtClassName() << expr->getSourceRange();
  }

  return result;
}

// NOTE(review): in this copy of the source, template argument lists appear to
// have been stripped (e.g. `dyn_cast(expr)`, `isa(expr)`, `cast(type)`,
// `decl->hasAttr()`, `decl->getAttr()`, `std::stack()`, `llvm::ArrayRef
// attrs` all lack their type argument), and comment text that was enclosed
// in angle brackets is missing. The code below does not compile as-is;
// restore the template arguments from the upstream file.

// Strips parentheses and lvalue casts from `expr`, translates it with
// doExpr(), and forwards both to the two-argument loadIfGLValue() overload,
// which performs the actual load (if any).
SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr) {
  // We are trying to load the value here, which is what an LValueToRValue
  // implicit cast is intended to do. We can ignore the cast if exists.
  expr = expr->IgnoreParenLValueCasts();
  return loadIfGLValue(expr, doExpr(expr));
}

// Produces an rvalue for `expr`, whose translation is `info`.
//
// Returns `info` unchanged when:
//   * it is null or already marked as an rvalue,
//   * `expr` has an opaque array type (so the caller can copy per element),
//   * `expr` references a non-alias structured/byte buffer.
// If loadIfAliasVarRef() recognizes an alias variable, returns the pointer it
// produced. Otherwise emits a load; when `info` carries a layout rule and the
// type is bool (scalar/vector/matrix), the loaded uint representation is
// converted back to bool via castToBool() and its layout rule is cleared.
// The returned instruction is marked as an rvalue.
SpirvInstruction *SpirvEmitter::loadIfGLValue(const Expr *expr,
                                              SpirvInstruction *info) {
  const auto exprType = expr->getType();

  // Do nothing if this is already rvalue
  if (!info || info->isRValue())
    return info;

  // Check whether we are trying to load an array of opaque objects as a whole.
  // If true, we are likely to copy it as a whole. To assist per-element
  // copying, avoid the load here and return the pointer directly.
  // TODO: consider moving this hack into SPIRV-Tools as a transformation.
  if (isOpaqueArrayType(exprType))
    return info;

  // Check whether we are trying to load an externally visible structured/byte
  // buffer as a whole. If true, it means we are creating alias for it. Avoid
  // the load and write the pointer directly to the alias variable then.
  //
  // Also for the case of alias function returns. If we are trying to load an
  // alias function return as a whole, it means we are assigning it to another
  // alias variable. Avoid the load and write the pointer directly.
  //
  // Note: legalization specific code
  if (isReferencingNonAliasStructuredOrByteBuffer(expr)) {
    return info;
  }

  if (loadIfAliasVarRef(expr, &info)) {
    // We are loading an alias variable as a whole here. This is likely for
    // wholesale assignments or function returns. Need to load the pointer.
    //
    // Note: legalization specific code
    return info;
  }

  SpirvInstruction *loadedInstr = nullptr;
  // TODO: Ouch. Very hacky. We need special path to get the value type if
  // we are loading a whole ConstantBuffer/TextureBuffer since the normal
  // type translation path won't work.
  if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
    loadedInstr = spvBuilder.createLoad(
        declIdMapper.getCTBufferPushConstantType(declContext), info,
        expr->getExprLoc());
  } else {
    loadedInstr = spvBuilder.createLoad(exprType, info, expr->getExprLoc());
  }
  assert(loadedInstr);

  // Special-case: According to the SPIR-V Spec: There is no physical size or
  // bit pattern defined for boolean type. Therefore an unsigned integer is used
  // to represent booleans when layout is required. In such cases, after loading
  // the uint, we should perform a comparison.
  {
    uint32_t vecSize = 1, numRows = 0, numCols = 0;
    if (info->getLayoutRule() != SpirvLayoutRule::Void &&
        isBoolOrVecMatOfBoolType(exprType)) {
      QualType uintType = astContext.UnsignedIntTy;
      if (isScalarType(exprType) || isVectorType(exprType, nullptr, &vecSize)) {
        // Scalar (vecSize stays 1) or vector: the stored type is uint or a
        // uint vector of the same width.
        const auto fromType =
            vecSize == 1 ? uintType
                         : astContext.getExtVectorType(uintType, vecSize);
        loadedInstr =
            castToBool(loadedInstr, fromType, exprType, expr->getLocStart());
      } else {
        // Matrix: rebuild the same matrix template with uint elements to get
        // the stored type.
        const bool isMat = isMxNMatrix(exprType, nullptr, &numRows, &numCols);
        assert(isMat);
        (void)isMat;

        const clang::Type *type = exprType.getCanonicalType().getTypePtr();
        const RecordType *RT = cast(type);
        const ClassTemplateSpecializationDecl *templateSpecDecl =
            cast(RT->getDecl());
        ClassTemplateDecl *templateDecl =
            templateSpecDecl->getSpecializedTemplate();
        const auto fromType = getHLSLMatrixType(
            astContext, theCompilerInstance.getSema(), templateDecl,
            astContext.UnsignedIntTy, numRows, numCols);
        loadedInstr =
            castToBool(loadedInstr, fromType, exprType, expr->getLocStart());
      }
      // Now that it is converted to Bool, it has no layout rule.
      // This result-id should be evaluated as bool from here on out.
      loadedInstr->setLayoutRule(SpirvLayoutRule::Void);
    }
  }

  loadedInstr->setRValue();
  return loadedInstr;
}

// Translates `expr` with doExpr() and applies the two-argument
// loadIfAliasVarRef() to the result, returning the possibly-replaced
// instruction.
SpirvInstruction *SpirvEmitter::loadIfAliasVarRef(const Expr *expr) {
  auto *instr = doExpr(expr);
  loadIfAliasVarRef(expr, &instr);
  return instr;
}

// If `*instr` is non-null, contains an alias component, and `varExpr` has a
// structured/byte buffer type, returns true. In that case, when `varExpr` is
// a glvalue, `*instr` is replaced in place by a load through it.
// Returns false (leaving `*instr` untouched) otherwise.
// `instr` itself must be non-null (asserted).
bool SpirvEmitter::loadIfAliasVarRef(const Expr *varExpr,
                                     SpirvInstruction **instr) {
  assert(instr);
  if ((*instr) && (*instr)->containsAliasComponent() &&
      isAKindOfStructuredOrByteBuffer(varExpr->getType())) {
    // Load the pointer of the aliased-to-variable if the expression has a
    // pointer to pointer type.
    if (varExpr->isGLValue()) {
      *instr = spvBuilder.createLoad(varExpr->getType(), *instr,
                                     varExpr->getExprLoc());
    }
    return true;
  }

  return false;
}

// Converts `value` from `fromType` to `toType`, dispatching on the target:
// float (or vec/mat of float) -> castToFloat(), bool -> castToBool(),
// signed/unsigned int -> castToInt(). Emits an error and returns nullptr for
// any other target type.
SpirvInstruction *SpirvEmitter::castToType(SpirvInstruction *value,
                                           QualType fromType, QualType toType,
                                           SourceLocation srcLoc) {
  if (isFloatOrVecMatOfFloatType(toType))
    return castToFloat(value, fromType, toType, srcLoc);

  // Order matters here. Bool (vector) values will also be considered as uint
  // (vector) values. So given a bool (vector) argument, isUintOrVecOfUintType()
  // will also return true. We need to check bool before uint. The opposite is
  // not true.
  if (isBoolOrVecMatOfBoolType(toType))
    return castToBool(value, fromType, toType, srcLoc);

  if (isSintOrVecMatOfSintType(toType) || isUintOrVecMatOfUintType(toType))
    return castToInt(value, fromType, toType, srcLoc);

  emitError("casting to type %0 unimplemented", {}) << toType;
  return nullptr;
}

// Translates a function declaration into a SPIR-V function.
//
// A declaration that is not a definition is only queued with
// addFunctionToWorkQueue(). For a definition, this:
//   * sets `curFunction` to `decl` for the duration of the call (FnEnvRAII),
//   * resets the break/continue stacks,
//   * for entry functions, prefixes the SPIR-V name with "src." and emits the
//     entry wrapper first (returning early if that fails),
//   * emits rich debug info when enabled,
//   * adds an implicit 'this' parameter for non-static member functions,
//   * creates parameters (skipping the hull-shader output patch parameter of
//     the patch constant function),
//   * translates the body and appends a return (or a null return value) if
//     the last basic block is not terminated.
void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
  // Forward declaration of a function inside another.
  if (!decl->isThisDeclarationADefinition()) {
    addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(), decl,
                           /*isEntryFunction*/ false);
    return;
  }

  // A RAII class for maintaining the current function under traversal.
  class FnEnvRAII {
  public:
    // Creates a new instance which sets fnEnv to the newFn on creation,
    // and resets fnEnv to its original value on destruction.
    FnEnvRAII(const FunctionDecl **fnEnv, const FunctionDecl *newFn)
        : oldFn(*fnEnv), fnSlot(fnEnv) {
      *fnEnv = newFn;
    }
    ~FnEnvRAII() { *fnSlot = oldFn; }

  private:
    const FunctionDecl *oldFn;
    const FunctionDecl **fnSlot;
  };

  FnEnvRAII fnEnvRAII(&curFunction, decl);

  // We are about to start translation for a new function. Clear the break stack
  // and the continue stack.
  breakStack = std::stack();
  continueStack = std::stack();

  // This will allow the entry-point name to be something like
  // myNamespace::myEntrypointFunc.
  std::string funcName = getFnName(decl);
  // The debug function keeps the un-prefixed name even for entry functions.
  std::string debugFuncName = funcName;

  SpirvFunction *func = declIdMapper.getOrRegisterFn(decl);

  const auto iter = functionInfoMap.find(decl);
  if (iter != functionInfoMap.end()) {
    const auto &entryInfo = iter->second;
    if (entryInfo->isEntryFunction) {
      funcName = "src." + funcName;
      // Create wrapper for the entry function
      if (!emitEntryFunctionWrapper(decl, func))
        return;
    }
  }

  const QualType retType =
      declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);

  spvBuilder.beginFunction(retType, decl->getLocStart(), funcName,
                           decl->hasAttr(), decl->hasAttr(), func);

  auto loc = decl->getLocStart();
  RichDebugInfo *info = nullptr;
  const auto &sm = astContext.getSourceManager();
  if (spirvOptions.debugInfoRich && decl->hasBody()) {
    const uint32_t line = sm.getPresumedLineNumber(loc);
    const uint32_t column = sm.getPresumedColumnNumber(loc);
    info = getOrCreateRichDebugInfo(loc);

    auto *source = info->source;
    // Note that info->scopeStack.back() is a lexical scope of the function
    // caller.
    auto *parentScope = info->compilationUnit;
    // TODO: figure out the proper flag based on the function decl.
    // using FlagIsPublic for now.
    uint32_t flags = 3u;
    // The line number in the source program at which the function scope begins.
    auto scopeLine = sm.getPresumedLineNumber(decl->getBody()->getLocStart());
    SpirvDebugFunction *debugFunction = spvBuilder.createDebugFunction(
        decl, debugFuncName, source, line, column, parentScope, "", flags,
        scopeLine, func);
    func->setDebugScope(new (spvContext) SpirvDebugScope(debugFunction));

    spvContext.pushDebugLexicalScope(info, debugFunction);
  }

  bool isNonStaticMemberFn = false;
  if (const auto *memberFn = dyn_cast(decl)) {
    if (!memberFn->isStatic()) {
      // For non-static member function, the first parameter should be the
      // object on which we are invoking this method.
      QualType valueType = memberFn->getThisType(astContext)->getPointeeType();

      // Remember the parameter for the 'this' object so later we can handle
      // CXXThisExpr correctly.
      curThis = spvBuilder.addFnParam(valueType, /*isPrecise*/ false,
                                      decl->getLocStart(), "param.this");
      if (isOrContainsAKindOfStructuredOrByteBuffer(valueType)) {
        curThis->setContainsAliasComponent(true);
        needsLegalization = true;
      }

      if (spirvOptions.debugInfoRich) {
        // Add DebugLocalVariable information
        const auto &sm = astContext.getSourceManager();
        const uint32_t line = sm.getPresumedLineNumber(loc);
        const uint32_t column = sm.getPresumedColumnNumber(loc);
        if (!info)
          info = getOrCreateRichDebugInfo(loc);
        // TODO: replace this with FlagArtificial|FlagObjectPointer.
        uint32_t flags = (1 << 5) | (1 << 8);
        auto *debugLocalVar = spvBuilder.createDebugLocalVariable(
            valueType, "this", info->source, line, column,
            info->scopeStack.back(), flags, 1);
        spvBuilder.createDebugDeclare(debugLocalVar, curThis);
      }

      isNonStaticMemberFn = true;
    }
  }

  // Create all parameters.
  for (uint32_t i = 0; i < decl->getNumParams(); ++i) {
    const ParmVarDecl *paramDecl = decl->getParamDecl(i);
    if (spvContext.isHS() && decl == patchConstFunc &&
        hlsl::IsHLSLOutputPatchType(paramDecl->getType())) {
      // Since the output patch used in hull shaders is translated to
      // a variable with Output storage class, there is no need
      // to pass the variable as function parameter in SPIR-V.
      continue;
    }
    // Parameter numbering starts at 1, shifted by one more when an implicit
    // 'this' parameter was added above.
    (void)declIdMapper.createFnParam(paramDecl, i + 1 + isNonStaticMemberFn);
  }

  if (decl->hasBody()) {
    // The entry basic block.
    auto *entryLabel = spvBuilder.createBasicBlock("bb.entry");
    spvBuilder.setInsertPoint(entryLabel);

    // Process all statements in the body.
    doStmt(decl->getBody());

    // We have processed all Stmts in this function and now in the last
    // basic block. Make sure we have a termination instruction.
    if (!spvBuilder.isCurrentBasicBlockTerminated()) {
      const auto retType = decl->getReturnType();
      const auto returnLoc = decl->getBody()->getLocEnd();
      if (retType->isVoidType()) {
        spvBuilder.createReturn(returnLoc);
      } else {
        // If the source code does not provide a proper return value for some
        // control flow path, it's undefined behavior. We just return null
        // value here.
        spvBuilder.createReturnValue(spvBuilder.getConstantNull(retType),
                                     returnLoc);
      }
    }
  }

  spvBuilder.endFunction();

  // NOTE(review): the matching pushDebugLexicalScope() above is guarded by
  // `debugInfoRich && decl->hasBody()`, while this pop is guarded only by
  // `debugInfoRich`. Non-definitions return at the top of this function, so
  // the two presumably always line up — confirm, or guard both identically.
  if (spirvOptions.debugInfoRich) {
    spvContext.popDebugLexicalScope(info);
  }
}

// Checks Vulkan-specific attribute constraints on `decl` and emits a
// diagnostic for each violation found. Returns false if any check failed.
// Returns false immediately when a SubpassInput(MS) element type is not a
// scalar/vector, because lowering that type later would assert.
bool SpirvEmitter::validateVKAttributes(const NamedDecl *decl) {
  bool success = true;
  if (const auto *varDecl = dyn_cast(decl)) {
    const auto varType = varDecl->getType();
    if ((isSubpassInput(varType) || isSubpassInputMS(varType)) &&
        !varDecl->hasAttr()) {
      emitError("missing vk::input_attachment_index attribute",
                varDecl->getLocation());
      success = false;
    }
  }

  // Presumably the input-attachment-index attribute (per the comment about
  // VKInputAttachmentIndexAttr below); template argument missing in this copy.
  if (decl->getAttr()) {
    if (!spvContext.isPS()) {
      emitError("SubpassInput(MS) only allowed in pixel shader",
                decl->getLocation());
      success = false;
    }

    if (!decl->isExternallyVisible()) {
      emitError("SubpassInput(MS) must be externally visible",
                decl->getLocation());
      success = false;
    }

    // We only allow VKInputAttachmentIndexAttr to be attached to global
    // variables. So it should be fine to cast here.
    const auto elementType =
        hlsl::GetHLSLResourceResultType(cast(decl)->getType());

    if (!isScalarType(elementType) && !isVectorType(elementType)) {
      emitError(
          "only scalar/vector types allowed as SubpassInput(MS) parameter type",
          decl->getLocation());
      // Return directly to avoid further type processing, which will hit
      // asserts when lowering the type.
      return false;
    }
  }

  // The frontend will make sure that
  // * vk::push_constant applies to global variables of struct type
  // * vk::binding applies to global variables or cbuffers/tbuffers
  // * vk::counter_binding applies to global variables of RW/Append/Consume
  //   StructuredBuffer
  // * vk::location applies to function parameters/returns and struct fields
  // So the only case we need to check co-existence is vk::push_constant and
  // vk::binding.

  if (const auto *pcAttr = decl->getAttr()) {
    const auto loc = pcAttr->getLocation();

    if (seenPushConstantAt.isInvalid()) {
      seenPushConstantAt = loc;
    } else {
      // TODO: Actually this is slightly incorrect. The Vulkan spec says:
      //   There must be no more than one push constant block statically used
      //   per shader entry point.
      // But we are checking whether there are more than one push constant
      // blocks defined. Tracking usage requires more work.
      emitError("cannot have more than one push constant block", loc);
      emitNote("push constant block previously defined here",
               seenPushConstantAt);
      success = false;
    }

    if (decl->hasAttr()) {
      emitError("vk::push_constant attribute cannot be used together with "
                "vk::binding attribute",
                loc);
      success = false;
    }
  }

  // vk::shader_record_nv is supported only on cbuffer/ConstantBuffer
  if (const auto *srbAttr = decl->getAttr()) {
    const auto loc = srbAttr->getLocation();
    const HLSLBufferDecl *bufDecl = nullptr;
    bool isValidType = false;
    // Valid targets: a cbuffer decl itself, a member declared inside a
    // cbuffer, or a variable whose type is a ConstantBuffer.
    if ((bufDecl = dyn_cast(decl)))
      isValidType = bufDecl->isCBuffer();
    else if ((bufDecl = dyn_cast(decl->getDeclContext())))
      isValidType = bufDecl->isCBuffer();
    else if (isa(decl))
      // NOTE(review): dyn_cast is dereferenced right after a successful isa
      // check; cast would state that intent directly.
      isValidType = isConstantBuffer(dyn_cast(decl)->getType());

    if (!isValidType) {
      emitError(
          "vk::shader_record_nv can be applied only to cbuffer/ConstantBuffer",
          loc);
      success = false;
    }
    if (decl->hasAttr()) {
      emitError("vk::shader_record_nv attribute cannot be used together with "
                "vk::binding attribute",
                loc);
      success = false;
    }
  }

  // vk::shader_record_ext is supported only on cbuffer/ConstantBuffer
  // (same checks as the vk::shader_record_nv block above).
  if (const auto *srbAttr = decl->getAttr()) {
    const auto loc = srbAttr->getLocation();
    const HLSLBufferDecl *bufDecl = nullptr;
    bool isValidType = false;
    if ((bufDecl = dyn_cast(decl)))
      isValidType = bufDecl->isCBuffer();
    else if ((bufDecl = dyn_cast(decl->getDeclContext())))
      isValidType = bufDecl->isCBuffer();
    else if (isa(decl))
      isValidType = isConstantBuffer(dyn_cast(decl)->getType());

    if (!isValidType) {
      emitError(
          "vk::shader_record_ext can be applied only to cbuffer/ConstantBuffer",
          loc);
      success = false;
    }
    if (decl->hasAttr()) {
      emitError("vk::shader_record_ext attribute cannot be used together with "
                "vk::binding attribute",
                loc);
      success = false;
    }
  }

  return success;
}

// Translates a cbuffer/tbuffer declaration.
// Warns about member initializers (ignored; suppressed by
// noWarnIgnoredFeatures) and errors on non-floating-point column-major matrix
// members. It then validates Vulkan attributes and creates either a shader
// record buffer (NV or EXT variant, depending on the attribute) or a regular
// constant/texture buffer.
void SpirvEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
  // This is a cbuffer/tbuffer decl.
  // Check and emit warnings for member initializers which are not
  // supported in Vulkan
  for (const auto *member : bufferDecl->decls()) {
    if (const auto *varMember = dyn_cast(member)) {
      if (!spirvOptions.noWarnIgnoredFeatures) {
        if (const auto *init = varMember->getInit())
          emitWarning("%select{tbuffer|cbuffer}0 member initializer "
                      "ignored since no Vulkan equivalent",
                      init->getExprLoc())
              << bufferDecl->isCBuffer() << init->getSourceRange();
      }

      // We cannot handle external initialization of column-major matrices now.
      if (isOrContainsNonFpColMajorMatrix(astContext, spirvOptions,
                                          varMember->getType(), varMember)) {
        emitError("externally initialized non-floating-point column-major "
                  "matrices not supported yet",
                  varMember->getLocation());
      }
    }
  }
  if (!validateVKAttributes(bufferDecl))
    return;
  if (bufferDecl->hasAttr()) {
    (void)declIdMapper.createShaderRecordBuffer(
        bufferDecl, DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferNV);
  } else if (bufferDecl->hasAttr()) {
    (void)declIdMapper.createShaderRecordBuffer(
        bufferDecl,
        DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferEXT);
  } else {
    (void)declIdMapper.createCTBuffer(bufferDecl);
  }
}

// Emits the static data members of `recordDecl` that carry an inline
// initializer. Implicit records are skipped entirely.
void SpirvEmitter::doRecordDecl(const RecordDecl *recordDecl) {
  // Ignore implicit records
  // Somehow we'll have implicit records with:
  //   static const int Length = count;
  // that can mess up with the normal CodeGen.
  if (recordDecl->isImplicit())
    return;

  // Handle each static member with inline initializer.
  // Each static member has a corresponding VarDecl inside the
  // RecordDecl. For those defined in the translation unit,
  // their VarDecls do not have initializer.
  for (auto *subDecl : recordDecl->decls())
    if (auto *varDecl = dyn_cast(subDecl))
      if (varDecl->isStaticDataMember() && varDecl->hasInit())
        doVarDecl(varDecl);
}

// Registers every enumerator of `decl` with declIdMapper as an enum constant.
void SpirvEmitter::doEnumDecl(const EnumDecl *decl) {
  for (auto it = decl->enumerator_begin(); it != decl->enumerator_end(); ++it)
    declIdMapper.createEnumConstant(*it);
}

// Translates a variable declaration. It checks these cases in order and
// returns after the first match:
//   * Vulkan attribute validation failure.
//   * 'string' variables (mapped to OpString).
//   * Arrays of RW/Append/Consume structured buffers (rejected with an error).
//   * Specialization constants, push constants, and NV/EXT shader record
//     buffers (selected by attribute; see the inline comments).
//   * Members of a cbuffer/tbuffer (the whole buffer is emitted).
//   * ConstantBuffer/TextureBuffer variables.
// Otherwise it creates an external, file-scope, or function-scope variable and
// emits its initialization.
// Note: the column-major matrix error below does not stop processing.
void SpirvEmitter::doVarDecl(const VarDecl *decl) {
  if (!validateVKAttributes(decl))
    return;

  const auto loc = decl->getLocation();

  // HLSL has the 'string' type which can be used for rare purposes such as
  // printf (SPIR-V's DebugPrintf). SPIR-V does not have a 'char' or 'string'
  // type, and therefore any variable of such type should not be created.
  // DeclResultIdMapper maps such decl to an OpString instruction that
  // represents the variable's initializer literal.
  if (isStringType(decl->getType())) {
    declIdMapper.createOrUpdateStringVar(decl);
    return;
  }

  // We cannot handle external initialization of column-major matrices now.
  if (isExternalVar(decl) &&
      isOrContainsNonFpColMajorMatrix(astContext, spirvOptions, decl->getType(),
                                      decl)) {
    emitError("externally initialized non-floating-point column-major "
              "matrices not supported yet",
              loc);
  }

  // Reject arrays of RW/append/consume structured buffers. They have associated
  // counters, which are quite nasty to handle.
  if (decl->getType()->isArrayType()) {
    // Peel off every array dimension to reach the innermost element type.
    auto type = decl->getType();
    do {
      type = type->getAsArrayTypeUnsafe()->getElementType();
    } while (type->isArrayType());

    if (isRWAppendConsumeSBuffer(type)) {
      emitError("arrays of RW/append/consume structured buffers unsupported",
                loc);
      return;
    }
  }

  if (decl->hasAttr()) {
    // This is a VarDecl for specialization constant.
    createSpecConstant(decl);
    return;
  }

  if (decl->hasAttr()) {
    // This is a VarDecl for PushConstant block.
    (void)declIdMapper.createPushConstant(decl);
    return;
  }

  if (decl->hasAttr()) {
    (void)declIdMapper.createShaderRecordBuffer(
        decl, DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferNV);
    return;
  }

  if (decl->hasAttr()) {
    (void)declIdMapper.createShaderRecordBuffer(
        decl, DeclResultIdMapper::ContextUsageKind::ShaderRecordBufferEXT);
    return;
  }

  // We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need
  // to emit their cbuffer/tbuffer as a whole and access each individual one
  // using access chains.
  // cbuffers and tbuffers are HLSLBufferDecls
  // ConstantBuffers and TextureBuffers are not HLSLBufferDecls.
  if (const auto *bufferDecl = dyn_cast(decl->getDeclContext())) {
    // This is a VarDecl of cbuffer/tbuffer type.
    doHLSLBufferDecl(bufferDecl);
    return;
  }

  if (isConstantTextureBuffer(decl->getType())) {
    // This is a VarDecl of ConstantBuffer/TextureBuffer type.
    (void)declIdMapper.createCTBuffer(decl);
    return;
  }

  SpirvVariable *var = nullptr;

  // The contents in externally visible variables can be updated via the
  // pipeline. They should be handled differently from file and function scope
  // variables.
  // File scope variables (static "global" and "local" variables) belongs to
  // the Private storage class, while function scope variables (normal "local"
  // variables) belongs to the Function storage class.
  if (isExternalVar(decl)) {
    var = declIdMapper.createExternVar(decl);
  } else {
    // We already know the variable is not externally visible here. If it does
    // not have local storage, it should be file scope variable.
    const bool isFileScopeVar = !decl->hasLocalStorage();

    if (isFileScopeVar)
      var = declIdMapper.createFileVar(decl, llvm::None);
    else
      var = declIdMapper.createFnVar(decl, llvm::None);

    // Emit OpStore to initialize the variable
    // TODO: revert back to use OpVariable initializer

    // We should only evaluate the initializer once for a static variable.
    if (isFileScopeVar) {
      if (decl->isStaticLocal()) {
        initOnce(decl->getType(), decl->getName(), var, decl->getInit());
      } else {
        // Defer to initialize these global variables at the beginning of the
        // entry function.
        toInitGloalVars.push_back(decl);
      }
    }
    // Function local variables. Just emit OpStore at the current insert point.
    else if (const Expr *init = decl->getInit()) {
      // Prefer a constant store when the initializer folds to a constant.
      if (auto *constInit = tryToEvaluateAsConst(init)) {
        spvBuilder.createStore(var, constInit, loc);
      } else {
        storeValue(var, loadIfGLValue(init), decl->getType(), loc);
      }

      // Update counter variable associated with local variables
      tryToAssignCounterVar(decl, init);
    }

    if (!isFileScopeVar && spirvOptions.debugInfoRich) {
      // Add DebugLocalVariable information
      const auto &sm = astContext.getSourceManager();
      const uint32_t line = sm.getPresumedLineNumber(loc);
      const uint32_t column = sm.getPresumedColumnNumber(loc);
      const auto *info = getOrCreateRichDebugInfo(loc);
      // TODO: replace this with FlagIsLocal enum.
      uint32_t flags = 1 << 2;
      auto *debugLocalVar = spvBuilder.createDebugLocalVariable(
          decl->getType(), decl->getName(), info->source, line, column,
          info->scopeStack.back(), flags);
      spvBuilder.createDebugDeclare(debugLocalVar, var);
    }

    // Variables that are not externally visible and of opaque types should
    // request legalization.
    if (!needsLegalization && isOpaqueType(decl->getType()))
      needsLegalization = true;
  }

  // All variables that are of opaque struct types should request legalization.
  if (!needsLegalization && isOpaqueStructType(decl->getType()))
    needsLegalization = true;
}

// Maps a loop attribute to a SPIR-V loop control mask:
// HLSLLoop/HLSLFastOpt -> DontUnroll, HLSLUnroll -> Unroll.
// HLSLAllowUAVCondition is ignored (with a warning unless
// noWarnIgnoredFeatures) and yields MaskNone; other kinds are unreachable.
spv::LoopControlMask SpirvEmitter::translateLoopAttribute(const Stmt *stmt,
                                                          const Attr &attr) {
  switch (attr.getKind()) {
  case attr::HLSLLoop:
  case attr::HLSLFastOpt:
    return spv::LoopControlMask::DontUnroll;
  case attr::HLSLUnroll:
    return spv::LoopControlMask::Unroll;
  case attr::HLSLAllowUAVCondition:
    if (!spirvOptions.noWarnIgnoredFeatures) {
      emitWarning("unsupported allow_uav_condition attribute ignored",
                  stmt->getLocStart());
    }
    break;
  default:
    llvm_unreachable("found unknown loop attribute");
  }
  return spv::LoopControlMask::MaskNone;
}

// Translates a `discard` statement. Errors outside pixel shaders.
// With SPV_EXT_demote_to_helper_invocation enabled, emits
// OpDemoteToHelperInvocationEXT, which does not end the block. Otherwise it
// emits OpKill and opens a fresh basic block for any statements that follow.
void SpirvEmitter::doDiscardStmt(const DiscardStmt *discardStmt) {
  assert(!spvBuilder.isCurrentBasicBlockTerminated());

  // The discard statement can only be called from a pixel shader
  if (!spvContext.isPS()) {
    emitError("discard statement may only be used in pixel shaders",
              discardStmt->getLoc());
    return;
  }

  if (featureManager.isExtensionEnabled(
          Extension::EXT_demote_to_helper_invocation)) {
    // SPV_EXT_demote_to_helper_invocation SPIR-V extension provides a new
    // instruction OpDemoteToHelperInvocationEXT allowing shaders to "demote" a
    // fragment shader invocation to behave like a helper invocation for its
    // duration. The demoted invocation will have no further side effects and
    // will not output to the framebuffer, but remains active and can
    // participate in computing derivatives and in subgroup operations. This is
    // a better match for the "discard" instruction in HLSL.
    spvBuilder.createDemoteToHelperInvocationEXT(discardStmt->getLoc());
  } else {
    // Note: if/when the demote behavior becomes part of the core Vulkan spec,
    // we should no longer generate OpKill for 'discard', and always generate
    // the demote behavior.
    spvBuilder.createKill(discardStmt->getLoc());
    // Some statements that alter the control flow (break, continue, return, and
    // discard), require creation of a new basic block to hold any statement
    // that may follow them.
    auto *newBB = spvBuilder.createBasicBlock();
    spvBuilder.setInsertPoint(newBB);
  }
}

// Translates a do-while loop into structured SPIR-V control flow using
// header, body, continue and merge blocks. The loop condition is evaluated in
// the continue block; a missing condition is treated as `true`.
void SpirvEmitter::doDoStmt(const DoStmt *theDoStmt,
                            llvm::ArrayRef attrs) {
  // do-while loops are composed of:
  //
  // do {
  //   body
  // } while (check);
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though do-while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks: the header block and the merge block. The header block is the
  // block before control flow diverges, and the merge block is the block
  // where control flow subsequently converges. The check can be performed in
  // the continue basic block.
  // The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //            +----------+
  //            |  header  | <-----------------------------------+
  //            +----------+                                     |
  //                 |                                           |  (true)
  //                 v                                           |
  //             +------+       +--------------------+           |
  //             | body | ----> |  continue (check)  |-----------+
  //             +------+       +--------------------+
  //                                      |
  //                                      | (false)
  //            +-------+                 |
  //            | merge | <---------------+
  //            +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.

  // Only the first attribute (if any) is translated.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(theDoStmt, *attrs.front());

  // Create basic blocks
  auto *headerBB = spvBuilder.createBasicBlock("do_while.header");
  auto *bodyBB = spvBuilder.createBasicBlock("do_while.body");
  auto *continueBB = spvBuilder.createBasicBlock("do_while.continue");
  auto *mergeBB = spvBuilder.createBasicBlock("do_while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Branch from the current insert point to the header block.
  spvBuilder.createBranch(headerBB, theDoStmt->getLocStart());
  spvBuilder.addSuccessor(headerBB);

  // Process the header block
  // The header block must always branch to the body.
  spvBuilder.setInsertPoint(headerBB);
  const Stmt *body = theDoStmt->getBody();
  spvBuilder.createBranch(bodyBB,
                          body ? body->getLocStart() : theDoStmt->getLocStart(),
                          mergeBB, continueBB, loopControl);
  spvBuilder.addSuccessor(bodyBB);
  // The current basic block has OpLoopMerge instruction. We need to set its
  // continue and merge target.
  spvBuilder.setContinueTarget(continueBB);
  spvBuilder.setMergeTarget(mergeBB);

  // Process the body block
  spvBuilder.setInsertPoint(bodyBB);
  if (body) {
    doStmt(body);
  }
  if (!spvBuilder.isCurrentBasicBlockTerminated()) {
    spvBuilder.createBranch(continueBB, body ? body->getLocEnd()
                                             : theDoStmt->getLocStart());
  }
  spvBuilder.addSuccessor(continueBB);

  // Process the continue block. The check for whether the loop should
  // continue lies in the continue block.
  // *NOTE*: There's a SPIR-V rule that when a conditional branch is to occur in
  // a continue block of a loop, there should be no OpSelectionMerge. Only an
  // OpBranchConditional must be specified.
  spvBuilder.setInsertPoint(continueBB);
  SpirvInstruction *condition = nullptr;
  if (const Expr *check = theDoStmt->getCond()) {
    condition = doExpr(check);
  } else {
    condition = spvBuilder.getConstantBool(true);
  }
  spvBuilder.createConditionalBranch(condition, headerBB, mergeBB,
                                     theDoStmt->getLocEnd());
  spvBuilder.addSuccessor(headerBB);
  spvBuilder.addSuccessor(mergeBB);

  // Set insertion point to the merge block for subsequent statements
  spvBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}

// Translates a `continue` statement: branches to the innermost continue
// target and opens a fresh (unreachable) basic block for any statements that
// follow it. Asserts that the current block is not already terminated.
void SpirvEmitter::doContinueStmt(const ContinueStmt *continueStmt) {
  assert(!spvBuilder.isCurrentBasicBlockTerminated());
  auto *continueTargetBB = continueStack.top();
  spvBuilder.createBranch(continueTargetBB, continueStmt->getLocStart());
  spvBuilder.addSuccessor(continueTargetBB);

  // Some statements that alter the control flow (break, continue, return, and
  // discard), require creation of a new basic block to hold any statement that
  // may follow them. For example: StmtB and StmtC below are put inside a new
  // basic block which is unreachable.
  //
  // while (true) {
  //   StmtA;
  //   continue;
  //   StmtB;
  //   StmtC;
  // }
  auto *newBB = spvBuilder.createBasicBlock();
  spvBuilder.setInsertPoint(newBB);
}

// Translates a while loop into structured SPIR-V control flow. The check
// block doubles as the loop header. A condition variable, if present, is
// re-evaluated in the check block on every iteration. The continue block only
// branches back to the check block.
void SpirvEmitter::doWhileStmt(const WhileStmt *whileStmt,
                               llvm::ArrayRef attrs) {
  // While loops are composed of:
  //   while (check) { body }
  //
  // SPIR-V requires loops to have a merge basic block as well as a continue
  // basic block. Even though while loops do not have an explicit continue
  // block as in for-loops, we still do need to create a continue block.
  //
  // Since SPIR-V requires structured control flow, we need two more basic
  // blocks: the header block and the merge block. The header block is the
  // block before control flow diverges, and the merge block is the block
  // where control flow subsequently converges. The check block can take the
  // responsibility of the header block.
  // The final CFG should normally be like the following. Exceptions
  // will occur with non-local exits like loop breaks or early returns.
  //
  //            +----------+
  //            |  header  | <------------------+
  //            | (check)  |                    |
  //            +----------+                    |
  //                 |                          |
  //         +-------+-------+                  |
  //         | false         | true             |
  //         |               v                  |
  //         |            +------+     +------------------+
  //         |            | body | --> | continue (no-op) |
  //         v            +------+     +------------------+
  //     +-------+
  //     | merge |
  //     +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.

  // Only the first attribute (if any) is translated.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(whileStmt, *attrs.front());

  // Create basic blocks
  auto *checkBB = spvBuilder.createBasicBlock("while.check");
  auto *bodyBB = spvBuilder.createBasicBlock("while.body");
  auto *continueBB = spvBuilder.createBasicBlock("while.continue");
  auto *mergeBB = spvBuilder.createBasicBlock("while.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the check block
  spvBuilder.createBranch(checkBB, whileStmt->getLocStart());
  spvBuilder.addSuccessor(checkBB);
  spvBuilder.setInsertPoint(checkBB);

  // If we have:
  //   while (int a = foo()) {...}
  // we should evaluate 'a' by calling 'foo()' every single time the check has
  // to occur.
  if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt())
    doStmt(condVarDecl);

  SpirvInstruction *condition = nullptr;
  const Expr *check = whileStmt->getCond();
  if (check) {
    condition = doExpr(check);
  } else {
    condition = spvBuilder.getConstantBool(true);
  }
  spvBuilder.createConditionalBranch(
      condition, bodyBB,
      /*false branch*/ mergeBB, whileStmt->getLocStart(),
      /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone,
      loopControl);
  spvBuilder.addSuccessor(bodyBB);
  spvBuilder.addSuccessor(mergeBB);
  // The current basic block has OpLoopMerge instruction. We need to set its
  // continue and merge target.
  spvBuilder.setContinueTarget(continueBB);
  spvBuilder.setMergeTarget(mergeBB);

  // Process the body block
  spvBuilder.setInsertPoint(bodyBB);
  const Stmt *body = whileStmt->getBody();
  if (body) {
    doStmt(body);
  }
  if (!spvBuilder.isCurrentBasicBlockTerminated())
    spvBuilder.createBranch(continueBB, whileStmt->getLocEnd());
  spvBuilder.addSuccessor(continueBB);

  // Process the continue block. While loops do not have an explicit
  // continue block. The continue block just branches to the check block.
  spvBuilder.setInsertPoint(continueBB);
  spvBuilder.createBranch(checkBB, whileStmt->getLocEnd());
  spvBuilder.addSuccessor(checkBB);

  // Set insertion point to the merge block for subsequent statements
  spvBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue and merge blocks.
  continueStack.pop();
  breakStack.pop();
}

// Translates a for loop into structured SPIR-V control flow. The init
// statement is emitted in the current block. The check block doubles as the
// loop header (a missing condition is treated as `true`). The increment
// expression is emitted in the continue block, which branches back to the
// check block.
void SpirvEmitter::doForStmt(const ForStmt *forStmt,
                             llvm::ArrayRef attrs) {
  // for loops are composed of:
  //   for (init; check; continue) body
  //
  // To translate a for loop, we'll need to emit all init statements
  // in the current basic block, and then have separate basic blocks for
  // check, body, and continue. Besides, since SPIR-V requires
  // structured control flow, we need two more basic blocks: the header block
  // and the merge block. The header block is the block before control flow
  // diverges, while the merge block is the block where control flow
  // subsequently converges. The check block can take the responsibility of
  // the header block.
  // The final CFG should normally be like the following. Exceptions will
  // occur with non-local exits like loop breaks or early returns.
  //             +--------+
  //             |  init  |
  //             +--------+
  //                 |
  //                 v
  //            +----------+
  //            |  header  | <---------------+
  //            | (check)  |                 |
  //            +----------+                 |
  //                 |                       |
  //         +-------+-------+               |
  //         | false         | true          |
  //         |               v               |
  //         |            +------+     +----------+
  //         |            | body | --> | continue |
  //         v            +------+     +----------+
  //     +-------+
  //     | merge |
  //     +-------+
  //
  // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec.
  const spv::LoopControlMask loopControl =
      attrs.empty() ? spv::LoopControlMask::MaskNone
                    : translateLoopAttribute(forStmt, *attrs.front());

  // Create basic blocks
  auto *checkBB = spvBuilder.createBasicBlock("for.check");
  auto *bodyBB = spvBuilder.createBasicBlock("for.body");
  auto *continueBB = spvBuilder.createBasicBlock("for.continue");
  auto *mergeBB = spvBuilder.createBasicBlock("for.merge");

  // Make sure any continue statements branch to the continue block, and any
  // break statements branch to the merge block.
  continueStack.push(continueBB);
  breakStack.push(mergeBB);

  // Process the init block
  if (const Stmt *initStmt = forStmt->getInit()) {
    doStmt(initStmt);
  }
  const Expr *check = forStmt->getCond();
  spvBuilder.createBranch(checkBB,
                          check ? check->getLocStart() : forStmt->getLocStart());
  spvBuilder.addSuccessor(checkBB);

  // Process the check block
  spvBuilder.setInsertPoint(checkBB);
  SpirvInstruction *condition = nullptr;
  if (check) {
    condition = doExpr(check);
  } else {
    condition = spvBuilder.getConstantBool(true);
  }
  const Stmt *body = forStmt->getBody();
  spvBuilder.createConditionalBranch(
      condition, bodyBB,
      /*false branch*/ mergeBB,
      check ? check->getLocEnd()
            : (body ? body->getLocStart() : forStmt->getLocStart()),
      /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone,
      loopControl);
  spvBuilder.addSuccessor(bodyBB);
  spvBuilder.addSuccessor(mergeBB);
  // The current basic block has OpLoopMerge instruction. We need to set its
  // continue and merge target.
  spvBuilder.setContinueTarget(continueBB);
  spvBuilder.setMergeTarget(mergeBB);

  // Process the body block
  spvBuilder.setInsertPoint(bodyBB);
  if (body) {
    doStmt(body);
  }
  if (!spvBuilder.isCurrentBasicBlockTerminated())
    spvBuilder.createBranch(continueBB, forStmt->getLocEnd());
  spvBuilder.addSuccessor(continueBB);

  // Process the continue block
  spvBuilder.setInsertPoint(continueBB);
  if (const Expr *cont = forStmt->getInc()) {
    doExpr(cont);
  }
  // should jump back to header
  spvBuilder.createBranch(checkBB, forStmt->getLocEnd());
  spvBuilder.addSuccessor(checkBB);

  // Set insertion point to the merge block for subsequent statements
  spvBuilder.setInsertPoint(mergeBB);

  // Done with the current scope's continue block and merge block.
  continueStack.pop();
  breakStack.pop();
}

// Translates an if statement.
// If the condition const-evaluates, only the taken branch is emitted and no
// selection construct is created. Otherwise it emits a structured selection
// with then/else/merge blocks. HLSLBranch/HLSLFlatten attributes map to
// DontFlatten/Flatten selection control.
void SpirvEmitter::doIfStmt(const IfStmt *ifStmt,
                            llvm::ArrayRef attrs) {
  // if statements are composed of:
  //   if (check) { then } else { else }
  //
  // To translate if statements, we'll need to emit the check expressions
  // in the current basic block, and then create separate basic blocks for
  // then and else. Additionally, we'll need a merge block as per
  // SPIR-V's structured control flow requirements. Depending whether there
  // exists the else branch, the final CFG should normally be like the
  // following. Exceptions will occur with non-local exits like loop breaks
  // or early returns.
  //             +-------+                        +-------+
  //             | check |                        | check |
  //             +-------+                        +-------+
  //                 |                                |
  //         +-------+-------+                  +-----+-----+
  //         | true          | false            | true      | false
  //         v               v         or       v           |
  //     +------+         +------+           +------+       |
  //     | then |         | else |           | then |       |
  //     +------+         +------+           +------+       |
  //        |                 |                 |           v
  //        |   +-------+     |                 |     +-------+
  //        +-> | merge | <---+                 +---> | merge |
  //            +-------+                             +-------+

  { // Try to see if we can const-eval the condition
    // NOTE(review): this early return happens before the condition variable
    // declaration (getConditionVariableDeclStmt(), below) is emitted. If a
    // folded condition declares a variable that the taken branch uses, that
    // variable may never be created — confirm this cannot happen.
    bool condition = false;
    if (ifStmt->getCond()->EvaluateAsBooleanCondition(condition, astContext)) {
      if (condition) {
        doStmt(ifStmt->getThen());
      } else if (ifStmt->getElse()) {
        doStmt(ifStmt->getElse());
      }
      return;
    }
  }

  auto selectionControl = spv::SelectionControlMask::MaskNone;
  if (!attrs.empty()) {
    const Attr *attribute = attrs.front();
    switch (attribute->getKind()) {
    case attr::HLSLBranch:
      selectionControl = spv::SelectionControlMask::DontFlatten;
      break;
    case attr::HLSLFlatten:
      selectionControl = spv::SelectionControlMask::Flatten;
      break;
    default:
      // warning emitted in hlsl::ProcessStmtAttributeForHLSL
      break;
    }
  }

  if (const auto *declStmt = ifStmt->getConditionVariableDeclStmt())
    doDeclStmt(declStmt);

  // First emit the instruction for evaluating the condition.
  auto *condition = doExpr(ifStmt->getCond());

  // Then we need to emit the instruction for the conditional branch.
  // We'll need the label ids for the then/else/merge blocks to do so.
  const bool hasElse = ifStmt->getElse() != nullptr;
  auto *thenBB = spvBuilder.createBasicBlock("if.true");
  auto *mergeBB = spvBuilder.createBasicBlock("if.merge");
  // Without an else branch, the false edge goes straight to the merge block.
  auto *elseBB = hasElse ? spvBuilder.createBasicBlock("if.false") : mergeBB;

  // Create the branch instruction. This will end the current basic block.
  // NOTE(review): the literal 0 below is passed as the continue target;
  // nullptr would be clearer if that parameter is a pointer — confirm.
  const auto *then = ifStmt->getThen();
  spvBuilder.createConditionalBranch(condition, thenBB, elseBB,
                                     then->getLocStart(), mergeBB,
                                     /*continue*/ 0, selectionControl);
  spvBuilder.addSuccessor(thenBB);
  spvBuilder.addSuccessor(elseBB);
  // The current basic block has the OpSelectionMerge instruction. We need
  // to record its merge target.
  spvBuilder.setMergeTarget(mergeBB);

  // Handle the then branch
  spvBuilder.setInsertPoint(thenBB);
  doStmt(then);
  if (!spvBuilder.isCurrentBasicBlockTerminated())
    spvBuilder.createBranch(mergeBB, ifStmt->getLocEnd());
  spvBuilder.addSuccessor(mergeBB);

  // Handle the else branch (if exists)
  if (hasElse) {
    spvBuilder.setInsertPoint(elseBB);
    const auto *elseStmt = ifStmt->getElse();
    doStmt(elseStmt);
    if (!spvBuilder.isCurrentBasicBlockTerminated())
      spvBuilder.createBranch(mergeBB, elseStmt->getLocEnd());
    spvBuilder.addSuccessor(mergeBB);
  }

  // From now on, we'll emit instructions into the merge block.
  spvBuilder.setInsertPoint(mergeBB);
}

void SpirvEmitter::doReturnStmt(const ReturnStmt *stmt) {
  if (const auto *retVal = stmt->getRetValue()) {
    // Update counter variable associated with function returns
    tryToAssignCounterVar(curFunction, retVal);

    auto *retInfo = loadIfGLValue(retVal);
    if (!retInfo)
      return;

    auto retType = retVal->getType();
    if (retInfo->getLayoutRule() != SpirvLayoutRule::Void &&
        retType->isStructureType()) {
      // We are returning some value from a non-Function storage class. Need to
      // create a temporary variable to "convert" the value to Function storage
      // class and then return.
      auto *tempVar =
          spvBuilder.addFnVar(retType, retVal->getLocEnd(), "temp.var.ret");
      storeValue(tempVar, retInfo, retType, retVal->getLocEnd());

      spvBuilder.createReturnValue(
          spvBuilder.createLoad(retType, tempVar, retVal->getLocEnd()),
          stmt->getReturnLoc());
    } else {
      spvBuilder.createReturnValue(retInfo, stmt->getReturnLoc());
    }
  } else {
    spvBuilder.createReturn(stmt->getReturnLoc());
  }

  // We are translating a ReturnStmt, we should be in some function's body.
  assert(curFunction->hasBody());

  // If this return statement is the last statement in the function, then
  // whe have no more work to do.
if (cast(curFunction->getBody())->body_back() == stmt) return; // Some statements that alter the control flow (break, continue, return, and // discard), require creation of a new basic block to hold any statement that // may follow them. In this case, the newly created basic block will contain // any statement that may come after an early return. auto *newBB = spvBuilder.createBasicBlock(); spvBuilder.setInsertPoint(newBB); } void SpirvEmitter::doBreakStmt(const BreakStmt *breakStmt) { assert(!spvBuilder.isCurrentBasicBlockTerminated()); auto *breakTargetBB = breakStack.top(); spvBuilder.addSuccessor(breakTargetBB); spvBuilder.createBranch(breakTargetBB, breakStmt->getLocStart()); // Some statements that alter the control flow (break, continue, return, and // discard), require creation of a new basic block to hold any statement that // may follow them. For example: StmtB and StmtC below are put inside a new // basic block which is unreachable. // // while (true) { // StmtA; // break; // StmtB; // StmtC; // } auto *newBB = spvBuilder.createBasicBlock(); spvBuilder.setInsertPoint(newBB); } void SpirvEmitter::doSwitchStmt(const SwitchStmt *switchStmt, llvm::ArrayRef attrs) { // Switch statements are composed of: // switch () { // // // // (optional) // } // // +-------+ // | check | // +-------+ // | // +-------+-------+----------------+---------------+ // | 1 | 2 | 3 | (others) // v v v v // +-------+ +-------------+ +-------+ +------------+ // | case1 | | case2 | | case3 | ... | default | // | | |(fallthrough)|---->| | | (optional) | // +-------+ |+------------+ +-------+ +------------+ // | | | // | | | // | +-------+ | | // | | | <--------------------+ | // +-> | merge | | // | | <-------------------------------------+ // +-------+ // If no attributes are given, or if "forcecase" attribute was provided, // we'll do our best to use OpSwitch if possible. 
// If any of the cases compares to a variable (rather than an integer // literal), we cannot use OpSwitch because OpSwitch expects literal // numbers as parameters. const bool isAttrForceCase = !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase; const bool canUseSpirvOpSwitch = (attrs.empty() || isAttrForceCase) && allSwitchCasesAreIntegerLiterals(switchStmt->getBody()); if (isAttrForceCase && !canUseSpirvOpSwitch && !spirvOptions.noWarnIgnoredFeatures) { emitWarning("ignored 'forcecase' attribute for the switch statement " "since one or more case values are not integer literals", switchStmt->getLocStart()); } if (canUseSpirvOpSwitch) processSwitchStmtUsingSpirvOpSwitch(switchStmt); else processSwitchStmtUsingIfStmts(switchStmt); } SpirvInstruction * SpirvEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) { llvm::SmallVector indices; const auto *base = collectArrayStructIndices( expr, /*rawIndex*/ false, /*rawIndices*/ nullptr, &indices); auto *info = loadIfAliasVarRef(base); if (!indices.empty()) { info = turnIntoElementPtr(base->getType(), info, expr->getType(), indices, base->getExprLoc()); } return info; } SpirvInstruction *SpirvEmitter::doBinaryOperator(const BinaryOperator *expr) { const auto opcode = expr->getOpcode(); // Handle assignment first since we need to evaluate rhs before lhs. // For other binary operations, we need to evaluate lhs before rhs. 
if (opcode == BO_Assign) { // Update counter variable associated with lhs of assignments tryToAssignCounterVar(expr->getLHS(), expr->getRHS()); return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()), /*isCompoundAssignment=*/false); } // Try to optimize floatMxN * float and floatN * float case if (opcode == BO_Mul) { if (auto *result = tryToGenFloatMatrixScale(expr)) return result; if (auto *result = tryToGenFloatVectorScale(expr)) return result; } return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode, expr->getLHS()->getType(), expr->getType(), expr->getSourceRange(), expr->getOperatorLoc()); } SpirvInstruction *SpirvEmitter::doCallExpr(const CallExpr *callExpr) { if (const auto *operatorCall = dyn_cast(callExpr)) return doCXXOperatorCallExpr(operatorCall); if (const auto *memberCall = dyn_cast(callExpr)) return doCXXMemberCallExpr(memberCall); // Intrinsic functions such as 'dot' or 'mul' if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) { return processIntrinsicCallExpr(callExpr); } // Normal standalone functions return processCall(callExpr); } SpirvInstruction *SpirvEmitter::getBaseOfMemberFunction(QualType objectType, SpirvInstruction * objInstr, const CXXMethodDecl* memberFn, SourceLocation loc) { // If objectType is different from the parent of memberFn, memberFn should be // defined in a base struct/class of objectType. We create OpAccessChain with // index 0 while iterating bases of objectType until we find the base with // the definition of memberFn. 
if (const auto *ptrType = objectType->getAs()) { if (const auto *recordType = ptrType->getPointeeType()->getAs()) { const auto *parentDeclOfMemberFn = memberFn->getParent(); if (recordType->getDecl() != parentDeclOfMemberFn) { const auto *cxxRecordDecl = dyn_cast(recordType->getDecl()); auto *zero = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); for (auto baseItr = cxxRecordDecl->bases_begin(), itrEnd = cxxRecordDecl->bases_end(); baseItr != itrEnd; baseItr++) { const auto *baseType = baseItr->getType()->getAs(); objectType = astContext.getPointerType(baseType->desugar()); objInstr = spvBuilder.createAccessChain(objectType, objInstr, {zero}, loc); if (baseType->getDecl() == parentDeclOfMemberFn) return objInstr; } } } } return nullptr; } SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) { const FunctionDecl *callee = getCalleeDefinition(callExpr); // Note that we always want the defintion because Stmts/Exprs in the // function body references the parameters in the definition. 
if (!callee) { emitError("found undefined function", callExpr->getExprLoc()); return nullptr; } const auto paramTypeMatchesArgType = [](QualType paramType, QualType argType) { if (argType == paramType) return true; if (const auto *refType = paramType->getAs()) paramType = refType->getPointeeType(); auto argUnqualifiedType = argType->getUnqualifiedDesugaredType(); auto paramUnqualifiedType = paramType->getUnqualifiedDesugaredType(); if (argUnqualifiedType == paramUnqualifiedType) return true; return false; }; const auto numParams = callee->getNumParams(); bool isNonStaticMemberCall = false; QualType objectType = {}; // Type of the object (if exists) SpirvInstruction *objInstr = nullptr; // EvalInfo for the object (if exists) llvm::SmallVector vars; // Variables for function call llvm::SmallVector isTempVar; // Temporary variable or not llvm::SmallVector args; // Evaluated arguments if (const auto *memberCall = dyn_cast(callExpr)) { const auto *memberFn = cast(memberCall->getCalleeDecl()); isNonStaticMemberCall = !memberFn->isStatic(); if (isNonStaticMemberCall) { // For non-static member calls, evaluate the object and pass it as the // first argument. const auto *object = memberCall->getImplicitObjectArgument(); object = object->IgnoreParenNoopCasts(astContext); // Update counter variable associated with the implicit object tryToAssignCounterVar(getOrCreateDeclForMethodObject(memberFn), object); objectType = object->getType(); objInstr = doExpr(object); if (auto *accessToBaseInstr = getBaseOfMemberFunction(objectType, objInstr, memberFn, memberCall->getExprLoc())) { objInstr = accessToBaseInstr; objectType = accessToBaseInstr->getAstResultType(); } // If not already a variable, we need to create a temporary variable and // pass the object pointer to the function. Example: // getObject().objectMethod(); // Also, any parameter passed to the member function must be of Function // storage class. 
if (objInstr->isRValue()) { args.push_back(createTemporaryVar( objectType, getAstTypeName(objectType), // May need to load to use as initializer loadIfGLValue(object, objInstr), object->getLocStart())); } else { // Based on SPIR-V spec, function parameter must always be in Function // scope. If we pass a non-function scope argument, we need // the legalization. if (objInstr->getStorageClass() != spv::StorageClass::Function || !isMemoryObjectDeclaration(objInstr)) beforeHlslLegalization = true; args.push_back(objInstr); } // We do not need to create a new temporary variable for the this // object. Use the evaluated argument. vars.push_back(args.back()); isTempVar.push_back(false); } } // Evaluate parameters for (uint32_t i = 0; i < numParams; ++i) { // We want the argument variable here so that we can write back to it // later. We will do the OpLoad of this argument manually. So ingore // the LValueToRValue implicit cast here. auto *arg = callExpr->getArg(i)->IgnoreParenLValueCasts(); const auto *param = callee->getParamDecl(i); const auto paramType = param->getType(); // Get the evaluation info if this argument is referencing some variable // *as a whole*, in which case we can avoid creating the temporary variable // for it if it can act as out parameter. SpirvInstruction *argInfo = nullptr; if (const auto *declRefExpr = dyn_cast(arg)) { argInfo = declIdMapper.getDeclEvalInfo(declRefExpr->getDecl(), arg->getLocStart()); } auto *argInst = doExpr(arg); bool isArgGlobalVarWithResourceType = argInfo && argInfo->getStorageClass() != spv::StorageClass::Function && isResourceType(paramType); // If argInfo is nullptr and argInst is a rvalue, we do not have a proper // pointer to pass to the function. we need a temporary variable in that // case. 
// // If we have an 'out/inout' resource as function argument, we need to // create a temporary variable for it because the function definition // expects are point-to-pointer argument for resources, which will be // resolved by legalization. if ((argInfo || (argInst && !argInst->isRValue())) && canActAsOutParmVar(param) && !isArgGlobalVarWithResourceType && paramTypeMatchesArgType(paramType, arg->getType())) { // Based on SPIR-V spec, function parameter must be always Function // scope. In addition, we must pass memory object declaration argument // to function. If we pass an argument that is not function scope // or not memory object declaration, we need the legalization. if (!argInfo || argInfo->getStorageClass() != spv::StorageClass::Function) beforeHlslLegalization = true; isTempVar.push_back(false); args.push_back(argInst); vars.push_back(argInfo ? argInfo : argInst); } else { // We need to create variables for holding the values to be used as // arguments. The variables themselves are of pointer types. const QualType varType = declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param); const std::string varName = "param.var." + param->getNameAsString(); // Temporary "param.var.*" variables are used for OpFunctionCall purposes. // 'precise' attribute on function parameters only affect computations // inside the function, not the variables at the call sites. Therefore, we // do not need to mark the "param.var.*" variables as precise. const bool isPrecise = false; auto *tempVar = spvBuilder.addFnVar(varType, arg->getLocStart(), varName, isPrecise); vars.push_back(tempVar); isTempVar.push_back(true); args.push_back(argInst); // Update counter variable associated with function parameters tryToAssignCounterVar(param, arg); // Manually load the argument here auto *rhsVal = loadIfGLValue(arg, args.back()); // The AST does not include cast nodes to and from the function parameter // type for 'out' and 'inout' cases. 
Example: // // void foo(out half3 param) {...} // void main() { float3 arg; foo(arg); } // // In such cases, we first do a manual cast before passing the argument to // the function. And we will cast back the results once the function call // has returned. if (canActAsOutParmVar(param) && !paramTypeMatchesArgType(paramType, arg->getType())) { if (const auto *refType = paramType->getAs()) rhsVal = castToType(rhsVal, arg->getType(), refType->getPointeeType(), arg->getLocStart()); } // Initialize the temporary variables using the contents of the arguments storeValue(tempVar, rhsVal, paramType, arg->getLocStart()); } } if (beforeHlslLegalization) needsLegalization = true; assert(vars.size() == isTempVar.size()); assert(vars.size() == args.size()); // Push the callee into the work queue if it is not there. addFunctionToWorkQueue(spvContext.getCurrentShaderModelKind(), callee, /*isEntryFunction*/ false); const QualType retType = declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(callee); // Get or forward declare the function SpirvFunction *func = declIdMapper.getOrRegisterFn(callee); auto *retVal = spvBuilder.createFunctionCall( retType, func, vars, callExpr->getCallee()->getExprLoc()); // Go through all parameters and write those marked as out/inout for (uint32_t i = 0; i < numParams; ++i) { const auto *param = callee->getParamDecl(i); const auto paramType = param->getType(); // If it calls a non-static member function, the object itself is argument // 0, and therefore all other argument positions are shifted by 1. const uint32_t index = i + isNonStaticMemberCall; // Using a resouce as a function parameter is never passed-by-copy. As a // result, even if the function parameter is marked as 'out' or 'inout', // there is no reason to copy back the results after the function call into // the resource. 
if (isTempVar[index] && canActAsOutParmVar(param) && !isResourceType(paramType)) { const auto *arg = callExpr->getArg(i); SpirvInstruction *value = spvBuilder.createLoad(paramType, vars[index], arg->getLocStart()); // Now we want to assign 'value' to arg. But first, in rare cases when // using 'out' or 'inout' where the parameter and argument have a type // mismatch, we need to first cast 'value' to the type of 'arg' because // the AST will not include a cast node. if (!paramTypeMatchesArgType(paramType, arg->getType())) { if (const auto *refType = paramType->getAs()) value = castToType(value, refType->getPointeeType(), arg->getType(), arg->getLocStart()); } processAssignment(arg, value, false, args[index]); } } return retVal; } SpirvInstruction *SpirvEmitter::doCastExpr(const CastExpr *expr) { const Expr *subExpr = expr->getSubExpr(); const QualType subExprType = subExpr->getType(); const QualType toType = expr->getType(); const auto srcLoc = expr->getExprLoc(); switch (expr->getCastKind()) { case CastKind::CK_LValueToRValue: return loadIfGLValue(subExpr); case CastKind::CK_NoOp: return doExpr(subExpr); case CastKind::CK_IntegralCast: case CastKind::CK_FloatingToIntegral: case CastKind::CK_HLSLCC_IntegralCast: case CastKind::CK_HLSLCC_FloatingToIntegral: { // Integer literals in the AST are represented using 64bit APInt // themselves and then implicitly casted into the expected bitwidth. // We need special treatment of integer literals here because generating // a 64bit constant and then explicit casting in SPIR-V requires Int64 // capability. We should avoid introducing unnecessary capabilities to // our best. 
if (auto *value = tryToEvaluateAsConst(expr)) { value->setRValue(); return value; } auto *value = castToInt(loadIfGLValue(subExpr), subExprType, toType, subExpr->getLocStart()); value->setRValue(); return value; } case CastKind::CK_FloatingCast: case CastKind::CK_IntegralToFloating: case CastKind::CK_HLSLCC_FloatingCast: case CastKind::CK_HLSLCC_IntegralToFloating: { // First try to see if we can do constant folding for floating point // numbers like what we are doing for integers in the above. if (auto *value = tryToEvaluateAsConst(expr)) { value->setRValue(); return value; } auto *value = castToFloat(loadIfGLValue(subExpr), subExprType, toType, subExpr->getLocStart()); value->setRValue(); return value; } case CastKind::CK_IntegralToBoolean: case CastKind::CK_FloatingToBoolean: case CastKind::CK_HLSLCC_IntegralToBoolean: case CastKind::CK_HLSLCC_FloatingToBoolean: { // First try to see if we can do constant folding. if (auto *value = tryToEvaluateAsConst(expr)) { value->setRValue(); return value; } auto *value = castToBool(loadIfGLValue(subExpr), subExprType, toType, subExpr->getLocStart()); value->setRValue(); return value; } case CastKind::CK_HLSLVectorSplat: { const size_t size = hlsl::GetHLSLVecSize(expr->getType()); return createVectorSplat(subExpr, size); } case CastKind::CK_HLSLVectorTruncationCast: { const QualType toVecType = toType; const QualType elemType = hlsl::GetHLSLVecElementType(toType); const auto toSize = hlsl::GetHLSLVecSize(toType); auto *composite = doExpr(subExpr); llvm::SmallVector elements; for (uint32_t i = 0; i < toSize; ++i) { elements.push_back(spvBuilder.createCompositeExtract( elemType, composite, {i}, expr->getExprLoc())); } auto *value = elements.front(); if (toSize > 1) { value = spvBuilder.createCompositeConstruct(toVecType, elements, expr->getExprLoc()); } value->setRValue(); return value; } case CastKind::CK_HLSLVectorToScalarCast: { // The underlying should already be a vector of size 1. 
assert(hlsl::GetHLSLVecSize(subExprType) == 1); return doExpr(subExpr); } case CastKind::CK_HLSLVectorToMatrixCast: { // If target type is already an 1xN matrix type, we just return the // underlying vector. if (is1xNMatrix(toType)) return doExpr(subExpr); // A vector can have no more than 4 elements. The only remaining case // is casting from size-4 vector to size-2-by-2 matrix. auto *vec = loadIfGLValue(subExpr); QualType elemType = {}; uint32_t rowCount = 0, colCount = 0; const bool isMat = isMxNMatrix(toType, &elemType, &rowCount, &colCount); assert(isMat && rowCount == 2 && colCount == 2); (void)isMat; QualType vec2Type = astContext.getExtVectorType(elemType, 2); auto *subVec1 = spvBuilder.createVectorShuffle(vec2Type, vec, vec, {0, 1}, expr->getLocStart()); auto *subVec2 = spvBuilder.createVectorShuffle(vec2Type, vec, vec, {2, 3}, expr->getLocStart()); auto *mat = spvBuilder.createCompositeConstruct(toType, {subVec1, subVec2}, expr->getLocStart()); mat->setRValue(); return mat; } case CastKind::CK_HLSLMatrixSplat: { // From scalar to matrix uint32_t rowCount = 0, colCount = 0; hlsl::GetHLSLMatRowColCount(toType, rowCount, colCount); // Handle degenerated cases first if (rowCount == 1 && colCount == 1) return doExpr(subExpr); if (colCount == 1) return createVectorSplat(subExpr, rowCount); const auto vecSplat = createVectorSplat(subExpr, colCount); if (rowCount == 1) return vecSplat; if (isa(vecSplat)) { llvm::SmallVector vectors( size_t(rowCount), cast(vecSplat)); auto *value = spvBuilder.getConstantComposite(toType, vectors); value->setRValue(); return value; } else { llvm::SmallVector vectors(size_t(rowCount), vecSplat); auto *value = spvBuilder.createCompositeConstruct(toType, vectors, expr->getLocEnd()); value->setRValue(); return value; } } case CastKind::CK_HLSLMatrixTruncationCast: { const QualType srcType = subExprType; auto *src = doExpr(subExpr); const QualType elemType = hlsl::GetHLSLMatElementType(srcType); llvm::SmallVector indexes; // It is 
possible that the source matrix is in fact a vector. // Example 1: Truncate float1x3 --> float1x2. // Example 2: Truncate float1x3 --> float1x1. // The front-end disallows float1x3 --> float2x1. { uint32_t srcVecSize = 0, dstVecSize = 0; if (isVectorType(srcType, nullptr, &srcVecSize) && isScalarType(toType)) { auto *val = spvBuilder.createCompositeExtract(toType, src, {0}, expr->getLocStart()); val->setRValue(); return val; } if (isVectorType(srcType, nullptr, &srcVecSize) && isVectorType(toType, nullptr, &dstVecSize)) { for (uint32_t i = 0; i < dstVecSize; ++i) indexes.push_back(i); auto *val = spvBuilder.createVectorShuffle(toType, src, src, indexes, expr->getLocStart()); val->setRValue(); return val; } } uint32_t srcRows = 0, srcCols = 0, dstRows = 0, dstCols = 0; hlsl::GetHLSLMatRowColCount(srcType, srcRows, srcCols); hlsl::GetHLSLMatRowColCount(toType, dstRows, dstCols); const QualType srcRowType = astContext.getExtVectorType(elemType, srcCols); const QualType dstRowType = astContext.getExtVectorType(elemType, dstCols); // Indexes to pass to OpVectorShuffle for (uint32_t i = 0; i < dstCols; ++i) indexes.push_back(i); llvm::SmallVector extractedVecs; for (uint32_t row = 0; row < dstRows; ++row) { // Extract a row SpirvInstruction *rowInstr = spvBuilder.createCompositeExtract( srcRowType, src, {row}, expr->getExprLoc()); // Extract the necessary columns from that row. // The front-end ensures dstCols <= srcCols. // If dstCols equals srcCols, we can use the whole row directly. 
if (dstCols == 1) { rowInstr = spvBuilder.createCompositeExtract(elemType, rowInstr, {0}, expr->getLocStart()); } else if (dstCols < srcCols) { rowInstr = spvBuilder.createVectorShuffle( dstRowType, rowInstr, rowInstr, indexes, expr->getLocStart()); } extractedVecs.push_back(rowInstr); } auto *val = extractedVecs.front(); if (extractedVecs.size() > 1) { val = spvBuilder.createCompositeConstruct(toType, extractedVecs, expr->getExprLoc()); } val->setRValue(); return val; } case CastKind::CK_HLSLMatrixToScalarCast: { // The underlying should already be a matrix of 1x1. assert(is1x1Matrix(subExprType)); return doExpr(subExpr); } case CastKind::CK_HLSLMatrixToVectorCast: { // If the underlying matrix is Mx1 or 1xM for M in {1, 2,3,4}, we can return // the underlying matrix because it'll be evaluated as a vector by default. if (is1x1Matrix(subExprType) || is1xNMatrix(subExprType) || isMx1Matrix(subExprType)) return doExpr(subExpr); // A vector can have no more than 4 elements. The only remaining case // is casting from a 2x2 matrix to a vector of size 4. 
auto *mat = loadIfGLValue(subExpr); QualType elemType = {}; uint32_t rowCount = 0, colCount = 0, elemCount = 0; const bool isMat = isMxNMatrix(subExprType, &elemType, &rowCount, &colCount); const bool isVec = isVectorType(toType, nullptr, &elemCount); assert(isMat && rowCount == 2 && colCount == 2); assert(isVec && elemCount == 4); (void)isMat; (void)isVec; QualType vec2Type = astContext.getExtVectorType(elemType, 2); auto *row0 = spvBuilder.createCompositeExtract(vec2Type, mat, {0}, srcLoc); auto *row1 = spvBuilder.createCompositeExtract(vec2Type, mat, {1}, srcLoc); auto *vec = spvBuilder.createVectorShuffle(toType, row0, row1, {0, 1, 2, 3}, srcLoc); vec->setRValue(); return vec; } case CastKind::CK_FunctionToPointerDecay: // Just need to return the function id return doExpr(subExpr); case CastKind::CK_FlatConversion: { SpirvInstruction *subExprInstr = nullptr; QualType evalType = subExprType; // Optimization: we can use OpConstantNull for cases where we want to // initialize an entire data structure to zeros. if (evaluatesToConstZero(subExpr, astContext)) { subExprInstr = spvBuilder.getConstantNull(toType); subExprInstr->setRValue(); return subExprInstr; } // Try to evaluate float literals as float rather than double. if (const auto *floatLiteral = dyn_cast(subExpr)) { subExprInstr = tryToEvaluateAsFloat32(floatLiteral->getValue()); if (subExprInstr) evalType = astContext.FloatTy; } // Evaluate 'literal float' initializer type as float rather than double. // TODO: This could result in rounding error if the initializer is a // non-literal expression that requires larger than 32 bits and has the // 'literal float' type. else if (subExprType->isSpecificBuiltinType(BuiltinType::LitFloat)) { evalType = astContext.FloatTy; } // Try to evaluate integer literals as 32-bit int rather than 64-bit int. 
else if (const auto *intLiteral = dyn_cast(subExpr)) { const bool isSigned = subExprType->isSignedIntegerType(); subExprInstr = tryToEvaluateAsInt32(intLiteral->getValue(), isSigned); if (subExprInstr) evalType = isSigned ? astContext.IntTy : astContext.UnsignedIntTy; } // For assigning one array instance to another one with the same array type // (regardless of constness and literalness), the rhs will be wrapped in a // FlatConversion. Similarly for assigning a struct to another struct with // identical members. // |- // `- ImplicitCastExpr // `- ImplicitCastExpr // `- else if (isSameType(astContext, toType, evalType) || // We can have casts changing the shape but without affecting // memory order, e.g., `float4 a[2]; float b[8] = (float[8])a;`. // This is also represented as FlatConversion. For such cases, we // can rely on the InitListHandler, which can decompse // vectors/matrices. subExprType->isArrayType()) { auto *valInstr = InitListHandler(astContext, *this).processCast(toType, subExpr); if (valInstr) valInstr->setRValue(); return valInstr; } // We can have casts changing the shape but without affecting memory order, // e.g., `float4 a[2]; float b[8] = (float[8])a;`. This is also represented // as FlatConversion. For such cases, we can rely on the InitListHandler, // which can decompse vectors/matrices. 
else if (subExprType->isArrayType()) { auto *valInstr = InitListHandler(astContext, *this) .processCast(expr->getType(), subExpr); if (valInstr) valInstr->setRValue(); return valInstr; } if (!subExprInstr) subExprInstr = doExpr(subExpr); auto *val = processFlatConversion(toType, evalType, subExprInstr, expr->getExprLoc()); val->setRValue(); return val; } case CastKind::CK_UncheckedDerivedToBase: case CastKind::CK_HLSLDerivedToBase: { // Find the index sequence of the base to which we are casting llvm::SmallVector baseIndices; getBaseClassIndices(expr, &baseIndices); // Turn them in to SPIR-V constants llvm::SmallVector baseIndexInstructions( baseIndices.size(), nullptr); for (uint32_t i = 0; i < baseIndices.size(); ++i) baseIndexInstructions[i] = spvBuilder.getConstantInt( astContext.UnsignedIntTy, llvm::APInt(32, baseIndices[i])); auto *derivedInfo = doExpr(subExpr); return turnIntoElementPtr(subExpr->getType(), derivedInfo, expr->getType(), baseIndexInstructions, subExpr->getExprLoc()); } case CastKind::CK_ArrayToPointerDecay: { // Literal string to const string conversion falls under this category. 
if (hlsl::IsStringLiteralType(subExprType) && hlsl::IsStringType(toType)) { return doExpr(subExpr); } else { emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc()) << expr->getCastKindName() << expr->getSourceRange(); expr->dump(); return 0; } } default: emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc()) << expr->getCastKindName() << expr->getSourceRange(); expr->dump(); return 0; } } SpirvInstruction *SpirvEmitter::processFlatConversion( const QualType type, const QualType initType, SpirvInstruction *initInstr, SourceLocation srcLoc) { // When translating ConstantBuffer or TextureBuffer types, we consider // the underlying type (T), and therefore we should bypass the FlatConversion // node when accessing these types: // `-MemberExpr // `-ImplicitCastExpr 'const T' lvalue // `-ArraySubscriptExpr 'ConstantBuffer':'ConstantBuffer' lvalue if (isConstantTextureBuffer(initType)) { return initInstr; } // Try to translate the canonical type first const auto canonicalType = type.getCanonicalType(); if (canonicalType != type) return processFlatConversion(canonicalType, initType, initInstr, srcLoc); // Primitive types { QualType ty = {}; if (isScalarType(type, &ty)) { if (const auto *builtinType = ty->getAs()) { switch (builtinType->getKind()) { case BuiltinType::Void: { emitError("cannot create a constant of void type", srcLoc); return 0; } case BuiltinType::Bool: return castToBool(initInstr, initType, ty, srcLoc); // Target type is an integer variant. case BuiltinType::Int: case BuiltinType::Short: case BuiltinType::Min12Int: case BuiltinType::Min16Int: case BuiltinType::Min16UInt: case BuiltinType::UShort: case BuiltinType::UInt: case BuiltinType::Long: case BuiltinType::LongLong: case BuiltinType::ULong: case BuiltinType::ULongLong: case BuiltinType::Int8_4Packed: case BuiltinType::UInt8_4Packed: return castToInt(initInstr, initType, ty, srcLoc); // Target type is a float variant. 
case BuiltinType::Double: case BuiltinType::Float: case BuiltinType::Half: case BuiltinType::HalfFloat: case BuiltinType::Min10Float: case BuiltinType::Min16Float: return castToFloat(initInstr, initType, ty, srcLoc); default: emitError("flat conversion of type %0 unimplemented", srcLoc) << builtinType->getTypeClassName(); return 0; } } } } // Vector types { QualType elemType = {}; uint32_t elemCount = {}; if (isVectorType(type, &elemType, &elemCount)) { auto *elem = processFlatConversion(elemType, initType, initInstr, srcLoc); llvm::SmallVector constituents(size_t(elemCount), elem); return spvBuilder.createCompositeConstruct(type, constituents, srcLoc); } } // Matrix types { QualType elemType = {}; uint32_t rowCount = 0, colCount = 0; if (isMxNMatrix(type, &elemType, &rowCount, &colCount)) { // By default HLSL matrices are row major, while SPIR-V matrices are // column major. We are mapping what HLSL semantically mean a row into a // column here. const QualType vecType = astContext.getExtVectorType(elemType, colCount); auto *elem = processFlatConversion(elemType, initType, initInstr, srcLoc); const llvm::SmallVector constituents( size_t(colCount), elem); auto *col = spvBuilder.createCompositeConstruct(vecType, constituents, srcLoc); const llvm::SmallVector rows(size_t(rowCount), col); return spvBuilder.createCompositeConstruct(type, rows, srcLoc); } } // Struct type if (const auto *structType = type->getAs()) { const auto *decl = structType->getDecl(); llvm::SmallVector fields; for (const auto *field : decl->fields()) { // There is a special case for FlatConversion. If T is a struct with only // one member, S, then (T) is allowed, which essentially // constructs a new T instance using the instance of S as its only member. // Check whether we are handling that case here first. 
if (field->getType().getCanonicalType() == initType.getCanonicalType()) { fields.push_back(initInstr); } else { fields.push_back(processFlatConversion(field->getType(), initType, initInstr, srcLoc)); } } return spvBuilder.createCompositeConstruct(type, fields, srcLoc); } // Array type if (const auto *arrayType = astContext.getAsConstantArrayType(type)) { const auto size = static_cast(arrayType->getSize().getZExtValue()); auto *elem = processFlatConversion(arrayType->getElementType(), initType, initInstr, srcLoc); llvm::SmallVector constituents(size_t(size), elem); return spvBuilder.createCompositeConstruct(type, constituents, srcLoc); } emitError("flat conversion of type %0 unimplemented", {}) << type->getTypeClassName(); type->dump(); return 0; } SpirvInstruction * SpirvEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) { const auto opcode = expr->getOpcode(); // Try to optimize floatMxN *= float and floatN *= float case if (opcode == BO_MulAssign) { if (auto *result = tryToGenFloatMatrixScale(expr)) return result; if (auto *result = tryToGenFloatVectorScale(expr)) return result; } const auto *rhs = expr->getRHS(); const auto *lhs = expr->getLHS(); SpirvInstruction *lhsPtr = nullptr; auto *result = processBinaryOp( lhs, rhs, opcode, expr->getComputationLHSType(), expr->getType(), expr->getSourceRange(), expr->getOperatorLoc(), &lhsPtr); return processAssignment(lhs, result, true, lhsPtr); } SpirvInstruction * SpirvEmitter::doConditionalOperator(const ConditionalOperator *expr) { const auto type = expr->getType(); const SourceLocation loc = expr->getExprLoc(); const Expr *cond = expr->getCond(); const Expr *falseExpr = expr->getFalseExpr(); const Expr *trueExpr = expr->getTrueExpr(); // According to HLSL doc, all sides of the ?: expression are always evaluated. // Corner-case: In HLSL, the condition of the ternary operator can be a // matrix of booleans which results in selecting between components of two // matrices. 
However, a matrix of booleans is not a valid type in SPIR-V. // If the AST has inserted a splat of a scalar/vector to a matrix, we can just // use that scalar/vector as an if-clause condition. if (auto *cast = dyn_cast(cond)) if (cast->getCastKind() == CK_HLSLMatrixSplat) cond = cast->getSubExpr(); // If we are selecting between two SampleState objects, none of the three // operands has a LValueToRValue implicit cast. auto *condition = loadIfGLValue(cond); auto *trueBranch = loadIfGLValue(trueExpr); auto *falseBranch = loadIfGLValue(falseExpr); // Corner-case: In HLSL, the condition of the ternary operator can be a // matrix of booleans which results in selecting between components of two // matrices. However, a matrix of booleans is not a valid type in SPIR-V. // Therefore, we need to perform OpSelect for each row of the matrix. { QualType condElemType = {}, elemType = {}; uint32_t rowCount = 0, colCount = 0; if (isMxNMatrix(type, &elemType, &rowCount, &colCount) && isMxNMatrix(cond->getType(), &condElemType) && condElemType->isBooleanType()) { const auto rowType = astContext.getExtVectorType(elemType, colCount); const auto condRowType = astContext.getExtVectorType(condElemType, colCount); llvm::SmallVector rows; for (uint32_t i = 0; i < rowCount; ++i) { auto *condRow = spvBuilder.createCompositeExtract(condRowType, condition, {i}, loc); auto *trueRow = spvBuilder.createCompositeExtract(rowType, trueBranch, {i}, loc); auto *falseRow = spvBuilder.createCompositeExtract(rowType, falseBranch, {i}, loc); rows.push_back( spvBuilder.createSelect(rowType, condRow, trueRow, falseRow, loc)); } auto *result = spvBuilder.createCompositeConstruct(type, rows, loc); result->setRValue(); return result; } } // For cases where the return type is a scalar or a vector, we can use // OpSelect to choose between the two. OpSelect's return type must be either // scalar or vector. 
if (isScalarType(type) || isVectorType(type)) { // The SPIR-V OpSelect instruction must have a selection argument that is // the same size as the return type. If the return type is a vector, the // selection must be a vector of booleans (one per output component). uint32_t count = 0; if (isVectorType(expr->getType(), nullptr, &count) && !isVectorType(expr->getCond()->getType())) { const llvm::SmallVector components(size_t(count), condition); condition = spvBuilder.createCompositeConstruct( astContext.getExtVectorType(astContext.BoolTy, count), components, expr->getCond()->getLocEnd()); } auto *value = spvBuilder.createSelect(type, condition, trueBranch, falseBranch, loc); value->setRValue(); return value; } // If we can't use OpSelect, we need to create if-else control flow. auto *tempVar = spvBuilder.addFnVar(type, loc, "temp.var.ternary"); auto *thenBB = spvBuilder.createBasicBlock("if.true"); auto *mergeBB = spvBuilder.createBasicBlock("if.merge"); auto *elseBB = spvBuilder.createBasicBlock("if.false"); // Create the branch instruction. This will end the current basic block. spvBuilder.createConditionalBranch(condition, thenBB, elseBB, expr->getCond()->getLocEnd(), mergeBB); spvBuilder.addSuccessor(thenBB); spvBuilder.addSuccessor(elseBB); spvBuilder.setMergeTarget(mergeBB); // Handle the then branch spvBuilder.setInsertPoint(thenBB); spvBuilder.createStore(tempVar, trueBranch, expr->getTrueExpr()->getLocStart()); spvBuilder.createBranch(mergeBB, expr->getTrueExpr()->getLocEnd()); spvBuilder.addSuccessor(mergeBB); // Handle the else branch spvBuilder.setInsertPoint(elseBB); spvBuilder.createStore(tempVar, falseBranch, expr->getFalseExpr()->getLocStart()); spvBuilder.createBranch(mergeBB, expr->getFalseExpr()->getLocEnd()); spvBuilder.addSuccessor(mergeBB); // From now on, emit instructions into the merge block. 
spvBuilder.setInsertPoint(mergeBB); auto *result = spvBuilder.createLoad(type, tempVar, expr->getLocEnd()); result->setRValue(); return result; } SpirvInstruction * SpirvEmitter::processByteAddressBufferStructuredBufferGetDimensions( const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument(); auto *objectInstr = loadIfAliasVarRef(object); const auto type = object->getType(); const bool isBABuf = isByteAddressBuffer(type) || isRWByteAddressBuffer(type); const bool isStructuredBuf = isStructuredBuffer(type) || isAppendStructuredBuffer(type) || isConsumeStructuredBuffer(type); assert(isBABuf || isStructuredBuf); // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a structure // with only one member that is a runtime array. We need to perform // OpArrayLength on member 0. SpirvInstruction *length = spvBuilder.createArrayLength( astContext.UnsignedIntTy, expr->getExprLoc(), objectInstr, 0); // For (RW)ByteAddressBuffers, GetDimensions() must return the array length // in bytes, but OpArrayLength returns the number of uints in the runtime // array. Therefore we must multiply the results by 4. if (isBABuf) { length = spvBuilder.createBinaryOp( spv::Op::OpIMul, astContext.UnsignedIntTy, length, // TODO(jaebaek): What line info we should emit for constants? spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 4u)), expr->getExprLoc()); } spvBuilder.createStore(doExpr(expr->getArg(0)), length, expr->getArg(0)->getLocStart()); if (isStructuredBuf) { // For (RW)StructuredBuffer, the stride of the runtime array (which is the // size of the struct) must also be written to the second argument. 
AlignmentSizeCalculator alignmentCalc(astContext, spirvOptions); uint32_t size = 0, stride = 0; std::tie(std::ignore, size) = alignmentCalc.getAlignmentAndSize(type, spirvOptions.sBufferLayoutRule, /*isRowMajor*/ llvm::None, &stride); auto *sizeInstr = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, size)); spvBuilder.createStore(doExpr(expr->getArg(1)), sizeInstr, expr->getArg(1)->getLocStart()); } return nullptr; } SpirvInstruction *SpirvEmitter::processRWByteAddressBufferAtomicMethods( hlsl::IntrinsicOp opcode, const CXXMemberCallExpr *expr) { // The signature of RWByteAddressBuffer atomic methods are largely: // void Interlocked*(in UINT dest, in UINT value); // void Interlocked*(in UINT dest, in UINT value, out UINT original_value); const auto *object = expr->getImplicitObjectArgument(); auto *objectInfo = loadIfAliasVarRef(object); auto *zero = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); auto *offset = doExpr(expr->getArg(0)); // Right shift by 2 to convert the byte offset to uint32_t offset auto *address = spvBuilder.createBinaryOp( spv::Op::OpShiftRightLogical, astContext.UnsignedIntTy, offset, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2)), expr->getExprLoc()); auto *ptr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, objectInfo, {zero, address}, object->getLocStart()); const bool isCompareExchange = opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareExchange; const bool isCompareStore = opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareStore; if (isCompareExchange || isCompareStore) { auto *comparator = doExpr(expr->getArg(1)); auto *originalVal = spvBuilder.createAtomicCompareExchange( astContext.UnsignedIntTy, ptr, spv::Scope::Device, spv::MemorySemanticsMask::MaskNone, spv::MemorySemanticsMask::MaskNone, doExpr(expr->getArg(2)), comparator, expr->getCallee()->getExprLoc()); if (isCompareExchange) spvBuilder.createStore(doExpr(expr->getArg(3)), originalVal, 
expr->getArg(3)->getLocStart()); } else { auto *value = doExpr(expr->getArg(1)); SpirvInstruction *originalVal = spvBuilder.createAtomicOp( translateAtomicHlslOpcodeToSpirvOpcode(opcode), astContext.UnsignedIntTy, ptr, spv::Scope::Device, spv::MemorySemanticsMask::MaskNone, value, expr->getCallee()->getExprLoc()); if (expr->getNumArgs() > 2) { originalVal = castToType(originalVal, astContext.UnsignedIntTy, expr->getArg(2)->getType(), expr->getArg(2)->getLocStart()); spvBuilder.createStore(doExpr(expr->getArg(2)), originalVal, expr->getArg(2)->getLocStart()); } } return nullptr; } SpirvInstruction * SpirvEmitter::processGetSamplePosition(const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument()->IgnoreParens(); auto *sampleCount = spvBuilder.createImageQuery( spv::Op::OpImageQuerySamples, astContext.UnsignedIntTy, expr->getExprLoc(), loadIfGLValue(object)); if (!spirvOptions.noWarnEmulatedFeatures) emitWarning("GetSamplePosition is emulated using many SPIR-V instructions " "due to lack of direct SPIR-V equivalent, so it only supports " "standard sample settings with 1, 2, 4, 8, or 16 samples and " "will return float2(0, 0) for other cases", expr->getCallee()->getExprLoc()); return emitGetSamplePosition(sampleCount, doExpr(expr->getArg(0)), expr->getCallee()->getExprLoc()); } SpirvInstruction * SpirvEmitter::processSubpassLoad(const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument()->IgnoreParens(); SpirvInstruction *sample = expr->getNumArgs() == 1 ? 
doExpr(expr->getArg(0)) : nullptr; auto *zero = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0)); auto *location = spvBuilder.getConstantComposite( astContext.getExtVectorType(astContext.IntTy, 2), {zero, zero}); return processBufferTextureLoad(object, location, /*constOffset*/ 0, /*varOffset*/ 0, /*lod*/ sample, /*residencyCode*/ 0, expr->getExprLoc()); } SpirvInstruction * SpirvEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument(); auto *objectInstr = loadIfGLValue(object); const auto type = object->getType(); const auto *recType = type->getAs(); assert(recType); const auto typeName = recType->getDecl()->getName(); const auto numArgs = expr->getNumArgs(); const Expr *mipLevel = nullptr, *numLevels = nullptr, *numSamples = nullptr; assert(isTexture(type) || isRWTexture(type) || isBuffer(type) || isRWBuffer(type)); // For Texture1D, arguments are either: // a) width // b) MipLevel, width, NumLevels // For Texture1DArray, arguments are either: // a) width, elements // b) MipLevel, width, elements, NumLevels // For Texture2D, arguments are either: // a) width, height // b) MipLevel, width, height, NumLevels // For Texture2DArray, arguments are either: // a) width, height, elements // b) MipLevel, width, height, elements, NumLevels // For Texture3D, arguments are either: // a) width, height, depth // b) MipLevel, width, height, depth, NumLevels // For Texture2DMS, arguments are: width, height, NumSamples // For Texture2DMSArray, arguments are: width, height, elements, NumSamples // For TextureCube, arguments are either: // a) width, height // b) MipLevel, width, height, NumLevels // For TextureCubeArray, arguments are either: // a) width, height, elements // b) MipLevel, width, height, elements, NumLevels // Note: SPIR-V Spec requires return type of OpImageQuerySize(Lod) to be a // scalar/vector of integers. 
SPIR-V Spec also requires return type of // OpImageQueryLevels and OpImageQuerySamples to be scalar integers. // The HLSL methods, however, have overloaded functions which have float // output arguments. Since the AST naturally won't have casting AST nodes for // such cases, we'll have to perform the cast ourselves. const auto storeToOutputArg = [this](const Expr *outputArg, SpirvInstruction *id, QualType type) { id = castToType(id, type, outputArg->getType(), outputArg->getExprLoc()); spvBuilder.createStore(doExpr(outputArg), id, outputArg->getLocStart()); }; if ((typeName == "Texture1D" && numArgs > 1) || (typeName == "Texture2D" && numArgs > 2) || (typeName == "TextureCube" && numArgs > 2) || (typeName == "Texture3D" && numArgs > 3) || (typeName == "Texture1DArray" && numArgs > 2) || (typeName == "TextureCubeArray" && numArgs > 3) || (typeName == "Texture2DArray" && numArgs > 3)) { mipLevel = expr->getArg(0); numLevels = expr->getArg(numArgs - 1); } if (isTextureMS(type)) { numSamples = expr->getArg(numArgs - 1); } uint32_t querySize = numArgs; // If numLevels arg is present, mipLevel must also be present. These are not // queried via ImageQuerySizeLod. if (numLevels) querySize -= 2; // If numLevels arg is present, mipLevel must also be present. else if (numSamples) querySize -= 1; const QualType resultQualType = querySize == 1 ? astContext.UnsignedIntTy : astContext.getExtVectorType(astContext.UnsignedIntTy, querySize); // Only Texture types use ImageQuerySizeLod. // TextureMS, RWTexture, Buffers, RWBuffers use ImageQuerySize. SpirvInstruction *lod = nullptr; if (isTexture(type) && !numSamples) { if (mipLevel) { // For Texture types when mipLevel argument is present. lod = doExpr(mipLevel); } else { // For Texture types when mipLevel argument is omitted. lod = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0)); } } SpirvInstruction *query = lod ? 
cast(spvBuilder.createImageQuery( spv::Op::OpImageQuerySizeLod, resultQualType, expr->getCallee()->getExprLoc(), objectInstr, lod)) : cast(spvBuilder.createImageQuery( spv::Op::OpImageQuerySize, resultQualType, expr->getCallee()->getExprLoc(), objectInstr)); if (querySize == 1) { const uint32_t argIndex = mipLevel ? 1 : 0; storeToOutputArg(expr->getArg(argIndex), query, resultQualType); } else { for (uint32_t i = 0; i < querySize; ++i) { const uint32_t argIndex = mipLevel ? i + 1 : i; auto *component = spvBuilder.createCompositeExtract( astContext.UnsignedIntTy, query, {i}, expr->getCallee()->getExprLoc()); // If the first arg is the mipmap level, we must write the results // starting from Arg(i+1), not Arg(i). storeToOutputArg(expr->getArg(argIndex), component, astContext.UnsignedIntTy); } } if (numLevels || numSamples) { const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples; const spv::Op opcode = numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples; auto *numLevelsSamplesQuery = spvBuilder.createImageQuery( opcode, astContext.UnsignedIntTy, expr->getCallee()->getExprLoc(), objectInstr); storeToOutputArg(numLevelsSamplesArg, numLevelsSamplesQuery, astContext.UnsignedIntTy); } return nullptr; } SpirvInstruction * SpirvEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr, bool unclamped) { // Possible signatures are as follows: // Texture1D(Array).CalculateLevelOfDetail(SamplerState S, float x); // Texture2D(Array).CalculateLevelOfDetail(SamplerState S, float2 xy); // TextureCube(Array).CalculateLevelOfDetail(SamplerState S, float3 xyz); // Texture3D.CalculateLevelOfDetail(SamplerState S, float3 xyz); // Return type is always a single float (LOD). 
assert(expr->getNumArgs() == 2u); const auto *object = expr->getImplicitObjectArgument(); auto *objectInfo = loadIfGLValue(object); auto *samplerState = doExpr(expr->getArg(0)); auto *coordinate = doExpr(expr->getArg(1)); auto *sampledImage = spvBuilder.createSampledImage( object->getType(), objectInfo, samplerState, expr->getExprLoc()); // The result type of OpImageQueryLod must be a float2. const QualType queryResultType = astContext.getExtVectorType(astContext.FloatTy, 2u); auto *query = spvBuilder.createImageQuery(spv::Op::OpImageQueryLod, queryResultType, expr->getExprLoc(), sampledImage, coordinate); // The first component of the float2 contains the mipmap array layer. // The second component of the float2 represents the unclamped lod. return spvBuilder.createCompositeExtract(astContext.FloatTy, query, unclamped ? 1 : 0, expr->getCallee()->getExprLoc()); } SpirvInstruction *SpirvEmitter::processTextureGatherRGBACmpRGBA( const CXXMemberCallExpr *expr, const bool isCmp, const uint32_t component) { // Parameters for .Gather{Red|Green|Blue|Alpha}() are one of the following // two sets: // * SamplerState s, float2 location, int2 offset // * SamplerState s, float2 location, int2 offset0, int2 offset1, // int offset2, int2 offset3 // // An additional 'out uint status' parameter can appear in both of the above. // // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() are one of the following // two sets: // * SamplerState s, float2 location, float compare_value, int2 offset // * SamplerState s, float2 location, float compare_value, int2 offset1, // int2 offset2, int2 offset3, int2 offset4 // // An additional 'out uint status' parameter can appear in both of the above. // // TextureCube's signature is somewhat different from the rest. 
// Parameters for .Gather{Red|Green|Blue|Alpha}() for TextureCube are: // * SamplerState s, float2 location, out uint status // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() for TextureCube are: // * SamplerState s, float2 location, float compare_value, out uint status // // Return type is always a 4-component vector. const FunctionDecl *callee = expr->getDirectCallee(); const auto numArgs = expr->getNumArgs(); const auto *imageExpr = expr->getImplicitObjectArgument(); const auto loc = expr->getCallee()->getExprLoc(); const QualType imageType = imageExpr->getType(); const QualType retType = callee->getReturnType(); // If the last arg is an unsigned integer, it must be the status. const bool hasStatusArg = expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType(); // Subtract 1 for status arg (if it exists), subtract 1 for compare_value (if // it exists), and subtract 2 for SamplerState and location. const auto numOffsetArgs = numArgs - hasStatusArg - isCmp - 2; // No offset args for TextureCube, 1 or 4 offset args for the rest. assert(numOffsetArgs == 0 || numOffsetArgs == 1 || numOffsetArgs == 4); auto *image = loadIfGLValue(imageExpr); auto *sampler = doExpr(expr->getArg(0)); auto *coordinate = doExpr(expr->getArg(1)); auto *compareVal = isCmp ? doExpr(expr->getArg(2)) : nullptr; // Handle offsets (if any). bool needsEmulation = false; SpirvInstruction *constOffset = nullptr, *varOffset = nullptr, *constOffsets = nullptr; if (numOffsetArgs == 1) { // The offset arg is not optional. handleOffsetInMethodCall(expr, 2 + isCmp, &constOffset, &varOffset); } else if (numOffsetArgs == 4) { auto *offset0 = tryToEvaluateAsConst(expr->getArg(2 + isCmp)); auto *offset1 = tryToEvaluateAsConst(expr->getArg(3 + isCmp)); auto *offset2 = tryToEvaluateAsConst(expr->getArg(4 + isCmp)); auto *offset3 = tryToEvaluateAsConst(expr->getArg(5 + isCmp)); // If any of the offsets is not constant, we then need to emulate the call // using 4 OpImageGather instructions. 
Otherwise, we can leverage the // ConstOffsets image operand. if (offset0 && offset1 && offset2 && offset3) { const QualType v2i32 = astContext.getExtVectorType(astContext.IntTy, 2); const auto offsetType = astContext.getConstantArrayType( v2i32, llvm::APInt(32, 4), clang::ArrayType::Normal, 0); constOffsets = spvBuilder.getConstantComposite( offsetType, {offset0, offset1, offset2, offset3}); } else { needsEmulation = true; } } auto *status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr; if (needsEmulation) { const auto elemType = hlsl::GetHLSLVecElementType(callee->getReturnType()); SpirvInstruction *texels[4]; for (uint32_t i = 0; i < 4; ++i) { varOffset = doExpr(expr->getArg(2 + isCmp + i)); auto *gatherRet = spvBuilder.createImageGather( retType, imageType, image, sampler, coordinate, spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, component, true)), compareVal, /*constOffset*/ nullptr, varOffset, /*constOffsets*/ nullptr, /*sampleNumber*/ nullptr, status, loc); texels[i] = spvBuilder.createCompositeExtract(elemType, gatherRet, {i}, loc); } return spvBuilder.createCompositeConstruct( retType, {texels[0], texels[1], texels[2], texels[3]}, loc); } return spvBuilder.createImageGather( retType, imageType, image, sampler, coordinate, spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, component, true)), compareVal, constOffset, varOffset, constOffsets, /*sampleNumber*/ nullptr, status, loc); } SpirvInstruction * SpirvEmitter::processTextureGatherCmp(const CXXMemberCallExpr *expr) { // Signature for Texture2D/Texture2DArray: // // float4 GatherCmp( // in SamplerComparisonState s, // in float2 location, // in float compare_value // [,in int2 offset] // [,out uint Status] // ); // // Signature for TextureCube/TextureCubeArray: // // float4 GatherCmp( // in SamplerComparisonState s, // in float2 location, // in float compare_value, // out uint Status // ); // // Other Texture types do not have the GatherCmp method. 
const FunctionDecl *callee = expr->getDirectCallee(); const auto numArgs = expr->getNumArgs(); const auto loc = expr->getExprLoc(); const bool hasStatusArg = expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType(); const bool hasOffsetArg = (numArgs == 5) || (numArgs == 4 && !hasStatusArg); const auto *imageExpr = expr->getImplicitObjectArgument(); auto *image = loadIfGLValue(imageExpr); auto *sampler = doExpr(expr->getArg(0)); auto *coordinate = doExpr(expr->getArg(1)); auto *comparator = doExpr(expr->getArg(2)); SpirvInstruction *constOffset = nullptr, *varOffset = nullptr; if (hasOffsetArg) handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset); const auto retType = callee->getReturnType(); const auto imageType = imageExpr->getType(); const auto status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : nullptr; return spvBuilder.createImageGather( retType, imageType, image, sampler, coordinate, /*component*/ nullptr, comparator, constOffset, varOffset, /*constOffsets*/ nullptr, /*sampleNumber*/ nullptr, status, loc); } SpirvInstruction *SpirvEmitter::processBufferTextureLoad( const Expr *object, SpirvInstruction *location, SpirvInstruction *constOffset, SpirvInstruction *varOffset, SpirvInstruction *lod, SpirvInstruction *residencyCode, SourceLocation loc) { // Loading for Buffer and RWBuffer translates to an OpImageFetch. // The result type of an OpImageFetch must be a vec4 of float or int. const auto type = object->getType(); assert(isBuffer(type) || isRWBuffer(type) || isTexture(type) || isRWTexture(type) || isSubpassInput(type) || isSubpassInputMS(type)); const bool doFetch = isBuffer(type) || isTexture(type); auto *objectInfo = loadIfGLValue(object); // For Texture2DMS and Texture2DMSArray, Sample must be used rather than Lod. 
SpirvInstruction *sampleNumber = nullptr; if (isTextureMS(type) || isSubpassInputMS(type)) { sampleNumber = lod; lod = nullptr; } const auto sampledType = hlsl::GetHLSLResourceResultType(type); QualType elemType = sampledType; uint32_t elemCount = 1; bool isTemplateOverStruct = false; // Check whether the template type is a vector type or struct type. if (!isVectorType(sampledType, &elemType, &elemCount)) { if (sampledType->getAsStructureType()) { isTemplateOverStruct = true; // For struct type, we need to make sure it can fit into a 4-component // vector. Detailed failing reasons will be emitted by the function so // we don't need to emit errors here. if (!canFitIntoOneRegister(astContext, sampledType, &elemType, &elemCount)) return nullptr; } } { // Treat a vector of size 1 the same as a scalar. if (hlsl::IsHLSLVecType(elemType) && hlsl::GetHLSLVecSize(elemType) == 1) elemType = hlsl::GetHLSLVecElementType(elemType); if (!elemType->isFloatingType() && !elemType->isIntegerType()) { emitError("loading %0 value unsupported", object->getExprLoc()) << type; return nullptr; } } // If residencyCode is nullptr, we are dealing with a Load method with 2 // arguments which does not return the operation status. if (residencyCode && residencyCode->isRValue()) { emitError( "an lvalue argument should be used for returning the operation status", loc); return nullptr; } // OpImageFetch and OpImageRead can only fetch a vector of 4 elements. const QualType texelType = astContext.getExtVectorType(elemType, 4u); auto *texel = spvBuilder.createImageFetchOrRead( doFetch, texelType, type, objectInfo, location, lod, constOffset, varOffset, /*constOffsets*/ nullptr, sampleNumber, residencyCode, loc); // If the result type is a vec1, vec2, or vec3, some extra processing // (extraction) is required. auto *retVal = extractVecFromVec4(texel, elemCount, elemType, loc); if (isTemplateOverStruct) { // Convert to the struct so that we are consistent with types in the AST. 
retVal = convertVectorToStruct(sampledType, elemType, retVal, loc); } retVal->setRValue(); return retVal; } SpirvInstruction *SpirvEmitter::processByteAddressBufferLoadStore( const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) { SpirvInstruction *result = nullptr; const auto object = expr->getImplicitObjectArgument(); auto *objectInfo = loadIfAliasVarRef(object); assert(numWords >= 1 && numWords <= 4); if (doStore) { assert(isRWByteAddressBuffer(object->getType())); assert(expr->getNumArgs() == 2); } else { assert(isRWByteAddressBuffer(object->getType()) || isByteAddressBuffer(object->getType())); if (expr->getNumArgs() == 2) { emitError( "(RW)ByteAddressBuffer::Load(in address, out status) not supported", expr->getExprLoc()); return 0; } } const Expr *addressExpr = expr->getArg(0); auto *byteAddress = doExpr(addressExpr); const QualType addressType = addressExpr->getType(); // The front-end prevents usage of templated Load2, Load3, Load4, Store2, // Store3, Store4 intrinsic functions. const bool isTemplatedLoadOrStore = (numWords == 1) && (doStore ? !expr->getArg(1)->getType()->isSpecificBuiltinType( BuiltinType::UInt) : !expr->getType()->isSpecificBuiltinType(BuiltinType::UInt)); // Do a OpShiftRightLogical by 2 (divide by 4 to get aligned memory // access). The AST always casts the address to unsinged integer, so shift // by unsinged integer 2. auto *constUint2 = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2)); SpirvInstruction *address = spvBuilder.createBinaryOp(spv::Op::OpShiftRightLogical, addressType, byteAddress, constUint2, expr->getExprLoc()); if (isTemplatedLoadOrStore) { // Templated load. Need to (potentially) perform more // loads/casts/composite-constructs. 
uint32_t bitOffset = 0; if (doStore) { auto *values = doExpr(expr->getArg(1)); RawBufferHandler(*this).processTemplatedStoreToBuffer( values, objectInfo, address, expr->getArg(1)->getType(), bitOffset); return nullptr; } else { RawBufferHandler rawBufferHandler(*this); return rawBufferHandler.processTemplatedLoadFromBuffer( objectInfo, address, expr->getType(), bitOffset); } } // Perform access chain into the RWByteAddressBuffer. // First index must be zero (member 0 of the struct is a // runtimeArray). The second index passed to OpAccessChain should be // the address. auto *constUint0 = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); if (doStore) { auto *values = doExpr(expr->getArg(1)); auto *curStoreAddress = address; for (uint32_t wordCounter = 0; wordCounter < numWords; ++wordCounter) { // Extract a 32-bit word from the input. auto *curValue = numWords == 1 ? values : spvBuilder.createCompositeExtract( astContext.UnsignedIntTy, values, {wordCounter}, expr->getArg(1)->getExprLoc()); // Update the output address if necessary. if (wordCounter > 0) { auto *offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, wordCounter)); curStoreAddress = spvBuilder.createBinaryOp(spv::Op::OpIAdd, addressType, address, offset, expr->getCallee()->getExprLoc()); } // Store the word to the right address at the output. auto *storePtr = spvBuilder.createAccessChain( astContext.UnsignedIntTy, objectInfo, {constUint0, curStoreAddress}, object->getLocStart()); spvBuilder.createStore(storePtr, curValue, expr->getCallee()->getExprLoc()); } } else { auto *loadPtr = spvBuilder.createAccessChain( astContext.UnsignedIntTy, objectInfo, {constUint0, address}, object->getLocStart()); result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr, expr->getCallee()->getExprLoc()); if (numWords > 1) { // Load word 2, 3, and 4 where necessary. Use OpCompositeConstruct to // return a vector result. 
llvm::SmallVector values; values.push_back(result); for (uint32_t wordCounter = 2; wordCounter <= numWords; ++wordCounter) { auto *offset = spvBuilder.getConstantInt( astContext.UnsignedIntTy, llvm::APInt(32, wordCounter - 1)); auto *newAddress = spvBuilder.createBinaryOp(spv::Op::OpIAdd, addressType, address, offset, expr->getCallee()->getExprLoc()); loadPtr = spvBuilder.createAccessChain( astContext.UnsignedIntTy, objectInfo, {constUint0, newAddress}, object->getLocStart()); values.push_back( spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr, expr->getCallee()->getExprLoc())); } const QualType resultType = astContext.getExtVectorType(addressType, numWords); result = spvBuilder.createCompositeConstruct(resultType, values, expr->getLocStart()); result->setRValue(); } } return result; } SpirvInstruction * SpirvEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) { if (expr->getNumArgs() == 2) { emitError( "(RW)StructuredBuffer::Load(in location, out status) not supported", expr->getExprLoc()); return 0; } const auto *buffer = expr->getImplicitObjectArgument(); auto *info = loadIfAliasVarRef(buffer); const QualType structType = hlsl::GetHLSLResourceResultType(buffer->getType()); auto *zero = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0)); auto *index = doExpr(expr->getArg(0)); return turnIntoElementPtr(buffer->getType(), info, structType, {zero, index}, buffer->getExprLoc()); } SpirvInstruction * SpirvEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr, bool isInc, bool loadObject) { auto *zero = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); auto *sOne = spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 1, true)); const auto srcLoc = expr->getCallee()->getExprLoc(); const auto *object = expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext); if (loadObject) { // We don't need the object's here since counter variable is a // separate variable. 
But we still need the side effects of evaluating the // object, e.g., if the source code is foo(...).IncrementCounter(), we still // want to emit the code for foo(...). (void)doExpr(object); } const auto *counterPair = getFinalACSBufferCounter(object); if (!counterPair) { emitFatalError("cannot find the associated counter variable", object->getExprLoc()); return nullptr; } auto *counterPtr = spvBuilder.createAccessChain( astContext.IntTy, counterPair->get(spvBuilder, spvContext), {zero}, srcLoc); SpirvInstruction *index = nullptr; if (isInc) { index = spvBuilder.createAtomicOp( spv::Op::OpAtomicIAdd, astContext.IntTy, counterPtr, spv::Scope::Device, spv::MemorySemanticsMask::MaskNone, sOne, srcLoc); } else { // Note that OpAtomicISub returns the value before the subtraction; // so we need to do substraction again with OpAtomicISub's return value. auto *prev = spvBuilder.createAtomicOp( spv::Op::OpAtomicISub, astContext.IntTy, counterPtr, spv::Scope::Device, spv::MemorySemanticsMask::MaskNone, sOne, srcLoc); index = spvBuilder.createBinaryOp(spv::Op::OpISub, astContext.IntTy, prev, sOne, srcLoc); } return index; } bool SpirvEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl, const Expr *srcExpr) { // We are handling associated counters here. Casts should not alter which // associated counter to manipulate. srcExpr = srcExpr->IgnoreParenCasts(); // For parameters of forward-declared functions. We must make sure the // associated counter variable is created. But for forward-declared functions, // the translation of the real definition may not be started yet. if (const auto *param = dyn_cast(dstDecl)) declIdMapper.createFnParamCounterVar(param); // For implicit objects of methods. Similar to the above. 
else if (const auto *thisObject = dyn_cast(dstDecl)) declIdMapper.createFnParamCounterVar(thisObject); // Handle AssocCounter#1 (see CounterVarFields comment) if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) { const auto *srcPair = getFinalACSBufferCounter(srcExpr); if (!srcPair) { emitFatalError("cannot find the associated counter variable", srcExpr->getExprLoc()); return false; } dstPair->assign(*srcPair, spvBuilder, spvContext); return true; } // Handle AssocCounter#3 llvm::SmallVector srcIndices; const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl); const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices); if (dstFields && srcFields) { // The destination is a struct whose fields are directly alias resources. // But that's not necessarily true for the source, which can be deep // nested structs. That means they will have different index "prefixes" // for all their fields; while the "prefix" for destination is effectively // an empty list (since it is not nested in other structs). We need to // strip the index prefix from the source. return dstFields->assign(*srcFields, /*dstIndices=*/{}, srcIndices, spvBuilder, spvContext); } // AssocCounter#2 and AssocCounter#4 for the lhs cannot happen since the lhs // is a stand-alone decl in this method. 
return false; } bool SpirvEmitter::tryToAssignCounterVar(const Expr *dstExpr, const Expr *srcExpr) { dstExpr = dstExpr->IgnoreParenCasts(); srcExpr = srcExpr->IgnoreParenCasts(); const auto *dstPair = getFinalACSBufferCounter(dstExpr); const auto *srcPair = getFinalACSBufferCounter(srcExpr); if ((dstPair == nullptr) != (srcPair == nullptr)) { emitFatalError("cannot handle associated counter variable assignment", srcExpr->getExprLoc()); return false; } // Handle AssocCounter#1 & AssocCounter#2 if (dstPair && srcPair) { dstPair->assign(*srcPair, spvBuilder, spvContext); return true; } // Handle AssocCounter#3 & AssocCounter#4 llvm::SmallVector dstIndices; llvm::SmallVector srcIndices; const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices); const auto *dstFields = getIntermediateACSBufferCounter(dstExpr, &dstIndices); if (dstFields && srcFields) { return dstFields->assign(*srcFields, dstIndices, srcIndices, spvBuilder, spvContext); } return false; } const CounterIdAliasPair * SpirvEmitter::getFinalACSBufferCounter(const Expr *expr) { // AssocCounter#1: referencing some stand-alone variable if (const auto *decl = getReferencedDef(expr)) return declIdMapper.getCounterIdAliasPair(decl); // AssocCounter#2: referencing some non-struct field llvm::SmallVector rawIndices; const auto *base = collectArrayStructIndices( expr, /*rawIndex=*/true, &rawIndices, /*indices*/ nullptr); const auto *decl = (base && isa(base)) ? getOrCreateDeclForMethodObject(cast(curFunction)) : getReferencedDef(base); return declIdMapper.getCounterIdAliasPair(decl, &rawIndices); } const CounterVarFields *SpirvEmitter::getIntermediateACSBufferCounter( const Expr *expr, llvm::SmallVector *rawIndices) { const auto *base = collectArrayStructIndices(expr, /*rawIndex=*/true, rawIndices, /*indices*/ nullptr); const auto *decl = (base && isa(base)) // Use the decl we created to represent the implicit object ? 
getOrCreateDeclForMethodObject(cast(curFunction)) // Find the referenced decl from the original source code : getReferencedDef(base); return declIdMapper.getCounterVarFields(decl); } const ImplicitParamDecl * SpirvEmitter::getOrCreateDeclForMethodObject(const CXXMethodDecl *method) { const auto found = thisDecls.find(method); if (found != thisDecls.end()) return found->second; const std::string name = method->getName().str() + ".this"; // Create a new identifier to convey the name auto &identifier = astContext.Idents.get(name); return thisDecls[method] = ImplicitParamDecl::Create( astContext, /*DC=*/nullptr, SourceLocation(), &identifier, method->getThisType(astContext)->getPointeeType()); } SpirvInstruction * SpirvEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) { const bool isAppend = expr->getNumArgs() == 1; auto *zero = spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 0)); const auto *object = expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext); auto *bufferInfo = loadIfAliasVarRef(object); auto *index = incDecRWACSBufferCounter( expr, isAppend, // We have already translated the object in the above. Avoid duplication. /*loadObject=*/false); auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType()); // If this is a variable to communicate with host e.g., ACSBuffer // and its type is bool or vector of bool, its effective type used // for SPIRV must be uint not bool. We must convert it to uint here. bool needCast = false; if (bufferInfo->getLayoutRule() != SpirvLayoutRule::Void && isBoolOrVecOfBoolType(bufferElemTy)) { uint32_t vecSize = 1; const bool isVec = isVectorType(bufferElemTy, nullptr, &vecSize); bufferElemTy = isVec ? 
astContext.getExtVectorType(astContext.UnsignedIntTy, vecSize) : astContext.UnsignedIntTy; needCast = true; } bufferInfo = turnIntoElementPtr(object->getType(), bufferInfo, bufferElemTy, {zero, index}, object->getExprLoc()); if (isAppend) { // Write out the value auto *arg0 = doExpr(expr->getArg(0)); if (!arg0) return nullptr; if (!arg0->isRValue()) { arg0 = spvBuilder.createLoad(bufferElemTy, arg0, expr->getArg(0)->getExprLoc()); } if (needCast && !isSameType(astContext, bufferElemTy, arg0->getAstResultType())) { arg0 = castToType(arg0, arg0->getAstResultType(), bufferElemTy, expr->getArg(0)->getExprLoc()); } storeValue(bufferInfo, arg0, bufferElemTy, expr->getCallee()->getExprLoc()); return 0; } else { // Note that we are returning a pointer (lvalue) here inorder to further // acess the fields in this element, e.g., buffer.Consume().a.b. So we // cannot forcefully set all normal function calls as returning rvalue. return bufferInfo; } } SpirvInstruction * SpirvEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) { // TODO: handle multiple stream-output objects const auto *object = expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext); const auto *stream = cast(object)->getDecl(); auto *value = doExpr(expr->getArg(0)); declIdMapper.writeBackOutputStream(stream, stream->getType(), value); spvBuilder.createEmitVertex(expr->getExprLoc()); return nullptr; } SpirvInstruction * SpirvEmitter::processStreamOutputRestart(const CXXMemberCallExpr *expr) { // TODO: handle multiple stream-output objects spvBuilder.createEndPrimitive(expr->getExprLoc()); return 0; } SpirvInstruction * SpirvEmitter::emitGetSamplePosition(SpirvInstruction *sampleCount, SpirvInstruction *sampleIndex, SourceLocation loc) { struct Float2 { float x; float y; }; static const Float2 pos2[] = { {4.0 / 16.0, 4.0 / 16.0}, {-4.0 / 16.0, -4.0 / 16.0}, }; static const Float2 pos4[] = { {-2.0 / 16.0, -6.0 / 16.0}, {6.0 / 16.0, -2.0 / 16.0}, {-6.0 / 16.0, 2.0 / 16.0}, {2.0 / 16.0, 
6.0 / 16.0}, }; static const Float2 pos8[] = { {1.0 / 16.0, -3.0 / 16.0}, {-1.0 / 16.0, 3.0 / 16.0}, {5.0 / 16.0, 1.0 / 16.0}, {-3.0 / 16.0, -5.0 / 16.0}, {-5.0 / 16.0, 5.0 / 16.0}, {-7.0 / 16.0, -1.0 / 16.0}, {3.0 / 16.0, 7.0 / 16.0}, {7.0 / 16.0, -7.0 / 16.0}, }; static const Float2 pos16[] = { {1.0 / 16.0, 1.0 / 16.0}, {-1.0 / 16.0, -3.0 / 16.0}, {-3.0 / 16.0, 2.0 / 16.0}, {4.0 / 16.0, -1.0 / 16.0}, {-5.0 / 16.0, -2.0 / 16.0}, {2.0 / 16.0, 5.0 / 16.0}, {5.0 / 16.0, 3.0 / 16.0}, {3.0 / 16.0, -5.0 / 16.0}, {-2.0 / 16.0, 6.0 / 16.0}, {0.0 / 16.0, -7.0 / 16.0}, {-4.0 / 16.0, -6.0 / 16.0}, {-6.0 / 16.0, 4.0 / 16.0}, {-8.0 / 16.0, 0.0 / 16.0}, {7.0 / 16.0, -4.0 / 16.0}, {6.0 / 16.0, 7.0 / 16.0}, {-7.0 / 16.0, -8.0 / 16.0}, }; // We are emitting the SPIR-V for the following HLSL source code: // // float2 position; // // if (count == 2) { // position = pos2[index]; // } // else if (count == 4) { // position = pos4[index]; // } // else if (count == 8) { // position = pos8[index]; // } // else if (count == 16) { // position = pos16[index]; // } // else { // position = float2(0.0f, 0.0f); // } const auto v2f32Type = astContext.getExtVectorType(astContext.FloatTy, 2); // Creates a SPIR-V function scope variable of type float2[len]. const auto createArray = [this, v2f32Type, loc](const Float2 *ptr, uint32_t len) { llvm::SmallVector components; for (uint32_t i = 0; i < len; ++i) { auto *x = spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(ptr[i].x)); auto *y = spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(ptr[i].y)); components.push_back(spvBuilder.getConstantComposite(v2f32Type, {x, y})); } const auto arrType = astContext.getConstantArrayType( v2f32Type, llvm::APInt(32, len), clang::ArrayType::Normal, 0); auto *val = spvBuilder.getConstantComposite(arrType, components); const std::string varName = "var.GetSamplePosition.data." 
+ std::to_string(len); auto *var = spvBuilder.addFnVar(arrType, loc, varName); spvBuilder.createStore(var, val, loc); return var; }; auto *pos2Arr = createArray(pos2, 2); auto *pos4Arr = createArray(pos4, 4); auto *pos8Arr = createArray(pos8, 8); auto *pos16Arr = createArray(pos16, 16); auto *resultVar = spvBuilder.addFnVar(v2f32Type, loc, "var.GetSamplePosition.result"); auto *then2BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then2"); auto *then4BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then4"); auto *then8BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then8"); auto *then16BB = spvBuilder.createBasicBlock("if.GetSamplePosition.then16"); auto *else2BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else2"); auto *else4BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else4"); auto *else8BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else8"); auto *else16BB = spvBuilder.createBasicBlock("if.GetSamplePosition.else16"); auto *merge2BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge2"); auto *merge4BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge4"); auto *merge8BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge8"); auto *merge16BB = spvBuilder.createBasicBlock("if.GetSamplePosition.merge16"); // if (count == 2) { const auto check2 = spvBuilder.createBinaryOp( spv::Op::OpIEqual, astContext.BoolTy, sampleCount, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 2)), loc); spvBuilder.createConditionalBranch(check2, then2BB, else2BB, loc, merge2BB); spvBuilder.addSuccessor(then2BB); spvBuilder.addSuccessor(else2BB); spvBuilder.setMergeTarget(merge2BB); // position = pos2[index]; // } spvBuilder.setInsertPoint(then2BB); auto *ac = spvBuilder.createAccessChain(v2f32Type, pos2Arr, {sampleIndex}, loc); spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc), loc); spvBuilder.createBranch(merge2BB, loc); spvBuilder.addSuccessor(merge2BB); // else if (count == 4) 
{ spvBuilder.setInsertPoint(else2BB); const auto check4 = spvBuilder.createBinaryOp( spv::Op::OpIEqual, astContext.BoolTy, sampleCount, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 4)), loc); spvBuilder.createConditionalBranch(check4, then4BB, else4BB, loc, merge4BB); spvBuilder.addSuccessor(then4BB); spvBuilder.addSuccessor(else4BB); spvBuilder.setMergeTarget(merge4BB); // position = pos4[index]; // } spvBuilder.setInsertPoint(then4BB); ac = spvBuilder.createAccessChain(v2f32Type, pos4Arr, {sampleIndex}, loc); spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc), loc); spvBuilder.createBranch(merge4BB, loc); spvBuilder.addSuccessor(merge4BB); // else if (count == 8) { spvBuilder.setInsertPoint(else4BB); const auto check8 = spvBuilder.createBinaryOp( spv::Op::OpIEqual, astContext.BoolTy, sampleCount, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 8)), loc); spvBuilder.createConditionalBranch(check8, then8BB, else8BB, loc, merge8BB); spvBuilder.addSuccessor(then8BB); spvBuilder.addSuccessor(else8BB); spvBuilder.setMergeTarget(merge8BB); // position = pos8[index]; // } spvBuilder.setInsertPoint(then8BB); ac = spvBuilder.createAccessChain(v2f32Type, pos8Arr, {sampleIndex}, loc); spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc), loc); spvBuilder.createBranch(merge8BB, loc); spvBuilder.addSuccessor(merge8BB); // else if (count == 16) { spvBuilder.setInsertPoint(else8BB); const auto check16 = spvBuilder.createBinaryOp( spv::Op::OpIEqual, astContext.BoolTy, sampleCount, spvBuilder.getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, 16)), loc); spvBuilder.createConditionalBranch(check16, then16BB, else16BB, loc, merge16BB); spvBuilder.addSuccessor(then16BB); spvBuilder.addSuccessor(else16BB); spvBuilder.setMergeTarget(merge16BB); // position = pos16[index]; // } spvBuilder.setInsertPoint(then16BB); ac = spvBuilder.createAccessChain(v2f32Type, pos16Arr, {sampleIndex}, loc); 
spvBuilder.createStore(resultVar, spvBuilder.createLoad(v2f32Type, ac, loc), loc); spvBuilder.createBranch(merge16BB, loc); spvBuilder.addSuccessor(merge16BB); // else { // position = float2(0.0f, 0.0f); // } spvBuilder.setInsertPoint(else16BB); auto *zero = spvBuilder.getConstantFloat(astContext.FloatTy, llvm::APFloat(0.0f)); auto *v2f32Zero = spvBuilder.getConstantComposite(v2f32Type, {zero, zero}); spvBuilder.createStore(resultVar, v2f32Zero, loc); spvBuilder.createBranch(merge16BB, loc); spvBuilder.addSuccessor(merge16BB); spvBuilder.setInsertPoint(merge16BB); spvBuilder.createBranch(merge8BB, loc); spvBuilder.addSuccessor(merge8BB); spvBuilder.setInsertPoint(merge8BB); spvBuilder.createBranch(merge4BB, loc); spvBuilder.addSuccessor(merge4BB); spvBuilder.setInsertPoint(merge4BB); spvBuilder.createBranch(merge2BB, loc); spvBuilder.addSuccessor(merge2BB); spvBuilder.setInsertPoint(merge2BB); return spvBuilder.createLoad(v2f32Type, resultVar, loc); } SpirvInstruction * SpirvEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) { const FunctionDecl *callee = expr->getDirectCallee(); llvm::StringRef group; uint32_t opcode = static_cast(hlsl::IntrinsicOp::Num_Intrinsics); if (hlsl::GetIntrinsicOp(callee, opcode, group)) { return processIntrinsicMemberCall(expr, static_cast(opcode)); } return processCall(expr); } void SpirvEmitter::handleOffsetInMethodCall(const CXXMemberCallExpr *expr, uint32_t index, SpirvInstruction **constOffset, SpirvInstruction **varOffset) { assert(constOffset && varOffset); // Ensure the given arg index is not out-of-range. 
assert(index < expr->getNumArgs()); *constOffset = *varOffset = nullptr; // Initialize both first if ((*constOffset = tryToEvaluateAsConst(expr->getArg(index)))) return; // Constant offset else *varOffset = doExpr(expr->getArg(index)); } SpirvInstruction * SpirvEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr, hlsl::IntrinsicOp opcode) { using namespace hlsl; SpirvInstruction *retVal = nullptr; switch (opcode) { case IntrinsicOp::MOP_Sample: retVal = processTextureSampleGather(expr, /*isSample=*/true); break; case IntrinsicOp::MOP_Gather: retVal = processTextureSampleGather(expr, /*isSample=*/false); break; case IntrinsicOp::MOP_SampleBias: retVal = processTextureSampleBiasLevel(expr, /*isBias=*/true); break; case IntrinsicOp::MOP_SampleLevel: retVal = processTextureSampleBiasLevel(expr, /*isBias=*/false); break; case IntrinsicOp::MOP_SampleGrad: retVal = processTextureSampleGrad(expr); break; case IntrinsicOp::MOP_SampleCmp: retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/true); break; case IntrinsicOp::MOP_SampleCmpLevelZero: retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/false); break; case IntrinsicOp::MOP_GatherRed: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 0); break; case IntrinsicOp::MOP_GatherGreen: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 1); break; case IntrinsicOp::MOP_GatherBlue: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 2); break; case IntrinsicOp::MOP_GatherAlpha: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 3); break; case IntrinsicOp::MOP_GatherCmp: retVal = processTextureGatherCmp(expr); break; case IntrinsicOp::MOP_GatherCmpRed: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/true, 0); break; case IntrinsicOp::MOP_Load: return processBufferTextureLoad(expr); case IntrinsicOp::MOP_Load2: return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ false); case IntrinsicOp::MOP_Load3: return 
processByteAddressBufferLoadStore(expr, 3, /*doStore*/ false); case IntrinsicOp::MOP_Load4: return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ false); case IntrinsicOp::MOP_Store: return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ true); case IntrinsicOp::MOP_Store2: return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ true); case IntrinsicOp::MOP_Store3: return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ true); case IntrinsicOp::MOP_Store4: return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ true); case IntrinsicOp::MOP_GetDimensions: retVal = processGetDimensions(expr); break; case IntrinsicOp::MOP_CalculateLevelOfDetail: retVal = processTextureLevelOfDetail(expr, /* unclamped */ false); case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped: retVal = processTextureLevelOfDetail(expr, /* unclamped */ true); break; case IntrinsicOp::MOP_IncrementCounter: retVal = spvBuilder.createUnaryOp(spv::Op::OpBitcast, astContext.UnsignedIntTy, incDecRWACSBufferCounter(expr, /*isInc*/ true), expr->getCallee()->getExprLoc()); break; case IntrinsicOp::MOP_DecrementCounter: retVal = spvBuilder.createUnaryOp( spv::Op::OpBitcast, astContext.UnsignedIntTy, incDecRWACSBufferCounter(expr, /*isInc*/ false), expr->getCallee()->getExprLoc()); break; case IntrinsicOp::MOP_Append: if (hlsl::IsHLSLStreamOutputType( expr->getImplicitObjectArgument()->getType())) return processStreamOutputAppend(expr); else return processACSBufferAppendConsume(expr); case IntrinsicOp::MOP_Consume: return processACSBufferAppendConsume(expr); case IntrinsicOp::MOP_RestartStrip: retVal = processStreamOutputRestart(expr); break; case IntrinsicOp::MOP_InterlockedAdd: case IntrinsicOp::MOP_InterlockedAnd: case IntrinsicOp::MOP_InterlockedOr: case IntrinsicOp::MOP_InterlockedXor: case IntrinsicOp::MOP_InterlockedUMax: case IntrinsicOp::MOP_InterlockedUMin: case IntrinsicOp::MOP_InterlockedMax: case IntrinsicOp::MOP_InterlockedMin: case 
IntrinsicOp::MOP_InterlockedExchange: case IntrinsicOp::MOP_InterlockedCompareExchange: case IntrinsicOp::MOP_InterlockedCompareStore: retVal = processRWByteAddressBufferAtomicMethods(opcode, expr); break; case IntrinsicOp::MOP_GetSamplePosition: retVal = processGetSamplePosition(expr); break; case IntrinsicOp::MOP_SubpassLoad: retVal = processSubpassLoad(expr); break; case IntrinsicOp::MOP_GatherCmpGreen: case IntrinsicOp::MOP_GatherCmpBlue: case IntrinsicOp::MOP_GatherCmpAlpha: emitError("no equivalent for %0 intrinsic method in Vulkan", expr->getCallee()->getExprLoc()) << expr->getMethodDecl()->getName(); return nullptr; case IntrinsicOp::MOP_TraceRayInline: return processTraceRayInline(expr); case IntrinsicOp::MOP_Abort: case IntrinsicOp::MOP_CandidateGeometryIndex: case IntrinsicOp::MOP_CandidateInstanceContributionToHitGroupIndex: case IntrinsicOp::MOP_CandidateInstanceID: case IntrinsicOp::MOP_CandidateInstanceIndex: case IntrinsicOp::MOP_CandidateObjectRayDirection: case IntrinsicOp::MOP_CandidateObjectRayOrigin: case IntrinsicOp::MOP_CandidateObjectToWorld3x4: case IntrinsicOp::MOP_CandidateObjectToWorld4x3: case IntrinsicOp::MOP_CandidatePrimitiveIndex: case IntrinsicOp::MOP_CandidateProceduralPrimitiveNonOpaque: case IntrinsicOp::MOP_CandidateTriangleBarycentrics: case IntrinsicOp::MOP_CandidateTriangleFrontFace: case IntrinsicOp::MOP_CandidateTriangleRayT: case IntrinsicOp::MOP_CandidateType: case IntrinsicOp::MOP_CandidateWorldToObject3x4: case IntrinsicOp::MOP_CandidateWorldToObject4x3: case IntrinsicOp::MOP_CommitNonOpaqueTriangleHit: case IntrinsicOp::MOP_CommitProceduralPrimitiveHit: case IntrinsicOp::MOP_CommittedGeometryIndex: case IntrinsicOp::MOP_CommittedInstanceContributionToHitGroupIndex: case IntrinsicOp::MOP_CommittedInstanceID: case IntrinsicOp::MOP_CommittedInstanceIndex: case IntrinsicOp::MOP_CommittedObjectRayDirection: case IntrinsicOp::MOP_CommittedObjectRayOrigin: case IntrinsicOp::MOP_CommittedObjectToWorld3x4: case 
IntrinsicOp::MOP_CommittedObjectToWorld4x3: case IntrinsicOp::MOP_CommittedPrimitiveIndex: case IntrinsicOp::MOP_CommittedRayT: case IntrinsicOp::MOP_CommittedStatus: case IntrinsicOp::MOP_CommittedTriangleBarycentrics: case IntrinsicOp::MOP_CommittedTriangleFrontFace: case IntrinsicOp::MOP_CommittedWorldToObject3x4: case IntrinsicOp::MOP_CommittedWorldToObject4x3: case IntrinsicOp::MOP_Proceed: case IntrinsicOp::MOP_RayFlags: case IntrinsicOp::MOP_RayTMin: case IntrinsicOp::MOP_WorldRayDirection: case IntrinsicOp::MOP_WorldRayOrigin: return processRayQueryIntrinsics(expr, opcode); default: emitError("intrinsic '%0' method unimplemented", expr->getCallee()->getExprLoc()) << expr->getDirectCallee()->getName(); return nullptr; } if (retVal) retVal->setRValue(); return retVal; } SpirvInstruction *SpirvEmitter::createImageSample( QualType retType, QualType imageType, SpirvInstruction *image, SpirvInstruction *sampler, SpirvInstruction *coordinate, SpirvInstruction *compareVal, SpirvInstruction *bias, SpirvInstruction *lod, std::pair grad, SpirvInstruction *constOffset, SpirvInstruction *varOffset, SpirvInstruction *constOffsets, SpirvInstruction *sample, SpirvInstruction *minLod, SpirvInstruction *residencyCodeId, SourceLocation loc) { // SampleDref* instructions in SPIR-V always return a scalar. // They also have the correct type in HLSL. if (compareVal) { return spvBuilder.createImageSample(retType, imageType, image, sampler, coordinate, compareVal, bias, lod, grad, constOffset, varOffset, constOffsets, sample, minLod, residencyCodeId, loc); } // Non-Dref Sample instructions in SPIR-V must always return a vec4. 
auto texelType = retType; QualType elemType = {}; uint32_t retVecSize = 0; if (isVectorType(retType, &elemType, &retVecSize) && retVecSize != 4) { texelType = astContext.getExtVectorType(elemType, 4); } else if (isScalarType(retType)) { retVecSize = 1; elemType = retType; texelType = astContext.getExtVectorType(retType, 4); } // The Lod and Grad image operands requires explicit-lod instructions. // Otherwise we use implicit-lod instructions. const bool isExplicit = lod || (grad.first && grad.second); // Implicit-lod instructions are only allowed in pixel shader. if (!spvContext.isPS() && !isExplicit) emitError("sampling with implicit lod is only allowed in fragment shaders", loc); auto *retVal = spvBuilder.createImageSample( texelType, imageType, image, sampler, coordinate, compareVal, bias, lod, grad, constOffset, varOffset, constOffsets, sample, minLod, residencyCodeId, loc); // Extract smaller vector from the vec4 result if necessary. if (texelType != retType) { retVal = extractVecFromVec4(retVal, retVecSize, elemType, loc); } return retVal; } SpirvInstruction * SpirvEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr, const bool isSample) { // Signatures: // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, Texture3D: // DXGI_FORMAT Object.Sample(sampler_state S, // float Location // [, int Offset] // [, float Clamp] // [, out uint Status]); // // For TextureCube and TextureCubeArray: // DXGI_FORMAT Object.Sample(sampler_state S, // float Location // [, float Clamp] // [, out uint Status]); // // For Texture2D/Texture2DArray: //