소스 검색

[spirv] Translate intrinsic dot product. (#496)

The translation supports dot product of vectors of floats using
SPIR-V's OpDot.

The translation also supports dot product of vectros of integers
using multiplication and addition.
Ehsan 8 년 전
부모
커밋
e6685a310d

+ 4 - 2
docs/SPIR-V.rst

@@ -370,10 +370,12 @@ For a function ``f`` which has a parameter of type ``T``, the generated SPIR-V s
 
 This approach gives us unified handling of function parameters and local variables: both of them are accessed via load/store instructions.
 
-Builtin functions
+Intrinsic functions
 -----------------
 
-[TODO]
+The following intrinsic HLSL functions are currently supported:
+
+- `dot` : performs dot product of two vectors, each containing floats or integers. If the two parameters are vectors of floats, we use SPIR-V's OpDot instruction to perform the translation. If the two parameters are vectors of integers, we multiply corresponding vector elementes using OpIMul and accumulate the results using OpIAdd to compute the dot product.
 
 Logistics
 =========

+ 6 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -99,6 +99,12 @@ public:
   uint32_t createCompositeConstruct(uint32_t resultType,
                                     llvm::ArrayRef<uint32_t> constituents);
 
+  /// \brief Creates a composite extract instruction. The given composite is
+  /// indexed using the given literal indexes to obtain the resulting element.
+  /// Returns the <result-id> for the extracted element.
+  uint32_t createCompositeExtract(uint32_t resultType, uint32_t composite,
+                                  llvm::ArrayRef<uint32_t> indexes);
+
   /// \brief Creates a load instruction loading the value of the given
   /// <result-type> from the given pointer. Returns the <result-id> for the
   /// loaded value.

+ 103 - 0
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -9,6 +9,7 @@
 
 #include "clang/SPIRV/EmitSPIRVAction.h"
 
+#include "dxc/HlslIntrinsicOp.h"
 #include "clang/AST/AST.h"
 #include "clang/AST/ASTConsumer.h"
 #include "clang/AST/ASTContext.h"
@@ -847,9 +848,111 @@ public:
     }
   }
 
+  uint32_t processIntrinsicDot(const CallExpr *callExpr) {
+    const uint32_t returnType =
+        typeTranslator.translateType(callExpr->getType());
+
+    // Get the function parameters. Expect 2 vectors as parameters.
+    assert(callExpr->getNumArgs() == 2u);
+    const Expr *arg0 = callExpr->getArg(0);
+    const Expr *arg1 = callExpr->getArg(1);
+    const uint32_t arg0Id = doExpr(arg0);
+    const uint32_t arg1Id = doExpr(arg1);
+    QualType arg0Type = arg0->getType();
+    QualType arg1Type = arg1->getType();
+    const size_t vec0Size = hlsl::GetHLSLVecSize(arg0Type);
+    const size_t vec1Size = hlsl::GetHLSLVecSize(arg1Type);
+    const QualType vec0ComponentType = hlsl::GetHLSLVecElementType(arg0Type);
+    const QualType vec1ComponentType = hlsl::GetHLSLVecElementType(arg1Type);
+    assert(callExpr->getType() == vec1ComponentType);
+    assert(vec0ComponentType == vec1ComponentType);
+    assert(vec0Size == vec1Size);
+    assert(vec0Size >= 1 && vec0Size <= 4);
+
+    // According to HLSL reference, the dot function only works on integers
+    // and floats.
+    const auto returnTypeBuiltinKind =
+        cast<BuiltinType>(callExpr->getType().getTypePtr())->getKind();
+    assert(returnTypeBuiltinKind == BuiltinType::Float ||
+           returnTypeBuiltinKind == BuiltinType::Int ||
+           returnTypeBuiltinKind == BuiltinType::UInt);
+
+    // Special case: dot product of two vectors, each of size 1. That is
+    // basically the same as regular multiplication of 2 scalars.
+    if (vec0Size == 1) {
+      const spv::Op spvOp = translateOp(BO_Mul, arg0Type);
+      return theBuilder.createBinaryOp(spvOp, returnType, arg0Id, arg1Id);
+    }
+
+    // If the vectors are of type Float, we can use OpDot.
+    if (returnTypeBuiltinKind == BuiltinType::Float) {
+      return theBuilder.createBinaryOp(spv::Op::OpDot, returnType, arg0Id,
+                                       arg1Id);
+    }
+    // Vector component type is Integer (signed or unsigned).
+    // Create all instructions necessary to perform a dot product on
+    // two integer vectors. SPIR-V OpDot does not support integer vectors.
+    // Therefore, we use other SPIR-V instructions (addition and
+    // multiplication).
+    else {
+      uint32_t result = 0;
+      llvm::SmallVector<uint32_t, 4> multIds;
+      const spv::Op multSpvOp = translateOp(BO_Mul, arg0Type);
+      const spv::Op addSpvOp = translateOp(BO_Add, arg0Type);
+
+      // Extract members from the two vectors and multiply them.
+      for (unsigned int i = 0; i < vec0Size; ++i) {
+        const uint32_t vec0member =
+            theBuilder.createCompositeExtract(returnType, arg0Id, {i});
+        const uint32_t vec1member =
+            theBuilder.createCompositeExtract(returnType, arg1Id, {i});
+        const uint32_t multId = theBuilder.createBinaryOp(
+            multSpvOp, returnType, vec0member, vec1member);
+        multIds.push_back(multId);
+      }
+      // Add all the multiplications.
+      result = multIds[0];
+      for (unsigned int i = 1; i < vec0Size; ++i) {
+        const uint32_t additionId =
+            theBuilder.createBinaryOp(addSpvOp, returnType, result, multIds[i]);
+        result = additionId;
+      }
+      return result;
+    }
+  }
+
+  uint32_t processIntrinsicCallExpr(const CallExpr *callExpr) {
+    const FunctionDecl *callee = callExpr->getDirectCallee();
+    assert(hlsl::IsIntrinsicOp(callee) &&
+           "doIntrinsicCallExpr was called for a non-intrinsic function.");
+
+    // Figure out which intrinsic function to translate.
+    llvm::StringRef group;
+    uint32_t opcode = static_cast<uint32_t>(hlsl::IntrinsicOp::Num_Intrinsics);
+    hlsl::GetIntrinsicOp(callee, opcode, group);
+
+    switch (static_cast<hlsl::IntrinsicOp>(opcode)) {
+    case hlsl::IntrinsicOp::IOP_dot: {
+      return processIntrinsicDot(callExpr);
+      break;
+    }
+    default:
+      break;
+    }
+
+    emitError("Intrinsic function '%0' not yet implemented.")
+        << callee->getName();
+    return 0;
+  }
+
   uint32_t doCallExpr(const CallExpr *callExpr) {
     const FunctionDecl *callee = callExpr->getDirectCallee();
 
+    // Intrinsic functions such as 'dot' or 'mul'
+    if (hlsl::IsIntrinsicOp(callee)) {
+      return processIntrinsicCallExpr(callExpr);
+    }
+
     if (callee) {
       const uint32_t returnType =
           typeTranslator.translateType(callExpr->getType());

+ 10 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -138,6 +138,16 @@ ModuleBuilder::createCompositeConstruct(uint32_t resultType,
   return resultId;
 }
 
+uint32_t
+ModuleBuilder::createCompositeExtract(uint32_t resultType, uint32_t composite,
+                                      llvm::ArrayRef<uint32_t> indexes) {
+  assert(insertPoint && "null insert point");
+  const uint32_t resultId = theContext.takeNextId();
+  instBuilder.opCompositeExtract(resultType, resultId, composite, indexes).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return resultId;
+}
+
 uint32_t ModuleBuilder::createLoad(uint32_t resultType, uint32_t pointer) {
   assert(insertPoint && "null insert point");
   const uint32_t resultId = theContext.takeNextId();

+ 162 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.dot.hlsl

@@ -0,0 +1,162 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// According to HLSL reference:
+// The components of the vectors may be either float or int.
+
+void main() {
+    // CHECK:      [[a:%\d+]] = OpLoad %int %a
+    // CHECK-NEXT: [[b:%\d+]] = OpLoad %int %b
+    // CHECK-NEXT: [[intdot1:%\d+]] = OpIMul %int [[a]] [[b]]
+    // CHECK-NEXT: OpStore %c [[intdot1]]
+    int1 a, b;
+    int c;
+    c = dot(a, b);
+
+    // CHECK:      [[d:%\d+]] = OpLoad %v2int %d
+    // CHECK-NEXT: [[e:%\d+]] = OpLoad %v2int %e
+    // CHECK-NEXT: [[d0:%\d+]] = OpCompositeExtract %int [[d]] 0
+    // CHECK-NEXT: [[e0:%\d+]] = OpCompositeExtract %int [[e]] 0
+    // CHECK-NEXT: [[mul_de0:%\d+]] = OpIMul %int [[d0]] [[e0]]
+    // CHECK-NEXT: [[d1:%\d+]] = OpCompositeExtract %int [[d]] 1
+    // CHECK-NEXT: [[e1:%\d+]] = OpCompositeExtract %int [[e]] 1
+    // CHECK-NEXT: [[mul_de1:%\d+]] = OpIMul %int [[d1]] [[e1]]
+    // CHECK-NEXT: [[intdot2:%\d+]] = OpIAdd %int [[mul_de0]] [[mul_de1]]
+    // CHECK-NEXT: OpStore %f [[intdot2]]
+    int2 d, e;
+    int f;
+    f = dot(d, e);
+
+    // CHECK:      [[g:%\d+]] = OpLoad %v3int %g
+    // CHECK-NEXT: [[h:%\d+]] = OpLoad %v3int %h
+    // CHECK-NEXT: [[g0:%\d+]] = OpCompositeExtract %int [[g]] 0
+    // CHECK-NEXT: [[h0:%\d+]] = OpCompositeExtract %int [[h]] 0
+    // CHECK-NEXT: [[mul_gh0:%\d+]] = OpIMul %int [[g0]] [[h0]]
+    // CHECK-NEXT: [[g1:%\d+]] = OpCompositeExtract %int [[g]] 1
+    // CHECK-NEXT: [[h1:%\d+]] = OpCompositeExtract %int [[h]] 1
+    // CHECK-NEXT: [[mul_gh1:%\d+]] = OpIMul %int [[g1]] [[h1]]
+    // CHECK-NEXT: [[g2:%\d+]] = OpCompositeExtract %int [[g]] 2
+    // CHECK-NEXT: [[h2:%\d+]] = OpCompositeExtract %int [[h]] 2
+    // CHECK-NEXT: [[mul_gh2:%\d+]] = OpIMul %int [[g2]] [[h2]]
+    // CHECK-NEXT: [[intdot3_add0:%\d+]] = OpIAdd %int [[mul_gh0]] [[mul_gh1]]
+    // CHECK-NEXT: [[intdot3:%\d+]] = OpIAdd %int [[intdot3_add0]] [[mul_gh2]]
+    // CHECK-NEXT: OpStore %i [[intdot3]]
+    int3 g, h;
+    int i;
+    i = dot(g, h);
+
+    // CHECK:      [[j:%\d+]] = OpLoad %v4int %j
+    // CHECK-NEXT: [[k:%\d+]] = OpLoad %v4int %k
+    // CHECK-NEXT: [[j0:%\d+]] = OpCompositeExtract %int [[j]] 0
+    // CHECK-NEXT: [[k0:%\d+]] = OpCompositeExtract %int [[k]] 0
+    // CHECK-NEXT: [[mul_jk0:%\d+]] = OpIMul %int [[j0]] [[k0]]
+    // CHECK-NEXT: [[j1:%\d+]] = OpCompositeExtract %int [[j]] 1
+    // CHECK-NEXT: [[k1:%\d+]] = OpCompositeExtract %int [[k]] 1
+    // CHECK-NEXT: [[mul_jk1:%\d+]] = OpIMul %int [[j1]] [[k1]]
+    // CHECK-NEXT: [[j2:%\d+]] = OpCompositeExtract %int [[j]] 2
+    // CHECK-NEXT: [[k2:%\d+]] = OpCompositeExtract %int [[k]] 2
+    // CHECK-NEXT: [[mul_jk2:%\d+]] = OpIMul %int [[j2]] [[k2]]
+    // CHECK-NEXT: [[j3:%\d+]] = OpCompositeExtract %int [[j]] 3
+    // CHECK-NEXT: [[k3:%\d+]] = OpCompositeExtract %int [[k]] 3
+    // CHECK-NEXT: [[mul_jk3:%\d+]] = OpIMul %int [[j3]] [[k3]]
+    // CHECK-NEXT: [[intdot4_add0:%\d+]] = OpIAdd %int [[mul_jk0]] [[mul_jk1]]
+    // CHECK-NEXT: [[intdot4_add1:%\d+]] = OpIAdd %int [[intdot4_add0]] [[mul_jk2]]
+    // CHECK-NEXT: [[intdot4:%\d+]] = OpIAdd %int [[intdot4_add1]] [[mul_jk3]]
+    // CHECK-NEXT: OpStore %l [[intdot4]]
+    int4 j, k;
+    int l;
+    l = dot(j, k);
+
+    // CHECK:      [[m:%\d+]] = OpLoad %float %m
+    // CHECK-NEXT: [[n:%\d+]] = OpLoad %float %n
+    // CHECK-NEXT: [[floatdot1:%\d+]] = OpFMul %float [[m]] [[n]]
+    // CHECK-NEXT: OpStore %o [[floatdot1]]
+    float1 m, n;
+    float o;
+    o = dot(m, n);
+
+    // CHECK:      [[p:%\d+]] = OpLoad %v2float %p
+    // CHECK-NEXT: [[q:%\d+]] = OpLoad %v2float %q
+    // CHECK-NEXT: [[floatdot2:%\d+]] = OpDot %float [[p]] [[q]]
+    // CHECK-NEXT: OpStore %r [[floatdot2]]
+    float2 p, q;
+    float r;
+    r = dot(p, q);
+
+    // CHECK:      [[s:%\d+]] = OpLoad %v3float %s
+    // CHECK-NEXT: [[t:%\d+]] = OpLoad %v3float %t
+    // CHECK-NEXT: [[floatdot3:%\d+]] = OpDot %float [[s]] [[t]]
+    // CHECK-NEXT: OpStore %u [[floatdot3]]
+    float3 s, t;
+    float u;
+    u = dot(s, t);
+
+    // CHECK:      [[v:%\d+]] = OpLoad %v4float %v
+    // CHECK-NEXT: [[w:%\d+]] = OpLoad %v4float %w
+    // CHECK-NEXT: [[floatdot4:%\d+]] = OpDot %float [[v]] [[w]]
+    // CHECK-NEXT: OpStore %x [[floatdot4]]
+    float4 v, w;
+    float x;
+    x = dot(v, w);
+
+    // CHECK:      [[ua:%\d+]] = OpLoad %uint %ua
+    // CHECK-NEXT: [[ub:%\d+]] = OpLoad %uint %ub
+    // CHECK-NEXT: [[uintdot1:%\d+]] = OpIMul %uint [[ua]] [[ub]]
+    // CHECK-NEXT: OpStore %uc [[uintdot1]]
+    uint1 ua, ub;
+    uint uc;
+    uc = dot(ua, ub);
+
+    // CHECK:      [[ud:%\d+]] = OpLoad %v2uint %ud
+    // CHECK-NEXT: [[ue:%\d+]] = OpLoad %v2uint %ue
+    // CHECK-NEXT: [[ud0:%\d+]] = OpCompositeExtract %uint [[ud]] 0
+    // CHECK-NEXT: [[ue0:%\d+]] = OpCompositeExtract %uint [[ue]] 0
+    // CHECK-NEXT: [[mul_ude0:%\d+]] = OpIMul %uint [[ud0]] [[ue0]]
+    // CHECK-NEXT: [[ud1:%\d+]] = OpCompositeExtract %uint [[ud]] 1
+    // CHECK-NEXT: [[ue1:%\d+]] = OpCompositeExtract %uint [[ue]] 1
+    // CHECK-NEXT: [[mul_ude1:%\d+]] = OpIMul %uint [[ud1]] [[ue1]]
+    // CHECK-NEXT: [[uintdot2:%\d+]] = OpIAdd %uint [[mul_ude0]] [[mul_ude1]]
+    // CHECK-NEXT: OpStore %uf [[uintdot2]]
+    uint2 ud, ue;
+    uint uf;
+    uf = dot(ud, ue);
+
+    // CHECK:      [[ug:%\d+]] = OpLoad %v3uint %ug
+    // CHECK-NEXT: [[uh:%\d+]] = OpLoad %v3uint %uh
+    // CHECK-NEXT: [[ug0:%\d+]] = OpCompositeExtract %uint [[ug]] 0
+    // CHECK-NEXT: [[uh0:%\d+]] = OpCompositeExtract %uint [[uh]] 0
+    // CHECK-NEXT: [[mul_ugh0:%\d+]] = OpIMul %uint [[ug0]] [[uh0]]
+    // CHECK-NEXT: [[ug1:%\d+]] = OpCompositeExtract %uint [[ug]] 1
+    // CHECK-NEXT: [[uh1:%\d+]] = OpCompositeExtract %uint [[uh]] 1
+    // CHECK-NEXT: [[mul_ugh1:%\d+]] = OpIMul %uint [[ug1]] [[uh1]]
+    // CHECK-NEXT: [[ug2:%\d+]] = OpCompositeExtract %uint [[ug]] 2
+    // CHECK-NEXT: [[uh2:%\d+]] = OpCompositeExtract %uint [[uh]] 2
+    // CHECK-NEXT: [[mul_ugh2:%\d+]] = OpIMul %uint [[ug2]] [[uh2]]
+    // CHECK-NEXT: [[uintdot3_add0:%\d+]] = OpIAdd %uint [[mul_ugh0]] [[mul_ugh1]]
+    // CHECK-NEXT: [[uintdot3:%\d+]] = OpIAdd %uint [[uintdot3_add0]] [[mul_ugh2]]
+    // CHECK-NEXT: OpStore %ui [[uintdot3]]
+    uint3 ug, uh;
+    uint ui;
+    ui = dot(ug, uh);
+
+    // CHECK:      [[uj:%\d+]] = OpLoad %v4uint %uj
+    // CHECK-NEXT: [[uk:%\d+]] = OpLoad %v4uint %uk
+    // CHECK-NEXT: [[uj0:%\d+]] = OpCompositeExtract %uint [[uj]] 0
+    // CHECK-NEXT: [[uk0:%\d+]] = OpCompositeExtract %uint [[uk]] 0
+    // CHECK-NEXT: [[mul_ujk0:%\d+]] = OpIMul %uint [[uj0]] [[uk0]]
+    // CHECK-NEXT: [[uj1:%\d+]] = OpCompositeExtract %uint [[uj]] 1
+    // CHECK-NEXT: [[uk1:%\d+]] = OpCompositeExtract %uint [[uk]] 1
+    // CHECK-NEXT: [[mul_ujk1:%\d+]] = OpIMul %uint [[uj1]] [[uk1]]
+    // CHECK-NEXT: [[uj2:%\d+]] = OpCompositeExtract %uint [[uj]] 2
+    // CHECK-NEXT: [[uk2:%\d+]] = OpCompositeExtract %uint [[uk]] 2
+    // CHECK-NEXT: [[mul_ujk2:%\d+]] = OpIMul %uint [[uj2]] [[uk2]]
+    // CHECK-NEXT: [[uj3:%\d+]] = OpCompositeExtract %uint [[uj]] 3
+    // CHECK-NEXT: [[uk3:%\d+]] = OpCompositeExtract %uint [[uk]] 3
+    // CHECK-NEXT: [[mul_ujk3:%\d+]] = OpIMul %uint [[uj3]] [[uk3]]
+    // CHECK-NEXT: [[uintdot4_add0:%\d+]] = OpIAdd %uint [[mul_ujk0]] [[mul_ujk1]]
+    // CHECK-NEXT: [[uintdot4_add1:%\d+]] = OpIAdd %uint [[uintdot4_add0]] [[mul_ujk2]]
+    // CHECK-NEXT: [[uintdot4:%\d+]] = OpIAdd %uint [[uintdot4_add1]] [[mul_ujk3]]
+    // CHECK-NEXT: OpStore %ul [[uintdot4]]
+    uint4 uj, uk;
+    uint ul;
+    ul = dot(uj, uk);
+}

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -110,4 +110,6 @@ TEST_F(FileTest, ControlFlowNestedIfForStmt) { runFileTest("cf.if.for.hlsl"); }
 
 TEST_F(FileTest, FunctionCall) { runFileTest("fn.call.hlsl"); }
 
+TEST_F(FileTest, IntrinsicsDot) { runFileTest("intrinsics.dot.hlsl"); }
+
 } // namespace