Browse Source

[spirv] Support SM6.0 wave ops using Vulkan 1.1 (#1118)

Support promoting to SPIR-V 1.3 when necessary

Support SM6.0 wave query and vote ops

* WaveIsFirstLane
* WaveGetLaneCount
* WaveGetLaneIndex
* WaveActiveAnyTrue
* WaveActiveAllTrue
* WaveActiveBallot

Support SM6.0 wave reduction ops

* WaveActiveAllEqual
* WaveActiveCountBits
* WaveActiveSum
* WaveActiveProduct
* WaveActiveBitAnd
* WaveActiveBitOr
* WaveActiveBitXor
* WaveActiveMin
* WaveActiveMax

Support SM6.0 wave scan/prefix ops

* WavePrefixSum
* WavePrefixProduct
* WavePrefixCountBits

Support SM6.0 wave broadcast ops

* WaveReadLaneAt
* WaveReadLaneFirst

Support SM6.0 quad-wide shuffle ops

*  QuadReadAcrossX
*  QuadReadAcrossY
*  QuadReadAcrossDiagonal
*  QuadReadLaneAt
Lei Zhang 7 years ago
parent
commit
c133d935eb
41 changed files with 1223 additions and 76 deletions
  1. 38 15
      docs/SPIR-V.rst
  2. 10 0
      tools/clang/include/clang/SPIRV/InstBuilder.h
  3. 15 3
      tools/clang/include/clang/SPIRV/ModuleBuilder.h
  4. 3 1
      tools/clang/include/clang/SPIRV/Structure.h
  5. 9 3
      tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
  6. 62 0
      tools/clang/lib/SPIRV/InstBuilderManual.cpp
  7. 36 12
      tools/clang/lib/SPIRV/ModuleBuilder.cpp
  8. 303 21
      tools/clang/lib/SPIRV/SPIRVEmitter.cpp
  9. 20 0
      tools/clang/lib/SPIRV/SPIRVEmitter.h
  10. 1 1
      tools/clang/lib/SPIRV/Structure.cpp
  11. 32 0
      tools/clang/test/CodeGenSPIRV/sm6.quad-read-across-diagonal.hlsl
  12. 32 0
      tools/clang/test/CodeGenSPIRV/sm6.quad-read-across-x.hlsl
  13. 32 0
      tools/clang/test/CodeGenSPIRV/sm6.quad-read-across-y.hlsl
  14. 33 0
      tools/clang/test/CodeGenSPIRV/sm6.quad-read-lane-at.hlsl
  15. 27 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-all-equal.hlsl
  16. 20 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-all-true.hlsl
  17. 20 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-any-true.hlsl
  18. 20 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-ballot.hlsl
  19. 38 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-bit-and.hlsl
  20. 38 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-bit-or.hlsl
  21. 38 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-bit-xor.hlsl
  22. 20 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-count-bits.hlsl
  23. 31 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-max.hlsl
  24. 31 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-min.hlsl
  25. 31 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-product.hlsl
  26. 31 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-active-sum.hlsl
  27. 3 2
      tools/clang/test/CodeGenSPIRV/sm6.wave-get-lane-count.hlsl
  28. 3 2
      tools/clang/test/CodeGenSPIRV/sm6.wave-get-lane-index.hlsl
  29. 13 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-is-first-lane.hlsl
  30. 19 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-count-bits.hlsl
  31. 31 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-product.hlsl
  32. 31 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-sum.hlsl
  33. 32 0
      tools/clang/test/CodeGenSPIRV/sm6.wave-read-lane-at.hlsl
  34. 10 8
      tools/clang/test/CodeGenSPIRV/sm6.wave-read-lane-first.hlsl
  35. 1 1
      tools/clang/tools/dxc/dxc.cpp
  36. 96 0
      tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp
  37. 2 2
      tools/clang/unittests/SPIRV/FileTestFixture.cpp
  38. 6 0
      tools/clang/unittests/SPIRV/FileTestFixture.h
  39. 3 3
      tools/clang/unittests/SPIRV/FileTestUtils.cpp
  40. 1 1
      tools/clang/unittests/SPIRV/FileTestUtils.h
  41. 1 1
      tools/clang/unittests/SPIRV/WholeFileTestFixture.cpp

+ 38 - 15
docs/SPIR-V.rst

@@ -2531,21 +2531,44 @@ generated. ``.RestartStrip()`` method calls will be translated into the SPIR-V
 Shader Model 6.0 Wave Intrinsics
 ================================
 
-Shader Model 6.0 introduces a set of wave operations, which are translated
-according to the following table:
-
-====================== ============================= =========================
-      Intrinsic               SPIR-V BuiltIn                Extension
-====================== ============================= =========================
-``WaveGetLaneCount()`` ``SubgroupSize``              ``SPV_KHR_shader_ballot``
-``WaveGetLaneIndex()`` ``SubgroupLocalInvocationId`` ``SPV_KHR_shader_ballot``
-====================== ============================= =========================
-
-======================= ================================ =========================
-      Intrinsic               SPIR-V Instruction                Extension
-======================= ================================ =========================
-``WaveReadLaneFirst()`` ``OpSubgroupFirstInvocationKHR`` ``SPV_KHR_shader_ballot``
-======================= ================================ =========================
+ ... note ::
+
+  Wave intrinsics requires SPIR-V 1.3, which is supported by Vulkan 1.1.
+  If you use wave intrinsics in your source code, the generated SPIR-V code
+  will be of version 1.3 instead of 1.0, which is supported by Vulkan 1.0.
+
+Shader model 6.0 introduces a set of wave operations. Apart from
+``WaveGetLaneCount()`` and ``WaveGetLaneIndex()``, which are translated into
+loading from SPIR-V builtin variable ``SubgroupSize`` and
+``SubgroupLocalInvocationId`` respectively, the rest are translated into SPIR-V
+group operations with ``Subgroup`` scope according to the following chart:
+
+============= ============================ =================================== ======================
+Wave Category       Wave Intrinsics               SPIR-V Opcode                SPIR-V Group Operation
+============= ============================ =================================== ======================
+Query         ``WaveIsFirstLane()``        ``OpGroupNonUniformElect``
+Vote          ``WaveActiveAnyTrue()``      ``OpGroupNonUniformAny``
+Vote          ``WaveActiveAllTrue()``      ``OpGroupNonUniformAll``
+Vote          ``WaveActiveBallot()``       ``OpGroupNonUniformBallot``
+Reduction     ``WaveActiveAllEqual()``     ``OpGroupNonUniformAllEqual``       ``Reduction``
+Reduction     ``WaveActiveCountBits()``    ``OpGroupNonUniformBallotBitCount`` ``Reduction``
+Reduction     ``WaveActiveSum()``          ``OpGroupNonUniform*Add``           ``Reduction``
+Reduction     ``WaveActiveProduct()``      ``OpGroupNonUniform*Mul``           ``Reduction``
+Reduction     ``WaveActiveBitAdd()``       ``OpGroupNonUniformBitwiseAnd``     ``Reduction``
+Reduction     ``WaveActiveBitOr()``        ``OpGroupNonUniformBitwiseOr``      ``Reduction``
+Reduction     ``WaveActiveBitXor()``       ``OpGroupNonUniformBitwiseXor``     ``Reduction``
+Reduction     ``WaveActiveMin()``          ``OpGroupNonUniform*Min``           ``Reduction``
+Reduction     ``WaveActiveMax()``          ``OpGroupNonUniform*Max``           ``Reduction``
+Scan/Prefix   ``WavePrefixSum()``          ``OpGroupNonUniform*Add``           ``ExclusiveScan``
+Scan/Prefix   ``WavePrefixProduct()``      ``OpGroupNonUniform*Mul``           ``ExclusiveScan``
+Scan/Prefix   ``WavePrefixCountBits()`     ``OpGroupNonUniformBallotBitCount`` ``ExclusiveScan``
+Broadcast     ``WaveReadLaneAt()``         ``OpGroupNonUniformBroadcast``
+Broadcast     ``WaveReadLaneFirst()``      ``OpGroupNonUniformBroadcastFirst``
+Quad          ``QuadReadAcrossX()``        ``OpGroupNonUniformQuadSwap``
+Quad          ``QuadReadAcrossY()``        ``OpGroupNonUniformQuadSwap``
+Quad          ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap``
+Quad          ``QuadReadLaneAt()``         ``OpGroupNonUniformQuadBroadcast``
+============= ============================ =================================== ======================
 
 Vulkan Command-line Options
 ===========================

+ 10 - 0
tools/clang/include/clang/SPIRV/InstBuilder.h

@@ -1034,6 +1034,16 @@ public:
                                     uint32_t result_id, uint32_t lhs,
                                     uint32_t rhs);
 
+  // All-in-one methods for creating OpGroupNonUniform* operations.
+  InstBuilder &groupNonUniformOp(spv::Op op, uint32_t result_type,
+                                 uint32_t result_id, uint32_t exec_scope);
+  InstBuilder &groupNonUniformUnaryOp(
+      spv::Op op, uint32_t result_type, uint32_t result_id, uint32_t exec_scope,
+      llvm::Optional<spv::GroupOperation> groupOp, uint32_t operand);
+  InstBuilder &groupNonUniformBinaryOp(spv::Op op, uint32_t result_type,
+                                       uint32_t result_id, uint32_t exec_scope,
+                                       uint32_t operand1, uint32_t operand2);
+
   // Methods for building constants.
   InstBuilder &opConstant(uint32_t result_type, uint32_t result_id,
                           uint32_t value);

+ 15 - 3
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -154,6 +154,17 @@ public:
   uint32_t createSpecConstantBinaryOp(spv::Op op, uint32_t resultType,
                                       uint32_t lhs, uint32_t rhs);
 
+  /// \brief Creates an operation with the given OpGroupNonUniform* SPIR-V
+  /// opcode. Returns the <result-id> for the result.
+  uint32_t createGroupNonUniformOp(spv::Op op, uint32_t resultType,
+                                   uint32_t execScope);
+  uint32_t createGroupNonUniformUnaryOp(
+      spv::Op op, uint32_t resultType, uint32_t execScope, uint32_t operand,
+      llvm::Optional<spv::GroupOperation> groupOp = llvm::None);
+  uint32_t createGroupNonUniformBinaryOp(spv::Op op, uint32_t resultType,
+                                         uint32_t execScope, uint32_t operand1,
+                                         uint32_t operand2);
+
   /// \brief Creates an atomic instruction with the given parameters.
   /// Returns the <result-id> for the result.
   uint32_t createAtomicOp(spv::Op opcode, uint32_t resultType,
@@ -303,11 +314,10 @@ public:
   /// \brief Creates an OpEndPrimitive instruction.
   void createEndPrimitive();
 
-  /// \brief Creates an OpSubgroupFirstInvocationKHR instruciton.
-  uint32_t createSubgroupFirstInvocation(uint32_t resultType, uint32_t value);
-
   // === SPIR-V Module Structure ===
 
+  inline void useSpirv1p3();
+
   inline void requireCapability(spv::Capability);
 
   inline void setAddressingModel(spv::AddressingModel);
@@ -477,6 +487,8 @@ void ModuleBuilder::setMemoryModel(spv::MemoryModel mm) {
   theModule.setMemoryModel(mm);
 }
 
+void ModuleBuilder::useSpirv1p3() { theModule.setVersion(0x00010300); }
+
 void ModuleBuilder::requireCapability(spv::Capability cap) {
   if (cap != spv::Capability::Max)
     theModule.addCapability(cap);

+ 3 - 1
tools/clang/include/clang/SPIRV/Structure.h

@@ -220,7 +220,7 @@ struct Header {
   void collect(const WordConsumer &consumer);
 
   const uint32_t magicNumber;
-  const uint32_t version;
+  uint32_t version;
   const uint32_t generator;
   uint32_t bound;
   const uint32_t reserved;
@@ -293,6 +293,7 @@ public:
   /// destructive; the module will be consumed and cleared after calling it.
   void take(InstBuilder *builder);
 
+  inline void setVersion(uint32_t version);
   /// \brief Sets the id bound to the given bound.
   inline void setBound(uint32_t newBound);
 
@@ -447,6 +448,7 @@ TypeIdPair::TypeIdPair(const Type &ty, uint32_t id) : type(ty), resultId(id) {}
 SPIRVModule::SPIRVModule()
     : addressingModel(llvm::None), memoryModel(llvm::None) {}
 
+void SPIRVModule::setVersion(uint32_t version) { header.version = version; }
 void SPIRVModule::setBound(uint32_t newBound) { header.bound = newBound; }
 
 void SPIRVModule::addCapability(spv::Capability cap) {

+ 9 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -870,6 +870,14 @@ bool DeclResultIdMapper::checkSemanticDuplication(bool forInput) {
   for (const auto &var : stageVars) {
     auto s = var.getSemanticStr();
 
+    if (s.empty()) {
+      // We translate WaveGetLaneCount() and WaveGetLaneIndex() into builtin
+      // variables. Those variables are inserted into the normal stage IO
+      // processing pipeline, but with the semantics as empty strings.
+      assert(var.isSpirvBuitin());
+      continue;
+    }
+
     if (forInput && var.getSigPoint()->IsInput()) {
       if (seenSemantics.count(s)) {
         emitError("input semantic '%0' used more than once", {}) << s;
@@ -1706,9 +1714,7 @@ uint32_t DeclResultIdMapper::getBuiltinVar(spv::BuiltIn builtIn) {
     return 0;
   }
 
-  // Both of them require the SPV_KHR_shader_ballot extension.
-  theBuilder.addExtension("SPV_KHR_shader_ballot");
-  theBuilder.requireCapability(spv::Capability::SubgroupBallotKHR);
+  theBuilder.requireCapability(spv::Capability::GroupNonUniform);
 
   uint32_t type = theBuilder.getUint32Type();
 

+ 62 - 0
tools/clang/lib/SPIRV/InstBuilderManual.cpp

@@ -81,6 +81,68 @@ InstBuilder &InstBuilder::specConstantBinaryOp(spv::Op op, uint32_t result_type,
   TheInst.emplace_back(static_cast<uint32_t>(op));
   TheInst.emplace_back(lhs);
   TheInst.emplace_back(rhs);
+  return *this;
+}
+
+InstBuilder &InstBuilder::groupNonUniformOp(spv::Op op, uint32_t result_type,
+                                            uint32_t result_id,
+                                            uint32_t exec_scope) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  // TODO: check op range
+
+  TheInst.reserve(4);
+  TheInst.emplace_back(static_cast<uint32_t>(op));
+  TheInst.emplace_back(result_type);
+  TheInst.emplace_back(result_id);
+  TheInst.emplace_back(exec_scope);
+
+  return *this;
+}
+
+InstBuilder &InstBuilder::groupNonUniformUnaryOp(
+    spv::Op op, uint32_t result_type, uint32_t result_id, uint32_t exec_scope,
+    llvm::Optional<spv::GroupOperation> groupOp, uint32_t operand) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  // TODO: check op range
+
+  TheInst.reserve(5);
+  TheInst.emplace_back(static_cast<uint32_t>(op));
+  TheInst.emplace_back(result_type);
+  TheInst.emplace_back(result_id);
+  TheInst.emplace_back(exec_scope);
+  if (groupOp.hasValue())
+    TheInst.emplace_back(static_cast<uint32_t>(groupOp.getValue()));
+  TheInst.emplace_back(operand);
+
+  return *this;
+}
+
+InstBuilder &
+InstBuilder::groupNonUniformBinaryOp(spv::Op op, uint32_t result_type,
+                                     uint32_t result_id, uint32_t exec_scope,
+                                     uint32_t operand1, uint32_t operand2) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  // TODO: check op range
+
+  TheInst.reserve(6);
+  TheInst.emplace_back(static_cast<uint32_t>(op));
+  TheInst.emplace_back(result_type);
+  TheInst.emplace_back(result_id);
+  TheInst.emplace_back(exec_scope);
+  TheInst.emplace_back(operand1);
+  TheInst.emplace_back(operand2);
 
   return *this;
 }

+ 36 - 12
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -247,6 +247,42 @@ uint32_t ModuleBuilder::createSpecConstantBinaryOp(spv::Op op,
   return id;
 }
 
+uint32_t ModuleBuilder::createGroupNonUniformOp(spv::Op op, uint32_t resultType,
+                                                uint32_t execScope) {
+  assert(insertPoint && "null insert point");
+  const uint32_t id = theContext.takeNextId();
+  instBuilder.groupNonUniformOp(op, resultType, id, execScope).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return id;
+}
+
+uint32_t ModuleBuilder::createGroupNonUniformUnaryOp(
+    spv::Op op, uint32_t resultType, uint32_t execScope, uint32_t operand,
+    llvm::Optional<spv::GroupOperation> groupOp) {
+  assert(insertPoint && "null insert point");
+  const uint32_t id = theContext.takeNextId();
+  instBuilder
+      .groupNonUniformUnaryOp(op, resultType, id, execScope, groupOp, operand)
+      .x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return id;
+}
+
+uint32_t ModuleBuilder::createGroupNonUniformBinaryOp(spv::Op op,
+                                                      uint32_t resultType,
+                                                      uint32_t execScope,
+                                                      uint32_t operand1,
+                                                      uint32_t operand2) {
+  assert(insertPoint && "null insert point");
+  const uint32_t id = theContext.takeNextId();
+  instBuilder
+      .groupNonUniformBinaryOp(op, resultType, id, execScope, operand1,
+                               operand2)
+      .x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return id;
+}
+
 uint32_t ModuleBuilder::createAtomicOp(spv::Op opcode, uint32_t resultType,
                                        uint32_t orignalValuePtr,
                                        uint32_t scopeId,
@@ -705,18 +741,6 @@ void ModuleBuilder::createEndPrimitive() {
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
-uint32_t ModuleBuilder::createSubgroupFirstInvocation(uint32_t resultType,
-                                                      uint32_t value) {
-  assert(insertPoint && "null insert point");
-  addExtension("SPV_KHR_shader_ballot");
-  requireCapability(spv::Capability::SubgroupBallotKHR);
-
-  uint32_t resultId = theContext.takeNextId();
-  instBuilder.opSubgroupFirstInvocationKHR(resultType, resultId, value).x();
-  insertPoint->appendInstruction(std::move(constructSite));
-  return resultId;
-}
-
 void ModuleBuilder::addExecutionMode(uint32_t entryPointId,
                                      spv::ExecutionMode em,
                                      llvm::ArrayRef<uint32_t> params) {

+ 303 - 21
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -203,8 +203,9 @@ bool isReferencingNonAliasStructuredOrByteBuffer(const Expr *expr) {
   return false;
 }
 
-bool spirvToolsLegalize(std::vector<uint32_t> *module, std::string *messages) {
-  spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
+bool spirvToolsLegalize(spv_target_env env, std::vector<uint32_t> *module,
+                        std::string *messages) {
+  spvtools::Optimizer optimizer(env);
 
   optimizer.SetMessageConsumer(
       [messages](spv_message_level_t /*level*/, const char * /*source*/,
@@ -220,8 +221,9 @@ bool spirvToolsLegalize(std::vector<uint32_t> *module, std::string *messages) {
   return optimizer.Run(module->data(), module->size(), module);
 }
 
-bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
-  spvtools::Optimizer optimizer(SPV_ENV_VULKAN_1_0);
+bool spirvToolsOptimize(spv_target_env env, std::vector<uint32_t> *module,
+                        std::string *messages) {
+  spvtools::Optimizer optimizer(env);
 
   optimizer.SetMessageConsumer(
       [messages](spv_message_level_t /*level*/, const char * /*source*/,
@@ -235,9 +237,9 @@ bool spirvToolsOptimize(std::vector<uint32_t> *module, std::string *messages) {
   return optimizer.Run(module->data(), module->size(), module);
 }
 
-bool spirvToolsValidate(std::vector<uint32_t> *module, std::string *messages,
-                        bool relaxLogicalPointer) {
-  spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
+bool spirvToolsValidate(spv_target_env env, std::vector<uint32_t> *module,
+                        std::string *messages, bool relaxLogicalPointer) {
+  spvtools::SpirvTools tools(env);
 
   tools.SetMessageConsumer(
       [messages](spv_message_level_t /*level*/, const char * /*source*/,
@@ -477,6 +479,41 @@ void getBaseClassIndices(const CastExpr *expr,
   }
 }
 
+spv::Capability getCapabilityForGroupNonUniform(spv::Op opcode) {
+  switch (opcode) {
+  case spv::Op::OpGroupNonUniformElect:
+    return spv::Capability::GroupNonUniform;
+  case spv::Op::OpGroupNonUniformAny:
+  case spv::Op::OpGroupNonUniformAll:
+  case spv::Op::OpGroupNonUniformAllEqual:
+    return spv::Capability::GroupNonUniformVote;
+  case spv::Op::OpGroupNonUniformBallot:
+  case spv::Op::OpGroupNonUniformBallotBitCount:
+  case spv::Op::OpGroupNonUniformBroadcast:
+  case spv::Op::OpGroupNonUniformBroadcastFirst:
+    return spv::Capability::GroupNonUniformBallot;
+  case spv::Op::OpGroupNonUniformIAdd:
+  case spv::Op::OpGroupNonUniformFAdd:
+  case spv::Op::OpGroupNonUniformIMul:
+  case spv::Op::OpGroupNonUniformFMul:
+  case spv::Op::OpGroupNonUniformSMax:
+  case spv::Op::OpGroupNonUniformUMax:
+  case spv::Op::OpGroupNonUniformFMax:
+  case spv::Op::OpGroupNonUniformSMin:
+  case spv::Op::OpGroupNonUniformUMin:
+  case spv::Op::OpGroupNonUniformFMin:
+  case spv::Op::OpGroupNonUniformBitwiseAnd:
+  case spv::Op::OpGroupNonUniformBitwiseOr:
+  case spv::Op::OpGroupNonUniformBitwiseXor:
+    return spv::Capability::GroupNonUniformArithmetic;
+  case spv::Op::OpGroupNonUniformQuadBroadcast:
+  case spv::Op::OpGroupNonUniformQuadSwap:
+    return spv::Capability::GroupNonUniformQuad;
+  }
+  assert(false && "unhandled opcode");
+  return spv::Capability::Max;
+}
+
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
@@ -490,8 +527,8 @@ SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci,
       declIdMapper(shaderModel, astContext, theBuilder, spirvOptions),
       typeTranslator(astContext, theBuilder, diags, options),
       entryFunctionId(0), curFunction(nullptr), curThis(0),
-      seenPushConstantAt(), isSpecConstantMode(false),
-      needsLegalization(false) {
+      seenPushConstantAt(), isSpecConstantMode(false), needsLegalization(false),
+      needsSpirv1p3(false) {
   if (shaderModel.GetKind() == hlsl::ShaderModel::Kind::Invalid)
     emitError("unknown shader module: %0", {}) << shaderModel.GetName();
   if (options.invertY && !shaderModel.IsVS() && !shaderModel.IsDS() &&
@@ -531,6 +568,12 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
   if (context.getDiagnostics().hasErrorOccurred())
     return;
 
+  spv_target_env targetEnv = SPV_ENV_VULKAN_1_0;
+  if (needsSpirv1p3) {
+    theBuilder.useSpirv1p3();
+    targetEnv = SPV_ENV_VULKAN_1_1;
+  }
+
   AddRequiredCapabilitiesForShaderModel();
 
   // Addressing and memory model are required in a valid SPIR-V module.
@@ -555,7 +598,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
     // Run legalization passes
     if (needsLegalization || declIdMapper.requiresLegalization()) {
       std::string messages;
-      if (!spirvToolsLegalize(&m, &messages)) {
+      if (!spirvToolsLegalize(targetEnv, &m, &messages)) {
         emitFatalError("failed to legalize SPIR-V: %0", {}) << messages;
         emitNote("please file a bug report on "
                  "https://github.com/Microsoft/DirectXShaderCompiler/issues "
@@ -570,7 +613,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
     // Run optimization passes
     if (theCompilerInstance.getCodeGenOpts().OptimizationLevel > 0) {
       std::string messages;
-      if (!spirvToolsOptimize(&m, &messages)) {
+      if (!spirvToolsOptimize(targetEnv, &m, &messages)) {
         emitFatalError("failed to optimize SPIR-V: %0", {}) << messages;
         emitNote("please file a bug report on "
                  "https://github.com/Microsoft/DirectXShaderCompiler/issues "
@@ -584,7 +627,7 @@ void SPIRVEmitter::HandleTranslationUnit(ASTContext &context) {
   // Validate the generated SPIR-V code
   if (!spirvOptions.disableValidation) {
     std::string messages;
-    if (!spirvToolsValidate(&m, &messages,
+    if (!spirvToolsValidate(targetEnv, &m, &messages,
                             declIdMapper.requiresLegalization())) {
       emitFatalError("generated SPIR-V is invalid: %0", {}) << messages;
       emitNote("please file a bug report on "
@@ -6019,6 +6062,7 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = processIntrinsicF32ToF16(callExpr);
     break;
   case hlsl::IntrinsicOp::IOP_WaveGetLaneCount: {
+    needsSpirv1p3 = true;
     const uint32_t retType =
         typeTranslator.translateType(callExpr->getCallReturnType(astContext));
     const uint32_t varId =
@@ -6026,23 +6070,73 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     retVal = theBuilder.createLoad(retType, varId);
   } break;
   case hlsl::IntrinsicOp::IOP_WaveGetLaneIndex: {
+    needsSpirv1p3 = true;
     const uint32_t retType =
         typeTranslator.translateType(callExpr->getCallReturnType(astContext));
     const uint32_t varId =
         declIdMapper.getBuiltinVar(spv::BuiltIn::SubgroupLocalInvocationId);
     retVal = theBuilder.createLoad(retType, varId);
   } break;
-  case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst: {
+  case hlsl::IntrinsicOp::IOP_WaveIsFirstLane:
+    retVal = processWaveQuery(callExpr, spv::Op::OpGroupNonUniformElect);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveAllTrue:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAll);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveAnyTrue:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAny);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveBallot:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformBallot);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveAllEqual:
+    retVal = processWaveVote(callExpr, spv::Op::OpGroupNonUniformAllEqual);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveCountBits:
+    retVal = processWaveReductionOrPrefix(
+        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
+        spv::GroupOperation::Reduce);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveActiveUSum:
+  case hlsl::IntrinsicOp::IOP_WaveActiveSum:
+  case hlsl::IntrinsicOp::IOP_WaveActiveUProduct:
+  case hlsl::IntrinsicOp::IOP_WaveActiveProduct:
+  case hlsl::IntrinsicOp::IOP_WaveActiveUMax:
+  case hlsl::IntrinsicOp::IOP_WaveActiveMax:
+  case hlsl::IntrinsicOp::IOP_WaveActiveUMin:
+  case hlsl::IntrinsicOp::IOP_WaveActiveMin:
+  case hlsl::IntrinsicOp::IOP_WaveActiveBitAnd:
+  case hlsl::IntrinsicOp::IOP_WaveActiveBitOr:
+  case hlsl::IntrinsicOp::IOP_WaveActiveBitXor: {
+    const auto retType = callExpr->getCallReturnType(astContext);
+    retVal = processWaveReductionOrPrefix(
+        callExpr, translateWaveOp(hlslOpcode, retType, callExpr->getExprLoc()),
+        spv::GroupOperation::Reduce);
+  } break;
+  case hlsl::IntrinsicOp::IOP_WavePrefixUSum:
+  case hlsl::IntrinsicOp::IOP_WavePrefixSum:
+  case hlsl::IntrinsicOp::IOP_WavePrefixUProduct:
+  case hlsl::IntrinsicOp::IOP_WavePrefixProduct: {
     const auto retType = callExpr->getCallReturnType(astContext);
-    if (!retType->isScalarType()) {
-      emitError("vector overloads of WaveReadLaneFirst unimplemented",
-                callExpr->getExprLoc());
-      return 0;
-    }
-    const uint32_t retTypeId = typeTranslator.translateType(retType);
-    retVal = theBuilder.createSubgroupFirstInvocation(
-        retTypeId, doExpr(callExpr->getArg(0)));
+    retVal = processWaveReductionOrPrefix(
+        callExpr, translateWaveOp(hlslOpcode, retType, callExpr->getExprLoc()),
+        spv::GroupOperation::ExclusiveScan);
   } break;
+  case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
+    retVal = processWaveReductionOrPrefix(
+        callExpr, spv::Op::OpGroupNonUniformBallotBitCount,
+        spv::GroupOperation::ExclusiveScan);
+    break;
+  case hlsl::IntrinsicOp::IOP_WaveReadLaneAt:
+  case hlsl::IntrinsicOp::IOP_WaveReadLaneFirst:
+    retVal = processWaveBroadcast(callExpr);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
+  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
+    retVal = processWaveQuadWideShuffle(callExpr, hlslOpcode);
+    break;
   case hlsl::IntrinsicOp::IOP_abort:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSampleCount:
   case hlsl::IntrinsicOp::IOP_GetRenderTargetSamplePosition: {
@@ -6413,6 +6507,194 @@ uint32_t SPIRVEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
   return theBuilder.createCompositeConstruct(uint4Type, accums);
 }
 
+uint32_t SPIRVEmitter::processWaveQuery(const CallExpr *callExpr,
+                                        spv::Op opcode) {
+  // Signatures:
+  // bool WaveIsFirstLane()
+  // uint WaveGetLaneCount()
+  // uint WaveGetLaneIndex()
+  assert(callExpr->getNumArgs() == 0);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  return theBuilder.createGroupNonUniformOp(opcode, retType, subgroupScope);
+}
+
+uint32_t SPIRVEmitter::processWaveVote(const CallExpr *callExpr,
+                                       spv::Op opcode) {
+  // Signatures:
+  // bool WaveActiveAnyTrue( bool expr )
+  // bool WaveActiveAllTrue( bool expr )
+  // bool uint4 WaveActiveBallot( bool expr )
+  assert(callExpr->getNumArgs() == 1);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
+  const uint32_t predicate = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  return theBuilder.createGroupNonUniformUnaryOp(opcode, retType, subgroupScope,
+                                                 predicate);
+}
+
+spv::Op SPIRVEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
+                                      SourceLocation srcLoc) {
+  const bool isSintType = isSintOrVecMatOfSintType(type);
+  const bool isUintType = isUintOrVecMatOfUintType(type);
+  const bool isFloatType = isFloatOrVecMatOfFloatType(type);
+
+#define WAVE_OP_CASE_INT(kind, intWaveOp)                                      \
+                                                                               \
+  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                    \
+    if (isSintType || isUintType) {                                            \
+      return spv::Op::OpGroupNonUniform##intWaveOp;                            \
+    }                                                                          \
+  } break
+
+#define WAVE_OP_CASE_INT_FLOAT(kind, intWaveOp, floatWaveOp)                   \
+                                                                               \
+  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                    \
+    if (isSintType || isUintType) {                                            \
+      return spv::Op::OpGroupNonUniform##intWaveOp;                            \
+    }                                                                          \
+    if (isFloatType) {                                                         \
+      return spv::Op::OpGroupNonUniform##floatWaveOp;                          \
+    }                                                                          \
+  } break
+
+#define WAVE_OP_CASE_SINT_UINT_FLOAT(kind, sintWaveOp, uintWaveOp,             \
+                                     floatWaveOp)                              \
+                                                                               \
+  case hlsl::IntrinsicOp::IOP_Wave##kind: {                                    \
+    if (isSintType) {                                                          \
+      return spv::Op::OpGroupNonUniform##sintWaveOp;                           \
+    }                                                                          \
+    if (isUintType) {                                                          \
+      return spv::Op::OpGroupNonUniform##uintWaveOp;                           \
+    }                                                                          \
+    if (isFloatType) {                                                         \
+      return spv::Op::OpGroupNonUniform##floatWaveOp;                          \
+    }                                                                          \
+  } break
+
+  switch (op) {
+    WAVE_OP_CASE_INT_FLOAT(ActiveUSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(ActiveSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(ActiveUProduct, IMul, FMul);
+    WAVE_OP_CASE_INT_FLOAT(ActiveProduct, IMul, FMul);
+    WAVE_OP_CASE_INT_FLOAT(PrefixUSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(PrefixSum, IAdd, FAdd);
+    WAVE_OP_CASE_INT_FLOAT(PrefixUProduct, IMul, FMul);
+    WAVE_OP_CASE_INT_FLOAT(PrefixProduct, IMul, FMul);
+    WAVE_OP_CASE_INT(ActiveBitAnd, BitwiseAnd);
+    WAVE_OP_CASE_INT(ActiveBitOr, BitwiseOr);
+    WAVE_OP_CASE_INT(ActiveBitXor, BitwiseXor);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMax, SMax, UMax, FMax);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMax, SMax, UMax, FMax);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMin, SMin, UMin, FMin);
+    WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMin, SMin, UMin, FMin);
+  }
+#undef WAVE_OP_CASE_INT_FLOAT
+#undef WAVE_OP_CASE_INT
+#undef WAVE_OP_CASE_SINT_UINT_FLOAT
+
+  emitError("translating wave operator '%0' unimplemented", srcLoc)
+      << static_cast<uint32_t>(op);
+  return spv::Op::OpNop;
+}
+
+uint32_t SPIRVEmitter::processWaveReductionOrPrefix(
+    const CallExpr *callExpr, spv::Op opcode, spv::GroupOperation groupOp) {
+  // Signatures:
+  // bool WaveActiveAllEqual( <type> expr )
+  // uint WaveActiveCountBits( bool bBit )
+  // <type> WaveActiveSum( <type> expr )
+  // <type> WaveActiveProduct( <type> expr )
+  // <int_type> WaveActiveBitAnd( <int_type> expr )
+  // <int_type> WaveActiveBitOr( <int_type> expr )
+  // <int_type> WaveActiveBitXor( <int_type> expr )
+  // <type> WaveActiveMin( <type> expr)
+  // <type> WaveActiveMax( <type> expr)
+  //
+  // uint WavePrefixCountBits(Bool bBit)
+  // <type> WavePrefixProduct(<type> value)
+  // <type> WavePrefixSum(<type> value)
+  assert(callExpr->getNumArgs() == 1);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(getCapabilityForGroupNonUniform(opcode));
+  const uint32_t predicate = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  return theBuilder.createGroupNonUniformUnaryOp(
+      opcode, retType, subgroupScope, predicate,
+      llvm::Optional<spv::GroupOperation>(groupOp));
+}
+
+uint32_t SPIRVEmitter::processWaveBroadcast(const CallExpr *callExpr) {
+  // Signatures:
+  // <type> WaveReadLaneFirst(<type> expr)
+  // <type> WaveReadLaneAt(<type> expr, uint laneIndex)
+  const auto numArgs = callExpr->getNumArgs();
+  assert(numArgs == 1 || numArgs == 2);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(spv::Capability::GroupNonUniformBallot);
+  const uint32_t value = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+  if (numArgs == 2)
+    return theBuilder.createGroupNonUniformBinaryOp(
+        spv::Op::OpGroupNonUniformBroadcast, retType, subgroupScope, value,
+        doExpr(callExpr->getArg(1)));
+  else
+    return theBuilder.createGroupNonUniformUnaryOp(
+        spv::Op::OpGroupNonUniformBroadcastFirst, retType, subgroupScope,
+        value);
+}
+
+uint32_t SPIRVEmitter::processWaveQuadWideShuffle(const CallExpr *callExpr,
+                                                  hlsl::IntrinsicOp op) {
+  // Signatures:
+  // <type> QuadReadAcrossX(<type> localValue)
+  // <type> QuadReadAcrossY(<type> localValue)
+  // <type> QuadReadAcrossDiagonal(<type> localValue)
+  // <type> QuadReadLaneAt(<type> sourceValue, uint quadLaneID)
+  assert(callExpr->getNumArgs() == 1 || callExpr->getNumArgs() == 2);
+  needsSpirv1p3 = true;
+  theBuilder.requireCapability(spv::Capability::GroupNonUniformQuad);
+
+  const uint32_t value = doExpr(callExpr->getArg(0));
+  const uint32_t subgroupScope = theBuilder.getConstantInt32(3);
+  const uint32_t retType =
+      typeTranslator.translateType(callExpr->getCallReturnType(astContext));
+
+  uint32_t target = 0;
+  spv::Op opcode = spv::Op::OpGroupNonUniformQuadSwap;
+  switch (op) {
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossX:
+    target = theBuilder.getConstantUint32(0);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossY:
+    target = theBuilder.getConstantUint32(1);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadAcrossDiagonal:
+    target = theBuilder.getConstantUint32(2);
+    break;
+  case hlsl::IntrinsicOp::IOP_QuadReadLaneAt:
+    target = doExpr(callExpr->getArg(1));
+    opcode = spv::Op::OpGroupNonUniformQuadBroadcast;
+    break;
+  default:
+    llvm_unreachable("case should not appear here");
+  }
+
+  return theBuilder.createGroupNonUniformBinaryOp(opcode, retType,
+                                                  subgroupScope, value, target);
+}
+
 uint32_t SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
   // Signature is: ret modf(x, ip)
   // [in]    x: the input floating-point value.

+ 20 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -130,6 +130,8 @@ private:
   /// taking consideration of the operand type.
   spv::Op translateOp(BinaryOperator::Opcode op, QualType type);
 
+  spv::Op translateWaveOp(hlsl::IntrinsicOp op, QualType type, SourceLocation);
+
   /// Generates SPIR-V instructions for the given normal (non-intrinsic and
   /// non-operator) standalone or member function call.
   SpirvEvalInfo processCall(const CallExpr *expr);
@@ -448,6 +450,21 @@ private:
   /// Processes Interlocked* intrinsic functions.
   uint32_t processIntrinsicInterlockedMethod(const CallExpr *,
                                              hlsl::IntrinsicOp);
+  /// Processes SM6.0 wave query intrinsic calls.
+  uint32_t processWaveQuery(const CallExpr *, spv::Op opcode);
+
+  /// Processes SM6.0 wave vote intrinsic calls.
+  uint32_t processWaveVote(const CallExpr *, spv::Op opcode);
+
+  /// Processes SM6.0 wave reduction or scan/prefix intrinsic calls.
+  uint32_t processWaveReductionOrPrefix(const CallExpr *, spv::Op op,
+                                        spv::GroupOperation groupOp);
+
+  /// Processes SM6.0 wave broadcast intrinsic calls.
+  uint32_t processWaveBroadcast(const CallExpr *);
+
+  /// Processes SM6.0 quad-wide shuffle.
+  uint32_t processWaveQuadWideShuffle(const CallExpr *, hlsl::IntrinsicOp op);
 
 private:
   /// Returns the <result-id> for constant value 0 of the given type.
@@ -926,6 +943,9 @@ private:
   /// Note: legalization specific code
   bool needsLegalization;
 
+  /// Indicates whether we should generate SPIR-V 1.3 instead of 1.0.
+  bool needsSpirv1p3;
+
   /// Mapping from methods to the decls to represent their implicit object
   /// parameters
   ///

+ 1 - 1
tools/clang/lib/SPIRV/Structure.cpp

@@ -172,7 +172,7 @@ void Function::getReachableBasicBlocks(std::vector<BasicBlock *> *bbVec) const {
 
 Header::Header()
     // We are using the unfied header, which shows spv::Version as the newest
-    // version. But we need to stick to 1.0 for Vulkan consumption.
+    // version. But we need to stick to 1.0 for Vulkan consumption by default.
     : magicNumber(spv::MagicNumber), version(0x00010000),
       generator((kGeneratorNumber << 16) | kToolVersion), bound(0),
       reserved(0) {}

+ 32 - 0
tools/clang/test/CodeGenSPIRV/sm6.quad-read-across-diagonal.hlsl

@@ -0,0 +1,32 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     int4 val1;
+    uint3 val2;
+    float val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformQuad
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+     int4 val1 = values[x].val1;
+    uint3 val2 = values[x].val2;
+    float val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4int %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %v4int %int_3 [[val1]] %uint_2
+    values[x].val1 = QuadReadAcrossDiagonal(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %v3uint %int_3 [[val2]] %uint_2
+    values[x].val2 = QuadReadAcrossDiagonal(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %float %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %float %int_3 [[val3]] %uint_2
+    values[x].val3 = QuadReadAcrossDiagonal(val3);
+}

+ 32 - 0
tools/clang/test/CodeGenSPIRV/sm6.quad-read-across-x.hlsl

@@ -0,0 +1,32 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     int4 val1;
+    uint3 val2;
+    float val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformQuad
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+     int4 val1 = values[x].val1;
+    uint3 val2 = values[x].val2;
+    float val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4int %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %v4int %int_3 [[val1]] %uint_0
+    values[x].val1 = QuadReadAcrossX(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %v3uint %int_3 [[val2]] %uint_0
+    values[x].val2 = QuadReadAcrossX(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %float %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %float %int_3 [[val3]] %uint_0
+    values[x].val3 = QuadReadAcrossX(val3);
+}

+ 32 - 0
tools/clang/test/CodeGenSPIRV/sm6.quad-read-across-y.hlsl

@@ -0,0 +1,32 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     int4 val1;
+    uint3 val2;
+    float val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformQuad
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+     int4 val1 = values[x].val1;
+    uint3 val2 = values[x].val2;
+    float val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4int %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %v4int %int_3 [[val1]] %uint_1
+    values[x].val1 = QuadReadAcrossY(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %v3uint %int_3 [[val2]] %uint_1
+    values[x].val2 = QuadReadAcrossY(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %float %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadSwap %float %int_3 [[val3]] %uint_1
+    values[x].val3 = QuadReadAcrossY(val3);
+}

+ 33 - 0
tools/clang/test/CodeGenSPIRV/sm6.quad-read-lane-at.hlsl

@@ -0,0 +1,33 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    float4 val1;
+     uint3 val2;
+       int val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformQuad
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+    float4 val1 = values[x].val1;
+     uint3 val2 = values[x].val2;
+       int val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4float %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadBroadcast %v4float %int_3 [[val1]] %uint_0
+    values[x].val1 = QuadReadLaneAt(val1, 0);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadBroadcast %v3uint %int_3 [[val2]] %uint_1
+    values[x].val2 = QuadReadLaneAt(val2, 1);
+// CHECK:      [[val3:%\d+]] = OpLoad %int %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformQuadBroadcast %int %int_3 [[val3]] %uint_2
+    values[x].val3 = QuadReadLaneAt(val3, 2);
+}
+

+ 27 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-all-equal.hlsl

@@ -0,0 +1,27 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    float4 val1;
+    uint val2;
+    bool res;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformVote
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+// CHECK:         [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_v4float %values %int_0 {{%\d+}} %int_0
+// CHECK-NEXT: [[f32val:%\d+]] = OpLoad %v4float [[ptr]]
+// TODO: The front end will return bool4 for the first call, which acutally should be bool.
+// XXXXX-NEXT:        {{%\d+}} = OpGroupNonUniformAllEqual %bool %int_3 [[f32val]]
+
+// CHECK:         [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %values %int_0 {{%\d+}} %int_1
+// CHECK-NEXT: [[u32val:%\d+]] = OpLoad %uint [[ptr]]
+// CHECK-NEXT:        {{%\d+}} = OpGroupNonUniformAllEqual %bool %int_3 [[u32val]]
+    values[x].res = WaveActiveAllEqual(values[x].val1) && WaveActiveAllEqual(values[x].val2);
+}

+ 20 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-all-true.hlsl

@@ -0,0 +1,20 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    uint val;
+    bool res;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformVote
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+// CHECK:      [[cmp:%\d+]] = OpIEqual %bool {{%\d+}} %uint_1
+// CHECK-NEXT:     {{%\d+}} = OpGroupNonUniformAll %bool %int_3 [[cmp]]
+    values[x].res = WaveActiveAllTrue(values[x].val == 1);
+}

+ 20 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-any-true.hlsl

@@ -0,0 +1,20 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    uint val;
+    bool res;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformVote
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+// CHECK:      [[cmp:%\d+]] = OpIEqual %bool {{%\d+}} %uint_0
+// CHECK-NEXT:     {{%\d+}} = OpGroupNonUniformAny %bool %int_3 [[cmp]]
+    values[x].res = WaveActiveAnyTrue(values[x].val == 0);
+}

+ 20 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-ballot.hlsl

@@ -0,0 +1,20 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    uint val;
+    uint4 res;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformBallot
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+// CHECK:      [[cmp:%\d+]] = OpIEqual %bool {{%\d+}} %uint_2
+// CHECK-NEXT:     {{%\d+}} = OpGroupNonUniformBallot %v4uint %int_3 [[cmp]]
+    values[x].res = WaveActiveBallot(values[x].val == 2);
+}

+ 38 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-bit-and.hlsl

@@ -0,0 +1,38 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// Note: WaveActiveBitAnd() only accepts unsigned interger scalars/vectors.
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    uint4 val1;
+    uint3 val2;
+    uint2 val3;
+     uint val4;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+    uint4 val1 = values[x].val1;
+    uint3 val2 = values[x].val2;
+    uint2 val3 = values[x].val3;
+     uint val4 = values[x].val4;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4uint %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseAnd %v4uint %int_3 Reduce [[val1]]
+    values[x].val1 = WaveActiveBitAnd(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseAnd %v3uint %int_3 Reduce [[val2]]
+    values[x].val2 = WaveActiveBitAnd(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %v2uint %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseAnd %v2uint %int_3 Reduce [[val3]]
+    values[x].val3 = WaveActiveBitAnd(val3);
+// CHECK:      [[val4:%\d+]] = OpLoad %uint %val4
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseAnd %uint %int_3 Reduce [[val4]]
+    values[x].val4 = WaveActiveBitAnd(val4);
+}

+ 38 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-bit-or.hlsl

@@ -0,0 +1,38 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// Note: WaveActiveBitOr() only accepts unsigned interger scalars/vectors.
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    uint4 val1;
+    uint3 val2;
+    uint2 val3;
+     uint val4;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+    uint4 val1 = values[x].val1;
+    uint3 val2 = values[x].val2;
+    uint2 val3 = values[x].val3;
+     uint val4 = values[x].val4;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4uint %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseOr %v4uint %int_3 Reduce [[val1]]
+    values[x].val1 = WaveActiveBitOr(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseOr %v3uint %int_3 Reduce [[val2]]
+    values[x].val2 = WaveActiveBitOr(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %v2uint %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseOr %v2uint %int_3 Reduce [[val3]]
+    values[x].val3 = WaveActiveBitOr(val3);
+// CHECK:      [[val4:%\d+]] = OpLoad %uint %val4
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseOr %uint %int_3 Reduce [[val4]]
+    values[x].val4 = WaveActiveBitOr(val4);
+}

+ 38 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-bit-xor.hlsl

@@ -0,0 +1,38 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// Note: WaveActiveBitXor() only accepts unsigned interger scalars/vectors.
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    uint4 val1;
+    uint3 val2;
+    uint2 val3;
+     uint val4;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+    uint4 val1 = values[x].val1;
+    uint3 val2 = values[x].val2;
+    uint2 val3 = values[x].val3;
+     uint val4 = values[x].val4;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4uint %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseXor %v4uint %int_3 Reduce [[val1]]
+    values[x].val1 = WaveActiveBitXor(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseXor %v3uint %int_3 Reduce [[val2]]
+    values[x].val2 = WaveActiveBitXor(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %v2uint %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseXor %v2uint %int_3 Reduce [[val3]]
+    values[x].val3 = WaveActiveBitXor(val3);
+// CHECK:      [[val4:%\d+]] = OpLoad %uint %val4
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBitwiseXor %uint %int_3 Reduce [[val4]]
+    values[x].val4 = WaveActiveBitXor(val4);
+}

+ 20 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-count-bits.hlsl

@@ -0,0 +1,20 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     uint val;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformBallot
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+// CHECK:  {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 Reduce {{%\d+}}
+    values[x].val = WaveActiveCountBits(values[x].val == 0);
+}
+

+ 31 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-max.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     uint4 val1;
+    float2 val2;
+       int val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+     uint4 val1 = values[x].val1;
+    float2 val2 = values[x].val2;
+       int val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4uint %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformUMax %v4uint %int_3 Reduce [[val1]]
+    values[x].val1 = WaveActiveMax(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v2float %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformFMax %v2float %int_3 Reduce [[val2]]
+    values[x].val2 = WaveActiveMax(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %int %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformSMax %int %int_3 Reduce [[val3]]
+    values[x].val3 = WaveActiveMax(val3);
+}

+ 31 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-min.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     uint4 val1;
+    float2 val2;
+       int val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+     uint4 val1 = values[x].val1;
+    float2 val2 = values[x].val2;
+       int val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4uint %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformUMin %v4uint %int_3 Reduce [[val1]]
+    values[x].val1 = WaveActiveMin(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v2float %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformFMin %v2float %int_3 Reduce [[val2]]
+    values[x].val2 = WaveActiveMin(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %int %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformSMin %int %int_3 Reduce [[val3]]
+    values[x].val3 = WaveActiveMin(val3);
+}

+ 31 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-product.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    float4 val1;
+     uint2 val2;
+       int val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+    float4 val1 = values[x].val1;
+     uint2 val2 = values[x].val2;
+       int val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4float %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformFMul %v4float %int_3 Reduce [[val1]]
+    values[x].val1 = WaveActiveProduct(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v2uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIMul %v2uint %int_3 Reduce [[val2]]
+    values[x].val2 = WaveActiveProduct(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %int %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIMul %int %int_3 Reduce [[val3]]
+    values[x].val3 = WaveActiveProduct(val3);
+}

+ 31 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-active-sum.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     int4 val1;
+    uint2 val2;
+    float val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+     int4 val1 = values[x].val1;
+    uint2 val2 = values[x].val2;
+    float val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4int %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIAdd %v4int %int_3 Reduce [[val1]]
+    values[x].val1 = WaveActiveSum(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v2uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIAdd %v2uint %int_3 Reduce [[val2]]
+    values[x].val2 = WaveActiveSum(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %float %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformFAdd %float %int_3 Reduce [[val3]]
+    values[x].val3 = WaveActiveSum(val3);
+}

+ 3 - 2
tools/clang/test/CodeGenSPIRV/sm6.wave-get-lane-count.hlsl

@@ -1,9 +1,10 @@
 // Run: %dxc -T cs_6_0 -E main
 
+// CHECK: ; Version: 1.3
+
 RWStructuredBuffer<uint> values;
 
-// CHECK: OpCapability SubgroupBallotKHR
-// CHECK: OpExtension "SPV_KHR_shader_ballot"
+// CHECK: OpCapability GroupNonUniform
 
 // CHECK: OpEntryPoint GLCompute
 // CHECK-SAME: %SubgroupSize

+ 3 - 2
tools/clang/test/CodeGenSPIRV/sm6.wave-get-lane-index.hlsl

@@ -1,9 +1,10 @@
 // Run: %dxc -T cs_6_0 -E main
 
+// CHECK: ; Version: 1.3
+
 RWStructuredBuffer<uint> values;
 
-// CHECK: OpCapability SubgroupBallotKHR
-// CHECK: OpExtension "SPV_KHR_shader_ballot"
+// CHECK: OpCapability GroupNonUniform
 
 // CHECK: OpEntryPoint GLCompute
 // CHECK-SAME: %SubgroupLocalInvocationId

+ 13 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-is-first-lane.hlsl

@@ -0,0 +1,13 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+RWStructuredBuffer<uint> values;
+
+// CHECK: OpCapability GroupNonUniform
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+// CHECK: {{%\d+}} = OpGroupNonUniformElect %bool %int_3
+    values[id.x] = WaveIsFirstLane();
+}

+ 19 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-count-bits.hlsl

@@ -0,0 +1,19 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     uint val;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformBallot
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+// CHECK:  {{%\d+}} = OpGroupNonUniformBallotBitCount %uint %int_3 ExclusiveScan {{%\d+}}
+    values[x].val = WavePrefixCountBits(values[x].val == 0);
+}

+ 31 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-product.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    float4 val1;
+     uint2 val2;
+       int val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+    float4 val1 = values[x].val1;
+     uint2 val2 = values[x].val2;
+       int val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4float %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformFMul %v4float %int_3 ExclusiveScan [[val1]]
+    values[x].val1 = WavePrefixProduct(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v2uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIMul %v2uint %int_3 ExclusiveScan [[val2]]
+    values[x].val2 = WavePrefixProduct(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %int %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIMul %int %int_3 ExclusiveScan [[val3]]
+    values[x].val3 = WavePrefixProduct(val3);
+}

+ 31 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-prefix-sum.hlsl

@@ -0,0 +1,31 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+     int4 val1;
+    uint2 val2;
+    float val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformArithmetic
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+     int4 val1 = values[x].val1;
+    uint2 val2 = values[x].val2;
+    float val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4int %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIAdd %v4int %int_3 ExclusiveScan [[val1]]
+    values[x].val1 = WavePrefixSum(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v2uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformIAdd %v2uint %int_3 ExclusiveScan [[val2]]
+    values[x].val2 = WavePrefixSum(val2);
+// CHECK:      [[val3:%\d+]] = OpLoad %float %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformFAdd %float %int_3 ExclusiveScan [[val3]]
+    values[x].val3 = WavePrefixSum(val3);
+}

+ 32 - 0
tools/clang/test/CodeGenSPIRV/sm6.wave-read-lane-at.hlsl

@@ -0,0 +1,32 @@
+// Run: %dxc -T cs_6_0 -E main
+
+// CHECK: ; Version: 1.3
+
+struct S {
+    float4 val1;
+     uint3 val2;
+       int val3;
+};
+
+RWStructuredBuffer<S> values;
+
+// CHECK: OpCapability GroupNonUniformBallot
+
+[numthreads(32, 1, 1)]
+void main(uint3 id: SV_DispatchThreadID) {
+    uint x = id.x;
+
+    float4 val1 = values[x].val1;
+     uint3 val2 = values[x].val2;
+       int val3 = values[x].val3;
+
+// CHECK:      [[val1:%\d+]] = OpLoad %v4float %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %v4float %int_3 [[val1]] %uint_15
+    values[x].val1 = WaveReadLaneAt(val1, 15);
+// CHECK:      [[val2:%\d+]] = OpLoad %v3uint %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %v3uint %int_3 [[val2]] %uint_42
+    values[x].val2 = WaveReadLaneAt(val2, 42);
+// CHECK:      [[val3:%\d+]] = OpLoad %int %val3
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcast %int %int_3 [[val3]] %uint_15
+    values[x].val3 = WaveReadLaneAt(val3, 15);
+}

+ 10 - 8
tools/clang/test/CodeGenSPIRV/sm6.wave-read-lane-first.hlsl

@@ -1,7 +1,6 @@
 // Run: %dxc -T cs_6_0 -E main
 
-// CHECK: OpCapability SubgroupBallotKHR
-// CHECK: OpExtension "SPV_KHR_shader_ballot"
+// CHECK: ; Version: 1.3
 
 struct S {
     uint4 val1;
@@ -11,6 +10,8 @@ struct S {
 
 RWStructuredBuffer<S> values;
 
+// CHECK: OpCapability GroupNonUniformBallot
+
 [numthreads(32, 1, 1)]
 void main(uint3 id: SV_DispatchThreadID) {
     uint x = id.x;
@@ -19,12 +20,13 @@ void main(uint3 id: SV_DispatchThreadID) {
      int2 val2 = values[x].val2;
     float val3 = values[x].val3;
 
-// OpSubgroupFirstInvocationKHR requires that:
-//   Result Type must be a 32-bit integer type or a 32-bit float type scalar.
-
-    // values[x].val1 = WaveReadLaneFirst(val1);
-    // values[x].val2 = WaveReadLaneFirst(val2);
+// CHECK:      [[val1:%\d+]] = OpLoad %v4uint %val1
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcastFirst %v4uint %int_3 [[val1]]
+    values[x].val1 = WaveReadLaneFirst(val1);
+// CHECK:      [[val2:%\d+]] = OpLoad %v2int %val2
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcastFirst %v2int %int_3 [[val2]]
+    values[x].val2 = WaveReadLaneFirst(val2);
 // CHECK:      [[val3:%\d+]] = OpLoad %float %val3
-// CHECK-NEXT:      {{%\d+}} = OpSubgroupFirstInvocationKHR %float [[val3]]
+// CHECK-NEXT:      {{%\d+}} = OpGroupNonUniformBroadcastFirst %float %int_3 [[val3]]
     values[x].val3 = WaveReadLaneFirst(val3);
 }

+ 1 - 1
tools/clang/tools/dxc/dxc.cpp

@@ -84,7 +84,7 @@ static bool DisassembleSpirv(IDxcBlob *binaryBlob, IDxcLibrary *library,
   memcpy(words.data(), binaryStr.data(), binaryStr.size());
 
   std::string assembly;
-  spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
+  spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
   uint32_t options = (SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES |
                       SPV_BINARY_TO_TEXT_OPTION_INDENT);
   if (withColor)

+ 96 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -983,21 +983,117 @@ TEST_F(FileTest, PrimitiveErrorGS) {
 }
 
 // Shader model 6.0 wave query
+TEST_F(FileTest, SM6WaveIsFirstLane) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-is-first-lane.hlsl");
+}
 TEST_F(FileTest, SM6WaveGetLaneCount) {
+  useVulkan1p1();
   runFileTest("sm6.wave-get-lane-count.hlsl");
 }
 TEST_F(FileTest, SM6WaveGetLaneIndex) {
+  useVulkan1p1();
   runFileTest("sm6.wave-get-lane-index.hlsl");
 }
 TEST_F(FileTest, SM6WaveBuiltInNoDuplicate) {
+  useVulkan1p1();
   runFileTest("sm6.wave.builtin.no-dup.hlsl");
 }
 
+// Shader model 6.0 wave vote
+TEST_F(FileTest, SM6WaveActiveAnyTrue) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-any-true.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveAllTrue) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-all-true.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveBallot) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-ballot.hlsl");
+}
+
+// Shader model 6.0 wave reduction
+TEST_F(FileTest, SM6WaveActiveAllEqual) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-all-equal.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveSum) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-sum.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveProduct) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-product.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveMax) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-max.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveMin) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-min.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveBitAnd) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-bit-and.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveBitOr) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-bit-or.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveBitXor) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-bit-xor.hlsl");
+}
+TEST_F(FileTest, SM6WaveActiveCountBits) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-active-count-bits.hlsl");
+}
+
+// Shader model 6.0 wave scan/prefix
+TEST_F(FileTest, SM6WavePrefixSum) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-prefix-sum.hlsl");
+}
+TEST_F(FileTest, SM6WavePrefixProduct) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-prefix-product.hlsl");
+}
+TEST_F(FileTest, SM6WavePrefixCountBits) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-prefix-count-bits.hlsl");
+}
+
 // Shader model 6.0 wave broadcast
+TEST_F(FileTest, SM6WaveReadLaneAt) {
+  useVulkan1p1();
+  runFileTest("sm6.wave-read-lane-at.hlsl");
+}
 TEST_F(FileTest, SM6WaveReadLaneFirst) {
+  useVulkan1p1();
   runFileTest("sm6.wave-read-lane-first.hlsl");
 }
 
+// Shader model 6.0 wave quad-wide shuffle
+TEST_F(FileTest, SM6QuadReadAcrossX) {
+  useVulkan1p1();
+  runFileTest("sm6.quad-read-across-x.hlsl");
+}
+TEST_F(FileTest, SM6QuadReadAcrossY) {
+  useVulkan1p1();
+  runFileTest("sm6.quad-read-across-y.hlsl");
+}
+TEST_F(FileTest, SM6QuadReadAcrossDiagonal) {
+  useVulkan1p1();
+  runFileTest("sm6.quad-read-across-diagonal.hlsl");
+}
+TEST_F(FileTest, SM6QuadReadLaneAt) {
+  useVulkan1p1();
+  runFileTest("sm6.quad-read-lane-at.hlsl");
+}
+
 // SPIR-V specific
 TEST_F(FileTest, SpirvStorageClass) { runFileTest("spirv.storage-class.hlsl"); }
 

+ 2 - 2
tools/clang/unittests/SPIRV/FileTestFixture.cpp

@@ -128,8 +128,8 @@ void FileTest::runFileTest(llvm::StringRef filename, Expect expect,
 
   // Run SPIR-V validation for successful compilations
   if (runValidation && expect != Expect::Failure) {
-    EXPECT_TRUE(
-        utils::validateSpirvBinary(generatedBinary, relaxLogicalPointer));
+    EXPECT_TRUE(utils::validateSpirvBinary(targetEnv, generatedBinary,
+                                           relaxLogicalPointer));
   }
 }
 

+ 6 - 0
tools/clang/unittests/SPIRV/FileTestFixture.h

@@ -10,6 +10,7 @@
 #ifndef LLVM_CLANG_UNITTESTS_SPIRV_FILE_TEST_FIXTURE_H
 #define LLVM_CLANG_UNITTESTS_SPIRV_FILE_TEST_FIXTURE_H
 
+#include "spirv-tools/libspirv.h"
 #include "llvm/ADT/StringRef.h"
 #include "gtest/gtest.h"
 
@@ -25,6 +26,10 @@ public:
     Failure, // Failure (with errors) - check error message
   };
 
+  FileTest() : targetEnv(SPV_ENV_VULKAN_1_0) {}
+
+  void useVulkan1p1() { targetEnv = SPV_ENV_VULKAN_1_1; }
+
   /// \brief Runs a File Test! (See class description for more info)
   void runFileTest(llvm::StringRef path, Expect expect = Expect::Success,
                    bool runValidation = true, bool relaxLogicalPointer = false);
@@ -40,6 +45,7 @@ private:
   std::vector<uint32_t> generatedBinary; ///< The generated SPIR-V Binary
   std::string checkCommands;             ///< CHECK commands that verify output
   std::string generatedSpirvAsm;         ///< Disassembled binary (SPIR-V code)
+  spv_target_env targetEnv;              ///< Environment to validate against
 };
 
 } // end namespace spirv

+ 3 - 3
tools/clang/unittests/SPIRV/FileTestUtils.cpp

@@ -22,7 +22,7 @@ namespace utils {
 bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
                             std::string *generatedSpirvAsm,
                             bool generateHeader) {
-  spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
+  spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
   spirvTools.SetMessageConsumer(
       [](spv_message_level_t, const char *, const spv_position_t &,
          const char *message) { fprintf(stdout, "%s\n", message); });
@@ -32,11 +32,11 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
   return spirvTools.Disassemble(binary, generatedSpirvAsm, options);
 }
 
-bool validateSpirvBinary(std::vector<uint32_t> &binary,
+bool validateSpirvBinary(spv_target_env env, std::vector<uint32_t> &binary,
                          bool relaxLogicalPointer) {
   spvtools::ValidatorOptions options;
   options.SetRelaxLogicalPointer(relaxLogicalPointer);
-  spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_0);
+  spvtools::SpirvTools spirvTools(env);
   spirvTools.SetMessageConsumer(
       [](spv_message_level_t, const char *, const spv_position_t &,
          const char *message) { fprintf(stdout, "%s\n", message); });

+ 1 - 1
tools/clang/unittests/SPIRV/FileTestUtils.h

@@ -32,7 +32,7 @@ bool disassembleSpirvBinary(std::vector<uint32_t> &binary,
 
 /// \brief Runs the SPIR-V Tools validation on the given SPIR-V binary.
 /// Returns true if validation is successful; false otherwise.
-bool validateSpirvBinary(std::vector<uint32_t> &binary,
+bool validateSpirvBinary(spv_target_env, std::vector<uint32_t> &binary,
                          bool relaxLogicalPointer);
 
 /// \brief Parses the Target Profile and Entry Point from the Run command

+ 1 - 1
tools/clang/unittests/SPIRV/WholeFileTestFixture.cpp

@@ -107,7 +107,7 @@ void WholeFileTest::runWholeFileTest(llvm::StringRef filename,
 
   // Run SPIR-V validation if requested.
   if (runSpirvValidation) {
-    EXPECT_TRUE(utils::validateSpirvBinary(generatedBinary,
+    EXPECT_TRUE(utils::validateSpirvBinary(SPV_ENV_VULKAN_1_0, generatedBinary,
                                            /*relaxLogicalPointer=*/false));
   }
 }