[spirv] Support non-fp MAD intrinsic function. (#1730)

Ehsan, 6 years ago
Commit 883ce4a32c

+ 3 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -431,6 +431,9 @@ public:
   /// \brief Decorates the given target <result-id> with nonuniformEXT
   void decorateNonUniformEXT(uint32_t targetId);
 
+  /// \brief Decorates the given target <result-id> with NoContraction
+  void decorateNoContraction(uint32_t targetId);
+
   // === Type ===
 
   uint32_t getVoidType();

+ 5 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -893,6 +893,11 @@ void ModuleBuilder::decorateNonUniformEXT(uint32_t targetId) {
   theModule.addDecoration(d, targetId);
 }
 
+void ModuleBuilder::decorateNoContraction(uint32_t targetId) {
+  const Decoration *d = Decoration::getNoContraction(theContext);
+  theModule.addDecoration(d, targetId);
+}
+
 #define IMPL_GET_PRIMITIVE_TYPE(ty)                                            \
                                                                                \
   uint32_t ModuleBuilder::get##ty##Type() {                                    \

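For orientation, here is a minimal usage sketch of the new builder hook. This is illustrative caller code, not part of the commit: it assumes ModuleBuilder lives in clang::spirv as elsewhere in this backend, and that the Khronos GLSL.std.450 enum header (providing GLSLstd450) is on the include path. The calls mirror the ones made in SPIRVEmitter below: create an instruction, then hand its <result-id> to decorateNoContraction, which records an "OpDecorate %id NoContraction" in the module.

  #include "clang/SPIRV/ModuleBuilder.h"
  // (assumed: a header providing the GLSLstd450 enum is also included)

  // Hypothetical helper, not DXC code: emit a GLSL.std.450 Fma and mark it
  // precise via the new decorateNoContraction() hook added by this commit.
  uint32_t emitPreciseFma(clang::spirv::ModuleBuilder &builder,
                          uint32_t floatTypeId, uint32_t a, uint32_t b,
                          uint32_t c) {
    const uint32_t glslSet = builder.getGLSLExtInstSet();
    // OpExtInst %float %glslSet Fma %a %b %c
    const uint32_t fma = builder.createExtInst(
        floatTypeId, glslSet, GLSLstd450::GLSLstd450Fma, {a, b, c});
    // New in this commit: adds "OpDecorate %fma NoContraction" to the module.
    builder.decorateNoContraction(fma);
    return fma;
  }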
+ 102 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -6574,6 +6574,9 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_lit:
     retVal = processIntrinsicLit(callExpr);
     break;
+  case hlsl::IntrinsicOp::IOP_mad:
+    retVal = processIntrinsicMad(callExpr);
+    break;
   case hlsl::IntrinsicOp::IOP_modf:
     retVal = processIntrinsicModf(callExpr);
     break;
@@ -6747,7 +6750,6 @@ SpirvEvalInfo SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
     INTRINSIC_OP_CASE(lerp, FMix, true);
     INTRINSIC_OP_CASE(log, Log, true);
     INTRINSIC_OP_CASE(log2, Log2, true);
-    INTRINSIC_OP_CASE(mad, Fma, true);
     INTRINSIC_OP_CASE_SINT_UINT_FLOAT(max, SMax, UMax, FMax, true);
     INTRINSIC_OP_CASE(umax, UMax, true);
     INTRINSIC_OP_CASE_SINT_UINT_FLOAT(min, SMin, UMin, FMin, true);
@@ -7425,6 +7427,105 @@ uint32_t SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
   return 0;
 }
 
+uint32_t SPIRVEmitter::processIntrinsicMad(const CallExpr *callExpr) {
+  // Signature is: ret mad(a, b, c)
+  // The return value and all three arguments must be scalars, vectors, or
+  // matrices with the same component type, which can be float or int.
+  // The return value is equal to "a * b + c".
+
+  // In the case of float arguments, we can use the GLSL extended instruction
+  // set's Fma instruction with NoContraction decoration. In the case of integer
+  // arguments, we'll have to manually perform an OpIMul followed by an OpIAdd
+  // (We should also apply NoContraction decoration to these two instructions to
+  // get precise arithmetic).
+
+  // TODO: We currently don't propagate the NoContraction decoration.
+
+  const Expr *arg0 = callExpr->getArg(0);
+  const Expr *arg1 = callExpr->getArg(1);
+  const Expr *arg2 = callExpr->getArg(2);
+  // All arguments and the return type are the same.
+  const auto argType = arg0->getType();
+  const auto argTypeId = typeTranslator.translateType(argType);
+  const uint32_t arg0Id = doExpr(arg0);
+  const uint32_t arg1Id = doExpr(arg1);
+  const uint32_t arg2Id = doExpr(arg2);
+
+  // For floating point arguments, we can use the extended instruction set's Fma
+  // instruction. Sadly we can't simply call processIntrinsicUsingGLSLInst
+  // because we need to specifically decorate the Fma instruction with
+  // NoContraction decoration.
+  if (isFloatOrVecMatOfFloatType(argType)) {
+    const auto opcode = GLSLstd450::GLSLstd450Fma;
+    const uint32_t glslInstSetId = theBuilder.getGLSLExtInstSet();
+    // For matrix cases, operate on each row of the matrix.
+    if (isMxNMatrix(arg0->getType())) {
+      const auto actOnEachVec = [this, glslInstSetId, opcode, arg1Id,
+                                 arg2Id](uint32_t index, uint32_t vecType,
+                                         uint32_t arg0RowId) {
+        const uint32_t arg1RowId =
+            theBuilder.createCompositeExtract(vecType, arg1Id, {index});
+        const uint32_t arg2RowId =
+            theBuilder.createCompositeExtract(vecType, arg2Id, {index});
+        const uint32_t fma = theBuilder.createExtInst(
+            vecType, glslInstSetId, opcode, {arg0RowId, arg1RowId, arg2RowId});
+        theBuilder.decorateNoContraction(fma);
+        return fma;
+      };
+      return processEachVectorInMatrix(arg0, arg0Id, actOnEachVec);
+    }
+    // Non-matrix cases
+    const uint32_t fma = theBuilder.createExtInst(
+        argTypeId, glslInstSetId, opcode, {arg0Id, arg1Id, arg2Id});
+    theBuilder.decorateNoContraction(fma);
+    return fma;
+  }
+
+  // For scalar and vector argument types.
+  {
+    if (isScalarType(argType) || isVectorType(argType)) {
+      const auto mul =
+          theBuilder.createBinaryOp(spv::Op::OpIMul, argTypeId, arg0Id, arg1Id);
+      const auto add =
+          theBuilder.createBinaryOp(spv::Op::OpIAdd, argTypeId, mul, arg2Id);
+      theBuilder.decorateNoContraction(mul);
+      theBuilder.decorateNoContraction(add);
+      return add;
+    }
+  }
+
+  // For matrix argument types.
+  {
+    uint32_t rowCount = 0, colCount = 0;
+    QualType elemType = {};
+    if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
+      const auto elemTypeId = typeTranslator.translateType(elemType);
+      const auto colTypeId = theBuilder.getVecType(elemTypeId, colCount);
+      llvm::SmallVector<uint32_t, 4> resultRows;
+      for (uint32_t i = 0; i < rowCount; ++i) {
+        const auto rowArg0 =
+            theBuilder.createCompositeExtract(colTypeId, arg0Id, {i});
+        const auto rowArg1 =
+            theBuilder.createCompositeExtract(colTypeId, arg1Id, {i});
+        const auto rowArg2 =
+            theBuilder.createCompositeExtract(colTypeId, arg2Id, {i});
+        const auto mul = theBuilder.createBinaryOp(spv::Op::OpIMul, colTypeId,
+                                                   rowArg0, rowArg1);
+        const auto add =
+            theBuilder.createBinaryOp(spv::Op::OpIAdd, colTypeId, mul, rowArg2);
+        theBuilder.decorateNoContraction(mul);
+        theBuilder.decorateNoContraction(add);
+        resultRows.push_back(add);
+      }
+      return theBuilder.createCompositeConstruct(argTypeId, resultRows);
+    }
+  }
+
+  emitError("invalid argument type passed to mad intrinsic function",
+            callExpr->getExprLoc());
+  return 0;
+}
+
 uint32_t SPIRVEmitter::processIntrinsicLit(const CallExpr *callExpr) {
   // Signature is: float4 lit(float n_dot_l, float n_dot_h, float m)
   //

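To make the semantics of the new integer path concrete, the following stand-alone sketch (plain C++, not DXC code; IntMat, madInt, and madFloat are illustrative names) computes the same result the emitted SPIR-V does: a component-wise a * b + c, processed row by row for matrices just like the per-row OpIMul/OpIAdd pairs above, while the float path maps to a single fused Fma per scalar, vector, or matrix row.

  #include <array>
  #include <cmath>
  #include <cstdint>
  #include <cstdio>

  template <std::size_t R, std::size_t C>
  using IntMat = std::array<std::array<std::int32_t, C>, R>;

  // Integer path: the emitter produces one OpIMul + OpIAdd per row; here this
  // is a plain component-wise multiply-add over each row.
  template <std::size_t R, std::size_t C>
  IntMat<R, C> madInt(const IntMat<R, C> &a, const IntMat<R, C> &b,
                      const IntMat<R, C> &c) {
    IntMat<R, C> result{};
    for (std::size_t r = 0; r < R; ++r)
      for (std::size_t col = 0; col < C; ++col)
        result[r][col] = a[r][col] * b[r][col] + c[r][col];
    return result;
  }

  // Float path: a single fused multiply-add (GLSL.std.450 Fma); std::fma
  // models the fused behavior on the host.
  float madFloat(float a, float b, float c) { return std::fma(a, b, c); }

  int main() {
    const IntMat<2, 3> a = {{{1, 2, 3}, {4, 5, 6}}};
    const IntMat<2, 3> b = {{{7, 8, 9}, {1, 2, 3}}};
    const IntMat<2, 3> c = {{{1, 1, 1}, {2, 2, 2}}};
    const IntMat<2, 3> r = madInt(a, b, c);
    // Prints "8 10.000000": 1*7+1 = 8 and fma(2, 3, 4) = 10.
    std::printf("%d %f\n", static_cast<int>(r[0][0]),
                madFloat(2.0f, 3.0f, 4.0f));
    return 0;
  }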
+ 3 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -359,6 +359,9 @@ private:
   uint32_t processIntrinsicMemoryBarrier(const CallExpr *, bool isDevice,
                                          bool groupSync, bool isAllBarrier);
 
+  /// Processes the 'mad' intrinsic function.
+  uint32_t processIntrinsicMad(const CallExpr *);
+
   /// Processes the 'modf' intrinsic function.
   uint32_t processIntrinsicModf(const CallExpr *);
 

+ 53 - 5
tools/clang/test/CodeGenSPIRV/intrinsics.mad.hlsl

@@ -2,21 +2,38 @@
 
 // CHECK: [[glsl:%\d+]] = OpExtInstImport "GLSL.std.450"
 
+// CHECK: OpDecorate [[fma1:%\d+]] NoContraction
+// CHECK: OpDecorate [[fma2:%\d+]] NoContraction
+// CHECK: OpDecorate [[fma3:%\d+]] NoContraction
+// CHECK: OpDecorate [[fma4:%\d+]] NoContraction
+// CHECK: OpDecorate [[mul1:%\d+]] NoContraction
+// CHECK: OpDecorate [[add1:%\d+]] NoContraction
+// CHECK: OpDecorate [[mul2:%\d+]] NoContraction
+// CHECK: OpDecorate [[add2:%\d+]] NoContraction
+// CHECK: OpDecorate [[mul3:%\d+]] NoContraction
+// CHECK: OpDecorate [[add3:%\d+]] NoContraction
+// CHECK: OpDecorate [[mul4:%\d+]] NoContraction
+// CHECK: OpDecorate [[add4:%\d+]] NoContraction
+
 void main() {
   float    a1, a2, a3, fma_a;
   float4   b1, b2, b3, fma_b;
   float2x3 c1, c2, c3, fma_c;
 
+  int    d1, d2, d3, fma_d;
+  int4   e1, e2, e3, fma_e;
+  int2x3 f1, f2, f3, fma_f;
+
 // CHECK:      [[a1:%\d+]] = OpLoad %float %a1
 // CHECK-NEXT: [[a2:%\d+]] = OpLoad %float %a2
 // CHECK-NEXT: [[a3:%\d+]] = OpLoad %float %a3
-// CHECK-NEXT:    {{%\d+}} = OpExtInst %float [[glsl]] Fma [[a1]] [[a2]] [[a3]]
+// CHECK-NEXT:    [[fma1]] = OpExtInst %float [[glsl]] Fma [[a1]] [[a2]] [[a3]]
   fma_a = mad(a1, a2, a3);
 
 // CHECK:      [[b1:%\d+]] = OpLoad %v4float %b1
 // CHECK-NEXT: [[b2:%\d+]] = OpLoad %v4float %b2
 // CHECK-NEXT: [[b3:%\d+]] = OpLoad %v4float %b3
-// CHECK-NEXT:    {{%\d+}} = OpExtInst %v4float [[glsl]] Fma [[b1]] [[b2]] [[b3]]
+// CHECK-NEXT:    [[fma2]] = OpExtInst %v4float [[glsl]] Fma [[b1]] [[b2]] [[b3]]
   fma_b = mad(b1, b2, b3);
 
 // CHECK:            [[c1:%\d+]] = OpLoad %mat2v3float %c1
@@ -25,11 +42,42 @@ void main() {
 // CHECK-NEXT:  [[c1_row0:%\d+]] = OpCompositeExtract %v3float [[c1]] 0
 // CHECK-NEXT:  [[c2_row0:%\d+]] = OpCompositeExtract %v3float [[c2]] 0
 // CHECK-NEXT:  [[c3_row0:%\d+]] = OpCompositeExtract %v3float [[c3]] 0
-// CHECK-NEXT: [[fma_row0:%\d+]] = OpExtInst %v3float [[glsl]] Fma [[c1_row0]] [[c2_row0]] [[c3_row0]]
+// CHECK-NEXT:          [[fma3]] = OpExtInst %v3float [[glsl]] Fma [[c1_row0]] [[c2_row0]] [[c3_row0]]
 // CHECK-NEXT:  [[c1_row1:%\d+]] = OpCompositeExtract %v3float [[c1]] 1
 // CHECK-NEXT:  [[c2_row1:%\d+]] = OpCompositeExtract %v3float [[c2]] 1
 // CHECK-NEXT:  [[c3_row1:%\d+]] = OpCompositeExtract %v3float [[c3]] 1
-// CHECK-NEXT: [[fma_row1:%\d+]] = OpExtInst %v3float [[glsl]] Fma [[c1_row1]] [[c2_row1]] [[c3_row1]]
-// CHECK-NEXT:          {{%\d+}} = OpCompositeConstruct %mat2v3float [[fma_row0]] [[fma_row1]]
+// CHECK-NEXT:          [[fma4]] = OpExtInst %v3float [[glsl]] Fma [[c1_row1]] [[c2_row1]] [[c3_row1]]
+// CHECK-NEXT:          {{%\d+}} = OpCompositeConstruct %mat2v3float [[fma3]] [[fma4]]
   fma_c = mad(c1, c2, c3);
+
+// CHECK:       [[d1:%\d+]] = OpLoad %int %d1
+// CHECK-NEXT:  [[d2:%\d+]] = OpLoad %int %d2
+// CHECK-NEXT:  [[d3:%\d+]] = OpLoad %int %d3
+// CHECK-NEXT:     [[mul1]] = OpIMul %int [[d1]] [[d2]]
+// CHECK-NEXT:     [[add1]] = OpIAdd %int [[mul1]] [[d3]]
+  fma_d = mad(d1, d2, d3);
+
+// CHECK:       [[e1:%\d+]] = OpLoad %v4int %e1
+// CHECK-NEXT:  [[e2:%\d+]] = OpLoad %v4int %e2
+// CHECK-NEXT:  [[e3:%\d+]] = OpLoad %v4int %e3
+// CHECK-NEXT:     [[mul2]] = OpIMul %v4int [[e1]] [[e2]]
+// CHECK-NEXT:     [[add2]] = OpIAdd %v4int [[mul2]] [[e3]]
+  fma_e = mad(e1, e2, e3);
+
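+// Note: int2x3 is lowered to %_arr_v3int_uint_2 (an array of two v3int rows),
+// since SPIR-V matrix types only allow floating-point columns; that is why
+// the emitter falls back to per-row OpIMul/OpIAdd here.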
+// CHECK:           [[f1:%\d+]] = OpLoad %_arr_v3int_uint_2 %f1
+// CHECK-NEXT:      [[f2:%\d+]] = OpLoad %_arr_v3int_uint_2 %f2
+// CHECK-NEXT:      [[f3:%\d+]] = OpLoad %_arr_v3int_uint_2 %f3
+// CHECK-NEXT:  [[f1row0:%\d+]] = OpCompositeExtract %v3int [[f1]] 0
+// CHECK-NEXT:  [[f2row0:%\d+]] = OpCompositeExtract %v3int [[f2]] 0
+// CHECK-NEXT:  [[f3row0:%\d+]] = OpCompositeExtract %v3int [[f3]] 0
+// CHECK-NEXT:         [[mul3]] = OpIMul %v3int [[f1row0]] [[f2row0]]
+// CHECK-NEXT:         [[add3]] = OpIAdd %v3int [[mul3]] [[f3row0]]
+// CHECK-NEXT:  [[f1row1:%\d+]] = OpCompositeExtract %v3int [[f1]] 1
+// CHECK-NEXT:  [[f2row1:%\d+]] = OpCompositeExtract %v3int [[f2]] 1
+// CHECK-NEXT:  [[f3row1:%\d+]] = OpCompositeExtract %v3int [[f3]] 1
+// CHECK-NEXT:         [[mul4]] = OpIMul %v3int [[f1row1]] [[f2row1]]
+// CHECK-NEXT:         [[add4]] = OpIAdd %v3int [[mul4]] [[f3row1]]
+// CHECK-NEXT:         {{%\d+}} = OpCompositeConstruct %_arr_v3int_uint_2 [[add3]] [[add4]]
+  fma_f = mad(f1, f2, f3);
 }
+