Răsfoiți Sursa

[spirv] Support non-fp MAD intrinsic function.

Ehsan 6 ani în urmă
părinte
comite
471aaab03b

+ 3 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -547,6 +547,9 @@ public:
   /// \brief Decorates the given target with patch
   void decoratePatch(SpirvInstruction *target, SourceLocation srcLoc = {});
 
+  /// \brief Decorates the given target with the NoContraction decoration.
+  void decorateNoContraction(SpirvInstruction *target, SourceLocation loc = {});
+
   /// --- Constants ---
   /// Each of these methods can acquire a unique constant from the SpirvContext,
   /// and add the context to the list of constants in the module.

+ 97 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -7135,6 +7135,103 @@ SpirvInstruction *SPIRVEmitter::processIntrinsicModf(const CallExpr *callExpr) {
   return nullptr;
 }
 
+SpirvInstruction *SPIRVEmitter::processIntrinsicMad(const CallExpr *callExpr) {
+  // Signature is: ret mad(a,b,c)
+  // All of the above must be a scalar, vector, or matrix with the same
+  // component types. Component types can be float or int.
+  // The return value is equal to "a * b + c" (nullptr on invalid types).
+
+  // In the case of float arguments, we can use the GLSL extended instruction
+  // set's Fma instruction with NoContraction decoration. In the case of integer
+  // arguments, we'll have to manually perform an OpIMul followed by an OpIAdd
+  // (we also apply NoContraction decoration to these two instructions to
+  // get precise arithmetic).
+
+  // TODO: We currently don't propagate the NoContraction decoration.
+
+  const Expr *arg0 = callExpr->getArg(0);
+  const Expr *arg1 = callExpr->getArg(1);
+  const Expr *arg2 = callExpr->getArg(2);
+  // All arguments and the return type are the same.
+  const auto argType = arg0->getType();
+  auto *arg0Instr = doExpr(arg0);
+  auto *arg1Instr = doExpr(arg1);
+  auto *arg2Instr = doExpr(arg2);
+
+  // For floating point arguments, we can use the extended instruction set's Fma
+  // instruction. Sadly we can't simply call processIntrinsicUsingGLSLInst
+  // because we need to specifically decorate the Fma instruction with
+  // NoContraction decoration.
+  if (isFloatOrVecMatOfFloatType(argType)) {
+    const auto opcode = GLSLstd450::GLSLstd450Fma;
+    auto *glslInstSet = spvBuilder.getGLSLExtInstSet();
+    // For matrix cases, operate on each row of the matrix.
+    if (isMxNMatrix(argType)) {
+      const auto actOnEachVec = [this, glslInstSet, arg1Instr,
+                                 arg2Instr](uint32_t index, QualType vecType,
+                                            SpirvInstruction *arg0Row) {
+        auto *arg1Row =
+            spvBuilder.createCompositeExtract(vecType, arg1Instr, {index});
+        auto *arg2Row =
+            spvBuilder.createCompositeExtract(vecType, arg2Instr, {index});
+        auto *fma = spvBuilder.createExtInst(vecType, glslInstSet, opcode,
+                                             {arg0Row, arg1Row, arg2Row});
+        spvBuilder.decorateNoContraction(fma);
+        return fma;
+      };
+      return processEachVectorInMatrix(arg0, arg0Instr, actOnEachVec);
+    }
+    // Non-matrix cases
+    auto *fma = spvBuilder.createExtInst(argType, glslInstSet, opcode,
+                                         {arg0Instr, arg1Instr, arg2Instr});
+    spvBuilder.decorateNoContraction(fma);
+    return fma;
+  }
+
+  // For non-float scalar and vector argument types: OpIMul, then OpIAdd.
+  {
+    if (isScalarType(argType) || isVectorType(argType)) {
+      auto *mul = spvBuilder.createBinaryOp(spv::Op::OpIMul, argType, arg0Instr,
+                                            arg1Instr);
+      auto *add =
+          spvBuilder.createBinaryOp(spv::Op::OpIAdd, argType, mul, arg2Instr);
+      spvBuilder.decorateNoContraction(mul);
+      spvBuilder.decorateNoContraction(add);
+      return add;
+    }
+  }
+
+  // For non-float matrix argument types: OpIMul/OpIAdd on each row.
+  {
+    uint32_t rowCount = 0, colCount = 0;
+    QualType elemType = {};
+    if (isMxNMatrix(argType, &elemType, &rowCount, &colCount)) {
+      const auto rowType = astContext.getExtVectorType(elemType, colCount);
+      llvm::SmallVector<SpirvInstruction *, 4> resultRows;
+      for (uint32_t i = 0; i < rowCount; ++i) {
+        auto *rowArg0 =
+            spvBuilder.createCompositeExtract(rowType, arg0Instr, {i});
+        auto *rowArg1 =
+            spvBuilder.createCompositeExtract(rowType, arg1Instr, {i});
+        auto *rowArg2 =
+            spvBuilder.createCompositeExtract(rowType, arg2Instr, {i});
+        auto *mul = spvBuilder.createBinaryOp(spv::Op::OpIMul, rowType, rowArg0,
+                                              rowArg1);
+        auto *add =
+            spvBuilder.createBinaryOp(spv::Op::OpIAdd, rowType, mul, rowArg2);
+        spvBuilder.decorateNoContraction(mul);
+        spvBuilder.decorateNoContraction(add);
+        resultRows.push_back(add);
+      }
+      return spvBuilder.createCompositeConstruct(argType, resultRows);
+    }
+  }
+
+  emitError("invalid argument type passed to mad intrinsic function",
+            callExpr->getExprLoc());
+  return nullptr;
+}
+
 SpirvInstruction *SPIRVEmitter::processIntrinsicLit(const CallExpr *callExpr) {
   // Signature is: float4 lit(float n_dot_l, float n_dot_h, float m)
   //

+ 1 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -365,7 +365,7 @@ private:
                                                   bool isAllBarrier);
 
   /// Processes the 'mad' intrinsic function.
-  uint32_t processIntrinsicMad(const CallExpr *);
+  SpirvInstruction *processIntrinsicMad(const CallExpr *);
 
   /// Processes the 'modf' intrinsic function.
   SpirvInstruction *processIntrinsicModf(const CallExpr *);

+ 7 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -1054,6 +1054,13 @@ void SpirvBuilder::decoratePatch(SpirvInstruction *target,
   module->addDecoration(decor);
 }
 
+void SpirvBuilder::decorateNoContraction(SpirvInstruction *target,
+                                         SourceLocation srcLoc) {
+  // Create a NoContraction decoration for the target and add it to the module.
+  module->addDecoration(new (context) SpirvDecoration(
+      srcLoc, target, spv::Decoration::NoContraction));
+}
+
 SpirvConstant *SpirvBuilder::getConstantInt(QualType type, llvm::APInt value,
                                             bool specConst) {
   // We do not reuse existing constant integers. Just create a new one.