Преглед на файлове

[spirv] Use uint rather than bool for stage IO vars (#1550)

Ehsan преди 7 години
родител
ревизия
f6bbfdf4d8

+ 55 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -23,10 +23,34 @@
 #include "llvm/ADT/StringMap.h"
 #include "llvm/ADT/StringSet.h"
 
+#include "SPIRVEmitter.h"
+
 namespace clang {
 namespace spirv {
 
 namespace {
+/// \brief Returns true if the given decl is a boolean stage I/O variable.
+/// Returns false if the type is not boolean, or the decl is a built-in stage
+/// variable.
+bool isBooleanStageIOVar(const NamedDecl *decl, QualType type,
+                         const hlsl::DXIL::SemanticKind semanticKind,
+                         const hlsl::SigPoint::Kind sigPointKind) {
+  // [[vk::builtin(...)]] makes the decl a built-in stage variable.
+  // IsFrontFace (if used as PSIn) is the only known boolean built-in stage
+  // variable.
+  const bool isBooleanBuiltin =
+      (decl->getAttr<VKBuiltInAttr>() != nullptr) ||
+      (semanticKind == hlsl::Semantic::Kind::IsFrontFace &&
+       sigPointKind == hlsl::SigPoint::Kind::PSIn);
+
+  // TODO: support boolean matrix stage I/O variable if needed.
+  QualType elemType = {};
+  const bool isBooleanType = ((TypeTranslator::isScalarType(type, &elemType) ||
+                               TypeTranslator::isVectorType(type, &elemType)) &&
+                              elemType->isBooleanType());
+
+  return isBooleanType && !isBooleanBuiltin;
+}
 
 /// \brief Returns the stage variable's register assignment for the given Decl.
 const hlsl::RegisterAssignment *getResourceBinding(const NamedDecl *decl) {
@@ -1402,6 +1426,7 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
       return true;
 
     const uint32_t srcTypeId = typeId; // Variable type in source code
+    const QualType srcQualType = type; // Variable type in source code
     uint32_t srcVecElemTypeId = 0;     // Variable element type if vector
 
     switch (semanticKind) {
@@ -1438,6 +1463,13 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
       break;
     }
 
+    // Boolean stage I/O variables must be represented as unsigned integers.
+    // Boolean built-in variables are represented as bool.
+    if (isBooleanStageIOVar(decl, type, semanticKind, sigPoint->GetKind())) {
+      type = typeTranslator.getUintTypeWithSourceComponents(type);
+      typeId = typeTranslator.translateType(type);
+    }
+
     // Handle the extra arrayness
     const uint32_t elementTypeId = typeId; // Array element's type
     if (arraySize != 0)
@@ -1596,6 +1628,14 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
       // Reciprocate SV_Position.w if requested
       if (semanticKind == hlsl::Semantic::Kind::Position)
         *value = invertWIfRequested(*value);
+
+      // Since boolean stage input variables are represented as unsigned
+      // integers, after loading them, we should cast them to boolean.
+      if (isBooleanStageIOVar(decl, srcQualType, semanticKind,
+                              sigPoint->GetKind())) {
+        *value = theEmitter.castToType(*value, type, srcQualType,
+                                       decl->getLocation());
+      }
     } else {
       if (noWriteBack)
         return true;
@@ -1657,6 +1697,14 @@ bool DeclResultIdMapper::createStageVars(const hlsl::SigPoint *sigPoint,
         ptr = theBuilder.createAccessChain(ptrType, varId, index);
         theBuilder.createStore(ptr, *value);
       }
+      // Since boolean output stage variables are represented as unsigned
+      // integers, we must cast the value to uint before storing.
+      else if (isBooleanStageIOVar(decl, srcQualType, semanticKind,
+                                   sigPoint->GetKind())) {
+        *value = theEmitter.castToType(*value, srcQualType, type,
+                                       decl->getLocation());
+        theBuilder.createStore(ptr, *value);
+      }
       // For all normal cases
       else {
         theBuilder.createStore(ptr, *value);
@@ -1833,6 +1881,13 @@ bool DeclResultIdMapper::writeBackOutputStream(const NamedDecl *decl,
     if (semanticInfo.semantic->GetKind() == hlsl::Semantic::Kind::Position)
       value = invertYIfRequested(value);
 
+    // Boolean stage output variables are represented as unsigned integers.
+    if (isBooleanStageIOVar(decl, type, semanticInfo.semantic->GetKind(),
+                            hlsl::SigPoint::Kind::GSOut)) {
+      QualType uintType = typeTranslator.getUintTypeWithSourceComponents(type);
+      value = theEmitter.castToType(value, type, uintType, decl->getLocation());
+    }
+
     theBuilder.createStore(found->second, value);
     return true;
   }

+ 14 - 11
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -32,6 +32,8 @@
 namespace clang {
 namespace spirv {
 
+class SPIRVEmitter;
+
 /// A struct containing information about a particular HLSL semantic.
 struct SemanticInfo {
   llvm::StringRef str;            ///< The original semantic string
@@ -257,7 +259,8 @@ private:
 class DeclResultIdMapper {
 public:
   inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
-                            ModuleBuilder &builder, TypeTranslator &translator,
+                            ModuleBuilder &builder, SPIRVEmitter &emitter,
+                            TypeTranslator &translator,
                             FeatureManager &features,
                             const SpirvCodeGenOptions &spirvOptions);
 
@@ -628,6 +631,7 @@ private:
 private:
   const hlsl::ShaderModel &shaderModel;
   ModuleBuilder &theBuilder;
+  SPIRVEmitter &theEmitter;
   const SpirvCodeGenOptions &spirvOptions;
   ASTContext &astContext;
   DiagnosticsEngine &diags;
@@ -742,16 +746,15 @@ void CounterIdAliasPair::assign(const CounterIdAliasPair &srcPair,
   builder.createStore(resultId, srcPair.get(builder, translator));
 }
 
-DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
-                                       ASTContext &context,
-                                       ModuleBuilder &builder,
-                                       TypeTranslator &translator,
-                                       FeatureManager &features,
-                                       const SpirvCodeGenOptions &options)
-    : shaderModel(model), theBuilder(builder), spirvOptions(options),
-      astContext(context), diags(context.getDiagnostics()),
-      typeTranslator(translator), entryFunctionId(0), laneCountBuiltinId(0),
-      laneIndexBuiltinId(0), needsLegalization(false),
+DeclResultIdMapper::DeclResultIdMapper(
+    const hlsl::ShaderModel &model, ASTContext &context, ModuleBuilder &builder,
+    SPIRVEmitter &emitter, TypeTranslator &translator, FeatureManager &features,
+    const SpirvCodeGenOptions &options)
+    : shaderModel(model), theBuilder(builder), theEmitter(emitter),
+      spirvOptions(options), astContext(context),
+      diags(context.getDiagnostics()), typeTranslator(translator),
+      entryFunctionId(0), laneCountBuiltinId(0), laneIndexBuiltinId(0),
+      needsLegalization(false),
       glPerVertex(model, context, builder, typeTranslator) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {

+ 1 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -599,7 +599,7 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
       theContext(), featureManager(diags, spirvOptions),
       theBuilder(&theContext, &featureManager, spirvOptions),
       typeTranslator(astContext, theBuilder, diags, spirvOptions),
-      declIdMapper(shaderModel, astContext, theBuilder, typeTranslator,
+      declIdMapper(shaderModel, astContext, theBuilder, *this, typeTranslator,
                    featureManager, spirvOptions),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
       seenPushConstantAt(), isSpecConstantMode(false),

+ 24 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -74,6 +74,30 @@ const hlsl::ConstantPacking *getPackOffset(const NamedDecl *decl) {
 
 } // anonymous namespace
 
+QualType TypeTranslator::getBoolTypeWithSourceComponents(QualType sourceType) {
+  if (isScalarType(sourceType)) {
+    return astContext.BoolTy;
+  }
+  uint32_t elemCount = 0;
+  if (isVectorType(sourceType, nullptr, &elemCount)) {
+    return astContext.getExtVectorType(astContext.BoolTy, elemCount);
+  }
+
+  llvm_unreachable("only scalar and vector types are supported in "
+                   "getBoolTypeWithSourceComponents");
+}
+QualType TypeTranslator::getUintTypeWithSourceComponents(QualType sourceType) {
+  if (isScalarType(sourceType)) {
+    return astContext.UnsignedIntTy;
+  }
+  uint32_t elemCount = 0;
+  if (isVectorType(sourceType, nullptr, &elemCount)) {
+    return astContext.getExtVectorType(astContext.UnsignedIntTy, elemCount);
+  }
+  llvm_unreachable("only scalar and vector types are supported in "
+                   "getUintTypeWithSourceComponents");
+}
+
 bool TypeTranslator::isRelaxedPrecisionType(QualType type,
                                             const SpirvCodeGenOptions &opts) {
   // Primitive types

+ 11 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -233,6 +233,17 @@ public:
   /// matrix type.
   uint32_t getComponentVectorType(QualType matrixType);
 
+  /// \brief Returns the QualType that has the same components as the source
+  /// type, but with boolean element type. For instance, if the source type is a
+  /// vector of 3 integers, returns the QualType for a vector of 3 booleans.
+  /// Supports only scalars and vectors.
+  QualType getBoolTypeWithSourceComponents(QualType srceType);
+  /// \brief Returns the QualType that has the same components as the source
+  /// type, but with 32-bit uint element type. For instance, if the source type
+  /// is a vector of 3 booleans, returns the QualType for a vector of 3 uints.
+  /// Supports only scalars and vectors.
+  QualType getUintTypeWithSourceComponents(QualType srceType);
+
   /// \brief Returns true if all members in structType are of the same element
   /// type and can be fit into a 4-component vector. Writes element type and
   /// count to *elemType and *elemCount if not nullptr. Otherwise, emit errors

+ 1 - 1
tools/clang/test/CodeGenSPIRV/semantic.is-front-face.gs.hlsl

@@ -5,7 +5,7 @@
 
 // CHECK:      OpDecorate %out_var_SV_IsFrontFace Location 0
 
-// CHECK:      %out_var_SV_IsFrontFace = OpVariable %_ptr_Output_bool Output
+// CHECK:      %out_var_SV_IsFrontFace = OpVariable %_ptr_Output_uint Output
 
 // GS per-vertex input
 struct GsVIn {

+ 33 - 20
tools/clang/test/CodeGenSPIRV/spirv.entry-function.wrapper.hlsl

@@ -9,40 +9,53 @@ struct S {
 struct T {
     S x;
     int y: D;
+    bool2 z : E;
 };
 
-// CHECK:       %in_var_A = OpVariable %_ptr_Input_bool Input
+// CHECK: [[v2uint0:%\d+]] = OpConstantComposite %v2uint %uint_0 %uint_0
+// CHECK: [[v2uint1:%\d+]] = OpConstantComposite %v2uint %uint_1 %uint_1
+
+// CHECK:       %in_var_A = OpVariable %_ptr_Input_uint Input
 // CHECK-NEXT:  %in_var_B = OpVariable %_ptr_Input_v2uint Input
 // CHECK-NEXT:  %in_var_C = OpVariable %_ptr_Input_mat2v3float Input
 // CHECK-NEXT:  %in_var_D = OpVariable %_ptr_Input_int Input
+// CHECK-NEXT:  %in_var_E = OpVariable %_ptr_Input_v2uint Input
 
-// CHECK-NEXT: %out_var_A = OpVariable %_ptr_Output_bool Output
+// CHECK-NEXT: %out_var_A = OpVariable %_ptr_Output_uint Output
 // CHECK-NEXT: %out_var_B = OpVariable %_ptr_Output_v2uint Output
 // CHECK-NEXT: %out_var_C = OpVariable %_ptr_Output_mat2v3float Output
 // CHECK-NEXT: %out_var_D = OpVariable %_ptr_Output_int Output
+// CHECK-NEXT: %out_var_E = OpVariable %_ptr_Output_v2uint Output
 
 // CHECK-NEXT:        %main = OpFunction %void None {{%\d+}}
 // CHECK-NEXT:     {{%\d+}} = OpLabel
 
 // CHECK-NEXT: %param_var_input = OpVariable %_ptr_Function_T Function
-// CHECK-NEXT: [[inA:%\d+]] = OpLoad %bool %in_var_A
-// CHECK-NEXT: [[inB:%\d+]] = OpLoad %v2uint %in_var_B
-// CHECK-NEXT: [[inC:%\d+]] = OpLoad %mat2v3float %in_var_C
-// CHECK-NEXT: [[inS:%\d+]] = OpCompositeConstruct %S [[inA]] [[inB]] [[inC]]
-// CHECK-NEXT: [[inD:%\d+]] = OpLoad %int %in_var_D
-// CHECK-NEXT: [[inT:%\d+]] = OpCompositeConstruct %T [[inS]] [[inD]]
-// CHECK-NEXT:                OpStore %param_var_input [[inT]]
-
-// CHECK-NEXT:  [[ret:%\d+]] = OpFunctionCall %T %src_main %param_var_input
-// CHECK-NEXT: [[outS:%\d+]] = OpCompositeExtract %S [[ret]] 0
-// CHECK-NEXT: [[outA:%\d+]] = OpCompositeExtract %bool [[outS]] 0
-// CHECK-NEXT:                 OpStore %out_var_A [[outA]]
-// CHECK-NEXT: [[outB:%\d+]] = OpCompositeExtract %v2uint [[outS]] 1
-// CHECK-NEXT:                 OpStore %out_var_B [[outB]]
-// CHECK-NEXT: [[outC:%\d+]] = OpCompositeExtract %mat2v3float [[outS]] 2
-// CHECK-NEXT:                 OpStore %out_var_C [[outC]]
-// CHECK-NEXT: [[outD:%\d+]] = OpCompositeExtract %int [[ret]] 1
-// CHECK-NEXT:                 OpStore %out_var_D [[outD]]
+// CHECK-NEXT:     [[inA:%\d+]] = OpLoad %uint %in_var_A
+// CHECK-NEXT: [[inAbool:%\d+]] = OpINotEqual %bool [[inA]] %uint_0
+// CHECK-NEXT:     [[inB:%\d+]] = OpLoad %v2uint %in_var_B
+// CHECK-NEXT:     [[inC:%\d+]] = OpLoad %mat2v3float %in_var_C
+// CHECK-NEXT:     [[inS:%\d+]] = OpCompositeConstruct %S [[inAbool]] [[inB]] [[inC]]
+// CHECK-NEXT:     [[inD:%\d+]] = OpLoad %int %in_var_D
+// CHECK-NEXT:     [[inE:%\d+]] = OpLoad %v2uint %in_var_E
+// CHECK-NEXT: [[inEbool:%\d+]] = OpINotEqual %v2bool [[inE]] [[v2uint0]]
+// CHECK-NEXT:     [[inT:%\d+]] = OpCompositeConstruct %T [[inS]] [[inD]] [[inEbool]]
+// CHECK-NEXT:                    OpStore %param_var_input [[inT]]
+
+// CHECK-NEXT:      [[ret:%\d+]] = OpFunctionCall %T %src_main %param_var_input
+// CHECK-NEXT:     [[outS:%\d+]] = OpCompositeExtract %S [[ret]] 0
+// CHECK-NEXT:     [[outA:%\d+]] = OpCompositeExtract %bool [[outS]] 0
+// CHECK-NEXT: [[outAuint:%\d+]] = OpSelect %uint [[outA]] %uint_1 %uint_0
+// CHECK-NEXT:                     OpStore %out_var_A [[outAuint]]
+// CHECK-NEXT:     [[outB:%\d+]] = OpCompositeExtract %v2uint [[outS]] 1
+// CHECK-NEXT:                     OpStore %out_var_B [[outB]]
+// CHECK-NEXT:     [[outC:%\d+]] = OpCompositeExtract %mat2v3float [[outS]] 2
+// CHECK-NEXT:                     OpStore %out_var_C [[outC]]
+// CHECK-NEXT:     [[outD:%\d+]] = OpCompositeExtract %int [[ret]] 1
+// CHECK-NEXT:                     OpStore %out_var_D [[outD]]
+// CHECK-NEXT:     [[outE:%\d+]] = OpCompositeExtract %v2bool [[ret]] 2
+// CHECK-NEXT: [[outEuint:%\d+]] = OpSelect %v2uint [[outE]] [[v2uint1]] [[v2uint0]]
+// CHECK-NEXT:                     OpStore %out_var_E [[outEuint]]
 
 // CHECK-NEXT:                 OpReturn
 // CHECK-NEXT:                 OpFunctionEnd