//===------- SPIRVEmitter.h - SPIR-V Binary Code Emitter --------*- C++ -*-===// // // The LLVM Compiler Infrastructure // // This file is distributed under the University of Illinois Open Source // License. See LICENSE.TXT for details. //===----------------------------------------------------------------------===// // // This file implements a SPIR-V emitter class that takes in HLSL AST and emits // SPIR-V binary words. // //===----------------------------------------------------------------------===// #include "SPIRVEmitter.h" #include "dxc/HlslIntrinsicOp.h" #include "spirv-tools/optimizer.hpp" #include "llvm/ADT/StringExtras.h" #include "InitListHandler.h" namespace clang { namespace spirv { namespace { /// Returns the type of the given decl. If the given decl is a FunctionDecl, /// returns its result type. inline QualType getTypeOrFnRetType(const ValueDecl *decl) { if (const auto *funcDecl = dyn_cast(decl)) { return funcDecl->getReturnType(); } return decl->getType(); } // Returns true if the given decl has the given semantic. bool hasSemantic(const DeclaratorDecl *decl, hlsl::DXIL::SemanticKind semanticKind) { using namespace hlsl; for (auto *annotation : decl->getUnusualAnnotations()) { if (auto *semanticDecl = dyn_cast(annotation)) { llvm::StringRef semanticName; uint32_t semanticIndex = 0; Semantic::DecomposeNameAndIndex(semanticDecl->SemanticName, &semanticName, &semanticIndex); const auto *semantic = Semantic::GetByName(semanticName); if (semantic->GetKind() == semanticKind) return true; } } return false; } bool patchConstFuncTakesHullOutputPatch(FunctionDecl *pcf) { for (const auto *param : pcf->parameters()) if (hlsl::IsHLSLOutputPatchType(param->getType())) return true; return false; } // TODO: Maybe we should move these type probing functions to TypeTranslator. /// Returns true if the given type is a bool or vector of bool type. 
bool isBoolOrVecOfBoolType(QualType type) {
  QualType elemType = {};
  if (!TypeTranslator::isScalarType(type, &elemType) &&
      !TypeTranslator::isVectorType(type, &elemType))
    return false;
  return elemType->isBooleanType();
}

/// Returns true if the given type is a scalar or vector whose element type is
/// a signed integer.
bool isSintOrVecOfSintType(QualType type) {
  QualType elemType = {};
  if (!TypeTranslator::isScalarType(type, &elemType) &&
      !TypeTranslator::isVectorType(type, &elemType))
    return false;
  return elemType->isSignedIntegerType();
}

/// Returns true if the given type is a scalar or vector whose element type is
/// an unsigned integer.
bool isUintOrVecOfUintType(QualType type) {
  QualType elemType = {};
  if (!TypeTranslator::isScalarType(type, &elemType) &&
      !TypeTranslator::isVectorType(type, &elemType))
    return false;
  return elemType->isUnsignedIntegerType();
}

/// Returns true if the given type is a scalar or vector whose element type is
/// a floating-point type.
bool isFloatOrVecOfFloatType(QualType type) {
  QualType elemType = {};
  if (!TypeTranslator::isScalarType(type, &elemType) &&
      !TypeTranslator::isVectorType(type, &elemType))
    return false;
  return elemType->isFloatingType();
}

/// Returns true if the given type is a bool scalar, bool vector, or an HLSL
/// matrix with bool elements.
bool isBoolOrVecMatOfBoolType(QualType type) {
  if (isBoolOrVecOfBoolType(type))
    return true;
  return hlsl::IsHLSLMatType(type) &&
         hlsl::GetHLSLMatElementType(type)->isBooleanType();
}

/// Returns true if the given type is a signed integer scalar, vector, or an
/// HLSL matrix with signed integer elements.
bool isSintOrVecMatOfSintType(QualType type) {
  if (isSintOrVecOfSintType(type))
    return true;
  return hlsl::IsHLSLMatType(type) &&
         hlsl::GetHLSLMatElementType(type)->isSignedIntegerType();
}

/// Returns true if the given type is an unsigned integer scalar, vector, or an
/// HLSL matrix with unsigned integer elements.
bool isUintOrVecMatOfUintType(QualType type) {
  if (isUintOrVecOfUintType(type))
    return true;
  return hlsl::IsHLSLMatType(type) &&
         hlsl::GetHLSLMatElementType(type)->isUnsignedIntegerType();
}

/// Returns true if the given type is a float or vector/matrix of float type.
bool isFloatOrVecMatOfFloatType(QualType type) { return isFloatOrVecOfFloatType(type) || (hlsl::IsHLSLMatType(type) && hlsl::GetHLSLMatElementType(type)->isFloatingType()); } inline bool isSpirvMatrixOp(spv::Op opcode) { return opcode == spv::Op::OpMatrixTimesMatrix || opcode == spv::Op::OpMatrixTimesVector || opcode == spv::Op::OpMatrixTimesScalar; } /// If expr is a (RW)StructuredBuffer.Load(), returns the object and writes /// index. Otherwiser, returns false. // TODO: The following doesn't handle Load(int, int) yet. And it is basically a // duplicate of doCXXMemberCallExpr. const Expr *isStructuredBufferLoad(const Expr *expr, const Expr **index) { using namespace hlsl; if (const auto *indexing = dyn_cast(expr)) { const auto *callee = indexing->getDirectCallee(); uint32_t opcode = static_cast(IntrinsicOp::Num_Intrinsics); llvm::StringRef group; if (GetIntrinsicOp(callee, opcode, group)) { if (static_cast(opcode) == IntrinsicOp::MOP_Load) { const auto *object = indexing->getImplicitObjectArgument(); if (TypeTranslator::isStructuredBuffer(object->getType())) { *index = indexing->getArg(0); return indexing->getImplicitObjectArgument(); } } } } return nullptr; } /// Returns true if the given VarDecl will be translated into a SPIR-V variable /// not in the Private or Function storage class. inline bool isExternalVar(const VarDecl *var) { // Class static variables should be put in the Private storage class. // groupshared variables are allowed to be declared as "static". But we still // need to put them in the Workgroup storage class. That is, when seeing // "static groupshared", ignore "static". return var->hasExternalFormalLinkage() ? !var->isStaticDataMember() : var->getAttr(); } /// Returns the referenced variable's DeclContext if the given expr is /// a DeclRefExpr referencing a ConstantBuffer/TextureBuffer. Otherwise, /// returns nullptr. 
const DeclContext *isConstantTextureBufferDeclRef(const Expr *expr) { if (const auto *declRefExpr = dyn_cast(expr->IgnoreParenCasts())) if (const auto *varDecl = dyn_cast(declRefExpr->getFoundDecl())) if (TypeTranslator::isConstantTextureBuffer(varDecl)) return varDecl->getType()->getAs()->getDecl(); return nullptr; } /// Returns true if /// * the given expr is an DeclRefExpr referencing a kind of structured or byte /// buffer and it is non-alias one, or /// * the given expr is an CallExpr returning a kind of structured or byte /// buffer. /// /// Note: legalization specific code bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) { expr = expr->IgnoreParenCasts(); if (const auto *declRefExpr = dyn_cast(expr)) { if (const auto *varDecl = dyn_cast(declRefExpr->getFoundDecl())) if (TypeTranslator::isAKindOfStructuredOrByteBuffer(varDecl->getType())) return isExternalVar(varDecl); } else if (const auto *callExpr = dyn_cast(expr)) { if (TypeTranslator::isAKindOfStructuredOrByteBuffer(callExpr->getType())) return true; } return false; } bool spirvToolsLegalize(spv_target_env env, std::vector *module, std::string *messages) { spvtools::Optimizer optimizer(env); optimizer.SetMessageConsumer( [messages](spv_message_level_t /*level*/, const char * /*source*/, const spv_position_t & /*position*/, const char *message) { *messages += message; }); optimizer.RegisterLegalizationPasses(); optimizer.RegisterPass(spvtools::CreateReplaceInvalidOpcodePass()); optimizer.RegisterPass(spvtools::CreateCompactIdsPass()); return optimizer.Run(module->data(), module->size(), module); } bool spirvToolsOptimize(spv_target_env env, std::vector *module, std::string *messages) { spvtools::Optimizer optimizer(env); optimizer.SetMessageConsumer( [messages](spv_message_level_t /*level*/, const char * /*source*/, const spv_position_t & /*position*/, const char *message) { *messages += message; }); optimizer.RegisterPerformancePasses(); 
optimizer.RegisterPass(spvtools::CreateCompactIdsPass()); return optimizer.Run(module->data(), module->size(), module); } bool spirvToolsValidate(spv_target_env env, std::vector *module, std::string *messages, bool relaxLogicalPointer) { spvtools::SpirvTools tools(env); tools.SetMessageConsumer( [messages](spv_message_level_t /*level*/, const char * /*source*/, const spv_position_t & /*position*/, const char *message) { *messages += message; }); spvtools::ValidatorOptions options; options.SetRelaxLogicalPointer(relaxLogicalPointer); return tools.Validate(module->data(), module->size(), options); } /// Translates atomic HLSL opcodes into the equivalent SPIR-V opcode. spv::Op translateAtomicHlslOpcodeToSpirvOpcode(hlsl::IntrinsicOp opcode) { using namespace hlsl; using namespace spv; switch (opcode) { case IntrinsicOp::IOP_InterlockedAdd: case IntrinsicOp::MOP_InterlockedAdd: return Op::OpAtomicIAdd; case IntrinsicOp::IOP_InterlockedAnd: case IntrinsicOp::MOP_InterlockedAnd: return Op::OpAtomicAnd; case IntrinsicOp::IOP_InterlockedOr: case IntrinsicOp::MOP_InterlockedOr: return Op::OpAtomicOr; case IntrinsicOp::IOP_InterlockedXor: case IntrinsicOp::MOP_InterlockedXor: return Op::OpAtomicXor; case IntrinsicOp::IOP_InterlockedUMax: case IntrinsicOp::MOP_InterlockedUMax: return Op::OpAtomicUMax; case IntrinsicOp::IOP_InterlockedUMin: case IntrinsicOp::MOP_InterlockedUMin: return Op::OpAtomicUMin; case IntrinsicOp::IOP_InterlockedMax: case IntrinsicOp::MOP_InterlockedMax: return Op::OpAtomicSMax; case IntrinsicOp::IOP_InterlockedMin: case IntrinsicOp::MOP_InterlockedMin: return Op::OpAtomicSMin; case IntrinsicOp::IOP_InterlockedExchange: case IntrinsicOp::MOP_InterlockedExchange: return Op::OpAtomicExchange; } assert(false && "unimplemented hlsl intrinsic opcode"); return Op::Max; } // Returns true if the given opcode is an accepted binary opcode in // OpSpecConstantOp. 
bool isAcceptedSpecConstantBinaryOp(spv::Op op) { switch (op) { case spv::Op::OpIAdd: case spv::Op::OpISub: case spv::Op::OpIMul: case spv::Op::OpUDiv: case spv::Op::OpSDiv: case spv::Op::OpUMod: case spv::Op::OpSRem: case spv::Op::OpSMod: case spv::Op::OpShiftRightLogical: case spv::Op::OpShiftRightArithmetic: case spv::Op::OpShiftLeftLogical: case spv::Op::OpBitwiseOr: case spv::Op::OpBitwiseXor: case spv::Op::OpBitwiseAnd: case spv::Op::OpVectorShuffle: case spv::Op::OpCompositeExtract: case spv::Op::OpCompositeInsert: case spv::Op::OpLogicalOr: case spv::Op::OpLogicalAnd: case spv::Op::OpLogicalNot: case spv::Op::OpLogicalEqual: case spv::Op::OpLogicalNotEqual: case spv::Op::OpIEqual: case spv::Op::OpINotEqual: case spv::Op::OpULessThan: case spv::Op::OpSLessThan: case spv::Op::OpUGreaterThan: case spv::Op::OpSGreaterThan: case spv::Op::OpULessThanEqual: case spv::Op::OpSLessThanEqual: case spv::Op::OpUGreaterThanEqual: case spv::Op::OpSGreaterThanEqual: return true; } return false; } /// Returns true if the given expression is an accepted initializer for a spec /// constant. bool isAcceptedSpecConstantInit(const Expr *init) { // Allow numeric casts init = init->IgnoreParenCasts(); if (isa(init) || isa(init) || isa(init)) return true; // Allow the minus operator which is used to specify negative values if (const auto *unaryOp = dyn_cast(init)) return unaryOp->getOpcode() == UO_Minus && isAcceptedSpecConstantInit(unaryOp->getSubExpr()); return false; } /// Returns true if the given function parameter can act as shader stage /// input parameter. inline bool canActAsInParmVar(const ParmVarDecl *param) { // If the parameter has no in/out/inout attribute, it is defaulted to // an in parameter. return !param->hasAttr() && // GS output streams are marked as inout, but it should not be // used as in parameter. !hlsl::IsHLSLStreamOutputType(param->getType()); } /// Returns true if the given function parameter can act as shader stage /// output parameter. 
inline bool canActAsOutParmVar(const ParmVarDecl *param) { return param->hasAttr() || param->hasAttr(); } /// Returns true if the given expression is of builtin type and can be evaluated /// to a constant zero. Returns false otherwise. inline bool evaluatesToConstZero(const Expr *expr, ASTContext &astContext) { const auto type = expr->getType(); if (!type->isBuiltinType()) return false; Expr::EvalResult evalResult; if (expr->EvaluateAsRValue(evalResult, astContext) && !evalResult.HasSideEffects) { const auto &val = evalResult.Val; return ((type->isBooleanType() && !val.getInt().getBoolValue()) || (type->isIntegerType() && !val.getInt().getBoolValue()) || (type->isFloatingType() && val.getFloat().isZero())); } return false; } /// Returns the HLSLBufferDecl if the given VarDecl is inside a cbuffer/tbuffer. /// Returns nullptr otherwise, including varDecl is a ConstantBuffer or /// TextureBuffer itself. inline const HLSLBufferDecl *getCTBufferContext(const VarDecl *varDecl) { if (const auto *bufferDecl = dyn_cast(varDecl->getDeclContext())) // Filter ConstantBuffer/TextureBuffer if (!bufferDecl->isConstantBufferView()) return bufferDecl; return nullptr; } /// Returns the real definition of the callee of the given CallExpr. /// /// If we are calling a forward-declared function, callee will be the /// FunctionDecl for the foward-declared function, not the actual /// definition. The foward-delcaration and defintion are two completely /// different AST nodes. inline const FunctionDecl *getCalleeDefinition(const CallExpr *expr) { const auto *callee = expr->getDirectCallee(); if (callee->isThisDeclarationADefinition()) return callee; // We need to update callee to the actual definition here if (!callee->isDefined(callee)) return nullptr; return callee; } /// Returns the referenced definition. The given expr is expected to be a /// DeclRefExpr or CallExpr after ignoring casts. Returns nullptr otherwise. 
const DeclaratorDecl *getReferencedDef(const Expr *expr) { if (!expr) return nullptr; expr = expr->IgnoreParenCasts(); if (const auto *declRefExpr = dyn_cast(expr)) { return dyn_cast_or_null(declRefExpr->getDecl()); } if (const auto *callExpr = dyn_cast(expr)) { return getCalleeDefinition(callExpr); } return nullptr; } /// Returns the number of base classes if this type is a derived class/struct. /// Returns zero otherwise. inline uint32_t getNumBaseClasses(QualType type) { if (const auto *cxxDecl = type->getAsCXXRecordDecl()) return cxxDecl->getNumBases(); return 0; } /// Gets the index sequence of casting a derived object to a base object by /// following the cast chain. void getBaseClassIndices(const CastExpr *expr, llvm::SmallVectorImpl *indices) { assert(expr->getCastKind() == CK_UncheckedDerivedToBase || expr->getCastKind() == CK_HLSLDerivedToBase); indices->clear(); QualType derivedType = expr->getSubExpr()->getType(); const auto *derivedDecl = derivedType->getAsCXXRecordDecl(); // Go through the base cast chain: for each of the derived to base cast, find // the index of the base in question in the derived's bases. 
for (auto pathIt = expr->path_begin(), pathIe = expr->path_end(); pathIt != pathIe; ++pathIt) { // The type of the base in question const auto baseType = (*pathIt)->getType(); uint32_t index = 0; for (auto baseIt = derivedDecl->bases_begin(), baseIe = derivedDecl->bases_end(); baseIt != baseIe; ++baseIt, ++index) if (baseIt->getType() == baseType) { indices->push_back(index); break; } assert(index < derivedDecl->getNumBases()); // Continue to proceed the next base in the chain derivedType = baseType; derivedDecl = derivedType->getAsCXXRecordDecl(); } } spv::Capability getCapabilityForGroupNonUniform(spv::Op opcode) { switch (opcode) { case spv::Op::OpGroupNonUniformElect: return spv::Capability::GroupNonUniform; case spv::Op::OpGroupNonUniformAny: case spv::Op::OpGroupNonUniformAll: case spv::Op::OpGroupNonUniformAllEqual: return spv::Capability::GroupNonUniformVote; case spv::Op::OpGroupNonUniformBallot: case spv::Op::OpGroupNonUniformBallotBitCount: case spv::Op::OpGroupNonUniformBroadcast: case spv::Op::OpGroupNonUniformBroadcastFirst: return spv::Capability::GroupNonUniformBallot; case spv::Op::OpGroupNonUniformIAdd: case spv::Op::OpGroupNonUniformFAdd: case spv::Op::OpGroupNonUniformIMul: case spv::Op::OpGroupNonUniformFMul: case spv::Op::OpGroupNonUniformSMax: case spv::Op::OpGroupNonUniformUMax: case spv::Op::OpGroupNonUniformFMax: case spv::Op::OpGroupNonUniformSMin: case spv::Op::OpGroupNonUniformUMin: case spv::Op::OpGroupNonUniformFMin: case spv::Op::OpGroupNonUniformBitwiseAnd: case spv::Op::OpGroupNonUniformBitwiseOr: case spv::Op::OpGroupNonUniformBitwiseXor: return spv::Capability::GroupNonUniformArithmetic; case spv::Op::OpGroupNonUniformQuadBroadcast: case spv::Op::OpGroupNonUniformQuadSwap: return spv::Capability::GroupNonUniformQuad; } assert(false && "unhandled opcode"); return spv::Capability::Max; } std::string getNamespacePrefix(const Decl *decl) { std::string nsPrefix = ""; const DeclContext *dc = decl->getDeclContext(); while (dc && 
!dc->isTranslationUnit()) { if (const NamespaceDecl *ns = dyn_cast(dc)) { if (!ns->isAnonymousNamespace()) { nsPrefix = ns->getName().str() + "::" + nsPrefix; } } dc = dc->getParent(); } return nsPrefix; } std::string getFnName(const FunctionDecl *fn) { // Prefix the function name with the struct name if necessary std::string classOrStructName = ""; if (const auto *memberFn = dyn_cast(fn)) if (const auto *st = dyn_cast(memberFn->getDeclContext())) classOrStructName = st->getName().str() + "."; return getNamespacePrefix(fn) + classOrStructName + fn->getName().str(); } } // namespace SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci, EmitSPIRVOptions &options) : theCompilerInstance(ci), astContext(ci.getASTContext()), diags(ci.getDiagnostics()), spirvOptions(options), entryFunctionName(ci.getCodeGenOpts().HLSLEntryFunction), shaderModel(*hlsl::ShaderModel::GetByName( ci.getCodeGenOpts().HLSLProfile.c_str())), theContext(), featureManager(diags, options), theBuilder(&theContext, &featureManager, options.enableReflect), typeTranslator(astContext, theBuilder, diags, options), declIdMapper(shaderModel, astContext, theBuilder, typeTranslator, featureManager, options), entryFunctionId(0), curFunction(nullptr), curThis(0), seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false) { if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid) emitError("unknown shader module: %0", {}) << shaderModel.GetName(); if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() && !shaderModel.IsGS()) emitError("-fvk-invert-y can only be used in VS/DS/GS", {}); if (options.useGlLayout && options.useDxLayout) emitError("cannot specify both -fvk-use-dx-layout and -fvk-use-gl-layout", {}); options.Initialize(); // Set shader module version theBuilder.setShaderModelVersion(shaderModel.GetMajor(), shaderModel.GetMinor()); // Set debug info const auto &inputFiles = ci.getFrontendOpts().Inputs; if (options.enableDebugInfo && !inputFiles.empty()) 
theBuilder.setSourceFileName(theContext.takeNextId(), inputFiles.front().getFile().str()); } void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) { // Stop translating if there are errors in previous compilation stages. if (context.getDiagnostics().hasErrorOccurred()) return; TranslationUnitDecl *tu = context.getTranslationUnitDecl(); // The entry function is the seed of the queue. for (auto *decl : tu->decls()) { if (auto *funcDecl = dyn_cast(decl)) { if (funcDecl->getName() == entryFunctionName) { workQueue.insert(funcDecl); } } else { // If ignoring unused resources, defer Decl handling inside // TranslationUnit to the time of first referencing. if (!spirvOptions.ignoreUnusedResources) { doDecl(decl); } } } // Translate all functions reachable from the entry function. // The queue can grow in the meanwhile; so need to keep evaluating // workQueue.size(). for (uint32_t i = 0; i < workQueue.size(); ++i) { doDecl(workQueue[i]); } if (context.getDiagnostics().hasErrorOccurred()) return; const spv_target_env targetEnv = featureManager.getTargetEnv(); AddRequiredCapabilitiesForShaderModel(); // Addressing and memory model are required in a valid SPIR-V module. theBuilder.setAddressingModel(spv::AddressingModel::Logical); theBuilder.setMemoryModel(spv::MemoryModel::GLSL450); theBuilder.addEntryPoint(getSpirvShaderStage(shaderModel), entryFunctionId, entryFunctionName, declIdMapper.collectStageVars()); // Add Location decorations to stage input/output variables. if (!declIdMapper.decorateStageIOLocations()) return; // Add descriptor set and binding decorations to resource variables. if (!declIdMapper.decorateResourceBindings()) return; // Output the constructed module. 
std::vector m = theBuilder.takeModule(); if (!spirvOptions.codeGenHighLevel) { // Run legalization passes if (needsLegalization || declIdMapper.requiresLegalization()) { std::string messages; if (!spirvToolsLegalize(targetEnv, &m, &messages)) { emitFatalError("failed to legalize SPIR-V: %0", {}) << messages; emitNote("please file a bug report on " "https://github.com/Microsoft/DirectXShaderCompiler/issues " "with source code if possible", {}); return; } else if (!messages.empty()) { emitWarning("SPIR-V legalization: %0", {}) << messages; } } // Run optimization passes if (theCompilerInstance.getCodeGenOpts().OptimizationLevel > 0) { std::string messages; if (!spirvToolsOptimize(targetEnv, &m, &messages)) { emitFatalError("failed to optimize SPIR-V: %0", {}) << messages; emitNote("please file a bug report on " "https://github.com/Microsoft/DirectXShaderCompiler/issues " "with source code if possible", {}); return; } } } // Validate the generated SPIR-V code if (!spirvOptions.disableValidation) { std::string messages; if (!spirvToolsValidate(targetEnv, &m, &messages, declIdMapper.requiresLegalization())) { emitFatalError("generated SPIR-V is invalid: %0", {}) << messages; emitNote("please file a bug report on " "https://github.com/Microsoft/DirectXShaderCompiler/issues " "with source code if possible", {}); return; } } theCompilerInstance.getOutStream()->write( reinterpret_cast(m.data()), m.size() * 4); } void SPIRVEmitter::doDecl(const Decl *decl) { if (decl->isImplicit() || isa(decl) || isa(decl)) return; if (const auto *varDecl = dyn_cast(decl)) { // We can have VarDecls inside cbuffer/tbuffer. For those VarDecls, we need // to emit their cbuffer/tbuffer as a whole and access each individual one // using access chains. 
if (const auto *bufferDecl = getCTBufferContext(varDecl)) { doHLSLBufferDecl(bufferDecl); } else { doVarDecl(varDecl); } } else if (const auto *namespaceDecl = dyn_cast(decl)) { for (auto *subDecl : namespaceDecl->decls()) // Note: We only emit functions as they are discovered through the call // graph starting from the entry-point. We should not emit unused // functions inside namespaces. if (!isa(subDecl)) doDecl(subDecl); } else if (const auto *funcDecl = dyn_cast(decl)) { doFunctionDecl(funcDecl); } else if (const auto *bufferDecl = dyn_cast(decl)) { doHLSLBufferDecl(bufferDecl); } else if (const auto *recordDecl = dyn_cast(decl)) { doRecordDecl(recordDecl); } else { emitError("decl type %0 unimplemented", decl->getLocation()) << decl->getDeclKindName(); } } void SPIRVEmitter::doStmt(const Stmt *stmt, llvm::ArrayRef attrs) { if (const auto *compoundStmt = dyn_cast(stmt)) { for (auto *st : compoundStmt->body()) doStmt(st); } else if (const auto *retStmt = dyn_cast(stmt)) { doReturnStmt(retStmt); } else if (const auto *declStmt = dyn_cast(stmt)) { doDeclStmt(declStmt); } else if (const auto *ifStmt = dyn_cast(stmt)) { doIfStmt(ifStmt, attrs); } else if (const auto *switchStmt = dyn_cast(stmt)) { doSwitchStmt(switchStmt, attrs); } else if (const auto *caseStmt = dyn_cast(stmt)) { processCaseStmtOrDefaultStmt(stmt); } else if (const auto *defaultStmt = dyn_cast(stmt)) { processCaseStmtOrDefaultStmt(stmt); } else if (const auto *breakStmt = dyn_cast(stmt)) { doBreakStmt(breakStmt); } else if (const auto *theDoStmt = dyn_cast(stmt)) { doDoStmt(theDoStmt, attrs); } else if (const auto *discardStmt = dyn_cast(stmt)) { doDiscardStmt(discardStmt); } else if (const auto *continueStmt = dyn_cast(stmt)) { doContinueStmt(continueStmt); } else if (const auto *whileStmt = dyn_cast(stmt)) { doWhileStmt(whileStmt, attrs); } else if (const auto *forStmt = dyn_cast(stmt)) { doForStmt(forStmt, attrs); } else if (const auto *nullStmt = dyn_cast(stmt)) { // For the null statement 
";". We don't need to do anything. } else if (const auto *expr = dyn_cast(stmt)) { // All cases for expressions used as statements doExpr(expr); } else if (const auto *attrStmt = dyn_cast(stmt)) { doStmt(attrStmt->getSubStmt(), attrStmt->getAttrs()); } else { emitError("statement class '%0' unimplemented", stmt->getLocStart()) << stmt->getStmtClassName() << stmt->getSourceRange(); } } SpirvEvalInfo SPIRVEmitter::doDeclRefExpr(const DeclRefExpr *expr) { const auto *decl = expr->getDecl(); auto id = declIdMapper.getDeclEvalInfo(decl, false); if (spirvOptions.ignoreUnusedResources && !id) { // First time referencing a Decl inside TranslationUnit. Register // into DeclResultIdMapper and emit SPIR-V for it and then query // again. doDecl(decl); id = declIdMapper.getDeclEvalInfo(decl); } return id; } SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) { SpirvEvalInfo result(/*id*/ 0); // Provide a hint to the typeTranslator that if a literal is discovered, its // intended usage is as this expression type. 
TypeTranslator::LiteralTypeHint hint(typeTranslator, expr->getType()); expr = expr->IgnoreParens(); if (const auto *declRefExpr = dyn_cast(expr)) { result = doDeclRefExpr(declRefExpr); } else if (const auto *memberExpr = dyn_cast(expr)) { result = doMemberExpr(memberExpr); } else if (const auto *castExpr = dyn_cast(expr)) { result = doCastExpr(castExpr); } else if (const auto *initListExpr = dyn_cast(expr)) { result = doInitListExpr(initListExpr); } else if (const auto *boolLiteral = dyn_cast(expr)) { const auto value = theBuilder.getConstantBool(boolLiteral->getValue(), isSpecConstantMode); result = SpirvEvalInfo(value).setConstant().setRValue(); } else if (const auto *intLiteral = dyn_cast(expr)) { const auto value = translateAPInt(intLiteral->getValue(), expr->getType()); result = SpirvEvalInfo(value).setConstant().setRValue(); } else if (const auto *floatLiteral = dyn_cast(expr)) { const auto value = translateAPFloat(floatLiteral->getValue(), expr->getType()); result = SpirvEvalInfo(value).setConstant().setRValue(); } else if (const auto *compoundAssignOp = dyn_cast(expr)) { // CompoundAssignOperator is a subclass of BinaryOperator. It should be // checked before BinaryOperator. 
result = doCompoundAssignOperator(compoundAssignOp); } else if (const auto *binOp = dyn_cast(expr)) { result = doBinaryOperator(binOp); } else if (const auto *unaryOp = dyn_cast(expr)) { result = doUnaryOperator(unaryOp); } else if (const auto *vecElemExpr = dyn_cast(expr)) { result = doHLSLVectorElementExpr(vecElemExpr); } else if (const auto *matElemExpr = dyn_cast(expr)) { result = doExtMatrixElementExpr(matElemExpr); } else if (const auto *funcCall = dyn_cast(expr)) { result = doCallExpr(funcCall); } else if (const auto *subscriptExpr = dyn_cast(expr)) { result = doArraySubscriptExpr(subscriptExpr); } else if (const auto *condExpr = dyn_cast(expr)) { result = doConditionalOperator(condExpr); } else if (const auto *defaultArgExpr = dyn_cast(expr)) { result = doExpr(defaultArgExpr->getParam()->getDefaultArg()); } else if (isa(expr)) { assert(curThis); result = curThis; } else { emitError("expression class '%0' unimplemented", expr->getExprLoc()) << expr->getStmtClassName() << expr->getSourceRange(); } return result; } SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr) { // We are trying to load the value here, which is what an LValueToRValue // implicit cast is intended to do. We can ignore the cast if exists. expr = expr->IgnoreParenLValueCasts(); return loadIfGLValue(expr, doExpr(expr)); } SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr, SpirvEvalInfo info) { // Do nothing if this is already rvalue if (info.isRValue()) return info; // Check whether we are trying to load an array of opaque objects as a whole. // If true, we are likely to copy it as a whole. To assist per-element // copying, avoid the load here and return the pointer directly. // TODO: consider moving this hack into SPIRV-Tools as a transformation. if (TypeTranslator::isOpaqueArrayType(expr->getType())) return info; // Check whether we are trying to load an externally visible structured/byte // buffer as a whole. If true, it means we are creating alias for it. 
Avoid // the load and write the pointer directly to the alias variable then. // // Also for the case of alias function returns. If we are trying to load an // alias function return as a whole, it means we are assigning it to another // alias variable. Avoid the load and write the pointer directly. // // Note: legalization specific code if (isReferencingNonAliasStructuredOrByteBuffer(expr)) { return info.setRValue(); } if (loadIfAliasVarRef(expr, info)) { // We are loading an alias variable as a whole here. This is likely for // wholesale assignments or function returns. Need to load the pointer. // // Note: legalization specific code return info; } uint32_t valType = 0; // TODO: Ouch. Very hacky. We need special path to get the value type if // we are loading a whole ConstantBuffer/TextureBuffer since the normal // type translation path won't work. if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) { valType = declIdMapper.getCTBufferPushConstantTypeId(declContext); } else { valType = typeTranslator.translateType(expr->getType(), info.getLayoutRule()); } uint32_t loadedId = theBuilder.createLoad(valType, info); // Special-case: According to the SPIR-V Spec: There is no physical size or // bit pattern defined for boolean type. Therefore an unsigned integer is used // to represent booleans when layout is required. In such cases, after loading // the uint, we should perform a comparison. { uint32_t vecSize = 1, numRows = 0, numCols = 0; if (info.getLayoutRule() != LayoutRule::Void && isBoolOrVecMatOfBoolType(expr->getType())) { const auto exprType = expr->getType(); QualType uintType = astContext.UnsignedIntTy; QualType boolType = astContext.BoolTy; if (TypeTranslator::isScalarType(exprType) || TypeTranslator::isVectorType(exprType, nullptr, &vecSize)) { const auto fromType = vecSize == 1 ? uintType : astContext.getExtVectorType(uintType, vecSize); const auto toType = vecSize == 1 ? 
boolType : astContext.getExtVectorType(boolType, vecSize); loadedId = castToBool(loadedId, fromType, toType); } else { const bool isMat = TypeTranslator::isMxNMatrix(exprType, nullptr, &numRows, &numCols); assert(isMat); const auto uintRowQualType = astContext.getExtVectorType(uintType, numCols); const auto uintRowQualTypeId = typeTranslator.translateType(uintRowQualType); const auto boolRowQualType = astContext.getExtVectorType(boolType, numCols); const auto boolRowQualTypeId = typeTranslator.translateType(boolRowQualType); const uint32_t resultTypeId = theBuilder.getMatType(boolType, boolRowQualTypeId, numRows); llvm::SmallVector rows; for (uint32_t i = 0; i < numRows; ++i) { const auto row = theBuilder.createCompositeExtract(uintRowQualTypeId, loadedId, {i}); rows.push_back(castToBool(row, uintRowQualType, boolRowQualType)); } loadedId = theBuilder.createCompositeConstruct(resultTypeId, rows); } // Now that it is converted to Bool, it has no layout rule. // This result-id should be evaluated as bool from here on out. info.setLayoutRule(LayoutRule::Void); } } return info.setResultId(loadedId).setRValue(); } SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) { auto info = doExpr(expr); loadIfAliasVarRef(expr, info); return info; } bool SPIRVEmitter::loadIfAliasVarRef(const Expr *varExpr, SpirvEvalInfo &info) { if (info.containsAliasComponent() && TypeTranslator::isAKindOfStructuredOrByteBuffer(varExpr->getType())) { // Aliased-to variables are all in the Uniform storage class with GLSL // std430 layout rules. const auto ptrType = typeTranslator.translateType(varExpr->getType()); // Load the pointer of the aliased-to-variable if the expression has a // pointer to pointer type. That is, the expression itself is a lvalue. // (Note that we translate alias function return values as pointer types, // not pointer to pointer types.) 
if (varExpr->isGLValue()) info.setResultId(theBuilder.createLoad(ptrType, info)); info.setStorageClass(spv::StorageClass::Uniform) .setLayoutRule(spirvOptions.sBufferLayoutRule) // Now it is a pointer to the global resource, which is lvalue. .setRValue(false) // Set to false to indicate that we've performed dereference over the // pointer-to-pointer and now should fallback to the normal path .setContainsAliasComponent(false); return true; } return false; } uint32_t SPIRVEmitter::castToType(uint32_t value, QualType fromType, QualType toType, SourceLocation srcLoc) { if (isFloatOrVecOfFloatType(toType)) return castToFloat(value, fromType, toType, srcLoc); // Order matters here. Bool (vector) values will also be considered as uint // (vector) values. So given a bool (vector) argument, isUintOrVecOfUintType() // will also return true. We need to check bool before uint. The opposite is // not true. if (isBoolOrVecOfBoolType(toType)) return castToBool(value, fromType, toType); if (isSintOrVecOfSintType(toType) || isUintOrVecOfUintType(toType)) return castToInt(value, fromType, toType, srcLoc); emitError("casting to type %0 unimplemented", {}) << toType; return 0; } void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) { assert(decl->isThisDeclarationADefinition()); // A RAII class for maintaining the current function under traversal. class FnEnvRAII { public: // Creates a new instance which sets fnEnv to the newFn on creation, // and resets fnEnv to its original value on destruction. FnEnvRAII(const FunctionDecl **fnEnv, const FunctionDecl *newFn) : oldFn(*fnEnv), fnSlot(fnEnv) { *fnEnv = newFn; } ~FnEnvRAII() { *fnSlot = oldFn; } private: const FunctionDecl *oldFn; const FunctionDecl **fnSlot; }; FnEnvRAII fnEnvRAII(&curFunction, decl); // We are about to start translation for a new function. Clear the break stack // and the continue stack. 
breakStack = std::stack(); continueStack = std::stack(); // This will allow the entry-point name to be something like // myNamespace::myEntrypointFunc. std::string funcName = getFnName(decl); uint32_t funcId = 0; if (funcName == entryFunctionName) { // The entry function surely does not have pre-assigned for // it like other functions that got added to the work queue following // function calls. funcId = theContext.takeNextId(); funcName = "src." + funcName; // Create wrapper for the entry function if (!emitEntryFunctionWrapper(decl, funcId)) return; } else { // Non-entry functions are added to the work queue following function // calls. We have already assigned s for it when translating // its call site. Query it here. funcId = declIdMapper.getDeclEvalInfo(decl); } const uint32_t retType = declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl); // Construct the function signature. llvm::SmallVector paramTypes; bool isNonStaticMemberFn = false; if (const auto *memberFn = dyn_cast(decl)) { isNonStaticMemberFn = !memberFn->isStatic(); if (isNonStaticMemberFn) { // For non-static member function, the first parameter should be the // object on which we are invoking this method. const uint32_t valueType = typeTranslator.translateType( memberFn->getThisType(astContext)->getPointeeType()); const uint32_t ptrType = theBuilder.getPointerType(valueType, spv::StorageClass::Function); paramTypes.push_back(ptrType); } } for (const auto *param : decl->params()) { const uint32_t valueType = declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param); const uint32_t ptrType = theBuilder.getPointerType(valueType, spv::StorageClass::Function); paramTypes.push_back(ptrType); } const uint32_t funcType = theBuilder.getFunctionType(retType, paramTypes); theBuilder.beginFunction(funcType, retType, funcName, funcId); if (isNonStaticMemberFn) { // Remember the parameter for the this object so later we can handle // CXXThisExpr correctly. 
curThis = theBuilder.addFnParam(paramTypes[0], "param.this"); } // Create all parameters. for (uint32_t i = 0; i < decl->getNumParams(); ++i) { const ParmVarDecl *paramDecl = decl->getParamDecl(i); (void)declIdMapper.createFnParam(paramDecl); } if (decl->hasBody()) { // The entry basic block. const uint32_t entryLabel = theBuilder.createBasicBlock("bb.entry"); theBuilder.setInsertPoint(entryLabel); // Process all statments in the body. doStmt(decl->getBody()); // We have processed all Stmts in this function and now in the last // basic block. Make sure we have a termination instruction. if (!theBuilder.isCurrentBasicBlockTerminated()) { const auto retType = decl->getReturnType(); if (retType->isVoidType()) { theBuilder.createReturn(); } else { // If the source code does not provide a proper return value for some // control flow path, it's undefined behavior. We just return null // value here. theBuilder.createReturnValue( theBuilder.getConstantNull(typeTranslator.translateType(retType))); } } } theBuilder.endFunction(); } bool SPIRVEmitter::validateVKAttributes(const NamedDecl *decl) { bool success = true; if (const auto *varDecl = dyn_cast(decl)) { const auto varType = varDecl->getType(); if ((TypeTranslator::isSubpassInput(varType) || TypeTranslator::isSubpassInputMS(varType)) && !varDecl->hasAttr()) { emitError("missing vk::input_attachment_index attribute", varDecl->getLocation()); success = false; } } if (const auto *iaiAttr = decl->getAttr()) { if (!shaderModel.IsPS()) { emitError("SubpassInput(MS) only allowed in pixel shader", decl->getLocation()); success = false; } if (!decl->isExternallyVisible()) { emitError("SubpassInput(MS) must be externally visible", decl->getLocation()); success = false; } // We only allow VKInputAttachmentIndexAttr to be attached to global // variables. So it should be fine to cast here. 
const auto elementType = hlsl::GetHLSLResourceResultType(cast(decl)->getType()); if (!TypeTranslator::isScalarType(elementType) && !TypeTranslator::isVectorType(elementType)) { emitError( "only scalar/vector types allowed as SubpassInput(MS) parameter type", decl->getLocation()); // Return directly to avoid further type processing, which will hit // asserts in TypeTranslator. return false; } } // The frontend will make sure that // * vk::push_constant applies to global variables of struct type // * vk::binding applies to global variables or cbuffers/tbuffers // * vk::counter_binding applies to global variables of RW/Append/Consume // StructuredBuffer // * vk::location applies to function parameters/returns and struct fields // So the only case we need to check co-existence is vk::push_constant and // vk::binding. if (const auto *pcAttr = decl->getAttr()) { const auto loc = pcAttr->getLocation(); if (seenPushConstantAt.isInvalid()) { seenPushConstantAt = loc; } else { // TODO: Actually this is slightly incorrect. The Vulkan spec says: // There must be no more than one push constant block statically used // per shader entry point. // But we are checking whether there are more than one push constant // blocks defined. Tracking usage requires more work. emitError("cannot have more than one push constant block", loc); emitNote("push constant block previously defined here", seenPushConstantAt); success = false; } if (decl->hasAttr()) { emitError("vk::push_constant attribute cannot be used together with " "vk::binding attribute", loc); success = false; } } return success; } void SPIRVEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) { // This is a cbuffer/tbuffer decl. 
// Check and emit warnings for member intializers which are not // supported in Vulkan for (const auto *member : bufferDecl->decls()) { if (const auto *varMember = dyn_cast(member)) { if (const auto *init = varMember->getInit()) emitWarning("%select{tbuffer|cbuffer}0 member initializer " "ignored since no equivalent in Vulkan", init->getExprLoc()) << bufferDecl->isCBuffer() << init->getSourceRange(); // We cannot handle external initialization of column-major matrices now. if (typeTranslator.isOrContainsNonFpColMajorMatrix(varMember->getType(), varMember)) { emitError("externally initialized non-floating-point column-major " "matrices not supported yet", varMember->getLocation()); } } } if (!validateVKAttributes(bufferDecl)) return; (void)declIdMapper.createCTBuffer(bufferDecl); } void SPIRVEmitter::doRecordDecl(const RecordDecl *recordDecl) { // Ignore implict records // Somehow we'll have implicit records with: // static const int Length = count; // that can mess up with the normal CodeGen. if (recordDecl->isImplicit()) return; // Handle each static member with inline initializer. // Each static member has a corresponding VarDecl inside the // RecordDecl. For those defined in the translation unit, // their VarDecls do not have initializer. for (auto *subDecl : recordDecl->decls()) if (auto *varDecl = dyn_cast(subDecl)) if (varDecl->isStaticDataMember() && varDecl->hasInit()) doVarDecl(varDecl); } void SPIRVEmitter::doVarDecl(const VarDecl *decl) { if (!validateVKAttributes(decl)) return; // We cannot handle external initialization of column-major matrices now. 
if (isExternalVar(decl) && typeTranslator.isOrContainsNonFpColMajorMatrix(decl->getType(), decl)) { emitError("externally initialized non-floating-point column-major " "matrices not supported yet", decl->getLocation()); } if (const auto *arrayType = astContext.getAsConstantArrayType(decl->getType())) { if (TypeTranslator::isAKindOfStructuredOrByteBuffer( arrayType->getElementType())) { emitError("arrays of structured/byte buffers unsupported", decl->getLocation()); return; } } if (decl->hasAttr()) { // This is a VarDecl for specialization constant. createSpecConstant(decl); return; } if (decl->hasAttr()) { // This is a VarDecl for PushConstant block. (void)declIdMapper.createPushConstant(decl); return; } if (isa(decl->getDeclContext())) { // This is a VarDecl of a ConstantBuffer/TextureBuffer type. (void)declIdMapper.createCTBuffer(decl); return; } SpirvEvalInfo varId(0); // The contents in externally visible variables can be updated via the // pipeline. They should be handled differently from file and function scope // variables. // File scope variables (static "global" and "local" variables) belongs to // the Private storage class, while function scope variables (normal "local" // variables) belongs to the Function storage class. if (isExternalVar(decl)) { varId = declIdMapper.createExternVar(decl); } else { // We already know the variable is not externally visible here. If it does // not have local storage, it should be file scope variable. const bool isFileScopeVar = !decl->hasLocalStorage(); if (isFileScopeVar) varId = declIdMapper.createFileVar(decl, llvm::None); else varId = declIdMapper.createFnVar(decl, llvm::None); // Emit OpStore to initialize the variable // TODO: revert back to use OpVariable initializer // We should only evaluate the initializer once for a static variable. 
if (isFileScopeVar) { if (decl->isStaticLocal()) { initOnce(decl->getType(), decl->getName(), varId, decl->getInit()); } else { // Defer to initialize these global variables at the beginning of the // entry function. toInitGloalVars.push_back(decl); } } // Function local variables. Just emit OpStore at the current insert point. else if (const Expr *init = decl->getInit()) { if (const auto constId = tryToEvaluateAsConst(init)) theBuilder.createStore(varId, constId); else storeValue(varId, loadIfGLValue(init), decl->getType()); // Update counter variable associated with local variables tryToAssignCounterVar(decl, init); } // Variables that are not externally visible and of opaque types should // request legalization. if (!needsLegalization && TypeTranslator::isOpaqueType(decl->getType())) needsLegalization = true; } if (TypeTranslator::isRelaxedPrecisionType(decl->getType(), spirvOptions)) { theBuilder.decorate(varId, spv::Decoration::RelaxedPrecision); } // All variables that are of opaque struct types should request legalization. if (!needsLegalization && TypeTranslator::isOpaqueStructType(decl->getType())) needsLegalization = true; } spv::LoopControlMask SPIRVEmitter::translateLoopAttribute(const Stmt *stmt, const Attr &attr) { switch (attr.getKind()) { case attr::HLSLLoop: case attr::HLSLFastOpt: return spv::LoopControlMask::DontUnroll; case attr::HLSLUnroll: return spv::LoopControlMask::Unroll; case attr::HLSLAllowUAVCondition: emitWarning("unsupported allow_uav_condition attribute ignored", stmt->getLocStart()); break; default: llvm_unreachable("found unknown loop attribute"); } return spv::LoopControlMask::MaskNone; } void SPIRVEmitter::doDiscardStmt(const DiscardStmt *discardStmt) { assert(!theBuilder.isCurrentBasicBlockTerminated()); theBuilder.createKill(); // Some statements that alter the control flow (break, continue, return, and // discard), require creation of a new basic block to hold any statement that // may follow them. 
const uint32_t newBB = theBuilder.createBasicBlock(); theBuilder.setInsertPoint(newBB); } void SPIRVEmitter::doDoStmt(const DoStmt *theDoStmt, llvm::ArrayRef attrs) { // do-while loops are composed of: // // do { // // } while(); // // SPIR-V requires loops to have a merge basic block as well as a continue // basic block. Even though do-while loops do not have an explicit continue // block as in for-loops, we still do need to create a continue block. // // Since SPIR-V requires structured control flow, we need two more basic // blocks,
and .
is the block before control flow // diverges, and is the block where control flow subsequently // converges. The can be performed in the basic block. // The final CFG should normally be like the following. Exceptions // will occur with non-local exits like loop breaks or early returns. // // +----------+ // | header | <-----------------------------------+ // +----------+ | // | | (true) // v | // +------+ +--------------------+ | // | body | ----> | continue () |-----------+ // +------+ +--------------------+ // | // | (false) // +-------+ | // | merge | <-------------+ // +-------+ // // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec. const spv::LoopControlMask loopControl = attrs.empty() ? spv::LoopControlMask::MaskNone : translateLoopAttribute(theDoStmt, *attrs.front()); // Create basic blocks const uint32_t headerBB = theBuilder.createBasicBlock("do_while.header"); const uint32_t bodyBB = theBuilder.createBasicBlock("do_while.body"); const uint32_t continueBB = theBuilder.createBasicBlock("do_while.continue"); const uint32_t mergeBB = theBuilder.createBasicBlock("do_while.merge"); // Make sure any continue statements branch to the continue block, and any // break statements branch to the merge block. continueStack.push(continueBB); breakStack.push(mergeBB); // Branch from the current insert point to the header block. theBuilder.createBranch(headerBB); theBuilder.addSuccessor(headerBB); // Process the
block // The header block must always branch to the body. theBuilder.setInsertPoint(headerBB); theBuilder.createBranch(bodyBB, mergeBB, continueBB, loopControl); theBuilder.addSuccessor(bodyBB); // The current basic block has OpLoopMerge instruction. We need to set its // continue and merge target. theBuilder.setContinueTarget(continueBB); theBuilder.setMergeTarget(mergeBB); // Process the block theBuilder.setInsertPoint(bodyBB); if (const Stmt *body = theDoStmt->getBody()) { doStmt(body); } if (!theBuilder.isCurrentBasicBlockTerminated()) theBuilder.createBranch(continueBB); theBuilder.addSuccessor(continueBB); // Process the block. The check for whether the loop should // continue lies in the continue block. // *NOTE*: There's a SPIR-V rule that when a conditional branch is to occur in // a continue block of a loop, there should be no OpSelectionMerge. Only an // OpBranchConditional must be specified. theBuilder.setInsertPoint(continueBB); uint32_t condition = 0; if (const Expr *check = theDoStmt->getCond()) { condition = doExpr(check); } else { condition = theBuilder.getConstantBool(true); } theBuilder.createConditionalBranch(condition, headerBB, mergeBB); theBuilder.addSuccessor(headerBB); theBuilder.addSuccessor(mergeBB); // Set insertion point to the block for subsequent statements theBuilder.setInsertPoint(mergeBB); // Done with the current scope's continue block and merge block. continueStack.pop(); breakStack.pop(); } void SPIRVEmitter::doContinueStmt(const ContinueStmt *continueStmt) { assert(!theBuilder.isCurrentBasicBlockTerminated()); const uint32_t continueTargetBB = continueStack.top(); theBuilder.createBranch(continueTargetBB); theBuilder.addSuccessor(continueTargetBB); // Some statements that alter the control flow (break, continue, return, and // discard), require creation of a new basic block to hold any statement that // may follow them. For example: StmtB and StmtC below are put inside a new // basic block which is unreachable. 
// // while (true) { // StmtA; // continue; // StmtB; // StmtC; // } const uint32_t newBB = theBuilder.createBasicBlock(); theBuilder.setInsertPoint(newBB); } void SPIRVEmitter::doWhileStmt(const WhileStmt *whileStmt, llvm::ArrayRef attrs) { // While loops are composed of: // while () { } // // SPIR-V requires loops to have a merge basic block as well as a continue // basic block. Even though while loops do not have an explicit continue // block as in for-loops, we still do need to create a continue block. // // Since SPIR-V requires structured control flow, we need two more basic // blocks,
and .
is the block before control flow // diverges, and is the block where control flow subsequently // converges. The block can take the responsibility of the
// block. The final CFG should normally be like the following. Exceptions // will occur with non-local exits like loop breaks or early returns. // // +----------+ // | header | <------------------+ // | (check) | | // +----------+ | // | | // +-------+-------+ | // | false | true | // | v | // | +------+ +------------------+ // | | body | --> | continue (no-op) | // v +------+ +------------------+ // +-------+ // | merge | // +-------+ // // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec. const spv::LoopControlMask loopControl = attrs.empty() ? spv::LoopControlMask::MaskNone : translateLoopAttribute(whileStmt, *attrs.front()); // Create basic blocks const uint32_t checkBB = theBuilder.createBasicBlock("while.check"); const uint32_t bodyBB = theBuilder.createBasicBlock("while.body"); const uint32_t continueBB = theBuilder.createBasicBlock("while.continue"); const uint32_t mergeBB = theBuilder.createBasicBlock("while.merge"); // Make sure any continue statements branch to the continue block, and any // break statements branch to the merge block. continueStack.push(continueBB); breakStack.push(mergeBB); // Process the block theBuilder.createBranch(checkBB); theBuilder.addSuccessor(checkBB); theBuilder.setInsertPoint(checkBB); // If we have: // while (int a = foo()) {...} // we should evaluate 'a' by calling 'foo()' every single time the check has // to occur. if (const auto *condVarDecl = whileStmt->getConditionVariableDeclStmt()) doStmt(condVarDecl); uint32_t condition = 0; if (const Expr *check = whileStmt->getCond()) { condition = doExpr(check); } else { condition = theBuilder.getConstantBool(true); } theBuilder.createConditionalBranch(condition, bodyBB, /*false branch*/ mergeBB, /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone, loopControl); theBuilder.addSuccessor(bodyBB); theBuilder.addSuccessor(mergeBB); // The current basic block has OpLoopMerge instruction. We need to set its // continue and merge target. 
theBuilder.setContinueTarget(continueBB); theBuilder.setMergeTarget(mergeBB); // Process the block theBuilder.setInsertPoint(bodyBB); if (const Stmt *body = whileStmt->getBody()) { doStmt(body); } if (!theBuilder.isCurrentBasicBlockTerminated()) theBuilder.createBranch(continueBB); theBuilder.addSuccessor(continueBB); // Process the block. While loops do not have an explicit // continue block. The continue block just branches to the block. theBuilder.setInsertPoint(continueBB); theBuilder.createBranch(checkBB); theBuilder.addSuccessor(checkBB); // Set insertion point to the block for subsequent statements theBuilder.setInsertPoint(mergeBB); // Done with the current scope's continue and merge blocks. continueStack.pop(); breakStack.pop(); } void SPIRVEmitter::doForStmt(const ForStmt *forStmt, llvm::ArrayRef attrs) { // for loops are composed of: // for (; ; ) // // To translate a for loop, we'll need to emit all statements // in the current basic block, and then have separate basic blocks for // , , and . Besides, since SPIR-V requires // structured control flow, we need two more basic blocks,
// and .
is the block before control flow diverges, // while is the block where control flow subsequently converges. // The block can take the responsibility of the
block. // The final CFG should normally be like the following. Exceptions will // occur with non-local exits like loop breaks or early returns. // +--------+ // | init | // +--------+ // | // v // +----------+ // | header | <---------------+ // | (check) | | // +----------+ | // | | // +-------+-------+ | // | false | true | // | v | // | +------+ +----------+ // | | body | --> | continue | // v +------+ +----------+ // +-------+ // | merge | // +-------+ // // For more details, see "2.11. Structured Control Flow" in the SPIR-V spec. const spv::LoopControlMask loopControl = attrs.empty() ? spv::LoopControlMask::MaskNone : translateLoopAttribute(forStmt, *attrs.front()); // Create basic blocks const uint32_t checkBB = theBuilder.createBasicBlock("for.check"); const uint32_t bodyBB = theBuilder.createBasicBlock("for.body"); const uint32_t continueBB = theBuilder.createBasicBlock("for.continue"); const uint32_t mergeBB = theBuilder.createBasicBlock("for.merge"); // Make sure any continue statements branch to the continue block, and any // break statements branch to the merge block. continueStack.push(continueBB); breakStack.push(mergeBB); // Process the block if (const Stmt *initStmt = forStmt->getInit()) { doStmt(initStmt); } theBuilder.createBranch(checkBB); theBuilder.addSuccessor(checkBB); // Process the block theBuilder.setInsertPoint(checkBB); uint32_t condition; if (const Expr *check = forStmt->getCond()) { condition = doExpr(check); } else { condition = theBuilder.getConstantBool(true); } theBuilder.createConditionalBranch(condition, bodyBB, /*false branch*/ mergeBB, /*merge*/ mergeBB, continueBB, spv::SelectionControlMask::MaskNone, loopControl); theBuilder.addSuccessor(bodyBB); theBuilder.addSuccessor(mergeBB); // The current basic block has OpLoopMerge instruction. We need to set its // continue and merge target. 
theBuilder.setContinueTarget(continueBB); theBuilder.setMergeTarget(mergeBB); // Process the block theBuilder.setInsertPoint(bodyBB); if (const Stmt *body = forStmt->getBody()) { doStmt(body); } if (!theBuilder.isCurrentBasicBlockTerminated()) theBuilder.createBranch(continueBB); theBuilder.addSuccessor(continueBB); // Process the block theBuilder.setInsertPoint(continueBB); if (const Expr *cont = forStmt->getInc()) { doExpr(cont); } theBuilder.createBranch(checkBB); // should jump back to header theBuilder.addSuccessor(checkBB); // Set insertion point to the block for subsequent statements theBuilder.setInsertPoint(mergeBB); // Done with the current scope's continue block and merge block. continueStack.pop(); breakStack.pop(); } void SPIRVEmitter::doIfStmt(const IfStmt *ifStmt, llvm::ArrayRef attrs) { // if statements are composed of: // if () { } else { } // // To translate if statements, we'll need to emit the expressions // in the current basic block, and then create separate basic blocks for // and . Additionally, we'll need a block as per // SPIR-V's structured control flow requirements. Depending whether there // exists the else branch, the final CFG should normally be like the // following. Exceptions will occur with non-local exits like loop breaks // or early returns. 
// +-------+ +-------+ // | check | | check | // +-------+ +-------+ // | | // +-------+-------+ +-----+-----+ // | true | false | true | false // v v or v | // +------+ +------+ +------+ | // | then | | else | | then | | // +------+ +------+ +------+ | // | | | v // | +-------+ | | +-------+ // +-> | merge | <-+ +---> | merge | // +-------+ +-------+ { // Try to see if we can const-eval the condition bool condition = false; if (ifStmt->getCond()->EvaluateAsBooleanCondition(condition, astContext)) { if (condition) { doStmt(ifStmt->getThen()); } else if (ifStmt->getElse()) { doStmt(ifStmt->getElse()); } return; } } auto selectionControl = spv::SelectionControlMask::MaskNone; if (!attrs.empty()) { const Attr *attribute = attrs.front(); switch (attribute->getKind()) { case attr::HLSLBranch: selectionControl = spv::SelectionControlMask::DontFlatten; break; case attr::HLSLFlatten: selectionControl = spv::SelectionControlMask::Flatten; break; default: emitWarning("unknown if statement attribute '%0' ignored", attribute->getLocation()) << attribute->getSpelling(); break; } } if (const auto *declStmt = ifStmt->getConditionVariableDeclStmt()) doDeclStmt(declStmt); // First emit the instruction for evaluating the condition. const uint32_t condition = doExpr(ifStmt->getCond()); // Then we need to emit the instruction for the conditional branch. // We'll need the for the then/else/merge block to do so. const bool hasElse = ifStmt->getElse() != nullptr; const uint32_t thenBB = theBuilder.createBasicBlock("if.true"); const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge"); const uint32_t elseBB = hasElse ? theBuilder.createBasicBlock("if.false") : mergeBB; // Create the branch instruction. This will end the current basic block. theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB, /*continue*/ 0, selectionControl); theBuilder.addSuccessor(thenBB); theBuilder.addSuccessor(elseBB); // The current basic block has the OpSelectionMerge instruction. 
We need // to record its merge target. theBuilder.setMergeTarget(mergeBB); // Handle the then branch theBuilder.setInsertPoint(thenBB); doStmt(ifStmt->getThen()); if (!theBuilder.isCurrentBasicBlockTerminated()) theBuilder.createBranch(mergeBB); theBuilder.addSuccessor(mergeBB); // Handle the else branch (if exists) if (hasElse) { theBuilder.setInsertPoint(elseBB); doStmt(ifStmt->getElse()); if (!theBuilder.isCurrentBasicBlockTerminated()) theBuilder.createBranch(mergeBB); theBuilder.addSuccessor(mergeBB); } // From now on, we'll emit instructions into the merge block. theBuilder.setInsertPoint(mergeBB); } void SPIRVEmitter::doReturnStmt(const ReturnStmt *stmt) { if (const auto *retVal = stmt->getRetValue()) { // Update counter variable associated with function returns tryToAssignCounterVar(curFunction, retVal); const auto retInfo = loadIfGLValue(retVal); const auto retType = retVal->getType(); if (retInfo.getStorageClass() != spv::StorageClass::Function && retType->isStructureType()) { // We are returning some value from a non-Function storage class. Need to // create a temporary variable to "convert" the value to Function storage // class and then return. const uint32_t valType = typeTranslator.translateType(retType); const uint32_t tempVar = theBuilder.addFnVar(valType, "temp.var.ret"); storeValue(tempVar, retInfo, retType); theBuilder.createReturnValue(theBuilder.createLoad(valType, tempVar)); } else { theBuilder.createReturnValue(retInfo); } } else { theBuilder.createReturn(); } // We are translating a ReturnStmt, we should be in some function's body. assert(curFunction->hasBody()); // If this return statement is the last statement in the function, then // whe have no more work to do. if (cast(curFunction->getBody())->body_back() == stmt) return; // Some statements that alter the control flow (break, continue, return, and // discard), require creation of a new basic block to hold any statement that // may follow them. 
In this case, the newly created basic block will contain // any statement that may come after an early return. const uint32_t newBB = theBuilder.createBasicBlock(); theBuilder.setInsertPoint(newBB); } void SPIRVEmitter::doBreakStmt(const BreakStmt *breakStmt) { assert(!theBuilder.isCurrentBasicBlockTerminated()); uint32_t breakTargetBB = breakStack.top(); theBuilder.addSuccessor(breakTargetBB); theBuilder.createBranch(breakTargetBB); // Some statements that alter the control flow (break, continue, return, and // discard), require creation of a new basic block to hold any statement that // may follow them. For example: StmtB and StmtC below are put inside a new // basic block which is unreachable. // // while (true) { // StmtA; // break; // StmtB; // StmtC; // } const uint32_t newBB = theBuilder.createBasicBlock(); theBuilder.setInsertPoint(newBB); } void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt, llvm::ArrayRef attrs) { // Switch statements are composed of: // switch () { // // // // (optional) // } // // +-------+ // | check | // +-------+ // | // +-------+-------+----------------+---------------+ // | 1 | 2 | 3 | (others) // v v v v // +-------+ +-------------+ +-------+ +------------+ // | case1 | | case2 | | case3 | ... | default | // | | |(fallthrough)|---->| | | (optional) | // +-------+ |+------------+ +-------+ +------------+ // | | | // | | | // | +-------+ | | // | | | <--------------------+ | // +-> | merge | | // | | <-------------------------------------+ // +-------+ // If no attributes are given, or if "forcecase" attribute was provided, // we'll do our best to use OpSwitch if possible. // If any of the cases compares to a variable (rather than an integer // literal), we cannot use OpSwitch because OpSwitch expects literal // numbers as parameters. 
const bool isAttrForceCase = !attrs.empty() && attrs.front()->getKind() == attr::HLSLForceCase; const bool canUseSpirvOpSwitch = (attrs.empty() || isAttrForceCase) && allSwitchCasesAreIntegerLiterals(switchStmt->getBody()); if (isAttrForceCase && !canUseSpirvOpSwitch) emitWarning("ignored 'forcecase' attribute for the switch statement " "since one or more case values are not integer literals", switchStmt->getLocStart()); if (canUseSpirvOpSwitch) processSwitchStmtUsingSpirvOpSwitch(switchStmt); else processSwitchStmtUsingIfStmts(switchStmt); } SpirvEvalInfo SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) { llvm::SmallVector indices; auto info = loadIfAliasVarRef(collectArrayStructIndices(expr, &indices)); if (!indices.empty()) { (void)turnIntoElementPtr(info, expr->getType(), indices); } return info; } SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) { const auto opcode = expr->getOpcode(); // Handle assignment first since we need to evaluate rhs before lhs. // For other binary operations, we need to evaluate lhs before rhs. 
if (opcode == BO_Assign) { // Update counter variable associated with lhs of assignments tryToAssignCounterVar(expr->getLHS(), expr->getRHS()); return processAssignment(expr->getLHS(), loadIfGLValue(expr->getRHS()), /*isCompoundAssignment=*/false); } // Try to optimize floatMxN * float and floatN * float case if (opcode == BO_Mul) { if (SpirvEvalInfo result = tryToGenFloatMatrixScale(expr)) return result; if (SpirvEvalInfo result = tryToGenFloatVectorScale(expr)) return result; } return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode, expr->getLHS()->getType(), expr->getType(), expr->getSourceRange()); } SpirvEvalInfo SPIRVEmitter::doCallExpr(const CallExpr *callExpr) { if (const auto *operatorCall = dyn_cast(callExpr)) return doCXXOperatorCallExpr(operatorCall); if (const auto *memberCall = dyn_cast(callExpr)) return doCXXMemberCallExpr(memberCall); // Intrinsic functions such as 'dot' or 'mul' if (hlsl::IsIntrinsicOp(callExpr->getDirectCallee())) { return processIntrinsicCallExpr(callExpr); } // Normal standalone functions return processCall(callExpr); } SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) { const FunctionDecl *callee = getCalleeDefinition(callExpr); // Note that we always want the defintion because Stmts/Exprs in the // function body references the parameters in the definition. if (!callee) { emitError("found undefined function", callExpr->getExprLoc()); return 0; } const auto numParams = callee->getNumParams(); bool isNonStaticMemberCall = false; QualType objectType = {}; // Type of the object (if exists) SpirvEvalInfo objectEvalInfo = 0; // EvalInfo for the object (if exists) bool needsTempVar = false; // Whether we need temporary variable. 
llvm::SmallVector params; // Temporary variables llvm::SmallVector args; // Evaluated arguments if (const auto *memberCall = dyn_cast(callExpr)) { const auto *memberFn = cast(memberCall->getCalleeDecl()); isNonStaticMemberCall = !memberFn->isStatic(); if (isNonStaticMemberCall) { // For non-static member calls, evaluate the object and pass it as the // first argument. const auto *object = memberCall->getImplicitObjectArgument(); object = object->IgnoreParenNoopCasts(astContext); // Update counter variable associated with the implicit object tryToAssignCounterVar(getOrCreateDeclForMethodObject(memberFn), object); objectType = object->getType(); objectEvalInfo = doExpr(object); uint32_t objectId = objectEvalInfo; // If not already a variable, we need to create a temporary variable and // pass the object pointer to the function. Example: // getObject().objectMethod(); // Also, any parameter passed to the member function must be of Function // storage class. needsTempVar = objectEvalInfo.isRValue() || objectEvalInfo.getStorageClass() != spv::StorageClass::Function; if (needsTempVar) { objectId = createTemporaryVar(objectType, TypeTranslator::getName(objectType), // May need to load to use as initializer loadIfGLValue(object, objectEvalInfo)); } args.push_back(objectId); // We do not need to create a new temporary variable for the this // object. Use the evaluated argument. params.push_back(args.back()); } } // Evaluate parameters for (uint32_t i = 0; i < numParams; ++i) { // We want the argument variable here so that we can write back to it // later. We will do the OpLoad of this argument manually. So ingore // the LValueToRValue implicit cast here. auto *arg = callExpr->getArg(i)->IgnoreParenLValueCasts(); const auto *param = callee->getParamDecl(i); // We need to create variables for holding the values to be used as // arguments. The variables themselves are of pointer types. 
const uint32_t varType = declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param); const std::string varName = "param.var." + param->getNameAsString(); const uint32_t tempVarId = theBuilder.addFnVar(varType, varName); params.push_back(tempVarId); args.push_back(doExpr(arg)); // Update counter variable associated with function parameters tryToAssignCounterVar(param, arg); // Manually load the argument here const auto rhsVal = loadIfGLValue(arg, args.back()); // Initialize the temporary variables using the contents of the arguments storeValue(tempVarId, rhsVal, param->getType()); } // Push the callee into the work queue if it is not there. if (!workQueue.count(callee)) { workQueue.insert(callee); } const uint32_t retType = declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(callee); // Get or forward declare the function const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee); const uint32_t retVal = theBuilder.createFunctionCall(retType, funcId, params); // If we created a temporary variable for the lvalue object this method is // invoked upon, we need to copy the contents in the temporary variable back // to the original object's variable in case there are side effects. 
if (needsTempVar && !objectEvalInfo.isRValue()) { const uint32_t typeId = typeTranslator.translateType(objectType); const uint32_t value = theBuilder.createLoad(typeId, params.front()); storeValue(objectEvalInfo, value, objectType); } // Go through all parameters and write those marked as out/inout for (uint32_t i = 0; i < numParams; ++i) { const auto *param = callee->getParamDecl(i); if (canActAsOutParmVar(param)) { const auto *arg = callExpr->getArg(i); const uint32_t index = i + isNonStaticMemberCall; const uint32_t typeId = typeTranslator.translateType(param->getType()); const uint32_t value = theBuilder.createLoad(typeId, params[index]); processAssignment(arg, value, false, args[index]); } } // Inherit the SpirvEvalInfo from the function definition return declIdMapper.getDeclEvalInfo(callee).setResultId(retVal); } SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) { const Expr *subExpr = expr->getSubExpr(); const QualType subExprType = subExpr->getType(); const QualType toType = expr->getType(); // Unfortunately the front-end fails to deduce some types in certain cases. // Provide a hint about literal type usage if possible. TypeTranslator::LiteralTypeHint hint(typeTranslator); // 'literal int' to 'float' conversion. If a literal integer is to be used as // a 32-bit float, the hint is a 32-bit integer. if (toType->isFloatingType() && subExprType->isSpecificBuiltinType(BuiltinType::LitInt) && llvm::APFloat::getSizeInBits(astContext.getFloatTypeSemantics(toType)) == 32) hint.setHint(astContext.IntTy); // 'literal float' to 'float' conversion where intended type is float32. if (toType->isFloatingType() && subExprType->isSpecificBuiltinType(BuiltinType::LitFloat) && llvm::APFloat::getSizeInBits(astContext.getFloatTypeSemantics(toType)) == 32) hint.setHint(astContext.FloatTy); // TODO: We could provide other useful hints. 
For instance: // For the case of toType being a boolean, if the fromType is a literal float, // we could provide a FloatTy hint and if the fromType is a literal integer, // we could provide an IntTy hint. The front-end, however, seems to deduce the // correct type in these cases; therefore we currently don't provide any // additional hints. switch (expr->getCastKind()) { case CastKind::CK_LValueToRValue: return loadIfGLValue(subExpr); case CastKind::CK_NoOp: return doExpr(subExpr); case CastKind::CK_IntegralCast: case CastKind::CK_FloatingToIntegral: case CastKind::CK_HLSLCC_IntegralCast: case CastKind::CK_HLSLCC_FloatingToIntegral: { // Integer literals in the AST are represented using 64bit APInt // themselves and then implicitly casted into the expected bitwidth. // We need special treatment of integer literals here because generating // a 64bit constant and then explicit casting in SPIR-V requires Int64 // capability. We should avoid introducing unnecessary capabilities to // our best. if (const uint32_t valueId = tryToEvaluateAsConst(expr)) return SpirvEvalInfo(valueId).setConstant().setRValue(); const auto valueId = castToInt(doExpr(subExpr), subExprType, toType, subExpr->getExprLoc()); return SpirvEvalInfo(valueId).setRValue(); } case CastKind::CK_FloatingCast: case CastKind::CK_IntegralToFloating: case CastKind::CK_HLSLCC_FloatingCast: case CastKind::CK_HLSLCC_IntegralToFloating: { // First try to see if we can do constant folding for floating point // numbers like what we are doing for integers in the above. 
if (const uint32_t valueId = tryToEvaluateAsConst(expr)) return SpirvEvalInfo(valueId).setConstant().setRValue(); const auto valueId = castToFloat(doExpr(subExpr), subExprType, toType, subExpr->getExprLoc()); return SpirvEvalInfo(valueId).setRValue(); } case CastKind::CK_IntegralToBoolean: case CastKind::CK_FloatingToBoolean: case CastKind::CK_HLSLCC_IntegralToBoolean: case CastKind::CK_HLSLCC_FloatingToBoolean: { // First try to see if we can do constant folding. if (const uint32_t valueId = tryToEvaluateAsConst(expr)) return SpirvEvalInfo(valueId).setConstant().setRValue(); const auto valueId = castToBool(doExpr(subExpr), subExprType, toType); return SpirvEvalInfo(valueId).setRValue(); } case CastKind::CK_HLSLVectorSplat: { const size_t size = hlsl::GetHLSLVecSize(expr->getType()); return createVectorSplat(subExpr, size); } case CastKind::CK_HLSLVectorTruncationCast: { const uint32_t toVecTypeId = typeTranslator.translateType(toType); const uint32_t elemTypeId = typeTranslator.translateType(hlsl::GetHLSLVecElementType(toType)); const auto toSize = hlsl::GetHLSLVecSize(toType); const uint32_t composite = doExpr(subExpr); llvm::SmallVector elements; for (uint32_t i = 0; i < toSize; ++i) { elements.push_back( theBuilder.createCompositeExtract(elemTypeId, composite, {i})); } auto valueId = elements.front(); if (toSize > 1) valueId = theBuilder.createCompositeConstruct(toVecTypeId, elements); return SpirvEvalInfo(valueId).setRValue(); } case CastKind::CK_HLSLVectorToScalarCast: { // The underlying should already be a vector of size 1. assert(hlsl::GetHLSLVecSize(subExprType) == 1); return doExpr(subExpr); } case CastKind::CK_HLSLVectorToMatrixCast: { // If target type is already an 1xN matrix type, we just return the // underlying vector. if (TypeTranslator::is1xNMatrix(toType)) return doExpr(subExpr); // A vector can have no more than 4 elements. The only remaining case // is casting from size-4 vector to size-2-by-2 matrix. 
const auto vec = loadIfGLValue(subExpr); QualType elemType = {}; uint32_t rowCount = 0, colCount = 0; const bool isMat = TypeTranslator::isMxNMatrix(toType, &elemType, &rowCount, &colCount); assert(isMat && rowCount == 2 && colCount == 2); uint32_t vec2Type = theBuilder.getVecType(typeTranslator.translateType(elemType), 2); const auto subVec1 = theBuilder.createVectorShuffle(vec2Type, vec, vec, {0, 1}); const auto subVec2 = theBuilder.createVectorShuffle(vec2Type, vec, vec, {2, 3}); const auto mat = theBuilder.createCompositeConstruct( theBuilder.getMatType(elemType, vec2Type, 2), {subVec1, subVec2}); return SpirvEvalInfo(mat).setRValue(); } case CastKind::CK_HLSLMatrixSplat: { // From scalar to matrix uint32_t rowCount = 0, colCount = 0; hlsl::GetHLSLMatRowColCount(toType, rowCount, colCount); // Handle degenerated cases first if (rowCount == 1 && colCount == 1) return doExpr(subExpr); if (colCount == 1) return createVectorSplat(subExpr, rowCount); const auto vecSplat = createVectorSplat(subExpr, colCount); if (rowCount == 1) return vecSplat; const uint32_t matType = typeTranslator.translateType(toType); llvm::SmallVector vectors(size_t(rowCount), vecSplat); if (vecSplat.isConstant()) { const auto valueId = theBuilder.getConstantComposite(matType, vectors); return SpirvEvalInfo(valueId).setConstant().setRValue(); } else { const auto valueId = theBuilder.createCompositeConstruct(matType, vectors); return SpirvEvalInfo(valueId).setRValue(); } } case CastKind::CK_HLSLMatrixTruncationCast: { const QualType srcType = subExprType; const uint32_t srcId = doExpr(subExpr); const QualType elemType = hlsl::GetHLSLMatElementType(srcType); const uint32_t dstTypeId = typeTranslator.translateType(toType); llvm::SmallVector indexes; // It is possible that the source matrix is in fact a vector. // For example: Truncate float1x3 --> float1x2. // The front-end disallows float1x3 --> float2x1. 
{ uint32_t srcVecSize = 0, dstVecSize = 0; if (TypeTranslator::isVectorType(srcType, nullptr, &srcVecSize) && TypeTranslator::isVectorType(toType, nullptr, &dstVecSize)) { for (uint32_t i = 0; i < dstVecSize; ++i) indexes.push_back(i); const auto valId = theBuilder.createVectorShuffle(dstTypeId, srcId, srcId, indexes); return SpirvEvalInfo(valId).setRValue(); } } uint32_t srcRows = 0, srcCols = 0, dstRows = 0, dstCols = 0; hlsl::GetHLSLMatRowColCount(srcType, srcRows, srcCols); hlsl::GetHLSLMatRowColCount(toType, dstRows, dstCols); const uint32_t elemTypeId = typeTranslator.translateType(elemType); const uint32_t srcRowType = theBuilder.getVecType(elemTypeId, srcCols); // Indexes to pass to OpVectorShuffle for (uint32_t i = 0; i < dstCols; ++i) indexes.push_back(i); llvm::SmallVector extractedVecs; for (uint32_t row = 0; row < dstRows; ++row) { // Extract a row uint32_t rowId = theBuilder.createCompositeExtract(srcRowType, srcId, {row}); // Extract the necessary columns from that row. // The front-end ensures dstCols <= srcCols. // If dstCols equals srcCols, we can use the whole row directly. if (dstCols == 1) { rowId = theBuilder.createCompositeExtract(elemTypeId, rowId, {0}); } else if (dstCols < srcCols) { rowId = theBuilder.createVectorShuffle( theBuilder.getVecType(elemTypeId, dstCols), rowId, rowId, indexes); } extractedVecs.push_back(rowId); } uint32_t valId = extractedVecs.front(); if (extractedVecs.size() > 1) { valId = theBuilder.createCompositeConstruct( typeTranslator.translateType(toType), extractedVecs); } return SpirvEvalInfo(valId).setRValue(); } case CastKind::CK_HLSLMatrixToScalarCast: { // The underlying should already be a matrix of 1x1. assert(TypeTranslator::is1x1Matrix(subExprType)); return doExpr(subExpr); } case CastKind::CK_HLSLMatrixToVectorCast: { // The underlying should already be a matrix of 1xN. 
assert(TypeTranslator::is1xNMatrix(subExprType) || TypeTranslator::isMx1Matrix(subExprType)); return doExpr(subExpr); } case CastKind::CK_FunctionToPointerDecay: // Just need to return the function id return doExpr(subExpr); case CastKind::CK_FlatConversion: { uint32_t subExprId = 0; QualType evalType = subExprType; // Optimization: we can use OpConstantNull for cases where we want to // initialize an entire data structure to zeros. if (evaluatesToConstZero(subExpr, astContext)) { subExprId = theBuilder.getConstantNull(typeTranslator.translateType(toType)); return SpirvEvalInfo(subExprId).setRValue().setConstant(); } TypeTranslator::LiteralTypeHint hint(typeTranslator); // Try to evaluate float literals as float rather than double. if (const auto *floatLiteral = dyn_cast(subExpr)) { subExprId = tryToEvaluateAsFloat32(floatLiteral->getValue()); if (subExprId) evalType = astContext.FloatTy; } // Evaluate 'literal float' initializer type as float rather than double. // TODO: This could result in rounding error if the initializer is a // non-literal expression that requires larger than 32 bits and has the // 'literal float' type. else if (subExprType->isSpecificBuiltinType(BuiltinType::LitFloat)) { evalType = astContext.FloatTy; hint.setHint(astContext.FloatTy); } // Try to evaluate integer literals as 32-bit int rather than 64-bit int. else if (const auto *intLiteral = dyn_cast(subExpr)) { const bool isSigned = subExprType->isSignedIntegerType(); subExprId = tryToEvaluateAsInt32(intLiteral->getValue(), isSigned); if (subExprId) evalType = isSigned ? astContext.IntTy : astContext.UnsignedIntTy; } // For assigning one array instance to another one with the same array type // (regardless of constness and literalness), the rhs will be wrapped in a // FlatConversion: // |- // `- ImplicitCastExpr // `- ImplicitCastExpr // `- // This FlatConversion does not affect CodeGen, so that we can ignore it. 
else if (subExprType->isArrayType() && typeTranslator.isSameType(expr->getType(), subExprType)) { return doExpr(subExpr); } if (!subExprId) subExprId = doExpr(subExpr); const auto valId = processFlatConversion(toType, evalType, subExprId, expr->getExprLoc()); return SpirvEvalInfo(valId).setRValue(); } case CastKind::CK_UncheckedDerivedToBase: case CastKind::CK_HLSLDerivedToBase: { // Find the index sequence of the base to which we are casting llvm::SmallVector baseIndices; getBaseClassIndices(expr, &baseIndices); // Turn them in to SPIR-V constants for (uint32_t i = 0; i < baseIndices.size(); ++i) baseIndices[i] = theBuilder.getConstantUint32(baseIndices[i]); auto derivedInfo = doExpr(subExpr); return turnIntoElementPtr(derivedInfo, expr->getType(), baseIndices); } default: emitError("implicit cast kind '%0' unimplemented", expr->getExprLoc()) << expr->getCastKindName() << expr->getSourceRange(); expr->dump(); return 0; } } uint32_t SPIRVEmitter::processFlatConversion(const QualType type, const QualType initType, const uint32_t initId, SourceLocation srcLoc) { // Try to translate the canonical type first const auto canonicalType = type.getCanonicalType(); if (canonicalType != type) return processFlatConversion(canonicalType, initType, initId, srcLoc); // Primitive types { QualType ty = {}; if (TypeTranslator::isScalarType(type, &ty)) { if (const auto *builtinType = ty->getAs()) { switch (builtinType->getKind()) { case BuiltinType::Void: { emitError("cannot create a constant of void type", srcLoc); return 0; } case BuiltinType::Bool: return castToBool(initId, initType, ty); // Target type is an integer variant. case BuiltinType::Int: case BuiltinType::Short: case BuiltinType::Min12Int: case BuiltinType::UShort: case BuiltinType::UInt: case BuiltinType::Long: case BuiltinType::LongLong: case BuiltinType::ULong: case BuiltinType::ULongLong: return castToInt(initId, initType, ty, srcLoc); // Target type is a float variant. 
case BuiltinType::Double: case BuiltinType::Float: case BuiltinType::Half: case BuiltinType::Min10Float: return castToFloat(initId, initType, ty, srcLoc); default: emitError("flat conversion of type %0 unimplemented", srcLoc) << builtinType->getTypeClassName(); return 0; } } } } // Vector types { QualType elemType = {}; uint32_t elemCount = {}; if (TypeTranslator::isVectorType(type, &elemType, &elemCount)) { const uint32_t elemId = processFlatConversion(elemType, initType, initId, srcLoc); llvm::SmallVector constituents(size_t(elemCount), elemId); return theBuilder.createCompositeConstruct( typeTranslator.translateType(type), constituents); } } // Matrix types { QualType elemType = {}; uint32_t rowCount = 0, colCount = 0; if (TypeTranslator::isMxNMatrix(type, &elemType, &rowCount, &colCount)) { // By default HLSL matrices are row major, while SPIR-V matrices are // column major. We are mapping what HLSL semantically mean a row into a // column here. const uint32_t vecType = theBuilder.getVecType( typeTranslator.translateType(elemType), colCount); const uint32_t elemId = processFlatConversion(elemType, initType, initId, srcLoc); const llvm::SmallVector constituents(size_t(colCount), elemId); const uint32_t colId = theBuilder.createCompositeConstruct(vecType, constituents); const llvm::SmallVector rows(size_t(rowCount), colId); return theBuilder.createCompositeConstruct( typeTranslator.translateType(type), rows); } } // Struct type if (const auto *structType = type->getAs()) { const auto *decl = structType->getDecl(); llvm::SmallVector fields; for (const auto *field : decl->fields()) { // There is a special case for FlatConversion. If T is a struct with only // one member, S, then (T) is allowed, which essentially // constructs a new T instance using the instance of S as its only member. // Check whether we are handling that case here first. 
if (field->getType().getCanonicalType() == initType.getCanonicalType()) { fields.push_back(initId); } else { fields.push_back( processFlatConversion(field->getType(), initType, initId, srcLoc)); } } return theBuilder.createCompositeConstruct( typeTranslator.translateType(type), fields); } // Array type if (const auto *arrayType = astContext.getAsConstantArrayType(type)) { const auto size = static_cast(arrayType->getSize().getZExtValue()); const uint32_t elemId = processFlatConversion(arrayType->getElementType(), initType, initId, srcLoc); llvm::SmallVector constituents(size_t(size), elemId); return theBuilder.createCompositeConstruct( typeTranslator.translateType(type), constituents); } emitError("flat conversion of type %0 unimplemented", {}) << type->getTypeClassName(); type->dump(); return 0; } SpirvEvalInfo SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) { const auto opcode = expr->getOpcode(); // Try to optimize floatMxN *= float and floatN *= float case if (opcode == BO_MulAssign) { if (SpirvEvalInfo result = tryToGenFloatMatrixScale(expr)) return result; if (SpirvEvalInfo result = tryToGenFloatVectorScale(expr)) return result; } const auto *rhs = expr->getRHS(); const auto *lhs = expr->getLHS(); SpirvEvalInfo lhsPtr = 0; const auto result = processBinaryOp(lhs, rhs, opcode, expr->getComputationLHSType(), expr->getType(), expr->getSourceRange(), &lhsPtr); return processAssignment(lhs, result, true, lhsPtr); } SpirvEvalInfo SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) { const auto type = expr->getType(); // Enhancement for special case when the ConditionalOperator return type is a // literal type. For example: // // float a = cond ? 1 : 2; // int b = cond ? 1.5 : 2.5; // // There will be no indications about whether '1' and '2' should be used as // 32-bit or 64-bit integers. Similarly, there will be no indication about // whether '1.5' and '2.5' should be used as 32-bit or 64-bit floats. 
// // We want to avoid using 64-bit int and 64-bit float as much as possible. // // Note that if the literal is in fact large enough that it can't be // represented in 32 bits (e.g. integer larger than 3e+9), we should *not* // provide a hint. TypeTranslator::LiteralTypeHint hint(typeTranslator); const bool isLitInt = type->isSpecificBuiltinType(BuiltinType::LitInt); const bool isLitFloat = type->isSpecificBuiltinType(BuiltinType::LitFloat); // Return type of ConditionalOperator is a 'literal int' or 'literal float' if (isLitInt || isLitFloat) { // There is no hint about the intended usage of the literal type. if (typeTranslator.getIntendedLiteralType(type) == type) { // If either branch is a literal that is larger than 32-bits, do not // provide a hint. if (!isLiteralLargerThan32Bits(expr->getTrueExpr()) && !isLiteralLargerThan32Bits(expr->getFalseExpr())) { if (isLitInt) hint.setHint(astContext.IntTy); else if (isLitFloat) hint.setHint(astContext.FloatTy); } } } // According to HLSL doc, all sides of the ?: expression are always // evaluated. const uint32_t typeId = typeTranslator.translateType(type); // If we are selecting between two SampleState objects, none of the three // operands has a LValueToRValue implicit cast. uint32_t condition = loadIfGLValue(expr->getCond()); const auto trueBranch = loadIfGLValue(expr->getTrueExpr()); const auto falseBranch = loadIfGLValue(expr->getFalseExpr()); // For cases where the return type is a scalar or a vector, we can use // OpSelect to choose between the two. OpSelect's return type must be either // scalar or vector. if (TypeTranslator::isScalarType(type) || TypeTranslator::isVectorType(type)) { // The SPIR-V OpSelect instruction must have a selection argument that is // the same size as the return type. If the return type is a vector, the // selection must be a vector of booleans (one per output component). 
uint32_t count = 0; if (TypeTranslator::isVectorType(expr->getType(), nullptr, &count) && !TypeTranslator::isVectorType(expr->getCond()->getType())) { const uint32_t condVecType = theBuilder.getVecType(theBuilder.getBoolType(), count); const llvm::SmallVector components(size_t(count), condition); condition = theBuilder.createCompositeConstruct(condVecType, components); } auto valueId = theBuilder.createSelect(typeId, condition, trueBranch, falseBranch); return SpirvEvalInfo(valueId).setRValue(); } // If we can't use OpSelect, we need to create if-else control flow. const uint32_t tempVar = theBuilder.addFnVar(typeId, "temp.var.ternary"); const uint32_t thenBB = theBuilder.createBasicBlock("if.true"); const uint32_t mergeBB = theBuilder.createBasicBlock("if.merge"); const uint32_t elseBB = theBuilder.createBasicBlock("if.false"); // Create the branch instruction. This will end the current basic block. theBuilder.createConditionalBranch(condition, thenBB, elseBB, mergeBB); theBuilder.addSuccessor(thenBB); theBuilder.addSuccessor(elseBB); theBuilder.setMergeTarget(mergeBB); // Handle the then branch theBuilder.setInsertPoint(thenBB); theBuilder.createStore(tempVar, trueBranch); theBuilder.createBranch(mergeBB); theBuilder.addSuccessor(mergeBB); // Handle the else branch theBuilder.setInsertPoint(elseBB); theBuilder.createStore(tempVar, falseBranch); theBuilder.createBranch(mergeBB); theBuilder.addSuccessor(mergeBB); // From now on, emit instructions into the merge block. 
theBuilder.setInsertPoint(mergeBB); return SpirvEvalInfo(theBuilder.createLoad(typeId, tempVar)).setRValue(); } uint32_t SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions( const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument(); const auto objectId = loadIfAliasVarRef(object); const auto type = object->getType(); const bool isByteAddressBuffer = TypeTranslator::isByteAddressBuffer(type) || TypeTranslator::isRWByteAddressBuffer(type); const bool isStructuredBuffer = TypeTranslator::isStructuredBuffer(type) || TypeTranslator::isAppendStructuredBuffer(type) || TypeTranslator::isConsumeStructuredBuffer(type); assert(isByteAddressBuffer || isStructuredBuffer); // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a structure // with only one member that is a runtime array. We need to perform // OpArrayLength on member 0. const auto uintType = theBuilder.getUint32Type(); uint32_t length = theBuilder.createBinaryOp(spv::Op::OpArrayLength, uintType, objectId, 0); // For (RW)ByteAddressBuffers, GetDimensions() must return the array length // in bytes, but OpArrayLength returns the number of uints in the runtime // array. Therefore we must multiply the results by 4. if (isByteAddressBuffer) { length = theBuilder.createBinaryOp(spv::Op::OpIMul, uintType, length, theBuilder.getConstantUint32(4u)); } theBuilder.createStore(doExpr(expr->getArg(0)), length); if (isStructuredBuffer) { // For (RW)StructuredBuffer, the stride of the runtime array (which is the // size of the struct) must also be written to the second argument. 
uint32_t size = 0, stride = 0; std::tie(std::ignore, size) = typeTranslator.getAlignmentAndSize( type, spirvOptions.sBufferLayoutRule, &stride); const auto sizeId = theBuilder.getConstantUint32(size); theBuilder.createStore(doExpr(expr->getArg(1)), sizeId); } return 0; } uint32_t SPIRVEmitter::processRWByteAddressBufferAtomicMethods( hlsl::IntrinsicOp opcode, const CXXMemberCallExpr *expr) { // The signature of RWByteAddressBuffer atomic methods are largely: // void Interlocked*(in UINT dest, in UINT value); // void Interlocked*(in UINT dest, in UINT value, out UINT original_value); const auto *object = expr->getImplicitObjectArgument(); const auto objectInfo = loadIfAliasVarRef(object); const auto uintType = theBuilder.getUint32Type(); const uint32_t zero = theBuilder.getConstantUint32(0); const uint32_t offset = doExpr(expr->getArg(0)); // Right shift by 2 to convert the byte offset to uint32_t offset const uint32_t address = theBuilder.createBinaryOp(spv::Op::OpShiftRightLogical, uintType, offset, theBuilder.getConstantUint32(2)); const auto ptrType = theBuilder.getPointerType(uintType, objectInfo.getStorageClass()); const uint32_t ptr = theBuilder.createAccessChain(ptrType, objectInfo, {zero, address}); const uint32_t scope = theBuilder.getConstantUint32(1); // Device const bool isCompareExchange = opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareExchange; const bool isCompareStore = opcode == hlsl::IntrinsicOp::MOP_InterlockedCompareStore; if (isCompareExchange || isCompareStore) { const uint32_t comparator = doExpr(expr->getArg(1)); const uint32_t originalVal = theBuilder.createAtomicCompareExchange( uintType, ptr, scope, zero, zero, doExpr(expr->getArg(2)), comparator); if (isCompareExchange) theBuilder.createStore(doExpr(expr->getArg(3)), originalVal); } else { const uint32_t value = doExpr(expr->getArg(1)); const uint32_t originalVal = theBuilder.createAtomicOp( translateAtomicHlslOpcodeToSpirvOpcode(opcode), uintType, ptr, scope, zero, value); if 
(expr->getNumArgs() > 2) theBuilder.createStore(doExpr(expr->getArg(2)), originalVal); } return 0; } uint32_t SPIRVEmitter::processGetSamplePosition(const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument()->IgnoreParens(); const auto sampleCount = theBuilder.createUnaryOp( spv::Op::OpImageQuerySamples, theBuilder.getUint32Type(), loadIfGLValue(object)); emitWarning( "GetSamplePosition only supports standard sample settings with 1, 2, 4, " "8, or 16 samples and will return float2(0, 0) for other cases", expr->getCallee()->getExprLoc()); return emitGetSamplePosition(sampleCount, doExpr(expr->getArg(0))); } SpirvEvalInfo SPIRVEmitter::processSubpassLoad(const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument()->IgnoreParens(); const uint32_t sample = expr->getNumArgs() == 1 ? doExpr(expr->getArg(0)) : 0; const uint32_t zero = theBuilder.getConstantInt32(0); const uint32_t location = theBuilder.getConstantComposite( theBuilder.getVecType(theBuilder.getInt32Type(), 2), {zero, zero}); return processBufferTextureLoad(object, location, /*constOffset*/ 0, /*varOffset*/ 0, /*lod*/ sample, /*residencyCode*/ 0); } uint32_t SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) { const auto *object = expr->getImplicitObjectArgument(); const auto objectId = loadIfGLValue(object); const auto type = object->getType(); const auto *recType = type->getAs(); assert(recType); const auto typeName = recType->getDecl()->getName(); const auto numArgs = expr->getNumArgs(); const Expr *mipLevel = nullptr, *numLevels = nullptr, *numSamples = nullptr; assert(TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) || TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type)); // For Texture1D, arguments are either: // a) width // b) MipLevel, width, NumLevels // For Texture1DArray, arguments are either: // a) width, elements // b) MipLevel, width, elements, NumLevels // For Texture2D, 
arguments are either: // a) width, height // b) MipLevel, width, height, NumLevels // For Texture2DArray, arguments are either: // a) width, height, elements // b) MipLevel, width, height, elements, NumLevels // For Texture3D, arguments are either: // a) width, height, depth // b) MipLevel, width, height, depth, NumLevels // For Texture2DMS, arguments are: width, height, NumSamples // For Texture2DMSArray, arguments are: width, height, elements, NumSamples // For TextureCube, arguments are either: // a) width, height // b) MipLevel, width, height, NumLevels // For TextureCubeArray, arguments are either: // a) width, height, elements // b) MipLevel, width, height, elements, NumLevels // Note: SPIR-V Spec requires return type of OpImageQuerySize(Lod) to be a // scalar/vector of integers. SPIR-V Spec also requires return type of // OpImageQueryLevels and OpImageQuerySamples to be scalar integers. // The HLSL methods, however, have overloaded functions which have float // output arguments. Since the AST naturally won't have casting AST nodes for // such cases, we'll have to perform the cast ourselves. const auto storeToOutputArg = [this](const Expr *outputArg, uint32_t toStoreId) { const auto outputArgType = outputArg->getType(); // Perform cast to float if necessary. 
if (isFloatOrVecMatOfFloatType(outputArgType)) { toStoreId = theBuilder.createUnaryOp( spv::Op::OpConvertUToF, typeTranslator.translateType(outputArgType), toStoreId); } theBuilder.createStore(doExpr(outputArg), toStoreId); }; if ((typeName == "Texture1D" && numArgs > 1) || (typeName == "Texture2D" && numArgs > 2) || (typeName == "TextureCube" && numArgs > 2) || (typeName == "Texture3D" && numArgs > 3) || (typeName == "Texture1DArray" && numArgs > 2) || (typeName == "TextureCubeArray" && numArgs > 3) || (typeName == "Texture2DArray" && numArgs > 3)) { mipLevel = expr->getArg(0); numLevels = expr->getArg(numArgs - 1); } if (TypeTranslator::isTextureMS(type)) { numSamples = expr->getArg(numArgs - 1); } uint32_t querySize = numArgs; // If numLevels arg is present, mipLevel must also be present. These are not // queried via ImageQuerySizeLod. if (numLevels) querySize -= 2; // If numLevels arg is present, mipLevel must also be present. else if (numSamples) querySize -= 1; const uint32_t uintId = theBuilder.getUint32Type(); const uint32_t resultTypeId = querySize == 1 ? uintId : theBuilder.getVecType(uintId, querySize); // Only Texture types use ImageQuerySizeLod. // TextureMS, RWTexture, Buffers, RWBuffers use ImageQuerySize. uint32_t lod = 0; if (TypeTranslator::isTexture(type) && !numSamples) { if (mipLevel) { // For Texture types when mipLevel argument is present. lod = doExpr(mipLevel); } else { // For Texture types when mipLevel argument is omitted. lod = theBuilder.getConstantInt32(0); } } const uint32_t query = lod ? theBuilder.createBinaryOp(spv::Op::OpImageQuerySizeLod, resultTypeId, objectId, lod) : theBuilder.createUnaryOp(spv::Op::OpImageQuerySize, resultTypeId, objectId); if (querySize == 1) { const uint32_t argIndex = mipLevel ? 
1 : 0; storeToOutputArg(expr->getArg(argIndex), query); } else { for (uint32_t i = 0; i < querySize; ++i) { const uint32_t component = theBuilder.createCompositeExtract(uintId, query, {i}); // If the first arg is the mipmap level, we must write the results // starting from Arg(i+1), not Arg(i). const uint32_t argIndex = mipLevel ? i + 1 : i; storeToOutputArg(expr->getArg(argIndex), component); } } if (numLevels || numSamples) { const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples; const spv::Op opcode = numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples; const uint32_t numLevelsSamplesQuery = theBuilder.createUnaryOp(opcode, uintId, objectId); storeToOutputArg(numLevelsSamplesArg, numLevelsSamplesQuery); } return 0; } uint32_t SPIRVEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr) { // Possible signatures are as follows: // Texture1D(Array).CalculateLevelOfDetail(SamplerState S, float x); // Texture2D(Array).CalculateLevelOfDetail(SamplerState S, float2 xy); // TextureCube(Array).CalculateLevelOfDetail(SamplerState S, float3 xyz); // Texture3D.CalculateLevelOfDetail(SamplerState S, float3 xyz); // Return type is always a single float (LOD). assert(expr->getNumArgs() == 2u); const auto *object = expr->getImplicitObjectArgument(); const uint32_t objectId = loadIfGLValue(object); const uint32_t samplerState = doExpr(expr->getArg(0)); const uint32_t coordinate = doExpr(expr->getArg(1)); const uint32_t sampledImageType = theBuilder.getSampledImageType( typeTranslator.translateType(object->getType())); const uint32_t sampledImage = theBuilder.createBinaryOp( spv::Op::OpSampledImage, sampledImageType, objectId, samplerState); // The result type of OpImageQueryLod must be a float2. 
const uint32_t queryResultType = theBuilder.getVecType(theBuilder.getFloat32Type(), 2u); const uint32_t query = theBuilder.createBinaryOp( spv::Op::OpImageQueryLod, queryResultType, sampledImage, coordinate); // The first component of the float2 contains the mipmap array layer. return theBuilder.createCompositeExtract(theBuilder.getFloat32Type(), query, {0}); } uint32_t SPIRVEmitter::processTextureGatherRGBACmpRGBA( const CXXMemberCallExpr *expr, const bool isCmp, const uint32_t component) { // Parameters for .Gather{Red|Green|Blue|Alpha}() are one of the following // two sets: // * SamplerState s, float2 location, int2 offset // * SamplerState s, float2 location, int2 offset0, int2 offset1, // int offset2, int2 offset3 // // An additional 'out uint status' parameter can appear in both of the above. // // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() are one of the following // two sets: // * SamplerState s, float2 location, float compare_value, int2 offset // * SamplerState s, float2 location, float compare_value, int2 offset1, // int2 offset2, int2 offset3, int2 offset4 // // An additional 'out uint status' parameter can appear in both of the above. // // TextureCube's signature is somewhat different from the rest. // Parameters for .Gather{Red|Green|Blue|Alpha}() for TextureCube are: // * SamplerState s, float2 location, out uint status // Parameters for .GatherCmp{Red|Green|Blue|Alpha}() for TextureCube are: // * SamplerState s, float2 location, float compare_value, out uint status // // Return type is always a 4-component vector. const FunctionDecl *callee = expr->getDirectCallee(); const auto numArgs = expr->getNumArgs(); const auto *imageExpr = expr->getImplicitObjectArgument(); const QualType imageType = imageExpr->getType(); const auto imageTypeId = typeTranslator.translateType(imageType); const auto retTypeId = typeTranslator.translateType(callee->getReturnType()); // If the last arg is an unsigned integer, it must be the status. 
const bool hasStatusArg = expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType(); // Subtract 1 for status arg (if it exists), subtract 1 for compare_value (if // it exists), and subtract 2 for SamplerState and location. const auto numOffsetArgs = numArgs - hasStatusArg - isCmp - 2; // No offset args for TextureCube, 1 or 4 offset args for the rest. assert(numOffsetArgs == 0 || numOffsetArgs == 1 || numOffsetArgs == 4); const uint32_t image = loadIfGLValue(imageExpr); const uint32_t sampler = doExpr(expr->getArg(0)); const uint32_t coordinate = doExpr(expr->getArg(1)); const uint32_t compareVal = isCmp ? doExpr(expr->getArg(2)) : 0; // Handle offsets (if any). bool needsEmulation = false; uint32_t constOffset = 0, varOffset = 0, constOffsets = 0; if (numOffsetArgs == 1) { // The offset arg is not optional. handleOffsetInMethodCall(expr, 2 + isCmp, &constOffset, &varOffset); } else if (numOffsetArgs == 4) { const auto offset0 = tryToEvaluateAsConst(expr->getArg(2 + isCmp)); const auto offset1 = tryToEvaluateAsConst(expr->getArg(3 + isCmp)); const auto offset2 = tryToEvaluateAsConst(expr->getArg(4 + isCmp)); const auto offset3 = tryToEvaluateAsConst(expr->getArg(5 + isCmp)); // If any of the offsets is not constant, we then need to emulate the call // using 4 OpImageGather instructions. Otherwise, we can leverage the // ConstOffsets image operand. if (offset0 && offset1 && offset2 && offset3) { const uint32_t v2i32 = theBuilder.getVecType(theBuilder.getInt32Type(), 2); const uint32_t offsetType = theBuilder.getArrayType(v2i32, theBuilder.getConstantUint32(4)); constOffsets = theBuilder.getConstantComposite( offsetType, {offset0, offset1, offset2, offset3}); } else { needsEmulation = true; } } const auto status = hasStatusArg ? 
doExpr(expr->getArg(numArgs - 1)) : 0; if (needsEmulation) { const auto elemType = typeTranslator.translateType( hlsl::GetHLSLVecElementType(callee->getReturnType())); uint32_t texels[4]; for (uint32_t i = 0; i < 4; ++i) { varOffset = doExpr(expr->getArg(2 + isCmp + i)); const uint32_t gatherRet = theBuilder.createImageGather( retTypeId, imageTypeId, image, sampler, coordinate, theBuilder.getConstantInt32(component), compareVal, /*constOffset*/ 0, varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0, status); texels[i] = theBuilder.createCompositeExtract(elemType, gatherRet, {i}); } return theBuilder.createCompositeConstruct( retTypeId, {texels[0], texels[1], texels[2], texels[3]}); } return theBuilder.createImageGather( retTypeId, imageTypeId, image, sampler, coordinate, theBuilder.getConstantInt32(component), compareVal, constOffset, varOffset, constOffsets, /*sampleNumber*/ 0, status); } uint32_t SPIRVEmitter::processTextureGatherCmp(const CXXMemberCallExpr *expr) { // Signature for Texture2D/Texture2DArray: // // float4 GatherCmp( // in SamplerComparisonState s, // in float2 location, // in float compare_value // [,in int2 offset] // [,out uint Status] // ); // // Signature for TextureCube/TextureCubeArray: // // float4 GatherCmp( // in SamplerComparisonState s, // in float2 location, // in float compare_value, // out uint Status // ); // // Other Texture types do not have the GatherCmp method. 
const FunctionDecl *callee = expr->getDirectCallee(); const auto numArgs = expr->getNumArgs(); const bool hasStatusArg = expr->getArg(numArgs - 1)->getType()->isUnsignedIntegerType(); const bool hasOffsetArg = (numArgs == 5) || (numArgs == 4 && !hasStatusArg); const auto *imageExpr = expr->getImplicitObjectArgument(); const uint32_t image = loadIfGLValue(imageExpr); const uint32_t sampler = doExpr(expr->getArg(0)); const uint32_t coordinate = doExpr(expr->getArg(1)); const uint32_t comparator = doExpr(expr->getArg(2)); uint32_t constOffset = 0, varOffset = 0; if (hasOffsetArg) handleOffsetInMethodCall(expr, 3, &constOffset, &varOffset); const auto retType = typeTranslator.translateType(callee->getReturnType()); const auto imageType = typeTranslator.translateType(imageExpr->getType()); const auto status = hasStatusArg ? doExpr(expr->getArg(numArgs - 1)) : 0; return theBuilder.createImageGather( retType, imageType, image, sampler, coordinate, /*component*/ 0, comparator, constOffset, varOffset, /*constOffsets*/ 0, /*sampleNumber*/ 0, status); } SpirvEvalInfo SPIRVEmitter::processBufferTextureLoad( const Expr *object, const uint32_t locationId, uint32_t constOffset, uint32_t varOffset, uint32_t lod, uint32_t residencyCode) { // Loading for Buffer and RWBuffer translates to an OpImageFetch. // The result type of an OpImageFetch must be a vec4 of float or int. const auto type = object->getType(); assert(TypeTranslator::isBuffer(type) || TypeTranslator::isRWBuffer(type) || TypeTranslator::isTexture(type) || TypeTranslator::isRWTexture(type) || TypeTranslator::isSubpassInput(type) || TypeTranslator::isSubpassInputMS(type)); const bool doFetch = TypeTranslator::isBuffer(type) || TypeTranslator::isTexture(type); const uint32_t objectId = loadIfGLValue(object); // For Texture2DMS and Texture2DMSArray, Sample must be used rather than Lod. 
uint32_t sampleNumber = 0; if (TypeTranslator::isTextureMS(type) || TypeTranslator::isSubpassInputMS(type)) { sampleNumber = lod; lod = 0; } const auto sampledType = hlsl::GetHLSLResourceResultType(type); QualType elemType = sampledType; uint32_t elemCount = 1; uint32_t elemTypeId = 0; (void)TypeTranslator::isVectorType(sampledType, &elemType, &elemCount); if (elemType->isFloatingType()) { elemTypeId = theBuilder.getFloat32Type(); } else if (elemType->isSignedIntegerType()) { elemTypeId = theBuilder.getInt32Type(); } else if (elemType->isUnsignedIntegerType()) { elemTypeId = theBuilder.getUint32Type(); } else { emitError("buffer/texture type unimplemented", object->getExprLoc()); return 0; } // OpImageFetch and OpImageRead can only fetch a vector of 4 elements. const uint32_t texelTypeId = theBuilder.getVecType(elemTypeId, 4u); const uint32_t texel = theBuilder.createImageFetchOrRead( doFetch, texelTypeId, type, objectId, locationId, lod, constOffset, varOffset, /*constOffsets*/ 0, sampleNumber, residencyCode); // If the result type is a vec1, vec2, or vec3, some extra processing // (extraction) is required. 
uint32_t retVal = extractVecFromVec4(texel, elemCount, elemTypeId); return SpirvEvalInfo(retVal).setRValue(); } SpirvEvalInfo SPIRVEmitter::processByteAddressBufferLoadStore( const CXXMemberCallExpr *expr, uint32_t numWords, bool doStore) { uint32_t resultId = 0; const auto object = expr->getImplicitObjectArgument(); const auto type = object->getType(); const auto objectInfo = loadIfAliasVarRef(object); assert(numWords >= 1 && numWords <= 4); if (doStore) { assert(typeTranslator.isRWByteAddressBuffer(type)); assert(expr->getNumArgs() == 2); } else { assert(typeTranslator.isRWByteAddressBuffer(type) || typeTranslator.isByteAddressBuffer(type)); if (expr->getNumArgs() == 2) { emitError( "(RW)ByteAddressBuffer::Load(in address, out status) not supported", expr->getExprLoc()); return 0; } } const Expr *addressExpr = expr->getArg(0); const uint32_t byteAddress = doExpr(addressExpr); const uint32_t addressTypeId = typeTranslator.translateType(addressExpr->getType()); // Do a OpShiftRightLogical by 2 (divide by 4 to get aligned memory // access). The AST always casts the address to unsinged integer, so shift // by unsinged integer 2. const uint32_t constUint2 = theBuilder.getConstantUint32(2); const uint32_t address = theBuilder.createBinaryOp( spv::Op::OpShiftRightLogical, addressTypeId, byteAddress, constUint2); // Perform access chain into the RWByteAddressBuffer. // First index must be zero (member 0 of the struct is a // runtimeArray). The second index passed to OpAccessChain should be // the address. const uint32_t uintTypeId = theBuilder.getUint32Type(); const uint32_t ptrType = theBuilder.getPointerType(uintTypeId, objectInfo.getStorageClass()); const uint32_t constUint0 = theBuilder.getConstantUint32(0); if (doStore) { const uint32_t valuesId = doExpr(expr->getArg(1)); uint32_t curStoreAddress = address; for (uint32_t wordCounter = 0; wordCounter < numWords; ++wordCounter) { // Extract a 32-bit word from the input. const uint32_t curValue = numWords == 1 ? 
valuesId : theBuilder.createCompositeExtract( uintTypeId, valuesId, {wordCounter}); // Update the output address if necessary. if (wordCounter > 0) { const uint32_t offset = theBuilder.getConstantUint32(wordCounter); curStoreAddress = theBuilder.createBinaryOp( spv::Op::OpIAdd, addressTypeId, address, offset); } // Store the word to the right address at the output. const uint32_t storePtr = theBuilder.createAccessChain( ptrType, objectInfo, {constUint0, curStoreAddress}); theBuilder.createStore(storePtr, curValue); } } else { uint32_t loadPtr = theBuilder.createAccessChain(ptrType, objectInfo, {constUint0, address}); resultId = theBuilder.createLoad(uintTypeId, loadPtr); if (numWords > 1) { // Load word 2, 3, and 4 where necessary. Use OpCompositeConstruct to // return a vector result. llvm::SmallVector values; values.push_back(resultId); for (uint32_t wordCounter = 2; wordCounter <= numWords; ++wordCounter) { const uint32_t offset = theBuilder.getConstantUint32(wordCounter - 1); const uint32_t newAddress = theBuilder.createBinaryOp( spv::Op::OpIAdd, addressTypeId, address, offset); loadPtr = theBuilder.createAccessChain(ptrType, objectInfo, {constUint0, newAddress}); values.push_back(theBuilder.createLoad(uintTypeId, loadPtr)); } const uint32_t resultType = theBuilder.getVecType(addressTypeId, numWords); resultId = theBuilder.createCompositeConstruct(resultType, values); } } return SpirvEvalInfo(resultId).setRValue(); } SpirvEvalInfo SPIRVEmitter::processStructuredBufferLoad(const CXXMemberCallExpr *expr) { if (expr->getNumArgs() == 2) { emitError( "(RW)StructuredBuffer::Load(in location, out status) not supported", expr->getExprLoc()); return 0; } const auto *buffer = expr->getImplicitObjectArgument(); auto info = loadIfAliasVarRef(buffer); const QualType structType = hlsl::GetHLSLResourceResultType(buffer->getType()); const uint32_t zero = theBuilder.getConstantInt32(0); const uint32_t index = doExpr(expr->getArg(0)); return turnIntoElementPtr(info, structType, 
{zero, index}); } uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr, bool isInc, bool loadObject) { const uint32_t i32Type = theBuilder.getInt32Type(); const uint32_t one = theBuilder.getConstantUint32(1); // As scope: Device const uint32_t zero = theBuilder.getConstantUint32(0); // As memory sema: None const uint32_t sOne = theBuilder.getConstantInt32(1); const auto *object = expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext); if (loadObject) { // We don't need the object's here since counter variable is a // separate variable. But we still need the side effects of evaluating the // object, e.g., if the source code is foo(...).IncrementCounter(), we still // want to emit the code for foo(...). (void)doExpr(object); } const auto *counterPair = getFinalACSBufferCounter(object); if (!counterPair) { emitFatalError("cannot find the associated counter variable", object->getExprLoc()); return 0; } const uint32_t counterPtrType = theBuilder.getPointerType( theBuilder.getInt32Type(), spv::StorageClass::Uniform); const uint32_t counterPtr = theBuilder.createAccessChain( counterPtrType, counterPair->get(theBuilder, typeTranslator), {zero}); uint32_t index = 0; if (isInc) { index = theBuilder.createAtomicOp(spv::Op::OpAtomicIAdd, i32Type, counterPtr, one, zero, sOne); } else { // Note that OpAtomicISub returns the value before the subtraction; // so we need to do substraction again with OpAtomicISub's return value. const auto prev = theBuilder.createAtomicOp(spv::Op::OpAtomicISub, i32Type, counterPtr, one, zero, sOne); index = theBuilder.createBinaryOp(spv::Op::OpISub, i32Type, prev, sOne); } return index; } bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl, const Expr *srcExpr) { // We are handling associated counters here. Casts should not alter which // associated counter to manipulate. srcExpr = srcExpr->IgnoreParenCasts(); // For parameters of forward-declared functions. 
We must make sure the // associated counter variable is created. But for forward-declared functions, // the translation of the real definition may not be started yet. if (const auto *param = dyn_cast(dstDecl)) declIdMapper.createFnParamCounterVar(param); // For implicit objects of methods. Similar to the above. else if (const auto *thisObject = dyn_cast(dstDecl)) declIdMapper.createFnParamCounterVar(thisObject); // Handle AssocCounter#1 (see CounterVarFields comment) if (const auto *dstPair = declIdMapper.getCounterIdAliasPair(dstDecl)) { const auto *srcPair = getFinalACSBufferCounter(srcExpr); if (!srcPair) { emitFatalError("cannot find the associated counter variable", srcExpr->getExprLoc()); return false; } dstPair->assign(*srcPair, theBuilder, typeTranslator); return true; } // Handle AssocCounter#3 llvm::SmallVector srcIndices; const auto *dstFields = declIdMapper.getCounterVarFields(dstDecl); const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices); if (dstFields && srcFields) { if (!dstFields->assign(*srcFields, theBuilder, typeTranslator)) { emitFatalError("cannot handle associated counter variable assignment", srcExpr->getExprLoc()); return false; } return true; } // AssocCounter#2 and AssocCounter#4 for the lhs cannot happen since the lhs // is a stand-alone decl in this method. 
return false; } bool SPIRVEmitter::tryToAssignCounterVar(const Expr *dstExpr, const Expr *srcExpr) { dstExpr = dstExpr->IgnoreParenCasts(); srcExpr = srcExpr->IgnoreParenCasts(); const auto *dstPair = getFinalACSBufferCounter(dstExpr); const auto *srcPair = getFinalACSBufferCounter(srcExpr); if ((dstPair == nullptr) != (srcPair == nullptr)) { emitFatalError("cannot handle associated counter variable assignment", srcExpr->getExprLoc()); return false; } // Handle AssocCounter#1 & AssocCounter#2 if (dstPair && srcPair) { dstPair->assign(*srcPair, theBuilder, typeTranslator); return true; } // Handle AssocCounter#3 & AssocCounter#4 llvm::SmallVector dstIndices; llvm::SmallVector srcIndices; const auto *srcFields = getIntermediateACSBufferCounter(srcExpr, &srcIndices); const auto *dstFields = getIntermediateACSBufferCounter(dstExpr, &dstIndices); if (dstFields && srcFields) { return dstFields->assign(*srcFields, dstIndices, srcIndices, theBuilder, typeTranslator); } return false; } const CounterIdAliasPair * SPIRVEmitter::getFinalACSBufferCounter(const Expr *expr) { // AssocCounter#1: referencing some stand-alone variable if (const auto *decl = getReferencedDef(expr)) return declIdMapper.getCounterIdAliasPair(decl); // AssocCounter#2: referencing some non-struct field llvm::SmallVector indices; const auto *base = collectArrayStructIndices(expr, &indices, /*rawIndex=*/true); const auto *decl = (base && isa(base)) ? getOrCreateDeclForMethodObject(cast(curFunction)) : getReferencedDef(base); return declIdMapper.getCounterIdAliasPair(decl, &indices); } const CounterVarFields *SPIRVEmitter::getIntermediateACSBufferCounter( const Expr *expr, llvm::SmallVector *indices) { const auto *base = collectArrayStructIndices(expr, indices, /*rawIndex=*/true); const auto *decl = (base && isa(base)) // Use the decl we created to represent the implicit object ? 
getOrCreateDeclForMethodObject(cast(curFunction)) // Find the referenced decl from the original source code : getReferencedDef(base); return declIdMapper.getCounterVarFields(decl); } const ImplicitParamDecl * SPIRVEmitter::getOrCreateDeclForMethodObject(const CXXMethodDecl *method) { const auto found = thisDecls.find(method); if (found != thisDecls.end()) return found->second; const std::string name = method->getName().str() + ".this"; // Create a new identifier to convey the name auto &identifier = astContext.Idents.get(name); return thisDecls[method] = ImplicitParamDecl::Create( astContext, /*DC=*/nullptr, SourceLocation(), &identifier, method->getThisType(astContext)->getPointeeType()); } SpirvEvalInfo SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) { const bool isAppend = expr->getNumArgs() == 1; const uint32_t zero = theBuilder.getConstantUint32(0); const auto *object = expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext); auto bufferInfo = loadIfAliasVarRef(object); uint32_t index = incDecRWACSBufferCounter( expr, isAppend, // We have already translated the object in the above. Avoid duplication. /*loadObject=*/false); const auto bufferElemTy = hlsl::GetHLSLResourceResultType(object->getType()); (void)turnIntoElementPtr(bufferInfo, bufferElemTy, {zero, index}); if (isAppend) { // Write out the value storeValue(bufferInfo, doExpr(expr->getArg(0)), bufferElemTy); return 0; } else { // Note that we are returning a pointer (lvalue) here inorder to further // acess the fields in this element, e.g., buffer.Consume().a.b. So we // cannot forcefully set all normal function calls as returning rvalue. 
return bufferInfo; } } uint32_t SPIRVEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) { // TODO: handle multiple stream-output objects const auto *object = expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext); const auto *stream = cast(object)->getDecl(); const uint32_t value = doExpr(expr->getArg(0)); declIdMapper.writeBackOutputStream(stream, stream->getType(), value); theBuilder.createEmitVertex(); return 0; } uint32_t SPIRVEmitter::processStreamOutputRestart(const CXXMemberCallExpr *expr) { // TODO: handle multiple stream-output objects theBuilder.createEndPrimitive(); return 0; } uint32_t SPIRVEmitter::emitGetSamplePosition(const uint32_t sampleCount, const uint32_t sampleIndex) { struct Float2 { float x; float y; }; static const Float2 pos2[] = { {4.0 / 16.0, 4.0 / 16.0}, {-4.0 / 16.0, -4.0 / 16.0}, }; static const Float2 pos4[] = { {-2.0 / 16.0, -6.0 / 16.0}, {6.0 / 16.0, -2.0 / 16.0}, {-6.0 / 16.0, 2.0 / 16.0}, {2.0 / 16.0, 6.0 / 16.0}, }; static const Float2 pos8[] = { {1.0 / 16.0, -3.0 / 16.0}, {-1.0 / 16.0, 3.0 / 16.0}, {5.0 / 16.0, 1.0 / 16.0}, {-3.0 / 16.0, -5.0 / 16.0}, {-5.0 / 16.0, 5.0 / 16.0}, {-7.0 / 16.0, -1.0 / 16.0}, {3.0 / 16.0, 7.0 / 16.0}, {7.0 / 16.0, -7.0 / 16.0}, }; static const Float2 pos16[] = { {1.0 / 16.0, 1.0 / 16.0}, {-1.0 / 16.0, -3.0 / 16.0}, {-3.0 / 16.0, 2.0 / 16.0}, {4.0 / 16.0, -1.0 / 16.0}, {-5.0 / 16.0, -2.0 / 16.0}, {2.0 / 16.0, 5.0 / 16.0}, {5.0 / 16.0, 3.0 / 16.0}, {3.0 / 16.0, -5.0 / 16.0}, {-2.0 / 16.0, 6.0 / 16.0}, {0.0 / 16.0, -7.0 / 16.0}, {-4.0 / 16.0, -6.0 / 16.0}, {-6.0 / 16.0, 4.0 / 16.0}, {-8.0 / 16.0, 0.0 / 16.0}, {7.0 / 16.0, -4.0 / 16.0}, {6.0 / 16.0, 7.0 / 16.0}, {-7.0 / 16.0, -8.0 / 16.0}, }; // We are emitting the SPIR-V for the following HLSL source code: // // float2 position; // // if (count == 2) { // position = pos2[index]; // } // else if (count == 4) { // position = pos4[index]; // } // else if (count == 8) { // position = pos8[index]; // } // else if (count == 16) { 
// position = pos16[index]; // } // else { // position = float2(0.0f, 0.0f); // } const uint32_t boolType = theBuilder.getBoolType(); const auto v2f32Type = theBuilder.getVecType(theBuilder.getFloat32Type(), 2); const uint32_t ptrType = theBuilder.getPointerType(v2f32Type, spv::StorageClass::Function); // Creates a SPIR-V function scope variable of type float2[len]. const auto createArray = [this, v2f32Type](const Float2 *ptr, uint32_t len) { llvm::SmallVector components; for (uint32_t i = 0; i < len; ++i) { const auto x = theBuilder.getConstantFloat32(ptr[i].x); const auto y = theBuilder.getConstantFloat32(ptr[i].y); components.push_back(theBuilder.getConstantComposite(v2f32Type, {x, y})); } const auto arrType = theBuilder.getArrayType(v2f32Type, theBuilder.getConstantUint32(len)); const auto val = theBuilder.getConstantComposite(arrType, components); const std::string varName = "var.GetSamplePosition.data." + std::to_string(len); const auto var = theBuilder.addFnVar(arrType, varName); theBuilder.createStore(var, val); return var; }; const uint32_t pos2Arr = createArray(pos2, 2); const uint32_t pos4Arr = createArray(pos4, 4); const uint32_t pos8Arr = createArray(pos8, 8); const uint32_t pos16Arr = createArray(pos16, 16); const uint32_t resultVar = theBuilder.addFnVar(v2f32Type, "var.GetSamplePosition.result"); const uint32_t then2BB = theBuilder.createBasicBlock("if.GetSamplePosition.then2"); const uint32_t then4BB = theBuilder.createBasicBlock("if.GetSamplePosition.then4"); const uint32_t then8BB = theBuilder.createBasicBlock("if.GetSamplePosition.then8"); const uint32_t then16BB = theBuilder.createBasicBlock("if.GetSamplePosition.then16"); const uint32_t else2BB = theBuilder.createBasicBlock("if.GetSamplePosition.else2"); const uint32_t else4BB = theBuilder.createBasicBlock("if.GetSamplePosition.else4"); const uint32_t else8BB = theBuilder.createBasicBlock("if.GetSamplePosition.else8"); const uint32_t else16BB = 
theBuilder.createBasicBlock("if.GetSamplePosition.else16"); const uint32_t merge2BB = theBuilder.createBasicBlock("if.GetSamplePosition.merge2"); const uint32_t merge4BB = theBuilder.createBasicBlock("if.GetSamplePosition.merge4"); const uint32_t merge8BB = theBuilder.createBasicBlock("if.GetSamplePosition.merge8"); const uint32_t merge16BB = theBuilder.createBasicBlock("if.GetSamplePosition.merge16"); // if (count == 2) { const auto check2 = theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount, theBuilder.getConstantUint32(2)); theBuilder.createConditionalBranch(check2, then2BB, else2BB, merge2BB); theBuilder.addSuccessor(then2BB); theBuilder.addSuccessor(else2BB); theBuilder.setMergeTarget(merge2BB); // position = pos2[index]; // } theBuilder.setInsertPoint(then2BB); auto ac = theBuilder.createAccessChain(ptrType, pos2Arr, {sampleIndex}); theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac)); theBuilder.createBranch(merge2BB); theBuilder.addSuccessor(merge2BB); // else if (count == 4) { theBuilder.setInsertPoint(else2BB); const auto check4 = theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount, theBuilder.getConstantUint32(4)); theBuilder.createConditionalBranch(check4, then4BB, else4BB, merge4BB); theBuilder.addSuccessor(then4BB); theBuilder.addSuccessor(else4BB); theBuilder.setMergeTarget(merge4BB); // position = pos4[index]; // } theBuilder.setInsertPoint(then4BB); ac = theBuilder.createAccessChain(ptrType, pos4Arr, {sampleIndex}); theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac)); theBuilder.createBranch(merge4BB); theBuilder.addSuccessor(merge4BB); // else if (count == 8) { theBuilder.setInsertPoint(else4BB); const auto check8 = theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount, theBuilder.getConstantUint32(8)); theBuilder.createConditionalBranch(check8, then8BB, else8BB, merge8BB); theBuilder.addSuccessor(then8BB); theBuilder.addSuccessor(else8BB); 
theBuilder.setMergeTarget(merge8BB); // position = pos8[index]; // } theBuilder.setInsertPoint(then8BB); ac = theBuilder.createAccessChain(ptrType, pos8Arr, {sampleIndex}); theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac)); theBuilder.createBranch(merge8BB); theBuilder.addSuccessor(merge8BB); // else if (count == 16) { theBuilder.setInsertPoint(else8BB); const auto check16 = theBuilder.createBinaryOp(spv::Op::OpIEqual, boolType, sampleCount, theBuilder.getConstantUint32(16)); theBuilder.createConditionalBranch(check16, then16BB, else16BB, merge16BB); theBuilder.addSuccessor(then16BB); theBuilder.addSuccessor(else16BB); theBuilder.setMergeTarget(merge16BB); // position = pos16[index]; // } theBuilder.setInsertPoint(then16BB); ac = theBuilder.createAccessChain(ptrType, pos16Arr, {sampleIndex}); theBuilder.createStore(resultVar, theBuilder.createLoad(v2f32Type, ac)); theBuilder.createBranch(merge16BB); theBuilder.addSuccessor(merge16BB); // else { // position = float2(0.0f, 0.0f); // } theBuilder.setInsertPoint(else16BB); const auto zero = theBuilder.getConstantFloat32(0); const auto v2f32Zero = theBuilder.getConstantComposite(v2f32Type, {zero, zero}); theBuilder.createStore(resultVar, v2f32Zero); theBuilder.createBranch(merge16BB); theBuilder.addSuccessor(merge16BB); theBuilder.setInsertPoint(merge16BB); theBuilder.createBranch(merge8BB); theBuilder.addSuccessor(merge8BB); theBuilder.setInsertPoint(merge8BB); theBuilder.createBranch(merge4BB); theBuilder.addSuccessor(merge4BB); theBuilder.setInsertPoint(merge4BB); theBuilder.createBranch(merge2BB); theBuilder.addSuccessor(merge2BB); theBuilder.setInsertPoint(merge2BB); return theBuilder.createLoad(v2f32Type, resultVar); } SpirvEvalInfo SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) { const FunctionDecl *callee = expr->getDirectCallee(); llvm::StringRef group; uint32_t opcode = static_cast(hlsl::IntrinsicOp::Num_Intrinsics); if (hlsl::GetIntrinsicOp(callee, opcode, group)) { 
return processIntrinsicMemberCall(expr, static_cast(opcode)); } return processCall(expr); } void SPIRVEmitter::handleOffsetInMethodCall(const CXXMemberCallExpr *expr, uint32_t index, uint32_t *constOffset, uint32_t *varOffset) { // Ensure the given arg index is not out-of-range. assert(index < expr->getNumArgs()); *constOffset = *varOffset = 0; // Initialize both first if (*constOffset = tryToEvaluateAsConst(expr->getArg(index))) return; // Constant offset else *varOffset = doExpr(expr->getArg(index)); }; SpirvEvalInfo SPIRVEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr, hlsl::IntrinsicOp opcode) { using namespace hlsl; uint32_t retVal = 0; switch (opcode) { case IntrinsicOp::MOP_Sample: retVal = processTextureSampleGather(expr, /*isSample=*/true); break; case IntrinsicOp::MOP_Gather: retVal = processTextureSampleGather(expr, /*isSample=*/false); break; case IntrinsicOp::MOP_SampleBias: retVal = processTextureSampleBiasLevel(expr, /*isBias=*/true); break; case IntrinsicOp::MOP_SampleLevel: retVal = processTextureSampleBiasLevel(expr, /*isBias=*/false); break; case IntrinsicOp::MOP_SampleGrad: retVal = processTextureSampleGrad(expr); break; case IntrinsicOp::MOP_SampleCmp: retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/true); break; case IntrinsicOp::MOP_SampleCmpLevelZero: retVal = processTextureSampleCmpCmpLevelZero(expr, /*isCmp=*/false); break; case IntrinsicOp::MOP_GatherRed: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 0); break; case IntrinsicOp::MOP_GatherGreen: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 1); break; case IntrinsicOp::MOP_GatherBlue: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 2); break; case IntrinsicOp::MOP_GatherAlpha: retVal = processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/false, 3); break; case IntrinsicOp::MOP_GatherCmp: retVal = processTextureGatherCmp(expr); break; case IntrinsicOp::MOP_GatherCmpRed: retVal = 
processTextureGatherRGBACmpRGBA(expr, /*isCmp=*/true, 0); break; case IntrinsicOp::MOP_Load: return processBufferTextureLoad(expr); case IntrinsicOp::MOP_Load2: return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ false); case IntrinsicOp::MOP_Load3: return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ false); case IntrinsicOp::MOP_Load4: return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ false); case IntrinsicOp::MOP_Store: return processByteAddressBufferLoadStore(expr, 1, /*doStore*/ true); case IntrinsicOp::MOP_Store2: return processByteAddressBufferLoadStore(expr, 2, /*doStore*/ true); case IntrinsicOp::MOP_Store3: return processByteAddressBufferLoadStore(expr, 3, /*doStore*/ true); case IntrinsicOp::MOP_Store4: return processByteAddressBufferLoadStore(expr, 4, /*doStore*/ true); case IntrinsicOp::MOP_GetDimensions: retVal = processGetDimensions(expr); break; case IntrinsicOp::MOP_CalculateLevelOfDetail: retVal = processTextureLevelOfDetail(expr); break; case IntrinsicOp::MOP_IncrementCounter: retVal = theBuilder.createUnaryOp( spv::Op::OpBitcast, theBuilder.getUint32Type(), incDecRWACSBufferCounter(expr, /*isInc*/ true)); break; case IntrinsicOp::MOP_DecrementCounter: retVal = theBuilder.createUnaryOp( spv::Op::OpBitcast, theBuilder.getUint32Type(), incDecRWACSBufferCounter(expr, /*isInc*/ false)); break; case IntrinsicOp::MOP_Append: if (hlsl::IsHLSLStreamOutputType( expr->getImplicitObjectArgument()->getType())) return processStreamOutputAppend(expr); else return processACSBufferAppendConsume(expr); case IntrinsicOp::MOP_Consume: return processACSBufferAppendConsume(expr); case IntrinsicOp::MOP_RestartStrip: retVal = processStreamOutputRestart(expr); break; case IntrinsicOp::MOP_InterlockedAdd: case IntrinsicOp::MOP_InterlockedAnd: case IntrinsicOp::MOP_InterlockedOr: case IntrinsicOp::MOP_InterlockedXor: case IntrinsicOp::MOP_InterlockedUMax: case IntrinsicOp::MOP_InterlockedUMin: case IntrinsicOp::MOP_InterlockedMax: case 
IntrinsicOp::MOP_InterlockedMin: case IntrinsicOp::MOP_InterlockedExchange: case IntrinsicOp::MOP_InterlockedCompareExchange: case IntrinsicOp::MOP_InterlockedCompareStore: retVal = processRWByteAddressBufferAtomicMethods(opcode, expr); break; case IntrinsicOp::MOP_GetSamplePosition: retVal = processGetSamplePosition(expr); break; case IntrinsicOp::MOP_SubpassLoad: retVal = processSubpassLoad(expr); break; case IntrinsicOp::MOP_GatherCmpGreen: case IntrinsicOp::MOP_GatherCmpBlue: case IntrinsicOp::MOP_GatherCmpAlpha: case IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped: emitError("no equivalent for %0 intrinsic method in Vulkan", expr->getCallee()->getExprLoc()) << expr->getMethodDecl()->getName(); return 0; default: emitError("intrinsic '%0' method unimplemented", expr->getCallee()->getExprLoc()) << expr->getDirectCallee()->getName(); return 0; } return SpirvEvalInfo(retVal).setRValue(); } uint32_t SPIRVEmitter::createImageSample( QualType retType, uint32_t imageType, uint32_t image, uint32_t sampler, uint32_t coordinate, uint32_t compareVal, uint32_t bias, uint32_t lod, std::pair grad, uint32_t constOffset, uint32_t varOffset, uint32_t constOffsets, uint32_t sample, uint32_t minLod, uint32_t residencyCodeId) { const auto retTypeId = typeTranslator.translateType(retType); // SampleDref* instructions in SPIR-V always return a scalar. // They also have the correct type in HLSL. if (compareVal) { return theBuilder.createImageSample(retTypeId, imageType, image, sampler, coordinate, compareVal, bias, lod, grad, constOffset, varOffset, constOffsets, sample, minLod, residencyCodeId); } // Non-Dref Sample instructions in SPIR-V must always return a vec4. 
auto texelTypeId = retTypeId; QualType elemType = {}; uint32_t elemTypeId = 0; uint32_t retVecSize = 0; if (TypeTranslator::isVectorType(retType, &elemType, &retVecSize) && retVecSize != 4) { elemTypeId = typeTranslator.translateType(elemType); texelTypeId = theBuilder.getVecType(elemTypeId, 4); } else if (TypeTranslator::isScalarType(retType)) { retVecSize = 1; elemTypeId = typeTranslator.translateType(retType); texelTypeId = theBuilder.getVecType(elemTypeId, 4); } // The Lod and Grad image operands requires explicit-lod instructions. // Otherwise we use implicit-lod instructions. const bool isExplicit = lod || (grad.first && grad.second); // Implicit-lod instructions are only allowed in pixel shader. if (!shaderModel.IsPS() && !isExplicit) needsLegalization = true; uint32_t retVal = theBuilder.createImageSample( texelTypeId, imageType, image, sampler, coordinate, compareVal, bias, lod, grad, constOffset, varOffset, constOffsets, sample, minLod, residencyCodeId); // Extract smaller vector from the vec4 result if necessary. if (texelTypeId != retTypeId) { retVal = extractVecFromVec4(retVal, retVecSize, elemTypeId); } return retVal; } uint32_t SPIRVEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr, const bool isSample) { // Signatures: // For Texture1D, Texture1DArray, Texture2D, Texture2DArray, Texture3D: // DXGI_FORMAT Object.Sample(sampler_state S, // float Location // [, int Offset] // [, float Clamp] // [, out uint Status]); // // For TextureCube and TextureCubeArray: // DXGI_FORMAT Object.Sample(sampler_state S, // float Location // [, float Clamp] // [, out uint Status]); // // For Texture2D/Texture2DArray: //