Przeglądaj źródła

[spirv] Translate intrinsic msad4 function. (#808)

Ehsan 7 lat temu
rodzic
commit
998bab9dea

+ 12 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -256,6 +256,18 @@ public:
   /// is created; otherwise an OpMemoryBarrier is created.
   void createBarrier(uint32_t exec, uint32_t memory, uint32_t semantics);
 
+  /// \brief Creates an OpBitFieldInsert SPIR-V instruction for the given
+  /// arguments.
+  uint32_t createBitFieldInsert(uint32_t resultType, uint32_t base,
+                                uint32_t insert, uint32_t offset,
+                                uint32_t count);
+
+  /// \brief Creates an OpBitFieldUExtract or OpBitFieldSExtract SPIR-V
+  /// instruction for the given arguments.
+  uint32_t createBitFieldExtract(uint32_t resultType, uint32_t base,
+                                 uint32_t offset, uint32_t count,
+                                 bool isSigned);
+
   /// \brief Creates an OpEmitVertex instruction.
   void createEmitVertex();
 

+ 26 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -563,6 +563,32 @@ void ModuleBuilder::createBarrier(uint32_t execution, uint32_t memory,
   insertPoint->appendInstruction(std::move(constructSite));
 }
 
+uint32_t ModuleBuilder::createBitFieldExtract(uint32_t resultType,
+                                              uint32_t base, uint32_t offset,
+                                              uint32_t count, bool isSigned) {
+  assert(insertPoint && "null insert point");
+  uint32_t resultId = theContext.takeNextId();
+  if (isSigned)
+    instBuilder.opBitFieldSExtract(resultType, resultId, base, offset, count);
+  else
+    instBuilder.opBitFieldUExtract(resultType, resultId, base, offset, count);
+  instBuilder.x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return resultId;
+}
+
+uint32_t ModuleBuilder::createBitFieldInsert(uint32_t resultType, uint32_t base,
+                                             uint32_t insert, uint32_t offset,
+                                             uint32_t count) {
+  assert(insertPoint && "null insert point");
+  uint32_t resultId = theContext.takeNextId();
+  instBuilder
+      .opBitFieldInsert(resultType, resultId, base, insert, offset, count)
+      .x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return resultId;
+}
+
 void ModuleBuilder::createEmitVertex() {
   assert(insertPoint && "null insert point");
   instBuilder.opEmitVertex().x();

+ 167 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -4317,6 +4317,9 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_modf:
     retVal = processIntrinsicModf(callExpr);
     break;
+  case hlsl::IntrinsicOp::IOP_msad4:
+    retVal = processIntrinsicMsad4(callExpr);
+    break;
   case hlsl::IntrinsicOp::IOP_sign: {
     if (isFloatOrVecMatOfFloatType(callExpr->getArg(0)->getType()))
       retVal = processIntrinsicFloatSign(callExpr);
@@ -4547,6 +4550,170 @@ SPIRVEmitter::processIntrinsicInterlockedMethod(const CallExpr *expr,
   return 0;
 }
 
+uint32_t SPIRVEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
+  emitWarning("msad4 intrinsic function is emulated using many SPIR-V "
+              "instructions due to lack of direct SPIR-V equivalent",
+              callExpr->getExprLoc());
+
+  // Compares a 4-byte reference value and an 8-byte source value and
+  // accumulates a vector of 4 sums. Each sum corresponds to the masked sum
+  // of absolute differences of a different byte alignment between the
+  // reference value and the source value.
+
+  // If we have:
+  // uint  v0; // reference
+  // uint2 v1; // source
+  // uint4 v2; // accum
+  // uint4 o0; // result of msad4
+  // uint4 r0, t0; // temporary values
+  //
+  // Then msad4(v0, v1, v2) translates to the following SM5 assembly according
+  // to fxc:
+  //   Step 1:
+  //     ushr r0.xyz, v1.xxxx, l(8, 16, 24, 0)
+  //   Step 2:
+  //         [result], [    width    ], [    offset   ], [ insert ], [ base ]
+  //     bfi   t0.yzw, l(0, 8, 16, 24), l(0, 24, 16, 8),  v1.yyyy  , r0.xxyz
+  //     mov t0.x, v1.x
+  //   Step 3:
+  //     msad o0.xyzw, v0.xxxx, t0.xyzw, v2.xyzw
+
+  const uint32_t glsl = theBuilder.getGLSLExtInstSet();
+  const auto boolType = theBuilder.getBoolType();
+  const auto intType = theBuilder.getInt32Type();
+  const auto uintType = theBuilder.getUint32Type();
+  const auto uint4Type = theBuilder.getVecType(uintType, 4);
+  const uint32_t reference = doExpr(callExpr->getArg(0));
+  const uint32_t source = doExpr(callExpr->getArg(1));
+  const uint32_t accum = doExpr(callExpr->getArg(2));
+  const auto uint0 = theBuilder.getConstantUint32(0);
+  const auto uint8 = theBuilder.getConstantUint32(8);
+  const auto uint16 = theBuilder.getConstantUint32(16);
+  const auto uint24 = theBuilder.getConstantUint32(24);
+
+  // Step 1.
+  const uint32_t v1x = theBuilder.createCompositeExtract(uintType, source, {0});
+  // r0.x = v1xS8 = v1.x shifted by 8 bits
+  uint32_t v1xS8 = theBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
+                                             uintType, v1x, uint8);
+  // r0.y = v1xS16 = v1.x shifted by 16 bits
+  uint32_t v1xS16 = theBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
+                                              uintType, v1x, uint16);
+  // r0.z = v1xS24 = v1.x shifted by 24 bits
+  uint32_t v1xS24 = theBuilder.createBinaryOp(spv::Op::OpShiftLeftLogical,
+                                              uintType, v1x, uint24);
+
+  // Step 2.
+  // Do bfi 3 times. DXIL bfi is equivalent to SPIR-V OpBitFieldInsert.
+  const uint32_t v1y = theBuilder.createCompositeExtract(uintType, source, {1});
+  // Note that t0.x = v1.x, nothing we need to do for that.
+  const uint32_t t0y =
+      theBuilder.createBitFieldInsert(uintType, /*base*/ v1xS8, /*insert*/ v1y,
+                                      /*offset*/ uint24,
+                                      /*width*/ uint8);
+  const uint32_t t0z =
+      theBuilder.createBitFieldInsert(uintType, /*base*/ v1xS16, /*insert*/ v1y,
+                                      /*offset*/ uint16,
+                                      /*width*/ uint16);
+  const uint32_t t0w =
+      theBuilder.createBitFieldInsert(uintType, /*base*/ v1xS24, /*insert*/ v1y,
+                                      /*offset*/ uint8,
+                                      /*width*/ uint24);
+
+  // Step 3. MSAD (Masked Sum of Absolute Differences)
+
+  // Now perform MSAD four times.
+  // Need to mimic this algorithm in SPIR-V!
+  //
+  // UINT msad( UINT ref, UINT src, UINT accum )
+  // {
+  //     for (UINT i = 0; i < 4; i++)
+  //     {
+  //         BYTE refByte, srcByte, absDiff;
+  // 
+  //         refByte = (BYTE)(ref >> (i * 8));
+  //         if (!refByte)
+  //         {
+  //             continue;
+  //         }
+  // 
+  //         srcByte = (BYTE)(src >> (i * 8));
+  //         if (refByte >= srcByte)
+  //         {
+  //             absDiff = refByte - srcByte;
+  //         }
+  //         else
+  //         {
+  //             absDiff = srcByte - refByte;
+  //         }
+  // 
+  //         // The recommended overflow behavior for MSAD is
+  //         // to do a 32-bit saturate. This is not
+  //         // required, however, and wrapping is allowed.
+  //         // So from an application point of view,
+  //         // overflow behavior is undefined.
+  //         if (UINT_MAX - accum < absDiff)
+  //         {
+  //             accum = UINT_MAX;
+  //             break;
+  //         }
+  //         accum += absDiff;
+  //     }
+  // 
+  //     return accum;
+  // }
+
+  llvm::SmallVector<uint32_t, 4> result;
+  const uint32_t accum0 =
+      theBuilder.createCompositeExtract(uintType, accum, {0});
+  const uint32_t accum1 =
+      theBuilder.createCompositeExtract(uintType, accum, {1});
+  const uint32_t accum2 =
+      theBuilder.createCompositeExtract(uintType, accum, {2});
+  const uint32_t accum3 =
+      theBuilder.createCompositeExtract(uintType, accum, {3});
+  const llvm::SmallVector<uint32_t, 4> sources = {v1x, t0y, t0z, t0w};
+  llvm::SmallVector<uint32_t, 4> accums = {accum0, accum1, accum2, accum3};
+  llvm::SmallVector<uint32_t, 4> refBytes;
+  llvm::SmallVector<uint32_t, 4> signedRefBytes;
+  llvm::SmallVector<uint32_t, 4> isRefByteZero;
+  for (uint32_t i = 0; i < 4; ++i) {
+    refBytes.push_back(theBuilder.createBitFieldExtract(
+        uintType, reference, /*offset*/ theBuilder.getConstantUint32(i * 8),
+        /*count*/ uint8, /*isSigned*/ false));
+    signedRefBytes.push_back(
+        theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, refBytes.back()));
+    isRefByteZero.push_back(theBuilder.createBinaryOp(
+        spv::Op::OpIEqual, boolType, refBytes.back(), uint0));
+  }
+
+  for (uint32_t msadNum = 0; msadNum < 4; ++msadNum) {
+    for (uint32_t byteCount = 0; byteCount < 4; ++byteCount) {
+      // 'count' is always 8 because we are extracting 8 bits out of 32.
+      const uint32_t srcByte = theBuilder.createBitFieldExtract(
+          uintType, sources[msadNum],
+          /*offset*/ theBuilder.getConstantUint32(8 * byteCount),
+          /*count*/ uint8, /*isSigned*/ false);
+      const uint32_t signedSrcByte =
+          theBuilder.createUnaryOp(spv::Op::OpBitcast, intType, srcByte);
+      const uint32_t sub = theBuilder.createBinaryOp(
+          spv::Op::OpISub, intType, signedRefBytes[byteCount], signedSrcByte);
+      const uint32_t absSub = theBuilder.createExtInst(
+          intType, glsl, GLSLstd450::GLSLstd450SAbs, {sub});
+      const uint32_t diff = theBuilder.createSelect(
+          uintType, isRefByteZero[byteCount], uint0,
+          theBuilder.createUnaryOp(spv::Op::OpBitcast, uintType, absSub));
+
+      // As pointed out by the DXIL reference above, it is *not* required to
+      // saturate the output to UINT_MAX in case of overflow. Wrapping around is
+      // also allowed. For simplicity, we will wrap around at this point.
+      accums[msadNum] = theBuilder.createBinaryOp(spv::Op::OpIAdd, uintType,
+                                                  accums[msadNum], diff);
+    }
+  }
+  return theBuilder.createCompositeConstruct(uint4Type, accums);
+}
+
 uint32_t SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
   // Signature is: ret modf(x, ip)
   // [in]    x: the input floating-point value.

+ 3 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -290,6 +290,9 @@ private:
   /// Processes the 'modf' intrinsic function.
   uint32_t processIntrinsicModf(const CallExpr *);
 
+  /// Processes the 'msad4' intrinsic function.
+  uint32_t processIntrinsicMsad4(const CallExpr *);
+
   /// Processes the 'mul' intrinsic function.
   uint32_t processIntrinsicMul(const CallExpr *);
 

+ 190 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.msad4.hlsl

@@ -0,0 +1,190 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: [[glsl:%\d+]] = OpExtInstImport "GLSL.std.450"
+
+uint4 main(uint reference : REF, uint2 source :SOURCE, uint4 accum : ACCUM) : MSAD_RESULT
+{
+
+// CHECK:          [[ref:%\d+]] = OpLoad %uint %reference
+// CHECK-NEXT:     [[src:%\d+]] = OpLoad %v2uint %source
+// CHECK-NEXT:   [[accum:%\d+]] = OpLoad %v4uint %accum
+// CHECK-NEXT:    [[src0:%\d+]] = OpCompositeExtract %uint [[src]] 0
+// CHECK-NEXT:  [[src0s8:%\d+]] = OpShiftLeftLogical %uint [[src0]] %uint_8
+// CHECK-NEXT: [[src0s16:%\d+]] = OpShiftLeftLogical %uint [[src0]] %uint_16
+// CHECK-NEXT: [[src0s24:%\d+]] = OpShiftLeftLogical %uint [[src0]] %uint_24
+// CHECK-NEXT:    [[src1:%\d+]] = OpCompositeExtract %uint [[src]] 1
+// CHECK-NEXT:    [[bfi0:%\d+]] = OpBitFieldInsert %uint [[src0s8]] [[src1]] %uint_24 %uint_8
+// CHECK-NEXT:    [[bfi1:%\d+]] = OpBitFieldInsert %uint [[src0s16]] [[src1]] %uint_16 %uint_16
+// CHECK-NEXT:    [[bfi2:%\d+]] = OpBitFieldInsert %uint [[src0s24]] [[src1]] %uint_8 %uint_24
+// CHECK-NEXT:  [[accum0:%\d+]] = OpCompositeExtract %uint [[accum]] 0
+// CHECK-NEXT:  [[accum1:%\d+]] = OpCompositeExtract %uint [[accum]] 1
+// CHECK-NEXT:  [[accum2:%\d+]] = OpCompositeExtract %uint [[accum]] 2
+// CHECK-NEXT:  [[accum3:%\d+]] = OpCompositeExtract %uint [[accum]] 3
+
+// Now perforoming MSAD four times
+
+// CHECK-NEXT:           [[refByte0:%\d+]] = OpBitFieldUExtract %uint [[ref]] %uint_0 %uint_8
+// CHECK-NEXT:        [[intRefByte0:%\d+]] = OpBitcast %int [[refByte0]]
+// CHECK-NEXT:     [[isRefByte0Zero:%\d+]] = OpIEqual %bool [[refByte0]] %uint_0
+// CHECK-NEXT:           [[refByte1:%\d+]] = OpBitFieldUExtract %uint [[ref]] %uint_8 %uint_8
+// CHECK-NEXT:        [[intRefByte1:%\d+]] = OpBitcast %int [[refByte1]]
+// CHECK-NEXT:     [[isRefByte1Zero:%\d+]] = OpIEqual %bool [[refByte1]] %uint_0
+// CHECK-NEXT:           [[refByte2:%\d+]] = OpBitFieldUExtract %uint [[ref]] %uint_16 %uint_8
+// CHECK-NEXT:        [[intRefByte2:%\d+]] = OpBitcast %int [[refByte2]]
+// CHECK-NEXT:     [[isRefByte2Zero:%\d+]] = OpIEqual %bool [[refByte2]] %uint_0
+// CHECK-NEXT:           [[refByte3:%\d+]] = OpBitFieldUExtract %uint [[ref]] %uint_24 %uint_8
+// CHECK-NEXT:        [[intRefByte3:%\d+]] = OpBitcast %int [[refByte3]]
+// CHECK-NEXT:     [[isRefByte3Zero:%\d+]] = OpIEqual %bool [[refByte3]] %uint_0
+
+// MSAD 0 Byte 0
+// CHECK-NEXT:          [[src0Byte0:%\d+]] = OpBitFieldUExtract %uint [[src0]] %uint_0 %uint_8
+// CHECK-NEXT:       [[intSrc0Byte0:%\d+]] = OpBitcast %int [[src0Byte0]]
+// CHECK-NEXT:               [[sub0:%\d+]] = OpISub %int [[intRefByte0]] [[intSrc0Byte0]]
+// CHECK-NEXT:            [[absSub0:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub0]]
+// CHECK-NEXT:        [[uintAbsSub0:%\d+]] = OpBitcast %uint [[absSub0]]
+// CHECK-NEXT:              [[diff0:%\d+]] = OpSelect %uint [[isRefByte0Zero]] %uint_0 [[uintAbsSub0]]
+// CHECK-NEXT:    [[accum0PlusDiff0:%\d+]] = OpIAdd %uint [[accum0]] [[diff0]]
+
+// MSAD 0 Byte 1
+// CHECK-NEXT:          [[src0Byte1:%\d+]] = OpBitFieldUExtract %uint [[src0]] %uint_8 %uint_8
+// CHECK-NEXT:       [[intSrc0Byte1:%\d+]] = OpBitcast %int [[src0Byte1]]
+// CHECK-NEXT:               [[sub1:%\d+]] = OpISub %int [[intRefByte1]] [[intSrc0Byte1]]
+// CHECK-NEXT:            [[absSub1:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub1]]
+// CHECK-NEXT:        [[uintAbsSub1:%\d+]] = OpBitcast %uint [[absSub1]]
+// CHECK-NEXT:              [[diff1:%\d+]] = OpSelect %uint [[isRefByte1Zero]] %uint_0 [[uintAbsSub1]]
+// CHECK-NEXT:   [[accum0PlusDiff01:%\d+]] = OpIAdd %uint [[accum0PlusDiff0]] [[diff1]]
+
+// MSAD 0 Byte 2
+// CHECK-NEXT:          [[src0Byte2:%\d+]] = OpBitFieldUExtract %uint [[src0]] %uint_16 %uint_8
+// CHECK-NEXT:       [[intSrc0Byte2:%\d+]] = OpBitcast %int [[src0Byte2]]
+// CHECK-NEXT:               [[sub2:%\d+]] = OpISub %int [[intRefByte2]] [[intSrc0Byte2]]
+// CHECK-NEXT:            [[absSub2:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub2]]
+// CHECK-NEXT:        [[uintAbsSub2:%\d+]] = OpBitcast %uint [[absSub2]]
+// CHECK-NEXT:              [[diff2:%\d+]] = OpSelect %uint [[isRefByte2Zero]] %uint_0 [[uintAbsSub2]]
+// CHECK-NEXT:  [[accum0PlusDiff012:%\d+]] = OpIAdd %uint [[accum0PlusDiff01]] [[diff2]]
+
+// MSAD 0 Byte 3
+// CHECK-NEXT:          [[src0Byte3:%\d+]] = OpBitFieldUExtract %uint [[src0]] %uint_24 %uint_8
+// CHECK-NEXT:       [[intSrc0Byte3:%\d+]] = OpBitcast %int [[src0Byte3]]
+// CHECK-NEXT:               [[sub3:%\d+]] = OpISub %int [[intRefByte3]] [[intSrc0Byte3]]
+// CHECK-NEXT:            [[absSub3:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub3]]
+// CHECK-NEXT:        [[uintAbsSub3:%\d+]] = OpBitcast %uint [[absSub3]]
+// CHECK-NEXT:              [[diff3:%\d+]] = OpSelect %uint [[isRefByte3Zero]] %uint_0 [[uintAbsSub3]]
+// CHECK-NEXT: [[accum0PlusDiff0123:%\d+]] = OpIAdd %uint [[accum0PlusDiff012]] [[diff3]]
+
+
+// MSAD 1 Byte 0
+// CHECK-NEXT:          [[src1Byte0:%\d+]] = OpBitFieldUExtract %uint [[bfi0]] %uint_0 %uint_8
+// CHECK-NEXT:       [[intSrc1Byte0:%\d+]] = OpBitcast %int [[src1Byte0]]
+// CHECK-NEXT:               [[sub0:%\d+]] = OpISub %int [[intRefByte0]] [[intSrc1Byte0]]
+// CHECK-NEXT:            [[absSub0:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub0]]
+// CHECK-NEXT:        [[uintAbsSub0:%\d+]] = OpBitcast %uint [[absSub0]]
+// CHECK-NEXT:              [[diff0:%\d+]] = OpSelect %uint [[isRefByte0Zero]] %uint_0 [[uintAbsSub0]]
+// CHECK-NEXT:    [[accum1PlusDiff0:%\d+]] = OpIAdd %uint [[accum1]] [[diff0]]
+
+// MSAD 1 Byte 1
+// CHECK-NEXT:          [[src1Byte1:%\d+]] = OpBitFieldUExtract %uint [[bfi0]] %uint_8 %uint_8
+// CHECK-NEXT:       [[intSrc1Byte1:%\d+]] = OpBitcast %int [[src1Byte1]]
+// CHECK-NEXT:               [[sub1:%\d+]] = OpISub %int [[intRefByte1]] [[intSrc1Byte1]]
+// CHECK-NEXT:            [[absSub1:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub1]]
+// CHECK-NEXT:        [[uintAbsSub1:%\d+]] = OpBitcast %uint [[absSub1]]
+// CHECK-NEXT:              [[diff1:%\d+]] = OpSelect %uint [[isRefByte1Zero]] %uint_0 [[uintAbsSub1]]
+// CHECK-NEXT:   [[accum1PlusDiff01:%\d+]] = OpIAdd %uint [[accum1PlusDiff0]] [[diff1]]
+
+// MSAD 1 Byte 2
+// CHECK-NEXT:          [[src1Byte2:%\d+]] = OpBitFieldUExtract %uint [[bfi0]] %uint_16 %uint_8
+// CHECK-NEXT:       [[intSrc1Byte2:%\d+]] = OpBitcast %int [[src1Byte2]]
+// CHECK-NEXT:               [[sub2:%\d+]] = OpISub %int [[intRefByte2]] [[intSrc1Byte2]]
+// CHECK-NEXT:            [[absSub2:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub2]]
+// CHECK-NEXT:        [[uintAbsSub2:%\d+]] = OpBitcast %uint [[absSub2]]
+// CHECK-NEXT:              [[diff2:%\d+]] = OpSelect %uint [[isRefByte2Zero]] %uint_0 [[uintAbsSub2]]
+// CHECK-NEXT:  [[accum1PlusDiff012:%\d+]] = OpIAdd %uint [[accum1PlusDiff01]] [[diff2]]
+
+// MSAD 1 Byte 3
+// CHECK-NEXT:          [[src1Byte3:%\d+]] = OpBitFieldUExtract %uint [[bfi0]] %uint_24 %uint_8
+// CHECK-NEXT:       [[intSrc1Byte3:%\d+]] = OpBitcast %int [[src1Byte3]]
+// CHECK-NEXT:               [[sub3:%\d+]] = OpISub %int [[intRefByte3]] [[intSrc1Byte3]]
+// CHECK-NEXT:            [[absSub3:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub3]]
+// CHECK-NEXT:        [[uintAbsSub3:%\d+]] = OpBitcast %uint [[absSub3]]
+// CHECK-NEXT:              [[diff3:%\d+]] = OpSelect %uint [[isRefByte3Zero]] %uint_0 [[uintAbsSub3]]
+// CHECK-NEXT: [[accum1PlusDiff0123:%\d+]] = OpIAdd %uint [[accum1PlusDiff012]] [[diff3]]
+
+
+// MSAD 2 Byte 0
+// CHECK-NEXT:          [[src2Byte0:%\d+]] = OpBitFieldUExtract %uint [[bfi1]] %uint_0 %uint_8
+// CHECK-NEXT:       [[intSrc2Byte0:%\d+]] = OpBitcast %int [[src2Byte0]]
+// CHECK-NEXT:               [[sub0:%\d+]] = OpISub %int [[intRefByte0]] [[intSrc2Byte0]]
+// CHECK-NEXT:            [[absSub0:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub0]]
+// CHECK-NEXT:        [[uintAbsSub0:%\d+]] = OpBitcast %uint [[absSub0]]
+// CHECK-NEXT:              [[diff0:%\d+]] = OpSelect %uint [[isRefByte0Zero]] %uint_0 [[uintAbsSub0]]
+// CHECK-NEXT:    [[accum2PlusDiff0:%\d+]] = OpIAdd %uint [[accum2]] [[diff0]]
+
+// MSAD 2 Byte 1
+// CHECK-NEXT:          [[src2Byte1:%\d+]] = OpBitFieldUExtract %uint [[bfi1]] %uint_8 %uint_8
+// CHECK-NEXT:       [[intSrc2Byte1:%\d+]] = OpBitcast %int [[src2Byte1]]
+// CHECK-NEXT:               [[sub1:%\d+]] = OpISub %int [[intRefByte1]] [[intSrc2Byte1]]
+// CHECK-NEXT:            [[absSub1:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub1]]
+// CHECK-NEXT:        [[uintAbsSub1:%\d+]] = OpBitcast %uint [[absSub1]]
+// CHECK-NEXT:              [[diff1:%\d+]] = OpSelect %uint [[isRefByte1Zero]] %uint_0 [[uintAbsSub1]]
+// CHECK-NEXT:   [[accum2PlusDiff01:%\d+]] = OpIAdd %uint [[accum2PlusDiff0]] [[diff1]]
+
+// MSAD 2 Byte 2
+// CHECK-NEXT:          [[src2Byte2:%\d+]] = OpBitFieldUExtract %uint [[bfi1]] %uint_16 %uint_8
+// CHECK-NEXT:       [[intSrc2Byte2:%\d+]] = OpBitcast %int [[src2Byte2]]
+// CHECK-NEXT:               [[sub2:%\d+]] = OpISub %int [[intRefByte2]] [[intSrc2Byte2]]
+// CHECK-NEXT:            [[absSub2:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub2]]
+// CHECK-NEXT:        [[uintAbsSub2:%\d+]] = OpBitcast %uint [[absSub2]]
+// CHECK-NEXT:              [[diff2:%\d+]] = OpSelect %uint [[isRefByte2Zero]] %uint_0 [[uintAbsSub2]]
+// CHECK-NEXT:  [[accum2PlusDiff012:%\d+]] = OpIAdd %uint [[accum2PlusDiff01]] [[diff2]]
+
+// MSAD 2 Byte 3
+// CHECK-NEXT:          [[src2Byte3:%\d+]] = OpBitFieldUExtract %uint [[bfi1]] %uint_24 %uint_8
+// CHECK-NEXT:       [[intSrc2Byte3:%\d+]] = OpBitcast %int [[src2Byte3]]
+// CHECK-NEXT:               [[sub3:%\d+]] = OpISub %int [[intRefByte3]] [[intSrc2Byte3]]
+// CHECK-NEXT:            [[absSub3:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub3]]
+// CHECK-NEXT:        [[uintAbsSub3:%\d+]] = OpBitcast %uint [[absSub3]]
+// CHECK-NEXT:              [[diff3:%\d+]] = OpSelect %uint [[isRefByte3Zero]] %uint_0 [[uintAbsSub3]]
+// CHECK-NEXT: [[accum2PlusDiff0123:%\d+]] = OpIAdd %uint [[accum2PlusDiff012]] [[diff3]]
+
+
+// MSAD 3 Byte 0
+// CHECK-NEXT:          [[src3Byte0:%\d+]] = OpBitFieldUExtract %uint [[bfi2]] %uint_0 %uint_8
+// CHECK-NEXT:       [[intSrc3Byte0:%\d+]] = OpBitcast %int [[src3Byte0]]
+// CHECK-NEXT:               [[sub0:%\d+]] = OpISub %int [[intRefByte0]] [[intSrc3Byte0]]
+// CHECK-NEXT:            [[absSub0:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub0]]
+// CHECK-NEXT:        [[uintAbsSub0:%\d+]] = OpBitcast %uint [[absSub0]]
+// CHECK-NEXT:              [[diff0:%\d+]] = OpSelect %uint [[isRefByte0Zero]] %uint_0 [[uintAbsSub0]]
+// CHECK-NEXT:    [[accum3PlusDiff0:%\d+]] = OpIAdd %uint [[accum3]] [[diff0]]
+
+// MSAD 3 Byte 1
+// CHECK-NEXT:          [[src3Byte1:%\d+]] = OpBitFieldUExtract %uint [[bfi2]] %uint_8 %uint_8
+// CHECK-NEXT:       [[intSrc3Byte1:%\d+]] = OpBitcast %int [[src3Byte1]]
+// CHECK-NEXT:               [[sub1:%\d+]] = OpISub %int [[intRefByte1]] [[intSrc3Byte1]]
+// CHECK-NEXT:            [[absSub1:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub1]]
+// CHECK-NEXT:        [[uintAbsSub1:%\d+]] = OpBitcast %uint [[absSub1]]
+// CHECK-NEXT:              [[diff1:%\d+]] = OpSelect %uint [[isRefByte1Zero]] %uint_0 [[uintAbsSub1]]
+// CHECK-NEXT:   [[accum3PlusDiff01:%\d+]] = OpIAdd %uint [[accum3PlusDiff0]] [[diff1]]
+
+// MSAD 3 Byte 2
+// CHECK-NEXT:          [[src3Byte2:%\d+]] = OpBitFieldUExtract %uint [[bfi2]] %uint_16 %uint_8
+// CHECK-NEXT:       [[intSrc3Byte2:%\d+]] = OpBitcast %int [[src3Byte2]]
+// CHECK-NEXT:               [[sub2:%\d+]] = OpISub %int [[intRefByte2]] [[intSrc3Byte2]]
+// CHECK-NEXT:            [[absSub2:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub2]]
+// CHECK-NEXT:        [[uintAbsSub2:%\d+]] = OpBitcast %uint [[absSub2]]
+// CHECK-NEXT:              [[diff2:%\d+]] = OpSelect %uint [[isRefByte2Zero]] %uint_0 [[uintAbsSub2]]
+// CHECK-NEXT:  [[accum3PlusDiff012:%\d+]] = OpIAdd %uint [[accum3PlusDiff01]] [[diff2]]
+
+// MSAD 3 Byte 3
+// CHECK-NEXT:          [[src3Byte3:%\d+]] = OpBitFieldUExtract %uint [[bfi2]] %uint_24 %uint_8
+// CHECK-NEXT:       [[intSrc3Byte3:%\d+]] = OpBitcast %int [[src3Byte3]]
+// CHECK-NEXT:               [[sub3:%\d+]] = OpISub %int [[intRefByte3]] [[intSrc3Byte3]]
+// CHECK-NEXT:            [[absSub3:%\d+]] = OpExtInst %int [[glsl]] SAbs [[sub3]]
+// CHECK-NEXT:        [[uintAbsSub3:%\d+]] = OpBitcast %uint [[absSub3]]
+// CHECK-NEXT:              [[diff3:%\d+]] = OpSelect %uint [[isRefByte3Zero]] %uint_0 [[uintAbsSub3]]
+// CHECK-NEXT: [[accum3PlusDiff0123:%\d+]] = OpIAdd %uint [[accum3PlusDiff012]] [[diff3]]
+
+// CHECK-NEXT: {{%\d+}} = OpCompositeConstruct %v4uint [[accum0PlusDiff0123]] [[accum1PlusDiff0123]] [[accum2PlusDiff0123]] [[accum3PlusDiff0123]]
+
+  uint4 result = msad4(reference, source, accum);
+  return result;
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -776,6 +776,7 @@ TEST_F(FileTest, IntrinsicsLit) { runFileTest("intrinsics.lit.hlsl"); }
 TEST_F(FileTest, IntrinsicsModf) { runFileTest("intrinsics.modf.hlsl"); }
 TEST_F(FileTest, IntrinsicsMad) { runFileTest("intrinsics.mad.hlsl"); }
 TEST_F(FileTest, IntrinsicsMax) { runFileTest("intrinsics.max.hlsl"); }
+TEST_F(FileTest, IntrinsicsMsad4) { runFileTest("intrinsics.msad4.hlsl"); }
 TEST_F(FileTest, IntrinsicsNormalize) {
   runFileTest("intrinsics.normalize.hlsl");
 }