Răsfoiți Sursa

Add support for WaveMatch and WaveMultiPrefix<Op> (#1867)

These new DXIL instructions are added to SM 6.5. The valid operations
for <Op> are:

    - BitAnd
    - BitOr
    - BitXor
    - CountBits
    - Product
    - Sum

In HLSL, these are exposed as:

    uint4 WaveMatch(<type> val)
    <type> WaveMultiPrefixBitAnd(<type> val, uint4 mask)
    <type> WaveMultiPrefixBitOr(<type> val, uint4 mask)
    <type> WaveMultiPrefixBitXor(<type> val, uint4 mask)
    uint WaveMultiPrefixCountBits(bool val, uint4 mask)
    <type> WaveMultiPrefixProduct(<type> val, uint4 mask)
    <type> WaveMultiPrefixSum(<type> val, uint4 mask)

In DXIL, these are exposed as:

    [BitAnd,BitOr,BitXor,Product,Sum]
    %dx.types.fouri32 @dx.op.waveMatch.T(i32 %opc, T %val)
    T @dx.op.waveMultiPrefixOp.T(i32 %opc, T %val, i32 %mask_x,
                                 i32 %mask_y, i32 %mask_z, i32 %mask_w,
                                 i8 %operation, i8 %signed)

    [CountBits]
    i32 @dx.op.waveMultiPrefixBitCount(i32 %opc, i1 %val, i32 %mask_x,
                                       i32 %mask_y, i32 %mask_z,
                                       i32 %mask_w)

Scalarization of vector types occurs as per the existing wave intrinsics.
For WaveMatch, the match is performed on each scalar and the results
are combined with bitwise AND. For WaveMultiPrefix, the operation is
performed on each scalar and combined into an aggregate.
Justin Holewinski 6 ani în urmă
părinte
comite
bc72038998

+ 3 - 0
docs/DXIL.rst

@@ -2252,6 +2252,9 @@ ID  Name                          Description
 162 Dot2AddHalf                   2D half dot product with accumulate to float
 163 Dot4AddI8Packed               signed dot product of 4 x i8 vectors packed into i32, with accumulate to i32
 164 Dot4AddU8Packed               unsigned dot product of 4 x u8 vectors packed into i32, with accumulate to i32
+165 WaveMatch                     returns the bitmask of active lanes that have the same value
+166 WaveMultiPrefixOp             returns the result of the operation on groups of lanes identified by a bitmask
+167 WaveMultiPrefixBitCount       returns the count of bits set to 1 on groups of lanes identified by a bitmask
 === ============================= =======================================================================================================================================================================================================================
 
 

+ 22 - 2
include/dxc/DXIL/DxilConstants.h

@@ -548,6 +548,9 @@ namespace DXIL {
     WaveGetLaneCount = 112, // returns the number of lanes in the wave
     WaveGetLaneIndex = 111, // returns the index of the current lane in the wave
     WaveIsFirstLane = 110, // returns 1 for the first lane in the wave
+    WaveMatch = 165, // returns the bitmask of active lanes that have the same value
+    WaveMultiPrefixBitCount = 167, // returns the count of bits set to 1 on groups of lanes identified by a bitmask
+    WaveMultiPrefixOp = 166, // returns the result of the operation on groups of lanes identified by a bitmask
     WavePrefixBitCount = 136, // returns the count of bits set to 1 on prior lanes
     WavePrefixOp = 121, // returns the result of the operation on prior lanes
     WaveReadLaneAt = 117, // returns the value from the specified lane
@@ -558,8 +561,9 @@ namespace DXIL {
     NumOpCodes_Dxil_1_2 = 141,
     NumOpCodes_Dxil_1_3 = 162,
     NumOpCodes_Dxil_1_4 = 165,
+    NumOpCodes_Dxil_1_5 = 168,
   
-    NumOpCodes = 165 // exclusive last value of enumeration
+    NumOpCodes = 168 // exclusive last value of enumeration
   };
   // OPCODE-ENUM:END
 
@@ -761,6 +765,9 @@ namespace DXIL {
     WaveGetLaneCount,
     WaveGetLaneIndex,
     WaveIsFirstLane,
+    WaveMatch,
+    WaveMultiPrefixBitCount,
+    WaveMultiPrefixOp,
     WavePrefixOp,
     WaveReadLaneAt,
     WaveReadLaneFirst,
@@ -770,8 +777,9 @@ namespace DXIL {
     NumOpClasses_Dxil_1_2 = 97,
     NumOpClasses_Dxil_1_3 = 118,
     NumOpClasses_Dxil_1_4 = 120,
+    NumOpClasses_Dxil_1_5 = 123,
   
-    NumOpClasses = 120 // exclusive last value of enumeration
+    NumOpClasses = 123 // exclusive last value of enumeration
   };
   // OPCODECLASS-ENUM:END
 
@@ -1057,6 +1065,18 @@ namespace DXIL {
   };
   // WAVEOPKIND-ENUM:END
 
+  /* <py::lines('WAVEMULTIPREFIXOPKIND-ENUM')>hctdb_instrhelp.get_enum_decl("WaveMultiPrefixOpKind")</py>*/
+  // WAVEMULTIPREFIXOPKIND-ENUM:BEGIN
+  // Kind of cross-lane operation for a multi-prefix wave operation; entries are listed alphabetically, not in value order
+  enum class WaveMultiPrefixOpKind : unsigned {
+    And = 1, // bitwise and of values
+    Or = 2, // bitwise or of values
+    Product = 4, // product of values
+    Sum = 0, // sum of values
+    Xor = 3, // bitwise xor of values
+  };
+  // WAVEMULTIPREFIXOPKIND-ENUM:END
+
   /* <py::lines('SIGNEDOPKIND-ENUM')>hctdb_instrhelp.get_enum_decl("SignedOpKind")</py>*/
   // SIGNEDOPKIND-ENUM:BEGIN
   // Sign vs. unsigned operands for operation

+ 109 - 0
include/dxc/DXIL/DxilInstructions.h

@@ -5440,5 +5440,114 @@ struct DxilInst_Dot4AddU8Packed {
   llvm::Value *get_b() const { return Instr->getOperand(3); }
   void set_b(llvm::Value *val) { Instr->setOperand(3, val); }
 };
+
+/// Wrapper around an llvm::Instruction pointer for the WaveMatch DXIL op: this instruction returns the bitmask of active lanes that have the same value
+struct DxilInst_WaveMatch {
+  llvm::Instruction *Instr;
+  // Construction and identification (the wrapper stores the raw pointer only; operator bool reports what IsDxilOpFuncCallInst says for OpCode::WaveMatch)
+  DxilInst_WaveMatch(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::WaveMatch);
+  }
+  // Validation support (expects exactly 2 call arguments; accessors below use index 1 for the value)
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (2 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false; // NOTE(review): dyn_cast result is dereferenced unchecked, so Instr must already be a CallInst
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes (index 0 is not named here; presumably the opcode constant — confirm)
+  enum OperandIdx {
+    arg_value = 1,
+  };
+  // Accessors (read/write operand 1 of Instr directly)
+  llvm::Value *get_value() const { return Instr->getOperand(1); }
+  void set_value(llvm::Value *val) { Instr->setOperand(1, val); }
+};
+
+/// Wrapper around an llvm::Instruction pointer for the WaveMultiPrefixOp DXIL op: this instruction returns the result of the operation on groups of lanes identified by a bitmask
+struct DxilInst_WaveMultiPrefixOp {
+  llvm::Instruction *Instr;
+  // Construction and identification (the wrapper stores the raw pointer only; operator bool reports what IsDxilOpFuncCallInst says for OpCode::WaveMultiPrefixOp)
+  DxilInst_WaveMultiPrefixOp(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::WaveMultiPrefixOp);
+  }
+  // Validation support (expects exactly 8 call arguments; indexes 1..7 are named in OperandIdx below)
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (8 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false; // NOTE(review): dyn_cast result is dereferenced unchecked, so Instr must already be a CallInst
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes (value, four mask components, then the op and sop selectors; index 0 is not named here)
+  enum OperandIdx {
+    arg_value = 1,
+    arg_mask0 = 2,
+    arg_mask1 = 3,
+    arg_mask2 = 4,
+    arg_mask3 = 5,
+    arg_op = 6,
+    arg_sop = 7,
+  };
+  // Accessors (the *_val variants require the operand to be an llvm::ConstantInt: the dyn_cast result is dereferenced unchecked)
+  llvm::Value *get_value() const { return Instr->getOperand(1); }
+  void set_value(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_mask0() const { return Instr->getOperand(2); }
+  void set_mask0(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_mask1() const { return Instr->getOperand(3); }
+  void set_mask1(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_mask2() const { return Instr->getOperand(4); }
+  void set_mask2(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_mask3() const { return Instr->getOperand(5); }
+  void set_mask3(llvm::Value *val) { Instr->setOperand(5, val); }
+  llvm::Value *get_op() const { return Instr->getOperand(6); }
+  void set_op(llvm::Value *val) { Instr->setOperand(6, val); }
+  int8_t get_op_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(6))->getZExtValue()); } // zero-extended constant narrowed to int8_t
+  void set_op_val(int8_t val) { Instr->setOperand(6, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); } // stores val as an 8-bit integer constant
+  llvm::Value *get_sop() const { return Instr->getOperand(7); }
+  void set_sop(llvm::Value *val) { Instr->setOperand(7, val); }
+  int8_t get_sop_val() const { return (int8_t)(llvm::dyn_cast<llvm::ConstantInt>(Instr->getOperand(7))->getZExtValue()); } // zero-extended constant narrowed to int8_t
+  void set_sop_val(int8_t val) { Instr->setOperand(7, llvm::Constant::getIntegerValue(llvm::IntegerType::get(Instr->getContext(), 8), llvm::APInt(8, (uint64_t)val))); } // stores val as an 8-bit integer constant
+};
+
+/// Wrapper around an llvm::Instruction pointer for the WaveMultiPrefixBitCount DXIL op: this instruction returns the count of bits set to 1 on groups of lanes identified by a bitmask
+struct DxilInst_WaveMultiPrefixBitCount {
+  llvm::Instruction *Instr;
+  // Construction and identification (the wrapper stores the raw pointer only; operator bool reports what IsDxilOpFuncCallInst says for OpCode::WaveMultiPrefixBitCount)
+  DxilInst_WaveMultiPrefixBitCount(llvm::Instruction *pInstr) : Instr(pInstr) {}
+  operator bool() const {
+    return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::WaveMultiPrefixBitCount);
+  }
+  // Validation support (expects exactly 6 call arguments; indexes 1..5 are named in OperandIdx below)
+  bool isAllowed() const { return true; }
+  bool isArgumentListValid() const {
+    if (6 != llvm::dyn_cast<llvm::CallInst>(Instr)->getNumArgOperands()) return false; // NOTE(review): dyn_cast result is dereferenced unchecked, so Instr must already be a CallInst
+    return true;
+  }
+  // Metadata
+  bool requiresUniformInputs() const { return false; }
+  // Operand indexes (value followed by the four mask components; index 0 is not named here)
+  enum OperandIdx {
+    arg_value = 1,
+    arg_mask0 = 2,
+    arg_mask1 = 3,
+    arg_mask2 = 4,
+    arg_mask3 = 5,
+  };
+  // Accessors (read/write the corresponding operand of Instr directly)
+  llvm::Value *get_value() const { return Instr->getOperand(1); }
+  void set_value(llvm::Value *val) { Instr->setOperand(1, val); }
+  llvm::Value *get_mask0() const { return Instr->getOperand(2); }
+  void set_mask0(llvm::Value *val) { Instr->setOperand(2, val); }
+  llvm::Value *get_mask1() const { return Instr->getOperand(3); }
+  void set_mask1(llvm::Value *val) { Instr->setOperand(3, val); }
+  llvm::Value *get_mask2() const { return Instr->getOperand(4); }
+  void set_mask2(llvm::Value *val) { Instr->setOperand(4, val); }
+  llvm::Value *get_mask3() const { return Instr->getOperand(5); }
+  void set_mask3(llvm::Value *val) { Instr->setOperand(5, val); }
+};
 // INSTR-HELPER:END
 } // namespace hlsl

+ 15 - 0
include/dxc/HlslIntrinsicOp.h

@@ -94,6 +94,13 @@ import hctdb_instrhelp
   IOP_WaveGetLaneCount,
   IOP_WaveGetLaneIndex,
   IOP_WaveIsFirstLane,
+  IOP_WaveMatch,
+  IOP_WaveMultiPrefixBitAnd,
+  IOP_WaveMultiPrefixBitOr,
+  IOP_WaveMultiPrefixBitXor,
+  IOP_WaveMultiPrefixCountBits,
+  IOP_WaveMultiPrefixProduct,
+  IOP_WaveMultiPrefixSum,
   IOP_WavePrefixCountBits,
   IOP_WavePrefixProduct,
   IOP_WavePrefixSum,
@@ -263,6 +270,8 @@ import hctdb_instrhelp
   IOP_WaveActiveUMin,
   IOP_WaveActiveUProduct,
   IOP_WaveActiveUSum,
+  IOP_WaveMultiPrefixUProduct,
+  IOP_WaveMultiPrefixUSum,
   IOP_WavePrefixUProduct,
   IOP_WavePrefixUSum,
   IOP_uabs,
@@ -293,6 +302,8 @@ import hctdb_instrhelp
   case IntrinsicOp::IOP_WaveActiveMin:
   case IntrinsicOp::IOP_WaveActiveProduct:
   case IntrinsicOp::IOP_WaveActiveSum:
+  case IntrinsicOp::IOP_WaveMultiPrefixProduct:
+  case IntrinsicOp::IOP_WaveMultiPrefixSum:
   case IntrinsicOp::IOP_WavePrefixProduct:
   case IntrinsicOp::IOP_WavePrefixSum:
   case IntrinsicOp::IOP_abs:
@@ -332,6 +343,10 @@ import hctdb_instrhelp
     return static_cast<unsigned>(IntrinsicOp::IOP_WaveActiveUProduct);
   case IntrinsicOp::IOP_WaveActiveSum:
     return static_cast<unsigned>(IntrinsicOp::IOP_WaveActiveUSum);
+  case IntrinsicOp::IOP_WaveMultiPrefixProduct:
+    return static_cast<unsigned>(IntrinsicOp::IOP_WaveMultiPrefixUProduct);
+  case IntrinsicOp::IOP_WaveMultiPrefixSum:
+    return static_cast<unsigned>(IntrinsicOp::IOP_WaveMultiPrefixUSum);
   case IntrinsicOp::IOP_WavePrefixProduct:
     return static_cast<unsigned>(IntrinsicOp::IOP_WavePrefixUProduct);
   case IntrinsicOp::IOP_WavePrefixSum:

+ 21 - 2
lib/DXIL/DxilOperations.cpp

@@ -316,6 +316,11 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = {
   {  OC::Dot2AddHalf,             "Dot2AddHalf",              OCC::Dot2AddHalf,              "dot2AddHalf",               { false, false,  true, false, false, false, false, false, false, false, false}, Attribute::ReadNone, },
   {  OC::Dot4AddI8Packed,         "Dot4AddI8Packed",          OCC::Dot4AddPacked,            "dot4AddPacked",             { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
   {  OC::Dot4AddU8Packed,         "Dot4AddU8Packed",          OCC::Dot4AddPacked,            "dot4AddPacked",             { false, false, false, false, false, false, false,  true, false, false, false}, Attribute::ReadNone, },
+
+  // Wave                                                                                                                    void,     h,     f,     d,    i1,    i8,   i16,   i32,   i64,   udt,   obj ,  function attribute
+  {  OC::WaveMatch,               "WaveMatch",                OCC::WaveMatch,                "waveMatch",                 { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
+  {  OC::WaveMultiPrefixOp,       "WaveMultiPrefixOp",        OCC::WaveMultiPrefixOp,        "waveMultiPrefixOp",         { false,  true,  true,  true, false,  true,  true,  true,  true, false, false}, Attribute::None,     },
+  {  OC::WaveMultiPrefixBitCount, "WaveMultiPrefixBitCount",  OCC::WaveMultiPrefixBitCount,  "waveMultiPrefixBitCount",   {  true, false, false, false, false, false, false, false, false, false, false}, Attribute::None,     },
 };
 // OPCODE-OLOADS:END
 
@@ -498,8 +503,9 @@ bool OP::IsDxilOpWave(OpCode C) {
   // WaveActiveAllEqual=115, WaveActiveBallot=116, WaveReadLaneAt=117,
   // WaveReadLaneFirst=118, WaveActiveOp=119, WaveActiveBit=120,
   // WavePrefixOp=121, QuadReadLaneAt=122, QuadOp=123, WaveAllBitCount=135,
-  // WavePrefixBitCount=136
-  return (110 <= op && op <= 123) || (135 <= op && op <= 136);
+  // WavePrefixBitCount=136, WaveMatch=165, WaveMultiPrefixOp=166,
+  // WaveMultiPrefixBitCount=167
+  return (110 <= op && op <= 123) || (135 <= op && op <= 136) || (165 <= op && op <= 167);
   // OPCODE-WAVE:END
 }
 
@@ -650,6 +656,12 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation,
     major = 6;  minor = 4;
     return;
   }
+  // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
+  // WaveMultiPrefixBitCount=167
+  if ((165 <= op && op <= 167)) {
+    major = 6;  minor = 5;
+    return;
+  }
   // OPCODE-SMMASK:END
 #undef SFLAG
 }
@@ -1045,6 +1057,11 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) {
   case OpCode::Dot2AddHalf:            A(pETy);     A(pI32); A(pETy); A(pF16); A(pF16); A(pF16); A(pF16); break;
   case OpCode::Dot4AddI8Packed:        A(pI32);     A(pI32); A(pI32); A(pI32); A(pI32); break;
   case OpCode::Dot4AddU8Packed:        A(pI32);     A(pI32); A(pI32); A(pI32); A(pI32); break;
+
+    // Wave
+  case OpCode::WaveMatch:              A(pI4S);     A(pI32); A(pETy); break;
+  case OpCode::WaveMultiPrefixOp:      A(pETy);     A(pI32); A(pETy); A(pI32); A(pI32); A(pI32); A(pI32); A(pI8);  A(pI8);  break;
+  case OpCode::WaveMultiPrefixBitCount:A(pI32);     A(pI32); A(pI1);  A(pI32); A(pI32); A(pI32); A(pI32); break;
   // OPCODE-OLOAD-FUNCS:END
   default: DXASSERT(false, "otherwise unhandled case"); break;
   }
@@ -1152,6 +1169,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::USubb:
   case OpCode::WaveActiveAllEqual:
   case OpCode::CreateHandleForLib:
+  case OpCode::WaveMatch:
     DXASSERT_NOMSG(FT->getNumParams() > 1);
     return FT->getParamType(1);
   case OpCode::TextureStore:
@@ -1196,6 +1214,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) {
   case OpCode::WavePrefixBitCount:
   case OpCode::IgnoreHit:
   case OpCode::AcceptHitAndEndSearch:
+  case OpCode::WaveMultiPrefixBitCount:
     return Type::getVoidTy(m_Ctx);
   case OpCode::CheckAccessFullyMapped:
   case OpCode::AtomicBinOp:

+ 4 - 0
lib/HLSL/DxilValidation.cpp

@@ -855,6 +855,10 @@ static bool ValidateOpcodeInProfile(DXIL::OpCode opcode,
   // Instructions: Dot2AddHalf=162, Dot4AddI8Packed=163, Dot4AddU8Packed=164
   if ((162 <= op && op <= 164))
     return (major > 6 || (major == 6 && minor >= 4));
+  // Instructions: WaveMatch=165, WaveMultiPrefixOp=166,
+  // WaveMultiPrefixBitCount=167
+  if ((165 <= op && op <= 167))
+    return (major > 6 || (major == 6 && minor >= 5));
   return true;
   // VALOPCODESM-TEXT:END
 }

+ 131 - 0
lib/HLSL/HLOperationLower.cpp

@@ -1089,6 +1089,68 @@ Value *TranslateWaveAllEqual(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
   return TrivialDxilOperation(DXIL::OpCode::WaveActiveAllEqual, args, Ty, RetTy,
                               hlslOP, Builder);
 }
+
+// WaveMatch(val<n>)->uint4: lowers the HL call CI to dx.op.waveMatch calls and returns a value of CI's type; IOP, Opc, ObjHelper and Translated are not read by this function
+Value *TranslateWaveMatch(CallInst *CI, IntrinsicOp IOP, OP::OpCode Opc,
+                          HLOperationLowerHelper &Helper,
+                          HLObjectOperationLowerHelper *ObjHelper,
+                          bool &Translated) {
+  hlsl::OP *Op = &Helper.hlslOP;
+  IRBuilder<> Builder(CI); // new instructions are inserted before CI
+
+  // Generate a dx.op.waveMatch call for each scalar in the input, and perform
+  // a bitwise AND between each result to derive the final bitmask in the case
+  // of vector inputs.
+
+  // (1) Collect the list of all scalar inputs (e.g. decompose vectors)
+  SmallVector<Value *, 4> ScalarInputs;
+
+  Value *Val = CI->getArgOperand(1); // argument 0 is skipped; presumably the HL opcode — confirm
+  Type *ValTy = Val->getType();
+  Type *EltTy = ValTy->getScalarType(); // scalar type used to pick the dx.op.waveMatch overload
+
+  if (ValTy->isVectorTy()) {
+    for (uint64_t i = 0, e = ValTy->getVectorNumElements(); i != e; ++i) {
+      Value *Elt = Builder.CreateExtractElement(Val, i);
+      ScalarInputs.push_back(Elt);
+    }
+  } else {
+    ScalarInputs.push_back(Val);
+  }
+
+  Value *Res = nullptr; // running aggregate; stays null until the first call is emitted
+  Constant *OpcArg = Op->GetU32Const((unsigned)DXIL::OpCode::WaveMatch); // opcode is hard-coded rather than taken from Opc
+  Value *Fn = Op->GetOpFunc(OP::OpCode::WaveMatch, EltTy);
+
+  // (2) For each scalar, emit a call to dx.op.waveMatch. If this is not the
+  // first scalar, then AND the result with the accumulator.
+  for (unsigned i = 0, e = ScalarInputs.size(); i != e; ++i) {
+    Value *Args[] = { OpcArg, ScalarInputs[i] };
+    Value *Call = Builder.CreateCall(Fn, Args);
+
+    if (Res) {
+      // Generate bitwise AND of the components (the aggregate is indexed 0..3)
+      for (unsigned j = 0; j != 4; ++j) {
+        Value *ResVal = Builder.CreateExtractValue(Res, j);
+        Value *CallVal = Builder.CreateExtractValue(Call, j);
+        Value *And = Builder.CreateAnd(ResVal, CallVal);
+        Res = Builder.CreateInsertValue(Res, And, j);
+      }
+    } else {
+      Res = Call;
+    }
+  }
+
+  // (3) Convert the final aggregate into a vector to make the types match
+  Value *ResVec = UndefValue::get(CI->getType());
+  for (unsigned i = 0; i != 4; ++i) {
+    Value *Elt = Builder.CreateExtractValue(Res, i);
+    ResVec = Builder.CreateInsertElement(ResVec, Elt, i);
+  }
+
+  return ResVec;
+}
+
 // Wave intrinsics of the form fn(valA)->valB, where no overloading takes place
 Value *TranslateWaveA2B(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
                         HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
@@ -1139,6 +1201,8 @@ static unsigned WaveIntrinsicToSignedOpKind(IntrinsicOp IOP) {
       IOP == IntrinsicOp::IOP_WaveActiveUMin ||
       IOP == IntrinsicOp::IOP_WaveActiveUSum ||
       IOP == IntrinsicOp::IOP_WaveActiveUProduct ||
+      IOP == IntrinsicOp::IOP_WaveMultiPrefixUProduct ||
+      IOP == IntrinsicOp::IOP_WaveMultiPrefixUSum ||
       IOP == IntrinsicOp::IOP_WavePrefixUSum ||
       IOP == IntrinsicOp::IOP_WavePrefixUProduct)
     return (unsigned)DXIL::SignedOpKind::Unsigned;
@@ -1173,6 +1237,19 @@ static unsigned WaveIntrinsicToOpKind(IntrinsicOp IOP) {
     return (unsigned)DXIL::WaveOpKind::Sum;
   case IntrinsicOp::IOP_WaveActiveProduct:
   case IntrinsicOp::IOP_WaveActiveUProduct:
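+  // MultiPrefix operations. NOTE(review): the IOP_WaveActiveProduct/IOP_WaveActiveUProduct labels directly above have no return of their own, so they now fall through into the IOP_WaveMultiPrefixBitAnd case instead of reaching default.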
+  case IntrinsicOp::IOP_WaveMultiPrefixBitAnd:
+    return (unsigned)DXIL::WaveMultiPrefixOpKind::And;
+  case IntrinsicOp::IOP_WaveMultiPrefixBitOr:
+    return (unsigned)DXIL::WaveMultiPrefixOpKind::Or;
+  case IntrinsicOp::IOP_WaveMultiPrefixBitXor:
+    return (unsigned)DXIL::WaveMultiPrefixOpKind::Xor;
+  case IntrinsicOp::IOP_WaveMultiPrefixProduct:
+  case IntrinsicOp::IOP_WaveMultiPrefixUProduct:
+    return (unsigned)DXIL::WaveMultiPrefixOpKind::Product;
+  case IntrinsicOp::IOP_WaveMultiPrefixSum:
+  case IntrinsicOp::IOP_WaveMultiPrefixUSum:
+    return (unsigned)DXIL::WaveMultiPrefixOpKind::Sum;
   default:
     DXASSERT(IOP == IntrinsicOp::IOP_WaveActiveProduct ||
              IOP == IntrinsicOp::IOP_WaveActiveUProduct,
@@ -1199,6 +1276,51 @@ Value *TranslateWaveA2A(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
                               CI->getOperand(1)->getType(), CI, hlslOP);
 }
 
+// WaveMultiPrefixOP(val<n>, mask) -> val<n>: builds the (value, mask0..3, op kind, sign kind) argument list for the HL call CI and hands it to TrivialDxilOperation; ObjHelper and Translated are not read here
+Value *TranslateWaveMultiPrefix(CallInst *CI, IntrinsicOp IOP, OP::OpCode Opc,
+                                HLOperationLowerHelper &Helper,
+                                HLObjectOperationLowerHelper *ObjHelper,
+                                bool &Translated) {
+  hlsl::OP *Op = &Helper.hlslOP;
+
+  Constant *KindValInt = Op->GetI8Const(WaveIntrinsicToOpKind(IOP)); // operation selector derived from the HL intrinsic
+  Constant *SignValInt = Op->GetI8Const(WaveIntrinsicToSignedOpKind(IOP)); // signed/unsigned selector derived from the HL intrinsic
+
+  // Decompose mask into scalars (elements 0..3 of HL argument 2)
+  IRBuilder<> Builder(CI);
+  Value *Mask = CI->getArgOperand(2);
+  Value *Mask0 = Builder.CreateExtractElement(Mask, 0ULL);
+  Value *Mask1 = Builder.CreateExtractElement(Mask, 1ULL);
+  Value *Mask2 = Builder.CreateExtractElement(Mask, 2ULL);
+  Value *Mask3 = Builder.CreateExtractElement(Mask, 3ULL);
+
+  Value *Args[] = { nullptr, CI->getOperand(1), // slot 0 is left null; presumably TrivialDxilOperation fills in the opcode — confirm
+                    Mask0, Mask1, Mask2, Mask3, KindValInt, SignValInt };
+
+  return TrivialDxilOperation(Opc, Args, CI->getOperand(1)->getType(), CI, Op);
+}
+
+// WaveMultiPrefixBitCount(i1, mask) -> i32: builds the (value, mask0..3) argument list for the HL call CI and hands it to TrivialDxilOperation with the void overload type; IOP, ObjHelper and Translated are not read here
+Value *TranslateWaveMultiPrefixBitCount(CallInst *CI, IntrinsicOp IOP,
+                                        OP::OpCode Opc,
+                                        HLOperationLowerHelper &Helper,
+                                        HLObjectOperationLowerHelper *ObjHelper,
+                                        bool &Translated) {
+  hlsl::OP *Op = &Helper.hlslOP;
+
+  // Decompose mask into scalars (elements 0..3 of HL argument 2)
+  IRBuilder<> Builder(CI);
+  Value *Mask = CI->getArgOperand(2);
+  Value *Mask0 = Builder.CreateExtractElement(Mask, 0ULL);
+  Value *Mask1 = Builder.CreateExtractElement(Mask, 1ULL);
+  Value *Mask2 = Builder.CreateExtractElement(Mask, 2ULL);
+  Value *Mask3 = Builder.CreateExtractElement(Mask, 3ULL);
+
+  Value *Args[] = { nullptr, CI->getOperand(1), Mask0, Mask1, Mask2, Mask3 }; // slot 0 is left null; presumably TrivialDxilOperation fills in the opcode — confirm
+
+  return TrivialDxilOperation(Opc, Args, Helper.voidTy, CI, Op);
+}
+
 // Wave intrinsics of the form fn()->val
 Value *TranslateWaveToVal(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode,
                           HLOperationLowerHelper &helper,  HLObjectOperationLowerHelper *pObjHelper, bool &Translated) {
@@ -4736,6 +4858,13 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     {IntrinsicOp::IOP_WaveGetLaneCount, TranslateWaveToVal, DXIL::OpCode::WaveGetLaneCount},
     {IntrinsicOp::IOP_WaveGetLaneIndex, TranslateWaveToVal, DXIL::OpCode::WaveGetLaneIndex},
     {IntrinsicOp::IOP_WaveIsFirstLane, TranslateWaveToVal, DXIL::OpCode::WaveIsFirstLane},
+    {IntrinsicOp::IOP_WaveMatch, TranslateWaveMatch, DXIL::OpCode::WaveMatch},
+    {IntrinsicOp::IOP_WaveMultiPrefixBitAnd, TranslateWaveMultiPrefix, DXIL::OpCode::WaveMultiPrefixOp},
+    {IntrinsicOp::IOP_WaveMultiPrefixBitOr, TranslateWaveMultiPrefix, DXIL::OpCode::WaveMultiPrefixOp},
+    {IntrinsicOp::IOP_WaveMultiPrefixBitXor, TranslateWaveMultiPrefix, DXIL::OpCode::WaveMultiPrefixOp},
+    {IntrinsicOp::IOP_WaveMultiPrefixCountBits, TranslateWaveMultiPrefixBitCount, DXIL::OpCode::WaveMultiPrefixBitCount},
+    {IntrinsicOp::IOP_WaveMultiPrefixProduct, TranslateWaveMultiPrefix, DXIL::OpCode::WaveMultiPrefixOp},
+    {IntrinsicOp::IOP_WaveMultiPrefixSum, TranslateWaveMultiPrefix, DXIL::OpCode::WaveMultiPrefixOp},
     {IntrinsicOp::IOP_WavePrefixCountBits, TranslateWaveA2B, DXIL::OpCode::WavePrefixBitCount},
     {IntrinsicOp::IOP_WavePrefixProduct, TranslateWaveA2A, DXIL::OpCode::WavePrefixOp},
     {IntrinsicOp::IOP_WavePrefixSum, TranslateWaveA2A, DXIL::OpCode::WavePrefixOp},
@@ -4910,6 +5039,8 @@ IntrinsicLower gLowerTable[static_cast<unsigned>(IntrinsicOp::Num_Intrinsics)] =
     { IntrinsicOp::IOP_WaveActiveUMin, TranslateWaveA2A, DXIL::OpCode::WaveActiveOp },
     { IntrinsicOp::IOP_WaveActiveUProduct, TranslateWaveA2A, DXIL::OpCode::WaveActiveOp },
     { IntrinsicOp::IOP_WaveActiveUSum, TranslateWaveA2A, DXIL::OpCode::WaveActiveOp },
+    { IntrinsicOp::IOP_WaveMultiPrefixUProduct, TranslateWaveMultiPrefix, DXIL::OpCode::WaveMultiPrefixOp },
+    { IntrinsicOp::IOP_WaveMultiPrefixUSum, TranslateWaveMultiPrefix, DXIL::OpCode::WaveMultiPrefixOp },
     { IntrinsicOp::IOP_WavePrefixUProduct, TranslateWaveA2A, DXIL::OpCode::WavePrefixOp },
     { IntrinsicOp::IOP_WavePrefixUSum, TranslateWaveA2A, DXIL::OpCode::WavePrefixOp },
     { IntrinsicOp::IOP_uabs, TranslateUAbs, DXIL::OpCode::NumOpCodes },

Fișier diff suprimat deoarece este prea mare
+ 173 - 125
tools/clang/lib/Sema/gen_intrin_main_tables_15.h


+ 1 - 1
tools/clang/test/CodeGenHLSL/abs1.hlsl

@@ -2,7 +2,7 @@
 
 // CHECK: main
 // After lowering, these would turn into multiple abs calls rather than a 4 x float
-// CHECK: call <4 x float> @"dx.hl.op..<4 x float> (i32, <4 x float>)"(i32 84,
+// CHECK: call <4 x float> @"dx.hl.op..<4 x float> (i32, <4 x float>)"(i32 91,
 
 float4 main(float4 a : A) : SV_TARGET {
   return abs(a*a.yxxx);

+ 72 - 0
tools/clang/test/CodeGenHLSL/quick-test/sm_6_5_wave.hlsl

@@ -0,0 +1,72 @@
+// RUN: %dxc -E main -T ps_6_5 %s | FileCheck %s
+
+StructuredBuffer<uint4> g_mask;
+
+uint4 main(uint4 input : ATTR0) : SV_Target {  // pixel-shader entry: calls each SM 6.5 wave intrinsic once per variant so the emitted DXIL can be pattern-matched
+    uint4 mask = g_mask[0];  // lane-group mask shared by all WaveMultiPrefix* calls below
+
+    // CHECK: call %dx.types.fouri32 @dx.op.waveMatch.i32(i32 165,
+    // CHECK: call %dx.types.fouri32 @dx.op.waveMatch.i32(i32 165,
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: call %dx.types.fouri32 @dx.op.waveMatch.i32(i32 165,
+    // CHECK-DAG: call %dx.types.fouri32 @dx.op.waveMatch.i32(i32 165,
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    // CHECK-DAG: and i32
+    uint4 res = WaveMatch(input);  // uint4 input: one waveMatch call per component, partial results ANDed together
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 1, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 1, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 1, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 1, i8 0)
+    res += WaveMultiPrefixBitAnd(input, mask);  // expects op operand i8 1 on each of the four scalar calls
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 2, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 2, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 2, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 2, i8 0)
+    res += WaveMultiPrefixBitOr(input, mask);  // expects op operand i8 2
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 3, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 3, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 3, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 3, i8 0)
+    res += WaveMultiPrefixBitXor(input, mask);  // expects op operand i8 3
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixBitCount(i32 167, i1 %{{[A-Za-z0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}})
+    res.x += WaveMultiPrefixCountBits((input.x == 1), mask);  // bool operand: single non-overloaded call
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 1)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 1)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 1)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 1)
+    res += WaveMultiPrefixProduct(input, mask);  // uint operands: expects op i8 4 with sop i8 1
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 4, i8 0)
+    res += WaveMultiPrefixProduct((int4)input, mask);  // int operands: expects op i8 4 with sop i8 0
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 1)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 1)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 1)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 1)
+    res += WaveMultiPrefixSum(input, mask);  // uint operands: expects op i8 0 with sop i8 1
+
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 0)
+    // CHECK: call i32 @dx.op.waveMultiPrefixOp.i32(i32 166, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i32 %{{[0-9]+}}, i8 0, i8 0)
+    res += WaveMultiPrefixSum((int4)input, mask);  // int operands: expects op i8 0 with sop i8 0
+
+    return res;  // every result feeds the output, so all calls above contribute to the returned value
+}

+ 4 - 1
tools/clang/tools/dxcompiler/dxcdisassembler.cpp

@@ -1180,7 +1180,10 @@ static const char *OpCodeSignatures[] = {
   "()",  // PrimitiveIndex
   "(acc,ax,ay,bx,by)",  // Dot2AddHalf
   "(acc,a,b)",  // Dot4AddI8Packed
-  "(acc,a,b)"  // Dot4AddU8Packed
+  "(acc,a,b)",  // Dot4AddU8Packed
+  "(value)",  // WaveMatch
+  "(value,mask0,mask1,mask2,mask3,op,sop)",  // WaveMultiPrefixOp
+  "(value,mask0,mask1,mask2,mask3)"  // WaveMultiPrefixBitCount
 };
 // OPCODE-SIGS:END
 

+ 201 - 0
tools/clang/unittests/HLSL/ExecutionTest.cpp

@@ -320,6 +320,15 @@ public:
   BEGIN_TEST_METHOD(WaveIntrinsicsPrefixUintTest)
   TEST_METHOD_PROPERTY(L"DataSource", L"Table:ShaderOpArithTable.xml#WaveIntrinsicsPrefixUintTable")
   END_TEST_METHOD()
+
+  BEGIN_TEST_METHOD(WaveIntrinsicsSM65IntTest)  // data-driven: rows come from the signed multi-prefix table named below
+  TEST_METHOD_PROPERTY(L"DataSource", L"Table:ShaderOpArithTable.xml#WaveIntrinsicsMultiPrefixIntTable")
+  END_TEST_METHOD()
+
+  BEGIN_TEST_METHOD(WaveIntrinsicsSM65UintTest)  // data-driven: rows come from the unsigned multi-prefix table named below
+  TEST_METHOD_PROPERTY(L"DataSource", L"Table:ShaderOpArithTable.xml#WaveIntrinsicsMultiPrefixUintTable")
+  END_TEST_METHOD()
+
   // TAEF data-driven tests.
   BEGIN_TEST_METHOD(UnaryFloatOpTest)
     TEST_METHOD_PROPERTY(L"DataSource", L"Table:ShaderOpArithTable.xml#UnaryFloatOpTable")
@@ -504,6 +513,10 @@ public:
   void WaveIntrinsicsActivePrefixTest(TableParameter *pParameterList,
                                       size_t numParameter, bool isPrefix);
 
+  template <typename T>  // T is the value/result element type; this file instantiates it with int and unsigned
+  void WaveIntrinsicsMultiPrefixOpTest(TableParameter *pParameterList,
+                                       size_t numParameters);  // numParameters: number of entries in pParameterList
+
   void BasicTriangleTestSetup(LPCSTR OpName, LPCWSTR FileName, D3D_SHADER_MODEL testModel);
 
   void RunBasicShaderModelTest(D3D_SHADER_MODEL shaderModel);
@@ -3240,6 +3253,22 @@ static TableParameter WaveIntrinsicsPrefixUintParameters[] = {
   { L"Validation.InputSet4", TableParameter::UINT32_TABLE, false }
 };
 
+static TableParameter WaveIntrinsicsMultiPrefixIntParameters[] = {  // columns consumed by WaveIntrinsicsSM65IntTest
+  { L"ShaderOp.Name", TableParameter::STRING, true },  // NOTE(review): the trailing bool looks like a "required" flag -- confirm against TableParameter
+  { L"ShaderOp.Target", TableParameter::STRING, true },
+  { L"ShaderOp.Text", TableParameter::STRING, true },
+  { L"Validation.Keys", TableParameter::INT32_TABLE, true },  // read back as std::vector<int> via GetDataArray<int>
+  { L"Validation.Values", TableParameter::INT32_TABLE, true },  // read back as std::vector<int> via GetDataArray<int>
+};
+
+static TableParameter WaveIntrinsicsMultiPrefixUintParameters[] = {  // columns consumed by WaveIntrinsicsSM65UintTest
+  { L"ShaderOp.Name", TableParameter::STRING, true },  // NOTE(review): the trailing bool looks like a "required" flag -- confirm against TableParameter
+  { L"ShaderOp.Target", TableParameter::STRING, true },
+  { L"ShaderOp.Text", TableParameter::STRING, true },
+  { L"Validation.Keys", TableParameter::UINT32_TABLE, true },  // read back as std::vector<unsigned> via GetDataArray<unsigned>
+  { L"Validation.Values", TableParameter::UINT32_TABLE, true },  // read back as std::vector<unsigned> via GetDataArray<unsigned>
+};
+
 static TableParameter WaveIntrinsicsActiveBoolParameters[] = {
   { L"ShaderOp.Name", TableParameter::STRING, true },
   { L"ShaderOp.Text", TableParameter::STRING, true },
@@ -6148,6 +6177,178 @@ TEST_F(ExecutionTest, WaveIntrinsicsPrefixUintTest) {
       /*isPrefix*/ true);
 }
 
+// Returns the value the CPU reference accumulator starts from for the
+// multi-prefix operation identified by testName. Names are matched without
+// regard to case: the product variants start at 1, BitAnd starts with every
+// bit set, and the remaining known operations -- as well as any name that is
+// not recognized -- start at 0.
+template <typename T>
+static T GetWaveMultiPrefixInitialAccumValue(LPCWSTR testName) {
+  struct SeedEntry {
+    LPCWSTR name;
+    int seed;
+  };
+
+  const SeedEntry seedTable[] = {
+    { L"WaveMultiPrefixProduct", 1 },
+    { L"WaveMultiPrefixUProduct", 1 },
+    { L"WaveMultiPrefixSum", 0 },
+    { L"WaveMultiPrefixUSum", 0 },
+    { L"WaveMultiPrefixBitOr", 0 },
+    { L"WaveMultiPrefixBitXor", 0 },
+    { L"WaveMultiPrefixCountBits", 0 },
+    { L"WaveMultiPrefixBitAnd", -1 },
+  };
+
+  for (const SeedEntry &entry : seedTable) {
+    if (_wcsicmp(testName, entry.name) == 0) {
+      return static_cast<T>(entry.seed);
+    }
+  }
+
+  // Unknown operation name.
+  return static_cast<T>(0);
+}
+
+// Returns the binary step used by the CPU reference for the multi-prefix
+// operation identified by testName (matched without regard to case). The
+// step receives the running accumulator and one earlier lane's value and
+// returns the new accumulator. Unrecognized names yield a step that always
+// returns 0.
+template <typename T>
+std::function<T(T, T)> GetWaveMultiPrefixReferenceFunction(LPCWSTR testName) {
+  const auto isNamed = [testName] (LPCWSTR candidate) -> bool {
+    return _wcsicmp(testName, candidate) == 0;
+  };
+
+  if (isNamed(L"WaveMultiPrefixProduct") ||
+      isNamed(L"WaveMultiPrefixUProduct")) {
+    return [] (T acc, T laneValue) -> T { return acc * laneValue; };
+  }
+
+  if (isNamed(L"WaveMultiPrefixSum") ||
+      isNamed(L"WaveMultiPrefixUSum")) {
+    return [] (T acc, T laneValue) -> T { return acc + laneValue; };
+  }
+
+  if (isNamed(L"WaveMultiPrefixBitAnd")) {
+    return [] (T acc, T laneValue) -> T { return acc & laneValue; };
+  }
+
+  if (isNamed(L"WaveMultiPrefixBitOr")) {
+    return [] (T acc, T laneValue) -> T { return acc | laneValue; };
+  }
+
+  if (isNamed(L"WaveMultiPrefixBitXor")) {
+    return [] (T acc, T laneValue) -> T { return acc ^ laneValue; };
+  }
+
+  if (isNamed(L"WaveMultiPrefixCountBits")) {
+    // CountBits takes a boolean per lane. The table supplies integers, and
+    // the shader turns them into booleans with (value > 10), so the reference
+    // adds one to the count whenever that same condition holds.
+    return [] (T acc, T laneValue) -> T {
+      return acc + (laneValue > 10 ? 1 : 0);
+    };
+  }
+
+  // Unknown operation name.
+  return [] (T, T) -> T { return 0; };
+}
+
+// Runs the shader described by one table row and checks every thread's
+// multi-prefix result against a reference computed on the CPU.
+//
+// pParameterList / numParameters describe the table columns to read:
+// ShaderOp.Name selects the reference operation, ShaderOp.Text and
+// ShaderOp.Target replace the shader source and profile, and Validation.Keys /
+// Validation.Values seed the per-thread input (both are repeated cyclically
+// over the threads). The shader is run once per MaskFunctionTable entry.
+//
+// Returns silently when no device can be created or the device reports no
+// wave-op support.
+template <class T>
+void
+ExecutionTest::WaveIntrinsicsMultiPrefixOpTest(TableParameter *pParameterList,
+                                               size_t numParameters) {
+  WEX::TestExecution::SetVerifyOutput
+    verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
+
+  // Host-side view of one buffer element; field order matches the ThreadData
+  // struct declared in the shader text of the table rows.
+  struct PerThreadData {
+    uint32_t key;
+    uint32_t firstLaneId;
+    uint32_t laneId;
+    uint32_t mask;
+    T value;
+    T result;
+  };
+
+  // Must agree with [numthreads(8, 12, 1)] in the shader text.
+  constexpr size_t NumThreadsX = 8;
+  constexpr size_t NumThreadsY = 12;
+  constexpr size_t NumThreadsZ = 1;
+
+  constexpr size_t ThreadsPerGroup = NumThreadsX * NumThreadsY * NumThreadsZ;
+  constexpr size_t DispatchGroupSize = 1;
+  constexpr size_t ThreadCount = ThreadsPerGroup * DispatchGroupSize;
+
+  // Marker stored in the shader-written fields before dispatch so that a
+  // field the shader never wrote can be recognized afterwards.
+  const uint32_t UnwrittenMarker = 0xdeadbeef;
+
+  CComPtr<IStream> pStream;
+  ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
+
+  CComPtr<ID3D12Device> pDevice;
+
+  if (!CreateDevice(&pDevice)) {
+    return;
+  }
+
+  if (!DoesDeviceSupportWaveOps(pDevice)) {
+    // Optional feature, so it's correct to not support it if declared as such.
+    WEX::Logging::Log::Comment(L"Device does not support wave operations.");
+    return;
+  }
+
+  std::shared_ptr<st::ShaderOpSet>
+    ShaderOpSet = std::make_shared<st::ShaderOpSet>();
+  st::ParseShaderOpSetFromStream(pStream, ShaderOpSet.get());
+
+  TableParameterHandler handler(pParameterList, numParameters);
+  CW2A shaderSource(handler.GetTableParamByName(L"ShaderOp.Text")->m_str);
+  CW2A shaderProfile(handler.GetTableParamByName(L"ShaderOp.Target")->m_str);
+  auto testName = handler.GetTableParamByName(L"ShaderOp.Name")->m_str;
+
+  std::vector<T> *keys = handler.GetDataArray<T>(L"Validation.Keys");
+  std::vector<T> *values = handler.GetDataArray<T>(L"Validation.Values");
+
+  // Both depend only on the test name, so resolve them once up front instead
+  // of once per mask function / per lane.
+  const T initialAccum = GetWaveMultiPrefixInitialAccumValue<T>(testName);
+  const auto refFn = GetWaveMultiPrefixReferenceFunction<T>(testName);
+
+  for (size_t maskIndex = 0; maskIndex < _countof(MaskFunctionTable); ++maskIndex) {
+    std::shared_ptr<ShaderOpTestResult> test =
+      RunShaderOpTestAfterParse(pDevice, m_support, "WaveIntrinsicsOp",
+      [&] (LPCSTR name, std::vector<BYTE> &data, st::ShaderOp *pShaderOp) {
+
+        const size_t dataSize = sizeof(PerThreadData) * ThreadCount;
+
+        data.resize(dataSize);
+        PerThreadData *pThreadData = reinterpret_cast<PerThreadData *>(data.data());
+
+        for (size_t i = 0; i != ThreadCount; ++i) {
+          pThreadData[i].key = keys->at(i % keys->size());
+          pThreadData[i].value = values->at(i % values->size());
+          pThreadData[i].firstLaneId = UnwrittenMarker;
+          pThreadData[i].laneId = UnwrittenMarker;
+          pThreadData[i].mask = MaskFunctionTable[maskIndex]((int)i);
+          // Explicit cast: the marker does not fit in a signed 32-bit T, so
+          // make the conversion visible rather than relying on an implicit
+          // narrowing.
+          pThreadData[i].result = static_cast<T>(UnwrittenMarker);
+        }
+
+        pShaderOp->Shaders.at(0).Text = shaderSource;
+        pShaderOp->Shaders.at(0).Target = shaderProfile;
+      }, ShaderOpSet);
+
+    MappedData mappedData;
+    test->Test->GetReadBackData("SWaveIntrinsicsOp", &mappedData);
+    PerThreadData *resultData = reinterpret_cast<PerThreadData *>(mappedData.data());
+
+    // Partition our data into waves. The shader stores WaveReadLaneFirst(id)
+    // in firstLaneId, so threads that share it ran in the same wave.
+    std::map<uint32_t, std::vector<PerThreadData *>> waves;
+
+    for (size_t i = 0, e = ThreadCount; i != e; ++i) {
+      PerThreadData *elt = &resultData[i];
+
+      // Basic sanity checks: the shader must have overwritten both markers.
+      VERIFY_IS_TRUE(elt->firstLaneId != UnwrittenMarker);
+      VERIFY_IS_TRUE(elt->laneId != UnwrittenMarker);
+
+      waves[elt->firstLaneId].push_back(elt);
+    }
+
+    // Verify each wave
+    for (auto &w : waves) {
+      std::vector<PerThreadData *> &waveData = w.second;
+
+      LogCommentFmt(L"LaneId    Mask      Key       Value     Result    Expected");
+      LogCommentFmt(L"--------  --------  --------  --------  --------  --------");
+      for (size_t i = 0, e = waveData.size(); i != e; ++i) {
+        const PerThreadData *data = waveData[i];
+
+        // Compute the prefix operation over every lane that precedes this one
+        // in the wave, has the same key value, and is part of the same active
+        // thread group (same mask). "Precedes" is decided by the lane index
+        // the shader reported, not by position in waveData: waveData is in
+        // thread-index order, which need not be lane order. Every reference
+        // operation is commutative and associative, so only the set of
+        // contributing lanes matters, not the order they are folded in.
+        T accum = initialAccum;
+        for (const PerThreadData *other : waveData) {
+          if (other->laneId < data->laneId &&
+              other->key == data->key && other->mask == data->mask) {
+            accum = refFn(accum, other->value);
+          }
+        }
+
+        LogCommentFmt(L"%08X  %08X  %08X  %08X  %08X  %08X", data->laneId, data->mask, data->key, data->value, data->result, accum);
+
+        VERIFY_IS_TRUE(accum == data->result);
+      }
+      LogCommentFmt(L"\n");
+    }
+  }
+}
+
+TEST_F(ExecutionTest, WaveIntrinsicsSM65IntTest) {
+  // Signed variant: values and results are handled as int.
+  TableParameter *const parameters = WaveIntrinsicsMultiPrefixIntParameters;
+  const size_t parameterCount = _countof(WaveIntrinsicsMultiPrefixIntParameters);
+  WaveIntrinsicsMultiPrefixOpTest<int>(parameters, parameterCount);
+}
+
+TEST_F(ExecutionTest, WaveIntrinsicsSM65UintTest) {
+  // Unsigned variant: values and results are handled as unsigned.
+  TableParameter *const parameters = WaveIntrinsicsMultiPrefixUintParameters;
+  const size_t parameterCount = _countof(WaveIntrinsicsMultiPrefixUintParameters);
+  WaveIntrinsicsMultiPrefixOpTest<unsigned>(parameters, parameterCount);
+}
+
 TEST_F(ExecutionTest, CBufferTestHalf) {
   WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
   CComPtr<IStream> pStream;

+ 724 - 0
tools/clang/unittests/HLSL/ShaderOpArithTable.xml

@@ -6001,6 +6001,730 @@
             </Parameter>
         </Row>
     </Table>
+    <Table Id="WaveIntrinsicsMultiPrefixIntTable">
+        <ParameterTypes>
+            <ParameterType Name="ShaderOp.Target">String</ParameterType>
+            <ParameterType Name="ShaderOp.Text">String</ParameterType>
+            <ParameterType Array="true" Name="Validation.Keys">String</ParameterType>
+            <ParameterType Array="true" Name="Validation.Values">String</ParameterType>
+        </ParameterTypes>
+        <Row Name="WaveMultiPrefixBitAnd">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixBitAnd</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    int value;
+                    int result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitAnd(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitAnd(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixBitOr">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixBitOr</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    int value;
+                    int result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitOr(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitOr(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixBitXor">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixBitXor</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    int value;
+                    int result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitXor(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitXor(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixSum">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixSum</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    int value;
+                    int result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixSum(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixSum(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixProduct">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixProduct</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    int value;
+                    int result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixProduct(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixProduct(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixCountBits">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixCountBits</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    int value;
+                    int result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixCountBits(data.value > 10, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixCountBits(data.value > 10, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+    </Table>
+    <Table Id="WaveIntrinsicsMultiPrefixUintTable">
+        <ParameterTypes>
+            <ParameterType Name="ShaderOp.Target">String</ParameterType>
+            <ParameterType Name="ShaderOp.Text">String</ParameterType>
+            <ParameterType Array="true" Name="Validation.Keys">String</ParameterType>
+            <ParameterType Array="true" Name="Validation.Values">String</ParameterType>
+        </ParameterTypes>
+        <Row Name="WaveMultiPrefixBitAnd">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixBitAnd</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    uint value;
+                    uint result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitAnd(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitAnd(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixBitOr">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixBitOr</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    uint value;
+                    uint result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitOr(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitOr(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixBitXor">
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixBitXor</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    uint value;
+                    uint result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitXor(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixBitXor(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixUSum">
+            <!-- Test row for WaveMultiPrefixSum on uint values, targeting cs_6_5.
+                 The shader loads one ThreadData per thread from g_buffer, records
+                 WaveReadLaneFirst(id) and WaveGetLaneIndex(), builds a lane mask with
+                 WaveMatch(data.key), and stores WaveMultiPrefixSum(data.value, mask)
+                 in data.result before writing the struct back.
+                 Both branches of the data.mask test issue the same two calls.
+                 NOTE(review): presumably this exercises the intrinsic under divergent
+                 control flow; confirm against the test harness.
+                 NOTE(review): Validation.Keys and Validation.Values presumably seed the
+                 key and value fields; the mapping onto threads is not shown here. -->
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixUSum</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    uint value;
+                    uint result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixSum(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixSum(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixUProduct">
+            <!-- Test row for WaveMultiPrefixProduct on uint values, targeting cs_6_5.
+                 The Row Name matches ShaderOp.Name so this row stays distinct from the
+                 WaveMultiPrefixUSum row in the same table.
+                 The shader loads one ThreadData per thread from g_buffer, records
+                 WaveReadLaneFirst(id) and WaveGetLaneIndex(), builds a lane mask with
+                 WaveMatch(data.key), and stores WaveMultiPrefixProduct(data.value, mask)
+                 in data.result before writing the struct back.
+                 Both branches of the data.mask test issue the same two calls.
+                 NOTE(review): presumably this exercises the intrinsic under divergent
+                 control flow; confirm against the test harness.
+                 NOTE(review): Validation.Keys and Validation.Values presumably seed the
+                 key and value fields; the mapping onto threads is not shown here. -->
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixUProduct</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    uint value;
+                    uint result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixProduct(data.value, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixProduct(data.value, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+        <Row Name="WaveMultiPrefixCountBits">
+            <!-- Test row for WaveMultiPrefixCountBits, targeting cs_6_5.
+                 The shader loads one ThreadData per thread from g_buffer, records
+                 WaveReadLaneFirst(id) and WaveGetLaneIndex(), builds a lane mask with
+                 WaveMatch(data.key), and stores
+                 WaveMultiPrefixCountBits(data.value > 10, mask) in data.result, so the
+                 counted predicate is "value greater than 10".
+                 Both branches of the data.mask test issue the same two calls.
+                 NOTE(review): presumably this exercises the intrinsic under divergent
+                 control flow; confirm against the test harness.
+                 NOTE(review): Validation.Keys and Validation.Values presumably seed the
+                 key and value fields; the mapping onto threads is not shown here. -->
+            <Parameter Name="ShaderOp.Name">WaveMultiPrefixCountBits</Parameter>
+            <Parameter Name="ShaderOp.Target">cs_6_5</Parameter>
+            <Parameter Name="ShaderOp.Text">
+                struct ThreadData {
+                    uint key;
+                    uint firstLaneId;
+                    uint laneId;
+                    uint mask;
+                    uint value;
+                    uint result;
+                };
+
+                RWStructuredBuffer&lt;ThreadData&gt; g_buffer : register(u0);
+
+                [numthreads(8, 12, 1)]
+                void main
+                (
+                    uint id : SV_GroupIndex
+                )
+                {
+                    ThreadData data = g_buffer[id];
+
+                    data.firstLaneId = WaveReadLaneFirst(id);
+                    data.laneId = WaveGetLaneIndex();
+
+                    if (data.mask != 0) {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixCountBits(data.value > 10, mask);
+                    } else {
+                        uint4 mask = WaveMatch(data.key);
+                        data.result = WaveMultiPrefixCountBits(data.value > 10, mask);
+                    }
+
+                    g_buffer[id] = data;
+                }
+            </Parameter>
+            <Parameter Name="Validation.Keys">
+                <Value>0</Value>
+                <Value>3</Value>
+                <Value>1</Value>
+                <Value>5</Value>
+                <Value>4</Value>
+            </Parameter>
+            <Parameter Name="Validation.Values">
+                <Value>10</Value>
+                <Value>42</Value>
+                <Value>1</Value>
+                <Value>64</Value>
+                <Value>11</Value>
+                <Value>76</Value>
+                <Value>90</Value>
+                <Value>111</Value>
+                <Value>9</Value>
+                <Value>6</Value>
+                <Value>79</Value>
+                <Value>34</Value>
+            </Parameter>
+        </Row>
+    </Table>
     <Table Id="DenormBinaryFloatOpTable">
         <ParameterTypes>
             <ParameterType Name="ShaderOp.Target">String</ParameterType>

+ 7 - 0
utils/hct/gen_intrin_main.txt

@@ -262,6 +262,13 @@ $type1 [[unsigned_op=WaveActiveUMax]] WaveActiveMax(in numeric<> value);
 uint   [[]] WavePrefixCountBits(in bool value);
 $type1 [[unsigned_op=WavePrefixUSum]] WavePrefixSum(in numeric<> value);
 $type1 [[unsigned_op=WavePrefixUProduct]] WavePrefixProduct(in numeric<> value);
+uint<4> [[]] WaveMatch(in numeric<> value);
+$type1 [[]] WaveMultiPrefixBitAnd(in any_int<> value, in uint<4> mask);
+$type1 [[]] WaveMultiPrefixBitOr(in any_int<> value, in uint<4> mask);
+$type1 [[]] WaveMultiPrefixBitXor(in any_int<> value, in uint<4> mask);
+uint [[]] WaveMultiPrefixCountBits(in bool value, in uint<4> mask);
+$type1 [[unsigned_op=WaveMultiPrefixUProduct]] WaveMultiPrefixProduct(in numeric<> value, in uint<4> mask);
+$type1 [[unsigned_op=WaveMultiPrefixUSum]] WaveMultiPrefixSum(in numeric<> value, in uint<4> mask);
 $type1 [[]] QuadReadLaneAt(in numeric<> value, in uint quadLane);
 $type1 [[]] QuadReadAcrossX(in numeric<> value);
 $type1 [[]] QuadReadAcrossY(in numeric<> value);

+ 38 - 0
utils/hct/hctdb.py

@@ -377,6 +377,9 @@ class db_dxil(object):
         for i in "Dot4AddU8Packed,Dot4AddI8Packed,Dot2AddHalf".split(","):
             self.name_idx[i].category = "Dot product with accumulate"
             self.name_idx[i].shader_model = 6,4
+        for i in "WaveMatch,WaveMultiPrefixOp,WaveMultiPrefixBitCount".split(","):
+            self.name_idx[i].category = "Wave"
+            self.name_idx[i].shader_model = 6,5
 
     def populate_llvm_instructions(self):
         # Add instructions that map to LLVM instructions.
@@ -1339,6 +1342,41 @@ class db_dxil(object):
         self.set_op_count_for_version(1, 4, next_op_idx)
         assert next_op_idx == 165, "next operation index is %d rather than 165 and thus opcodes are broken" % next_op_idx
 
+        self.add_dxil_op("WaveMatch", next_op_idx, "WaveMatch", "returns the bitmask of active lanes that have the same value", "hfd8wil", "", [
+            db_dxil_param(0, "$u4", "", "operation result"),
+            db_dxil_param(2, "$o", "value", "input value")])
+        next_op_idx += 1
+
+        self.add_dxil_op("WaveMultiPrefixOp", next_op_idx, "WaveMultiPrefixOp", "returns the result of the operation on groups of lanes identified by a bitmask", "hfd8wil", "", [
+            db_dxil_param(0, "$o", "", "operation result"),
+            db_dxil_param(2, "$o", "value", "input value"),
+            db_dxil_param(3, "i32", "mask0", "mask 0"),
+            db_dxil_param(4, "i32", "mask1", "mask 1"),
+            db_dxil_param(5, "i32", "mask2", "mask 2"),
+            db_dxil_param(6, "i32", "mask3", "mask 3"),
+            db_dxil_param(7, "i8", "op", "operation", enum_name="WaveMultiPrefixOpKind", is_const=True),
+            db_dxil_param(8, "i8", "sop", "sign of operands", enum_name="SignedOpKind", is_const=True)])
+        next_op_idx += 1
+        self.add_enum_type("WaveMultiPrefixOpKind", "Kind of cross-lane for multi-prefix operation", [
+            (0, "Sum", "sum of values"),
+            (1, "And", "bitwise and of values"),
+            (2, "Or", "bitwise or of values"),
+            (3, "Xor", "bitwise xor of values"),
+            (4, "Product", "product of values")])
+
+        self.add_dxil_op("WaveMultiPrefixBitCount", next_op_idx, "WaveMultiPrefixBitCount", "returns the count of bits set to 1 on groups of lanes identified by a bitmask", "v", "", [
+            db_dxil_param(0, "i32", "", "operation result"),
+            db_dxil_param(2, "i1", "value", "input value"),
+            db_dxil_param(3, "i32", "mask0", "mask 0"),
+            db_dxil_param(4, "i32", "mask1", "mask 1"),
+            db_dxil_param(5, "i32", "mask2", "mask 2"),
+            db_dxil_param(6, "i32", "mask3", "mask 3")])
+        next_op_idx += 1
+
+        # End of DXIL 1.5 opcodes.
+        self.set_op_count_for_version(1, 5, next_op_idx)
+        assert next_op_idx == 168, "next operation index is %d rather than 168 and thus opcodes are broken" % next_op_idx
+
         # Set interesting properties.
         self.build_indices()
         for i in "CalculateLOD,DerivCoarseX,DerivCoarseY,DerivFineX,DerivFineY,Sample,SampleBias,SampleCmp,TextureGather,TextureGatherCmp".split(","):

Unele fișiere nu au fost afișate deoarece prea multe fișiere au fost modificate în acest diff