Преглед на файлове

[spirv] Add support for basic geometry shader (#772)

This commit supports .Append() and .RestartStrip() method calls
on stream-output objects, which will be translated into SPIR-V
OpEmitVertex and OpEndPrimitive, respectively.

For each .Append() call, all affected stage output variables
will be flushed.
Lei Zhang преди 7 години
родител
ревизия
37eade8495

+ 9 - 1
docs/SPIR-V.rst

@@ -2073,7 +2073,15 @@ given output stream.
 |``TriangleStream``   | ``OutputTriangleStrip``     |
 +---------------------+-----------------------------+
 
-TODO: Describe more details about how geometry shaders are translated. e.g. OutputStreams, etc.
+In other shader stages, stage output variables are only written in the `entry
+function wrapper`_ after calling the source code entry function. However,
+geometry shaders can output as many vertices as they wish, by calling the
+``.Append()`` method on the output stream object. Therefore, it is incorrect to
+have only one flush in the entry function wrapper like other stages. Instead,
+each time a ``*Stream<T>::Append()`` is encountered, all stage output variables
+behind ``T`` will be flushed before SPIR-V ``OpEmitVertex`` instruction is
+generated. ``.RestartStrip()`` method calls will be translated into the SPIR-V
+``OpEndPrimitive`` instruction.
 
 Vulkan Command-line Options
 ===========================

+ 6 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -254,6 +254,12 @@ public:
   /// \brief Creates an OpControlBarrier instruction with the given flags.
   void createControlBarrier(uint32_t exec, uint32_t memory, uint32_t semantics);
 
+  /// \brief Creates an OpEmitVertex instruction.
+  void createEmitVertex();
+
+  /// \brief Creates an OpEndPrimitive instruction.
+  void createEndPrimitive();
+
   // === SPIR-V Module Structure ===
 
   inline void requireCapability(spv::Capability);

+ 107 - 23
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -64,11 +64,11 @@ ResourceVar::Category getResourceCategory(QualType type) {
 
 /// \brief Returns true if the given declaration has a primitive type qualifier.
 /// Returns false otherwise.
-bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
-  return (decl->hasAttr<HLSLTriangleAttr>() ||
-          decl->hasAttr<HLSLTriangleAdjAttr>() ||
-          decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAdjAttr>() ||
-          decl->hasAttr<HLSLLineAttr>());
+inline bool hasGSPrimitiveTypeQualifier(const Decl *decl) {
+  return decl->hasAttr<HLSLTriangleAttr>() ||
+         decl->hasAttr<HLSLTriangleAdjAttr>() ||
+         decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAttr>() ||
+         decl->hasAttr<HLSLLineAdjAttr>();
 }
 
 /// \brief Deduces the parameter qualifier for the given decl.
@@ -128,8 +128,13 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
   // none of them should be created as arrays.
   assert(sigPoint->GetKind() != hlsl::DXIL::SigPointKind::HSCPOut);
 
-  return createStageVars(decl, sigPoint, /*asInput=*/false, type,
-                         /*arraySize=*/0, llvm::None, &storedValue, "out.var");
+  return createStageVars(
+      sigPoint, decl, /*asInput=*/false, type,
+      /*arraySize=*/0, "out.var", llvm::None, &storedValue,
+      // Write back of stage output variables in GS is manually controlled by
+      // .Append() intrinsic method, implemented in writeBackOutputStream().
+      // So noWriteBack should be set to true for GS.
+      shaderModel.IsGS());
 }
 
 bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
@@ -143,8 +148,9 @@ bool DeclResultIdMapper::createStageOutputVar(const DeclaratorDecl *decl,
   const auto *sigPoint =
       hlsl::SigPoint::GetSigPoint(hlsl::DXIL::SigPointKind::HSCPOut);
 
-  return createStageVars(decl, sigPoint, /*asInput=*/false, type, arraySize,
-                         invocationId, &storedValue, "out.var");
+  return createStageVars(sigPoint, decl, /*asInput=*/false, type, arraySize,
+                         "out.var", invocationId, &storedValue,
+                         /*noWriteBack=*/false);
 }
 
 bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
@@ -162,12 +168,18 @@ bool DeclResultIdMapper::createStageInputVar(const ParmVarDecl *paramDecl,
     arraySize = hlsl::GetHLSLOutputPatchCount(type);
     type = hlsl::GetHLSLOutputPatchElementType(type);
   }
+  if (hasGSPrimitiveTypeQualifier(paramDecl)) {
+    const auto *typeDecl = astContext.getAsConstantArrayType(type);
+    arraySize = static_cast<uint32_t>(typeDecl->getSize().getZExtValue());
+    type = typeDecl->getElementType();
+  }
 
   const auto *sigPoint = deduceSigPoint(paramDecl, /*asInput=*/true,
                                         shaderModel.GetKind(), forPCF);
 
-  return createStageVars(paramDecl, sigPoint, /*asInput=*/true, type, arraySize,
-                         llvm::None, loadedValue, "in.var");
+  return createStageVars(sigPoint, paramDecl, /*asInput=*/true, type, arraySize,
+                         "in.var", llvm::None, loadedValue,
+                         /*noWriteBack=*/false);
 }
 
 const DeclResultIdMapper::DeclSpirvInfo *
@@ -746,9 +758,9 @@ bool DeclResultIdMapper::decorateResourceBindings() {
 }
 
 bool DeclResultIdMapper::createStageVars(
-    const DeclaratorDecl *decl, const hlsl::SigPoint *sigPoint, bool asInput,
-    QualType type, uint32_t arraySize, llvm::Optional<uint32_t> invocationId,
-    uint32_t *value, const llvm::Twine &namePrefix) {
+    const hlsl::SigPoint *sigPoint, const DeclaratorDecl *decl, bool asInput,
+    QualType type, uint32_t arraySize, const llvm::Twine &namePrefix,
+    llvm::Optional<uint32_t> invocationId, uint32_t *value, bool noWriteBack) {
   // invocationId should only be used for handling HS per-vertex output.
   if (invocationId.hasValue()) {
     assert(shaderModel.IsHS() && arraySize != 0 && !asInput);
@@ -793,8 +805,9 @@ bool DeclResultIdMapper::createStageVars(
     // * SV_InsideTessFactor is a single float for tri patch, and an array of
     //   size 2 for a quad patch, but it must always be an array of size 2 in
     //   SPIR-V for Vulkan.
-    if (glPerVertex.tryToAccess(semanticKind, semanticIndex, invocationId,
-                                value, sigPoint->GetKind()))
+    if (glPerVertex.tryToAccess(sigPoint->GetKind(), semanticKind,
+                                semanticIndex, invocationId, value,
+                                noWriteBack))
       return true;
 
     if (semanticKind == hlsl::Semantic::Kind::DomainLocation)
@@ -822,6 +835,7 @@ bool DeclResultIdMapper::createStageVars(
     stageVar.setSpirvId(varId);
     stageVar.setLocationAttr(decl->getAttr<VKLocationAttr>());
     stageVars.push_back(stageVar);
+    stageVarIds[decl] = varId;
 
     // TODO: the following may not be correct?
     if (sigPoint->GetSignatureKind() ==
@@ -875,6 +889,9 @@ bool DeclResultIdMapper::createStageVars(
             *value, *value, {0, 1});
       }
     } else {
+      if (noWriteBack)
+        return true;
+
       uint32_t ptr = varId;
 
       // Special handling of SV_TessFactor HS patch constant output.
@@ -931,7 +948,7 @@ bool DeclResultIdMapper::createStageVars(
   if (!type->isStructureType()) {
     emitError("semantic string missing for shader %select{output|input}0 "
               "variable '%1'",
-              decl->getLocStart())
+              decl->getLocation())
         << asInput << decl->getName();
     return false;
   }
@@ -945,8 +962,9 @@ bool DeclResultIdMapper::createStageVars(
 
     for (const auto *field : structDecl->fields()) {
       uint32_t subValue = 0;
-      if (!createStageVars(field, sigPoint, asInput, field->getType(),
-                           arraySize, invocationId, &subValue, namePrefix))
+      if (!createStageVars(sigPoint, field, asInput, field->getType(),
+                           arraySize, namePrefix, invocationId, &subValue,
+                           noWriteBack))
         return false;
       subValues.push_back(subValue);
     }
@@ -999,10 +1017,14 @@ bool DeclResultIdMapper::createStageVars(
     // out the value to the correct array element.
     for (const auto *field : structDecl->fields()) {
       const uint32_t fieldType = typeTranslator.translateType(field->getType());
-      uint32_t subValue = theBuilder.createCompositeExtract(
-          fieldType, *value, {field->getFieldIndex()});
-      if (!createStageVars(field, sigPoint, asInput, field->getType(),
-                           arraySize, invocationId, &subValue, namePrefix))
+      uint32_t subValue = 0;
+      if (!noWriteBack)
+        subValue = theBuilder.createCompositeExtract(fieldType, *value,
+                                                     {field->getFieldIndex()});
+
+      if (!createStageVars(sigPoint, field, asInput, field->getType(),
+                           arraySize, namePrefix, invocationId, &subValue,
+                           noWriteBack))
         return false;
     }
   }
@@ -1010,6 +1032,68 @@ bool DeclResultIdMapper::createStageVars(
   return true;
 }
 
+bool DeclResultIdMapper::writeBackOutputStream(const ValueDecl *decl,
+                                               uint32_t value) {
+  assert(shaderModel.IsGS()); // Only for GS use
+
+  QualType type = decl->getType();
+
+  if (hlsl::IsHLSLStreamOutputType(type))
+    type = hlsl::GetHLSLResourceResultType(type);
+  if (hasGSPrimitiveTypeQualifier(decl))
+    type = astContext.getAsConstantArrayType(type)->getElementType();
+
+  llvm::StringRef semanticStr;
+  const hlsl::Semantic *semantic = {};
+  uint32_t semanticIndex = {};
+
+  if (getStageVarSemantic(decl, &semanticStr, &semantic, &semanticIndex)) {
+    // Found semantic attached directly to this Decl. Write the value for this
+    // Decl to the corresponding stage output variable.
+
+    const uint32_t srcTypeId = typeTranslator.translateType(type);
+
+    // Handle SV_Position, SV_ClipDistance, and SV_CullDistance
+    if (glPerVertex.tryToAccess(hlsl::DXIL::SigPointKind::GSOut,
+                                semantic->GetKind(), semanticIndex, llvm::None,
+                                &value, /*noWriteBack=*/false))
+      return true;
+
+    // Query the <result-id> for the stage output variable generated out
+    // of this decl.
+    const auto found = stageVarIds.find(decl);
+
+    // We should have recorded its stage output variable previously.
+    assert(found != stageVarIds.end());
+
+    theBuilder.createStore(found->second, value);
+    return true;
+  }
+
+  // If the decl itself doesn't have semantic string attached, it should be
+  // a struct having all its fields with semantic strings.
+  if (!type->isStructureType()) {
+    emitError("semantic string missing for shader output variable '%0'",
+              decl->getLocation())
+        << decl->getName();
+    return false;
+  }
+
+  const auto *structDecl = cast<RecordType>(type.getTypePtr())->getDecl();
+
+  // Write out each field
+  for (const auto *field : structDecl->fields()) {
+    const uint32_t fieldType = typeTranslator.translateType(field->getType());
+    const uint32_t subValue = theBuilder.createCompositeExtract(
+        fieldType, value, {field->getFieldIndex()});
+
+    if (!writeBackOutputStream(field, subValue))
+      return false;
+  }
+
+  return true;
+}
+
 void DeclResultIdMapper::decoratePSInterpolationMode(const DeclaratorDecl *decl,
                                                      QualType type,
                                                      uint32_t varId) {

+ 33 - 9
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -140,7 +140,7 @@ private:
 class DeclResultIdMapper {
 public:
   inline DeclResultIdMapper(const hlsl::ShaderModel &stage, ASTContext &context,
-                            ModuleBuilder &builder, DiagnosticsEngine &diag,
+                            ModuleBuilder &builder,
                             const EmitSPIRVOptions &spirvOptions);
 
   /// \brief Creates the stage output variables by parsing the semantics
@@ -252,6 +252,18 @@ public:
   /// mapper.
   std::vector<uint32_t> collectStageVars() const;
 
+  /// \brief Writes out the contents in the function parameter for the GS
+  /// stream output to the corresponding stage output variables in a recursive
+  /// manner. Returns true on success, false if errors occur.
+  ///
+  /// decl is the Decl with semantic string attached and will be used to find
+  /// the stage output variable to write to, value is the <result-id> for the
+  /// SPIR-V variable to read data from.
+  ///
+  /// This method is specially for writing back per-vertex data at the time of
+  /// OpEmitVertex in GS.
+  bool writeBackOutputStream(const ValueDecl *decl, uint32_t value);
+
   /// \brief Decorates all stage input and output variables with proper
   /// location and returns true on success.
   ///
@@ -310,18 +322,21 @@ private:
   /// For HS/DS/GS, the outermost arrayness should be discarded and use
   /// arraySize instead.
   ///
-  /// Also performs updating the stage variables (loading/storing from/to the
-  /// given value) depending on asInput.
+  /// Also performs reading the stage variables and compose a temporary value
+  /// of the given type and writing into *value, if asInput is true. Otherwise,
+  /// Decomposes the *value according to type and writes back into the stage
+  /// output variables, unless noWriteBack is set to true. noWriteBack is used
+  /// by GS since in GS we manually control write back using .Append() method.
   ///
   /// invocationId is only used for HS to indicate the index of the output
   /// array element to write to.
   ///
   /// Assumes the decl has semantic attached to itself or to its fields.
-  bool createStageVars(const DeclaratorDecl *decl,
-                       const hlsl::SigPoint *sigPoint, bool asInput,
-                       QualType type, uint32_t arraySize,
+  bool createStageVars(const hlsl::SigPoint *sigPoint,
+                       const DeclaratorDecl *decl, bool asInput, QualType type,
+                       uint32_t arraySize, const llvm::Twine &namePrefix,
                        llvm::Optional<uint32_t> invocationId, uint32_t *value,
-                       const llvm::Twine &namePrefix);
+                       bool noWriteBack);
 
   /// Creates the SPIR-V variable instruction for the given StageVar and returns
   /// the <result-id>. Also sets whether the StageVar is a SPIR-V builtin and
@@ -349,6 +364,7 @@ private:
   const hlsl::ShaderModel &shaderModel;
   ModuleBuilder &theBuilder;
   const EmitSPIRVOptions &spirvOptions;
+  ASTContext &astContext;
   DiagnosticsEngine &diags;
 
   TypeTranslator typeTranslator;
@@ -359,6 +375,14 @@ private:
   llvm::DenseMap<const NamedDecl *, DeclSpirvInfo> astDecls;
   /// Vector of all defined stage variables.
   llvm::SmallVector<StageVar, 8> stageVars;
+  /// Mapping from Clang AST decls to the corresponding stage variables'
+  /// <result-id>s.
+  /// This field is only used by GS for manually emitting vertices, when
+  /// we need to query the <result-id> of the output stage variables
+  /// involved in writing back. For other cases, stage variable reading
+  /// and writing is done at the time of creating that stage variable,
+  /// so that we don't need to query them again for reading and writing.
+  llvm::DenseMap<const NamedDecl *, uint32_t> stageVarIds;
   /// Vector of all defined resource variables.
   llvm::SmallVector<ResourceVar, 8> resourceVars;
   /// Mapping from {RW|Append|Consume}StructuredBuffers to their
@@ -373,10 +397,10 @@ public:
 DeclResultIdMapper::DeclResultIdMapper(const hlsl::ShaderModel &model,
                                        ASTContext &context,
                                        ModuleBuilder &builder,
-                                       DiagnosticsEngine &diag,
                                        const EmitSPIRVOptions &options)
     : shaderModel(model), theBuilder(builder), spirvOptions(options),
-      diags(diag), typeTranslator(context, builder, diag), entryFunctionId(0),
+      astContext(context), diags(context.getDiagnostics()),
+      typeTranslator(context, builder, diags), entryFunctionId(0),
       glPerVertex(model, context, builder, typeTranslator) {}
 
 bool DeclResultIdMapper::decorateStageIOLocations() {

+ 48 - 13
tools/clang/lib/SPIRV/GlPerVertex.cpp

@@ -11,6 +11,7 @@
 
 #include <algorithm>
 
+#include "clang/AST/Attr.h"
 #include "clang/AST/HlslTypes.h"
 
 namespace clang {
@@ -44,6 +45,15 @@ inline QualType getTypeOrFnRetType(const DeclaratorDecl *decl) {
   }
   return decl->getType();
 }
+
+/// Returns true if the given declaration has a primitive type qualifier.
+/// Returns false otherwise.
+inline bool hasGSPrimitiveTypeQualifier(const DeclaratorDecl *decl) {
+  return decl->hasAttr<HLSLTriangleAttr>() ||
+         decl->hasAttr<HLSLTriangleAdjAttr>() ||
+         decl->hasAttr<HLSLPointAttr>() || decl->hasAttr<HLSLLineAttr>() ||
+         decl->hasAttr<HLSLLineAdjAttr>();
+}
 } // anonymous namespace
 
 GlPerVertex::GlPerVertex(const hlsl::ShaderModel &sm, ASTContext &context,
@@ -170,14 +180,21 @@ bool GlPerVertex::doClipCullDistanceDecl(const DeclaratorDecl *decl,
       return doClipCullDistanceDecl(
           decl, hlsl::GetHLSLOutputPatchElementType(baseType), asInput);
     }
+
     if (hlsl::IsHLSLStreamOutputType(baseType)) {
       return doClipCullDistanceDecl(
           decl, hlsl::GetHLSLOutputPatchElementType(baseType), asInput);
     }
+    if (hasGSPrimitiveTypeQualifier(decl)) {
+      // GS inputs have an additional arrayness that we should remove to check
+      // the underlying type instead.
+      baseType = astContext.getAsConstantArrayType(baseType)->getElementType();
+      return doClipCullDistanceDecl(decl, baseType, asInput);
+    }
 
     emitError("semantic string missing for shader %select{output|input}0 "
               "variable '%1'",
-              decl->getLocStart())
+              decl->getLocation())
         << asInput << decl->getName();
     return false;
   }
@@ -356,34 +373,52 @@ uint32_t GlPerVertex::createCullDistanceVar(bool asInput, uint32_t arraySize) {
   return theBuilder.addStageBuiltinVar(type, sc, spv::BuiltIn::CullDistance);
 }
 
-bool GlPerVertex::tryToAccess(hlsl::Semantic::Kind semanticKind,
+bool GlPerVertex::tryToAccess(hlsl::SigPoint::Kind sigPointKind,
+                              hlsl::Semantic::Kind semanticKind,
                               uint32_t semanticIndex,
                               llvm::Optional<uint32_t> invocationId,
-                              uint32_t *value,
-                              hlsl::SigPoint::Kind sigPointKind) {
+                              uint32_t *value, bool noWriteBack) {
   // invocationId should only be used for HSPCOut.
   assert(invocationId.hasValue() ? sigPointKind == hlsl::SigPoint::Kind::HSCPOut
                                  : true);
 
+  switch (semanticKind) {
+  case hlsl::Semantic::Kind::Position:
+  case hlsl::Semantic::Kind::ClipDistance:
+  case hlsl::Semantic::Kind::CullDistance:
+    // gl_PerVertex only cares about these builtins.
+    break;
+  default:
+    return false; // Fall back to the normal path
+  }
+
   switch (sigPointKind) {
+  case hlsl::SigPoint::Kind::PSIn:
+    // We don't handle stand-alone Position builtin in this class.
+    if (semanticKind == hlsl::Semantic::Kind::Position)
+      return false; // Fall back to the normal path
+
+  // Fall through
+
   case hlsl::SigPoint::Kind::HSCPIn:
   case hlsl::SigPoint::Kind::DSCPIn:
   case hlsl::SigPoint::Kind::GSVIn:
     return readField(semanticKind, semanticIndex, value);
-  case hlsl::SigPoint::Kind::PSIn:
+
+  case hlsl::SigPoint::Kind::GSOut:
     // We don't handle stand-alone Position builtin in this class.
-    return semanticKind == hlsl::Semantic::Kind::Position
-               ? 0 // Fall back to the normal path
-               : readField(semanticKind, semanticIndex, value);
+    if (semanticKind == hlsl::Semantic::Kind::Position)
+      return false; // Fall back to the normal path
+
+  // Fall through
+
   case hlsl::SigPoint::Kind::VSOut:
   case hlsl::SigPoint::Kind::HSCPOut:
   case hlsl::SigPoint::Kind::DSOut:
+    if (noWriteBack)
+      return true;
+
     return writeField(semanticKind, semanticIndex, invocationId, value);
-  case hlsl::SigPoint::Kind::GSOut:
-    // We don't handle stand-alone Position builtin in this class.
-    return semanticKind == hlsl::Semantic::Kind::Position
-               ? 0 // Fall back to the normal path
-               : writeField(semanticKind, semanticIndex, invocationId, value);
   }
 
   return false;

+ 13 - 6
tools/clang/lib/SPIRV/GlPerVertex.h

@@ -84,15 +84,22 @@ public:
   void requireCapabilityIfNecessary();
 
   /// Tries to access the builtin translated from the given HLSL semantic of the
-  /// given index. If sigPoint indicates this is input, builtins will be read
-  /// to compose a new temporary value of the correct type and writes to *value.
-  /// Otherwise, the *value will be decomposed and writes to the builtins.
+  /// given index.
+  ///
+  /// If sigPoint indicates this is input, builtins will be read to compose a
+  /// new temporary value of the correct type and writes to *value. Otherwise,
+  /// the *value will be decomposed and writes to the builtins, unless
+  /// noWriteBack is true, which means do not write back the value.
+  ///
+  /// If invocation (should only be used for HS) is not llvm::None, only
+  /// accesses the element at the invocation offset in the gl_PerVeterx array.
+  ///
   /// Emits SPIR-V instructions and returns true if we are accessing builtins
   /// belonging to gl_PerVertex. Does nothing and returns true if we are
   /// accessing builtins not in gl_PerVertex. Returns false if errors occurs.
-  bool tryToAccess(hlsl::Semantic::Kind, uint32_t semanticIndex,
-                   llvm::Optional<uint32_t> invocationId, uint32_t *value,
-                   hlsl::SigPoint::Kind sigPoint);
+  bool tryToAccess(hlsl::SigPoint::Kind sigPoint, hlsl::Semantic::Kind,
+                   uint32_t semanticIndex, llvm::Optional<uint32_t> invocation,
+                   uint32_t *value, bool noWriteBack);
 
 private:
   template <unsigned N>

+ 12 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -560,6 +560,18 @@ void ModuleBuilder::createControlBarrier(uint32_t execution, uint32_t memory,
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
+void ModuleBuilder::createEmitVertex() {
+  assert(insertPoint && "null insert point");
+  instBuilder.opEmitVertex().x();
+  insertPoint->appendInstruction(std::move(constructSite));
+}
+
+void ModuleBuilder::createEndPrimitive() {
+  assert(insertPoint && "null insert point");
+  instBuilder.opEndPrimitive().x();
+  insertPoint->appendInstruction(std::move(constructSite));
+}
+
 void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
                                      spv::ExecutionMode em,
                                      llvm::ArrayRef<uint32_t> params) {

+ 37 - 3
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -271,7 +271,7 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       shaderModel(*hlsl::ShaderModel::GetByName(
           ci.getCodeGenOpts().HLSLProfile.c_str())),
       theContext(), theBuilder(&theContext),
-      declIdMapper(shaderModel, astContext, theBuilder, diags, spirvOptions),
+      declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
       typeTranslator(astContext, theBuilder, diags), entryFunctionId(0),
       curFunction(nullptr), curThis(0), needsLegalization(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
@@ -2309,6 +2309,27 @@ SPIRVEmitter::processACSBufferAppendConsume(const CXXMemberCallExpr *expr) {
   }
 }
 
+uint32_t
+SPIRVEmitter::processStreamOutputAppend(const CXXMemberCallExpr *expr) {
+  // TODO: handle multiple stream-output objects
+  const auto *object =
+      expr->getImplicitObjectArgument()->IgnoreParenNoopCasts(astContext);
+  const auto *stream = cast<DeclRefExpr>(object)->getDecl();
+  const uint32_t value = doExpr(expr->getArg(0));
+
+  declIdMapper.writeBackOutputStream(stream, value);
+  theBuilder.createEmitVertex();
+
+  return 0;
+}
+
+uint32_t
+SPIRVEmitter::processStreamOutputRestart(const CXXMemberCallExpr *expr) {
+  // TODO: handle multiple stream-output objects
+  theBuilder.createEndPrimitive();
+  return 0;
+}
+
 SpirvEvalInfo SPIRVEmitter::doCXXMemberCallExpr(const CXXMemberCallExpr *expr) {
   const FunctionDecl *callee = expr->getDirectCallee();
 
@@ -2396,8 +2417,15 @@ SPIRVEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr,
         spv::Op::OpBitcast, theBuilder.getUint32Type(),
         incDecRWACSBufferCounter(expr, /*isInc*/ false));
   case IntrinsicOp::MOP_Append:
+    if (hlsl::IsHLSLStreamOutputType(
+            expr->getImplicitObjectArgument()->getType()))
+      return processStreamOutputAppend(expr);
+    else
+      return processACSBufferAppendConsume(expr);
   case IntrinsicOp::MOP_Consume:
     return processACSBufferAppendConsume(expr);
+  case IntrinsicOp::MOP_RestartStrip:
+    return processStreamOutputRestart(expr);
   case IntrinsicOp::MOP_InterlockedAdd:
   case IntrinsicOp::MOP_InterlockedAnd:
   case IntrinsicOp::MOP_InterlockedOr:
@@ -2412,7 +2440,7 @@ SPIRVEmitter::processIntrinsicMemberCall(const CXXMemberCallExpr *expr,
     return processRWByteAddressBufferAtomicMethods(opcode, expr);
   }
 
-  emitError("HLSL intrinsic member call unimplemented: %0")
+  emitError("intrinsic '%0' method unimplemented")
       << expr->getDirectCallee()->getName();
   return 0;
 }
@@ -5810,7 +5838,13 @@ bool SPIRVEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     if (canActAsOutParmVar(param)) {
       // Load the value from the parameter after function call
       const uint32_t typeId = typeTranslator.translateType(param->getType());
-      const uint32_t loadedParam = theBuilder.createLoad(typeId, params[i]);
+      uint32_t loadedParam = 0;
+
+      // Write back of stage output variables in GS is manually controlled by
+      // .Append() intrinsic method. No need to load the parameter since we
+      // won't need to write back here.
+      if (!shaderModel.IsGS())
+        loadedParam = theBuilder.createLoad(typeId, params[i]);
 
       if (!declIdMapper.createStageOutputVar(param, loadedParam, false))
         return false;

+ 7 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -617,6 +617,13 @@ private:
   /// the loaded value for .Consume; returns zero for .Append().
   SpirvEvalInfo processACSBufferAppendConsume(const CXXMemberCallExpr *expr);
 
+  /// \brief Generates SPIR-V instructions to emit the current vertex in GS.
+  uint32_t processStreamOutputAppend(const CXXMemberCallExpr *expr);
+
+  /// \brief Generates SPIR-V instructions to end emitting the current
+  /// primitive in GS.
+  uint32_t processStreamOutputRestart(const CXXMemberCallExpr *expr);
+
 private:
   /// \brief Wrapper method to create a fatal error message and report it
   /// in the diagnostic engine associated with this consumer.

+ 57 - 0
tools/clang/test/CodeGenSPIRV/gs.emit.hlsl

@@ -0,0 +1,57 @@
+// Run: %dxc -T gs_6_0 -E main
+
+struct GsInnerOut {
+    float2 bar  : BAR;
+};
+
+struct GsPerVertexOut {
+    float4 pos  : SV_Position;
+    float3 foo  : FOO;
+    GsInnerOut s;
+};
+
+// CHECK: [[null:%\d+]] = OpConstantNull %GsPerVertexOut
+
+[maxvertexcount(2)]
+void main(in    line float2 foo[2] : FOO,
+          in    line float4 pos[2] : SV_Position,
+          inout      LineStream<GsPerVertexOut> outData)
+{
+// CHECK:            %src_main = OpFunction %void None
+// CHECK:            %bb_entry = OpLabel
+
+// CHECK-NEXT:         %vertex = OpVariable %_ptr_Function_GsPerVertexOut Function
+    GsPerVertexOut vertex;
+// CHECK-NEXT:                   OpStore %vertex [[null]]
+    vertex = (GsPerVertexOut)0;
+
+// Write back to stage output variables
+// CHECK-NEXT: [[vertex:%\d+]] = OpLoad %GsPerVertexOut %vertex
+// CHECK-NEXT:    [[pos:%\d+]] = OpCompositeExtract %v4float [[vertex]] 0
+// CHECK-NEXT:                   OpStore %gl_Position [[pos]]
+// CHECK-NEXT:    [[foo:%\d+]] = OpCompositeExtract %v3float [[vertex]] 1
+// CHECK-NEXT:                   OpStore %out_var_FOO [[foo]]
+// CHECK-NEXT:      [[s:%\d+]] = OpCompositeExtract %GsInnerOut [[vertex]] 2
+// CHECK-NEXT:    [[bar:%\d+]] = OpCompositeExtract %v2float [[s]] 0
+// CHECK-NEXT:                   OpStore %out_var_BAR [[bar]]
+// CHECK-NEXT:                   OpEmitVertex
+
+    outData.Append(vertex);
+
+// Write back to stage output variables
+// CHECK-NEXT: [[vertex:%\d+]] = OpLoad %GsPerVertexOut %vertex
+// CHECK-NEXT:    [[pos:%\d+]] = OpCompositeExtract %v4float [[vertex]] 0
+// CHECK-NEXT:                   OpStore %gl_Position [[pos]]
+// CHECK-NEXT:    [[foo:%\d+]] = OpCompositeExtract %v3float [[vertex]] 1
+// CHECK-NEXT:                   OpStore %out_var_FOO [[foo]]
+// CHECK-NEXT:      [[s:%\d+]] = OpCompositeExtract %GsInnerOut [[vertex]] 2
+// CHECK-NEXT:    [[bar:%\d+]] = OpCompositeExtract %v2float [[s]] 0
+// CHECK-NEXT:                   OpStore %out_var_BAR [[bar]]
+// CHECK-NEXT:                   OpEmitVertex
+    outData.Append(vertex);
+
+// CHECK-NEXT:                   OpEndPrimitive
+    outData.RestartStrip();
+
+// CHECK-NEXT:                   OpReturn
+}

+ 0 - 0
tools/clang/test/CodeGenSPIRV/hull.pcf.input-patch.hlsl → tools/clang/test/CodeGenSPIRV/hs.pcf.input-patch.hlsl


+ 0 - 0
tools/clang/test/CodeGenSPIRV/hull.pcf.output-patch.hlsl → tools/clang/test/CodeGenSPIRV/hs.pcf.output-patch.hlsl


+ 0 - 0
tools/clang/test/CodeGenSPIRV/hull.pcf.primitive-id.1.hlsl → tools/clang/test/CodeGenSPIRV/hs.pcf.primitive-id.1.hlsl


+ 0 - 0
tools/clang/test/CodeGenSPIRV/hull.pcf.primitive-id.2.hlsl → tools/clang/test/CodeGenSPIRV/hs.pcf.primitive-id.2.hlsl


+ 0 - 0
tools/clang/test/CodeGenSPIRV/hull.pcf.void.hlsl → tools/clang/test/CodeGenSPIRV/hs.pcf.void.hlsl


+ 0 - 0
tools/clang/test/CodeGenSPIRV/hull.structure.hlsl → tools/clang/test/CodeGenSPIRV/hs.structure.hlsl


+ 1 - 0
tools/clang/test/CodeGenSPIRV/spirv.interface.ds.hlsl

@@ -2,6 +2,7 @@
 
 // CHECK: OpCapability ClipDistance
 // CHECK: OpCapability CullDistance
+// CHECK: OpCapability Tessellation
 
 // HS PCF output
 

+ 149 - 0
tools/clang/test/CodeGenSPIRV/spirv.interface.gs.hlsl

@@ -0,0 +1,149 @@
+// Run: %dxc -T gs_6_0 -E main
+
+// CHECK: OpCapability ClipDistance
+// CHECK: OpCapability CullDistance
+// CHECK: OpCapability Geometry
+
+struct GsPerVertexIn {
+    float4 pos   : SV_Position;      // Builtin Position
+    float3 clip2 : SV_ClipDistance2; // Builtin ClipDistance
+    float2 clip0 : SV_ClipDistance0; // Builtin ClipDistance
+    float3 foo   : FOO;              // Input variable
+};
+
+struct GsInnerOut {
+    float4 pos   : SV_Position;      // Builtion Position
+    float2 foo   : FOO;              // Output variable
+    float2 cull3 : SV_CullDistance3; // Builtin CullDistance
+};
+
+struct GsPerVertexOut {
+    GsInnerOut s;
+    float  cull2 : SV_CullDistance2; // Builtin CullDistance
+    float4 clip  : SV_ClipDistance;  // Builtin ClipDistance
+    float4 bar   : BAR;              // Output variable
+};
+
+// Input  builtin : gl_PerVertex (Position, ClipDistance)
+// Output builtin : Position, ClipDistance, CullDistance
+// Input  variable: FOO, BAR
+// Output variable: FOO, BAR
+
+// CHECK: OpEntryPoint Geometry %main "main" %gl_PerVertexIn %gl_ClipDistance %gl_CullDistance %in_var_BAR %in_var_FOO %gl_Position %out_var_FOO %out_var_BAR
+
+// CHECK: OpMemberDecorate %type_gl_PerVertex 0 BuiltIn Position
+// CHECK: OpMemberDecorate %type_gl_PerVertex 1 BuiltIn PointSize
+// CHECK: OpMemberDecorate %type_gl_PerVertex 2 BuiltIn ClipDistance
+// CHECK: OpMemberDecorate %type_gl_PerVertex 3 BuiltIn CullDistance
+// CHECK: OpDecorate %type_gl_PerVertex Block
+
+// CHECK: OpDecorate %gl_ClipDistance BuiltIn ClipDistance
+// CHECK: OpDecorate %gl_CullDistance BuiltIn CullDistance
+// CHECK: OpDecorate %gl_Position BuiltIn Position
+
+// CHECK: OpDecorate %in_var_BAR Location 0
+// CHECK: OpDecorate %in_var_FOO Location 1
+// CHECK: OpDecorate %out_var_FOO Location 0
+// CHECK: OpDecorate %out_var_BAR Location 1
+
+// Input : clip0 + clip2 : 5 floats
+// Input : no cull       : 1 floats (default)
+// CHECK: %type_gl_PerVertex = OpTypeStruct %v4float %float %_arr_float_uint_5 %_arr_float_uint_1
+
+// CHECK: %gl_PerVertexIn = OpVariable %_ptr_Input__arr_type_gl_PerVertex_uint_2 Input
+
+// Input : clip          : 4 floats
+// Input : cull2 + cull3 : 3 floats (default)
+// CHECK: %gl_ClipDistance = OpVariable %_ptr_Output__arr_float_uint_4 Output
+// CHECK: %gl_CullDistance = OpVariable %_ptr_Output__arr_float_uint_3 Output
+
+// CHECK: %in_var_BAR = OpVariable %_ptr_Input__arr_v2float_uint_2 Input
+// CHECK: %in_var_FOO = OpVariable %_ptr_Input__arr_v3float_uint_2 Input
+// CHECK: %gl_Position = OpVariable %_ptr_Output_v4float Output
+// CHECK: %out_var_FOO = OpVariable %_ptr_Output_v2float Output
+// CHECK: %out_var_BAR = OpVariable %_ptr_Output_v4float Output
+
+[maxvertexcount(2)]
+void main(in    line float2                     bar   [2] : BAR,
+          in    line GsPerVertexIn              inData[2],
+          inout      LineStream<GsPerVertexOut> outData)
+{
+// Layout of input ClipDistance array:
+//   clip0: 2 floats, offset 0
+//   clip2: 3 floats, offset 2
+
+// Layout of output ClipDistance array:
+//   clip : 4 floats, offset 0
+
+// Layout of output CullDistance array:
+//   cull2: 1 floats, offset 0
+//   cull3: 2 floats, offset 1
+
+    GsPerVertexOut vertex;
+
+    vertex = (GsPerVertexOut)0;
+
+    outData.Append(vertex);
+
+    outData.RestartStrip();
+// CHECK:      [[bar:%\d+]] = OpLoad %_arr_v2float_uint_2 %in_var_BAR
+// CHECK-NEXT:                OpStore %param_var_bar [[bar]]
+
+// Compose an array for GsPerVertexIn::pos
+// CHECK-NEXT:       [[ptr0:%\d+]] = OpAccessChain %_ptr_Input_v4float %gl_PerVertexIn %uint_0 %uint_0
+// CHECK-NEXT:       [[val0:%\d+]] = OpLoad %v4float [[ptr0]]
+// CHECK-NEXT:       [[ptr1:%\d+]] = OpAccessChain %_ptr_Input_v4float %gl_PerVertexIn %uint_1 %uint_0
+// CHECK-NEXT:       [[val1:%\d+]] = OpLoad %v4float [[ptr1]]
+// CHECK-NEXT:   [[inPosArr:%\d+]] = OpCompositeConstruct %_arr_v4float_uint_2 [[val0]] [[val1]]
+
+// Compose an array for GsPerVertexIn::clip2
+// CHECK-NEXT:       [[ptr0:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_0 %uint_2 %uint_2
+// CHECK-NEXT:       [[val0:%\d+]] = OpLoad %float [[ptr0]]
+// CHECK-NEXT:       [[ptr1:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_0 %uint_2 %uint_3
+// CHECK-NEXT:       [[val1:%\d+]] = OpLoad %float [[ptr1]]
+// CHECK-NEXT:       [[ptr2:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_0 %uint_2 %uint_4
+// CHECK-NEXT:       [[val2:%\d+]] = OpLoad %float [[ptr2]]
+// CHECK-NEXT:     [[clip20:%\d+]] = OpCompositeConstruct %v3float [[val0]] [[val1]] [[val2]]
+// CHECK-NEXT:       [[ptr0:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_1 %uint_2 %uint_2
+// CHECK-NEXT:       [[val0:%\d+]] = OpLoad %float [[ptr0]]
+// CHECK-NEXT:       [[ptr1:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_1 %uint_2 %uint_3
+// CHECK-NEXT:       [[val1:%\d+]] = OpLoad %float [[ptr1]]
+// CHECK-NEXT:       [[ptr2:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_1 %uint_2 %uint_4
+// CHECK-NEXT:       [[val2:%\d+]] = OpLoad %float [[ptr2]]
+// CHECK-NEXT:     [[clip21:%\d+]] = OpCompositeConstruct %v3float [[val0]] [[val1]] [[val2]]
+// CHECK-NEXT: [[inClip2Arr:%\d+]] = OpCompositeConstruct %_arr_v3float_uint_2 [[clip20]] [[clip21]]
+
+// Compose an array for GsPerVertexIn::clip0
+// CHECK-NEXT:       [[ptr0:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_0 %uint_2 %uint_0
+// CHECK-NEXT:       [[val0:%\d+]] = OpLoad %float [[ptr0]]
+// CHECK-NEXT:       [[ptr1:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_0 %uint_2 %uint_1
+// CHECK-NEXT:       [[val1:%\d+]] = OpLoad %float [[ptr1]]
+// CHECK-NEXT:     [[clip00:%\d+]] = OpCompositeConstruct %v2float [[val0]] [[val1]]
+// CHECK-NEXT:       [[ptr0:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_1 %uint_2 %uint_0
+// CHECK-NEXT:       [[val0:%\d+]] = OpLoad %float [[ptr0]]
+// CHECK-NEXT:       [[ptr1:%\d+]] = OpAccessChain %_ptr_Input_float %gl_PerVertexIn %uint_1 %uint_2 %uint_1
+// CHECK-NEXT:       [[val1:%\d+]] = OpLoad %float [[ptr1]]
+// CHECK-NEXT:     [[clip01:%\d+]] = OpCompositeConstruct %v2float [[val0]] [[val1]]
+// CHECK-NEXT: [[inClip0Arr:%\d+]] = OpCompositeConstruct %_arr_v2float_uint_2 [[clip00]] [[clip01]]
+
+// CHECK-NEXT:   [[inFooArr:%\d+]] = OpLoad %_arr_v3float_uint_2 %in_var_FOO
+
+// CHECK-NEXT:      [[val0:%\d+]] = OpCompositeExtract %v4float [[inPosArr]] 0
+// CHECK-NEXT:      [[val1:%\d+]] = OpCompositeExtract %v3float [[inClip2Arr]] 0
+// CHECK-NEXT:      [[val2:%\d+]] = OpCompositeExtract %v2float [[inClip0Arr]] 0
+// CHECK-NEXT:      [[val3:%\d+]] = OpCompositeExtract %v3float [[inFooArr]] 0
+// CHECK-NEXT:   [[inData0:%\d+]] = OpCompositeConstruct %GsPerVertexIn [[val0]] [[val1]] [[val2]] [[val3]]
+// CHECK-NEXT:      [[val0:%\d+]] = OpCompositeExtract %v4float [[inPosArr]] 1
+// CHECK-NEXT:      [[val1:%\d+]] = OpCompositeExtract %v3float [[inClip2Arr]] 1
+// CHECK-NEXT:      [[val2:%\d+]] = OpCompositeExtract %v2float [[inClip0Arr]] 1
+// CHECK-NEXT:      [[val3:%\d+]] = OpCompositeExtract %v3float [[inFooArr]] 1
+// CHECK-NEXT:   [[inData1:%\d+]] = OpCompositeConstruct %GsPerVertexIn [[val0]] [[val1]] [[val2]] [[val3]]
+
+// CHECK-NEXT:    [[inData:%\d+]] = OpCompositeConstruct %_arr_GsPerVertexIn_uint_2 [[inData0]] [[inData1]]
+// CHECK-NEXT:                      OpStore %param_var_inData [[inData]]
+
+// CHECK-NEXT:           {{%\d+}} = OpFunctionCall %void %src_main %param_var_bar %param_var_inData %param_var_outData
+
+// No write back after the call
+// CHECK-NEXT:                      OpReturn
+}

+ 1 - 0
tools/clang/test/CodeGenSPIRV/spirv.interface.hs.hlsl

@@ -6,6 +6,7 @@
 
 // CHECK: OpCapability ClipDistance
 // CHECK: OpCapability CullDistance
+// CHECK: OpCapability Tessellation
 
 // Input control point
 struct HsCpIn

+ 14 - 8
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -737,6 +737,9 @@ TEST_F(FileTest, SpirvStageIOInterfaceHS) {
 TEST_F(FileTest, SpirvStageIOInterfaceDS) {
   runFileTest("spirv.interface.ds.hlsl");
 }
+TEST_F(FileTest, SpirvStageIOInterfaceGS) {
+  runFileTest("spirv.interface.gs.hlsl");
+}
 TEST_F(FileTest, SpirvStageIOInterfacePS) {
   runFileTest("spirv.interface.ps.hlsl");
 }
@@ -821,21 +824,24 @@ TEST_F(FileTest, VulkanLayoutConsumeSBufferStd430) {
   runFileTest("vk.layout.csbuffer.std430.hlsl");
 }
 
-// For different Patch Constant Functions (for Hull shaders)
-TEST_F(FileTest, HullShaderPCFVoid) { runFileTest("hull.pcf.void.hlsl"); }
+// HS: for different Patch Constant Functions
+TEST_F(FileTest, HullShaderPCFVoid) { runFileTest("hs.pcf.void.hlsl"); }
 TEST_F(FileTest, HullShaderPCFTakesInputPatch) {
-  runFileTest("hull.pcf.input-patch.hlsl");
+  runFileTest("hs.pcf.input-patch.hlsl");
 }
 TEST_F(FileTest, HullShaderPCFTakesOutputPatch) {
-  runFileTest("hull.pcf.output-patch.hlsl");
+  runFileTest("hs.pcf.output-patch.hlsl");
 }
 TEST_F(FileTest, HullShaderPCFTakesPrimitiveId) {
-  runFileTest("hull.pcf.primitive-id.1.hlsl");
+  runFileTest("hs.pcf.primitive-id.1.hlsl");
 }
 TEST_F(FileTest, HullShaderPCFTakesPrimitiveIdButMainDoesnt) {
-  runFileTest("hull.pcf.primitive-id.2.hlsl");
+  runFileTest("hs.pcf.primitive-id.2.hlsl");
 }
-// For the structure of Hull Shaders
-TEST_F(FileTest, HullShaderStructure) { runFileTest("hull.structure.hlsl"); }
+// HS: for the structure of hull shaders
+TEST_F(FileTest, HullShaderStructure) { runFileTest("hs.structure.hlsl"); }
+
+// GS: emit vertex and emit primitive
+TEST_F(FileTest, GeometryShaderEmit) { runFileTest("gs.emit.hlsl"); }
 
 } // namespace