@@ -203,8 +203,9 @@ bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) {
   return false;
 }
 
-bool spirvToolsLegalize(std::vector<uint32_t> *module, std::string *messages) {
-  spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
+bool spirvToolsLegalize(spv_target_env env, std::vector<uint32_t> *module,
+                        std::string *messages) {
+  spvtools::Optimizer optimizer(env);
 
   optimizer.SetMessageConsumer(
       [messages](spv_message_level_t /*level*/, const char * /*source*/,
@@ -220,8 +221,9 @@ bool spirvToolsLegalize(std::vector<uint32_t> *module, std::string *messages) {
   return optimizer.Run(module->data(), module->size(), module);
 }
 
-bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
-  spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
+bool spirvToolsOptimize(spv_target_env env, std::vector<uint32_t> *module,
+                        std::string *messages) {
+  spvtools::Optimizer optimizer(env);
 
   optimizer.SetMessageConsumer(
       [messages](spv_message_level_t /*level*/, const char * /*source*/,
@@ -235,9 +237,9 @@ bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
   return optimizer.Run(module->data(), module->size(), module);
 }
 
-bool spirvToolsValidate(std::vector<uint32_t> *module, std::string *messages,
-                        bool relaxLogicalPointer) {
-  spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
+bool spirvToolsValidate(spv_target_env env, std::vector<uint32_t> *module,
+                        std::string *messages, bool relaxLogicalPointer) {
+  spvtools::SpirvTools tools(env);
 
   tools.SetMessageConsumer(
       [messages](spv_message_level_t /*level*/, const char * /*source*/,
@@ -477,6 +479,41 @@ void getBaseClassIndices(const CastExpr *expr,
   }
 }
 
+spv::Capability getCapabilityForGroupNonUniform(spv::Op opcode) {
+  switch (opcode) {
+  case spv::Op::OpGroupNonUniformElect:
+    return spv::Capability::GroupNonUniform;
+  case spv::Op::OpGroupNonUniformAny:
+  case spv::Op::OpGroupNonUniformAll:
+  case spv::Op::OpGroupNonUniformAllEqual:
+    return spv::Capability::GroupNonUniformVote;
+  case spv::Op::OpGroupNonUniformBallot:
+  case spv::Op::OpGroupNonUniformBallotBitCount:
+  case spv::Op::OpGroupNonUniformBroadcast:
+  case spv::Op::OpGroupNonUniformBroadcastFirst:
+    return spv::Capability::GroupNonUniformBallot;
+  case spv::Op::OpGroupNonUniformIAdd:
+  case spv::Op::OpGroupNonUniformFAdd:
+  case spv::Op::OpGroupNonUniformIMul:
+  case spv::Op::OpGroupNonUniformFMul:
+  case spv::Op::OpGroupNonUniformSMax:
+  case spv::Op::OpGroupNonUniformUMax:
+  case spv::Op::OpGroupNonUniformFMax:
+  case spv::Op::OpGroupNonUniformSMin:
+  case spv::Op::OpGroupNonUniformUMin:
+  case spv::Op::OpGroupNonUniformFMin:
+  case spv::Op::OpGroupNonUniformBitwiseAnd:
+  case spv::Op::OpGroupNonUniformBitwiseOr:
+  case spv::Op::OpGroupNonUniformBitwiseXor:
+    return spv::Capability::GroupNonUniformArithmetic;
+  case spv::Op::OpGroupNonUniformQuadBroadcast:
+  case spv::Op::OpGroupNonUniformQuadSwap:
+    return spv::Capability::GroupNonUniformQuad;
+  }
+  assert(false && "unhandled opcode");
+  return spv::Capability::Max;
+}
+
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
@@ -490,8 +527,8 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
       typeTranslator(astContext, theBuilder, diags, options),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
-      seenPushConstantAt(), isSpecConstantMode(false),
-      needsLegalization(false) {
+      seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
+      needsSpirv1p3(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
@@ -531,6 +568,12 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
   if (context.getDiagnostics().hasErrorOccurred())
     return;
 
+  spv_target_env targetEnv = SPV_ENV_VULKAN_1_0;
+  if (needsSpirv1p3) {
+    theBuilder.useSpirv1p3();
+    targetEnv = SPV_ENV_VULKAN_1_1;
+  }
+
   AddRequiredCapabilitiesForShaderModel();
 
   // Addressing and memory model are required in a valid SPIR-V module.
@@ -555,7 +598,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
   // Run legalization passes
   if (needsLegalization || declIdMapper.requiresLegalization()) {
     std::string messages;
-    if (!spirvToolsLegalize(&m, &messages)) {
+    if (!spirvToolsLegalize(targetEnv, &m, &messages)) {
       emitFatalError("failed to legalize SPIR-V: %0", {}) << messages;
       emitNote("please file a bug report on "
                "https://github.com/Microsoft/DirectXShaderCompiler/issues "
@@ -570,7 +613,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
   // Run optimization passes
   if (theCompilerInstance.getCodeGenOpts().OptimizationLevel > 0) {
     std::string messages;
-    if (!spirvToolsOptimize(&m, &messages)) {
+    if (!spirvToolsOptimize(targetEnv, &m, &messages)) {
       emitFatalError("failed to optimize SPIR-V: %0", {}) << messages;
       emitNote("please file a bug report on "
               "https://github.com/Microsoft/DirectXShaderCompiler/issues "
@@ -584,7 +627,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
   // Validate the generated SPIR-V code
   if (!spirvOptions.disableValidation) {
     std::string messages;
-    if (!spirvToolsValidate(&m, &messages,
+    if (!spirvToolsValidate(targetEnv, &m, &messages,
                             declIdMapper.requiresLegalization())) {
       emitFatalError("generated SPIR-V is invalid: %0", {}) << messages;
       emitNote("please file a bug report on "
@@ -6019,6 +6062,7 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = processIntrinsicF32ToF16(callExpr);
     break;
   case hlsl::IntrinsicOp::IOP_WaveGetLaneCount: {
+    needsSpirv1p3 = true;
     const uint32_t retType =
         typeTranslator.translateType(callExpr->getCallReturnType(astContext));
     const uint32_t varId =
@@ -6026,23 +6070,73 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = theBuilder.createLoad(retType, varId);
   } break;
   case hlsl::IntrinsicOp::IOP_WaveGetLaneIndex: {
+    needsSpirv1p3 = true;
     const uint32_t retType =
         typeTranslator.translateType(callExpr->getCallReturnType(astContext));
     const uint32_t varId =
         declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId);
     retVal = theBuilder.createLoad(retType, varId);
   } break;
-  case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst: {
+  case hlsl::IntrinsicOp::IOP_WaveIsFirstLane:
+    retVal = processWaveQuery(callExpr, spv::Op::OpGroupNonUniformElect);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveAllTrue:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAll);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveAnyTrue:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAny);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveBallot:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformBallot);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveAllEqual:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
+    retVal = processWaveReductionOrPrefix(
+        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
+        spv::GroupOperation::Reduce);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
+  case hlsl::IntrinsicOp::IOP_WaveActiveSum:
+  case hlsl::IntrinsicOp::IOP_WaveActiveUProduct:
+  case hlsl::IntrinsicOp::IOP_WaveActiveProduct:
+  case hlsl::IntrinsicOp::IOP_WaveActiveUMax:
+  case hlsl::IntrinsicOp::IOP_WaveActiveMax:
+  case hlsl::IntrinsicOp::IOP_WaveActiveUMin:
+  case hlsl::IntrinsicOp::IOP_WaveActiveMin:
+  case hlsl::IntrinsicOp::IOP_WaveActiveBitAnd:
+  case hlsl::IntrinsicOp::IOP_WaveActiveBitOr:
+  case hlsl::IntrinsicOp::IOP_WaveActiveBitXor: {
+    const auto retType = callExpr->getCallReturnType(astContext);
+    retVal = processWaveReductionOrPrefix(
+        callExpr, translateWaveOp(hlslOpcode, retType, callExpr->getExprLoc()),
+        spv::GroupOperation::Reduce);
+  } break;
+  case hlsl::IntrinsicOp::IOP_WavePrefixUSum:
+  case hlsl::IntrinsicOp::IOP_WavePrefixSum:
+  case hlsl::IntrinsicOp::IOP_WavePrefixUProduct:
+  case hlsl::IntrinsicOp::IOP_WavePrefixProduct: {
     const auto retType = callExpr->getCallReturnType(astContext);
-    if (!retType->isScalarType()) {
-      emitError("vector overloads of WaveReadLaneFirst unimplemented",
-                callExpr->getExprLoc());
-      return 0;
-    }
-    const uint32_t retTypeId = typeTranslator.translateType(retType);
-    retVal = theBuilder.createSubgroupFirstInvocation(
-        retTypeId, doExpr(callExpr->getArg(0)));
+    retVal = processWaveReductionOrPrefix(
+        callExpr, translateWaveOp(hlslOpcode, retType, callExpr->getExprLoc()),
+        spv::GroupOperation::ExclusiveScan);
   } break;
+  case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
+    retVal = processWaveReductionOrPrefix(
+        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
+        spv::GroupOperation::ExclusiveScan);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveReadLaneAt:
+  case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
+    retVal = processWaveBroadcast(callExpr);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
+  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
+    retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
+    break;
   case hlsl::IntrinsicOp::IOP_abort:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
@@ -6413,6 +6507,194 @@ uint32_t SPIRVEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
   return theBuilder.createCompositeConstruct(uint4Type, accums);
 }
 
+uint32_t SPIRVEmitter::processWaveQuery(const CallExpr *callExpr,
+                                        spv::Op opcode) {
+  // Signatures:
+  // bool WaveIsFirstLane()
+  // uint WaveGetLaneCount()
+  // uint WaveGetLaneIndex()
+  assert(callExpr->getNumArgs() == 0);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  return theBuilder.createGroupNonUniformOp(opcode, retType, subgroupScope);
+}
+
+uint32_t SPIRVEmitter::processWaveVote(const CallExpr *callExpr,
+                                       spv::Op opcode) {
+  // Signatures:
+  // bool WaveActiveAnyTrue( bool expr )
+  // bool WaveActiveAllTrue( bool expr )
+  // uint4 WaveActiveBallot( bool expr )
+  assert(callExpr->getNumArgs() == 1);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
+  const uint32_t predicate = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  return theBuilder.createGroupNonUniformUnaryOp(opcode, retType, subgroupScope,
+                                                 predicate);
+}
+
+spv::Op SPIRVEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
+                                      SourceLocation srcLoc) {
+  const bool isSintType = isSintOrVecMatOfSintType(type);
+  const bool isUintType = isUintOrVecMatOfUintType(type);
+  const bool isFloatType = isFloatOrVecMatOfFloatType(type);
+
+#define WAVE_OP_CASE_INT(kind, intWaveOp) \
+ \
+  case hlsl::IntrinsicOp::IOP_Wave##kind: { \
+    if (isSintType || isUintType) { \
+      return spv::Op::OpGroupNonUniform##intWaveOp; \
+    } \
+  } break
+
+#define WAVE_OP_CASE_INT_FLOAT(kind, intWaveOp, floatWaveOp) \
+ \
+  case hlsl::IntrinsicOp::IOP_Wave##kind: { \
+    if (isSintType || isUintType) { \
+      return spv::Op::OpGroupNonUniform##intWaveOp; \
+    } \
+    if (isFloatType) { \
+      return spv::Op::OpGroupNonUniform##floatWaveOp; \
+    } \
+  } break
+
+#define WAVE_OP_CASE_SINT_UINT_FLOAT(kind, sintWaveOp, uintWaveOp, \
+                                     floatWaveOp) \
+ \
+  case hlsl::IntrinsicOp::IOP_Wave##kind: { \
+    if (isSintType) { \
+      return spv::Op::OpGroupNonUniform##sintWaveOp; \
+    } \
+    if (isUintType) { \
+      return spv::Op::OpGroupNonUniform##uintWaveOp; \
+    } \
+    if (isFloatType) { \
+      return spv::Op::OpGroupNonUniform##floatWaveOp; \
+    } \
+  } break
+
+  switch (op) {
+    WAVE_OP_CASE_INT_FLOAT(ActiveUSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(ActiveSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(ActiveUProduct, IMul, FMul);
+    WAVE_OP_CASE_INT_FLOAT(ActiveProduct, IMul, FMul);
+    WAVE_OP_CASE_INT_FLOAT(PrefixUSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(PrefixSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(PrefixUProduct, IMul, FMul);
+    WAVE_OP_CASE_INT_FLOAT(PrefixProduct, IMul, FMul);
+    WAVE_OP_CASE_INT(ActiveBitAnd, BitwiseAnd);
+    WAVE_OP_CASE_INT(ActiveBitOr, BitwiseOr);
+    WAVE_OP_CASE_INT(ActiveBitXor, BitwiseXor);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMax, SMax, UMax, FMax);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMax, SMax, UMax, FMax);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMin, SMin, UMin, FMin);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMin, SMin, UMin, FMin);
+  }
+#undef WAVE_OP_CASE_INT_FLOAT
+#undef WAVE_OP_CASE_INT
+#undef WAVE_OP_CASE_SINT_UINT_FLOAT
+
+  emitError("translating wave operator '%0' unimplemented", srcLoc)
+      << static_cast<uint32_t>(op);
+  return spv::Op::OpNop;
+}
+
+uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
+    const CallExpr *callExpr, spv::Op opcode, spv::GroupOperation groupOp) {
+  // Signatures:
+  // bool WaveActiveAllEqual( <type> expr )
+  // uint WaveActiveCountBits( bool bBit )
+  // <type> WaveActiveSum( <type> expr )
+  // <type> WaveActiveProduct( <type> expr )
+  // <int_type> WaveActiveBitAnd( <int_type> expr )
+  // <int_type> WaveActiveBitOr( <int_type> expr )
+  // <int_type> WaveActiveBitXor( <int_type> expr )
+  // <type> WaveActiveMin( <type> expr)
+  // <type> WaveActiveMax( <type> expr)
+  //
+  // uint WavePrefixCountBits(bool bBit)
+  // <type> WavePrefixProduct(<type> value)
+  // <type> WavePrefixSum(<type> value)
+  assert(callExpr->getNumArgs() == 1);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
+  const uint32_t predicate = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  return theBuilder.createGroupNonUniformUnaryOp(
+      opcode, retType, subgroupScope, predicate,
+      llvm::Optional<spv::GroupOperation>(groupOp));
+}
+
+uint32_t SPIRVEmitter::processWaveBroadcast(const CallExpr *callExpr) {
+  // Signatures:
+  // <type> WaveReadLaneFirst(<type> expr)
+  // <type> WaveReadLaneAt(<type> expr, uint laneIndex)
+  const auto numArgs = callExpr->getNumArgs();
+  assert(numArgs == 1 || numArgs == 2);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
+  const uint32_t value = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  if (numArgs == 2)
+    return theBuilder.createGroupNonUniformBinaryOp(
+        spv::Op::OpGroupNonUniformBroadcast, retType, subgroupScope, value,
+        doExpr(callExpr->getArg(1)));
+  else
+    return theBuilder.createGroupNonUniformUnaryOp(
+        spv::Op::OpGroupNonUniformBroadcastFirst, retType, subgroupScope,
+        value);
+}
+
+uint32_t SPIRVEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
+                                                  hlsl::IntrinsicOp op) {
+  // Signatures:
+  // <type> QuadReadAcrossX(<type> localValue)
+  // <type> QuadReadAcrossY(<type> localValue)
+  // <type> QuadReadAcrossDiagonal(<type> localValue)
+  // <type> QuadReadLaneAt(<type> sourceValue, uint quadLaneID)
+  assert(callExpr->getNumArgs() == 1 || callExpr->getNumArgs() == 2);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
+
+  const uint32_t value = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+
+  uint32_t target = 0;
+  spv::Op opcode = spv::Op::OpGroupNonUniformQuadSwap;
+  switch (op) {
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
+    target = theBuilder.getConstantUint32(0);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
+    target = theBuilder.getConstantUint32(1);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
+    target = theBuilder.getConstantUint32(2);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
+    target = doExpr(callExpr->getArg(1));
+    opcode = spv::Op::OpGroupNonUniformQuadBroadcast;
+    break;
+  default:
+    llvm_unreachable("case should not appear here");
+  }
+
+  return theBuilder.createGroupNonUniformBinaryOp(opcode, retType,
+                                                  subgroupScope, value, target);
+}
+
 uint32_t SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
   // Signature is: ret modf(x, ip)
   // [in] x: the input floating-point value.