소스 검색

[spirv] Translate arithmetic operations for integers and floats (#461)

* Covers the following operations: +, -, *, /, %
Lei Zhang 8 년 전
부모
커밋
5b37ea784f

+ 4 - 0
tools/clang/include/clang/SPIRV/InstBuilder.h

@@ -791,6 +791,10 @@ public:
                                            uint32_t result_id, uint32_t value,
                                            uint32_t index);
 
+  // All-in-one method for creating binary operations.
+  InstBuilder &binaryOp(spv::Op op, uint32_t result_type, uint32_t result_id,
+                        uint32_t lhs, uint32_t rhs);
+
   // Methods for building constants.
   InstBuilder &opConstant(uint32_t result_type, uint32_t result_id,
                           uint32_t value);

+ 5 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -91,6 +91,11 @@ public:
   uint32_t createAccessChain(uint32_t resultType, uint32_t base,
                              llvm::ArrayRef<uint32_t> indexes);
 
+  /// \brief Creates a binary operation with the given SPIR-V opcode. Returns
+  /// the <result-id> for the result.
+  uint32_t createBinaryOp(spv::Op op, uint32_t resultType, uint32_t lhs,
+                          uint32_t rhs);
+
   /// \brief Creates a return instruction.
   void createReturn();
   /// \brief Creates a return value instruction.

+ 90 - 4
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -372,19 +372,41 @@ public:
 
   uint32_t doBinaryOperator(const BinaryOperator *expr) {
     const auto opcode = expr->getOpcode();
-    const uint32_t rhs = doExpr(expr->getRHS());
-    const uint32_t lhs = doExpr(expr->getLHS());
 
-    switch (opcode) {
-    case BO_Assign:
+    // Handle assignment first since we need to evaluate rhs before lhs.
+    // For other binary operations, we need to evaluate lhs before rhs.
+    if (opcode == BO_Assign) {
+      const uint32_t rhs = doExpr(expr->getRHS());
+      const uint32_t lhs = doExpr(expr->getLHS());
+
       theBuilder.createStore(lhs, rhs);
       // Assignment returns a rvalue.
       return rhs;
+    }
+
+    const uint32_t lhs = doExpr(expr->getLHS());
+    const uint32_t rhs = doExpr(expr->getRHS());
+    const uint32_t typeId = typeTranslator.translateType(expr->getType());
+    const QualType elemType = expr->getLHS()->getType();
+
+    switch (opcode) {
+    case BO_Add:
+    case BO_Sub:
+    case BO_Mul:
+    case BO_Div:
+    case BO_Rem: {
+      const spv::Op spvOp = translateOp(opcode, elemType);
+      return theBuilder.createBinaryOp(spvOp, typeId, lhs, rhs);
+    }
+    case BO_Assign: {
+      llvm_unreachable("assignment already handled before");
+    } break;
     default:
       break;
     }
 
     emitError("BinaryOperator '%0' is not supported yet.") << opcode;
+    expr->dump();
     return 0;
   }
 
@@ -422,6 +444,70 @@ public:
     }
   }
 
+  spv::Op translateOp(BinaryOperator::Opcode op, QualType type) {
+    // TODO: the following is not considering vector types yet.
+    const bool isSintType = type->isSignedIntegerType();
+    const bool isUintType = type->isUnsignedIntegerType();
+    const bool isFloatType = type->isFloatingType();
+
+#define BIN_OP_CASE_INT_FLOAT(kind, intBinOp, floatBinOp)                      \
+  \
+case BO_##kind : {                                                             \
+    if (isSintType || isUintType) {                                            \
+      return spv::Op::Op##intBinOp;                                            \
+    }                                                                          \
+    if (isFloatType) {                                                         \
+      return spv::Op::Op##floatBinOp;                                          \
+    }                                                                          \
+  }                                                                            \
+  break
+
+#define BIN_OP_CASE_SINT_UINT_FLOAT(kind, sintBinOp, uintBinOp, floatBinOp)    \
+  \
+case BO_##kind : {                                                             \
+    if (isSintType) {                                                          \
+      return spv::Op::Op##sintBinOp;                                           \
+    }                                                                          \
+    if (isUintType) {                                                          \
+      return spv::Op::Op##uintBinOp;                                           \
+    }                                                                          \
+    if (isFloatType) {                                                         \
+      return spv::Op::Op##floatBinOp;                                          \
+    }                                                                          \
+  }                                                                            \
+  break
+
+    switch (op) {
+      BIN_OP_CASE_INT_FLOAT(Add, IAdd, FAdd);
+      BIN_OP_CASE_INT_FLOAT(Sub, ISub, FSub);
+      BIN_OP_CASE_INT_FLOAT(Mul, IMul, FMul);
+      BIN_OP_CASE_SINT_UINT_FLOAT(Div, SDiv, UDiv, FDiv);
+      // According to HLSL spec, "the modulus operator returns the remainder of
+      // a division." "The % operator is defined only in cases where either both
+      // sides are positive or both sides are negative."
+      //
+      // In SPIR-V, there are two reminder operations: Op*Rem and Op*Mod. With
+      // the former, the sign of a non-0 result comes from Operand 1, while
+      // with the latter, from Operand 2.
+      //
+      // For operands with different signs, technically we can map % to either
+      // Op*Rem or Op*Mod since it's undefined behavior. But it is more
+      // consistent with C (HLSL starts as a C derivative) and Clang frontend
+      // const expression evaluation if we map % to Op*Rem.
+      //
+      // Note there is no OpURem in SPIR-V.
+      BIN_OP_CASE_SINT_UINT_FLOAT(Rem, SRem, UMod, FRem);
+    default:
+      break;
+    }
+
+#undef BIN_OP_CASE_INT_FLOAT
+#undef BIN_OP_CASE_SINT_UINT_FLOAT
+
+    emitError("translating binary operator '%0' unimplemented") << op;
+    return spv::Op::OpNop;
+  }
+
   uint32_t translateAPValue(const APValue &value, const QualType targetType) {
     if (targetType->isBooleanType()) {
       const bool boolValue = value.getInt().getBoolValue();

+ 20 - 0
tools/clang/lib/SPIRV/InstBuilderManual.cpp

@@ -26,6 +26,26 @@ std::vector<uint32_t> InstBuilder::take() {
   return result;
 }
 
+InstBuilder &InstBuilder::binaryOp(spv::Op op, uint32_t result_type,
+                                   uint32_t result_id, uint32_t lhs,
+                                   uint32_t rhs) {
+  if (!TheInst.empty()) {
+    TheStatus = Status::NestedInst;
+    return *this;
+  }
+
+  // TODO: check op range
+
+  TheInst.reserve(5);
+  TheInst.emplace_back(static_cast<uint32_t>(op));
+  TheInst.emplace_back(result_type);
+  TheInst.emplace_back(result_id);
+  TheInst.emplace_back(lhs);
+  TheInst.emplace_back(rhs);
+
+  return *this;
+}
+
 InstBuilder &InstBuilder::opConstant(uint32_t resultType, uint32_t resultId,
                                      uint32_t value) {
   if (!TheInst.empty()) {

+ 9 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -142,6 +142,15 @@ uint32_t ModuleBuilder::createAccessChain(uint32_t resultType, uint32_t base,
   return id;
 }
 
+uint32_t ModuleBuilder::createBinaryOp(spv::Op op, uint32_t resultType,
+                                       uint32_t lhs, uint32_t rhs) {
+  assert(insertPoint && "null insert point");
+  const uint32_t id = theContext.takeNextId();
+  instBuilder.binaryOp(op, resultType, id, lhs, rhs).x();
+  insertPoint->appendInstruction(std::move(constructSite));
+  return id;
+}
+
 void ModuleBuilder::createReturn() {
   assert(insertPoint && "null insert point");
   instBuilder.opReturn().x();

+ 121 - 0
tools/clang/test/CodeGenSPIRV/binary-op.arithmetic.scalar.hlsl

@@ -0,0 +1,121 @@
+// Run: %dxc -T ps_6_0 -E main
+
+void main() {
+    int a, b, c;
+    uint i, j, k;
+    float o, p, q;
+
+// CHECK:      [[a0:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[b0:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[c0:%\d+]] = OpIAdd %int [[a0]] [[b0]]
+// CHECK-NEXT: OpStore %c [[c0]]
+    c = a + b;
+// CHECK-NEXT: [[i0:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[j0:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[k0:%\d+]] = OpIAdd %uint [[i0]] [[j0]]
+// CHECK-NEXT: OpStore %k [[k0]]
+    k = i + j;
+// CHECK-NEXT: [[o0:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[p0:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[q0:%\d+]] = OpFAdd %float [[o0]] [[p0]]
+// CHECK-NEXT: OpStore %q [[q0]]
+    q = o + p;
+
+// CHECK-NEXT: [[a1:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[b1:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[c1:%\d+]] = OpISub %int [[a1]] [[b1]]
+// CHECK-NEXT: OpStore %c [[c1]]
+    c = a - b;
+// CHECK-NEXT: [[i1:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[j1:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[k1:%\d+]] = OpISub %uint [[i1]] [[j1]]
+// CHECK-NEXT: OpStore %k [[k1]]
+    k = i - j;
+// CHECK-NEXT: [[o1:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[p1:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[q1:%\d+]] = OpFSub %float [[o1]] [[p1]]
+// CHECK-NEXT: OpStore %q [[q1]]
+    q = o - p;
+
+// CHECK-NEXT: [[a2:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[b2:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[c2:%\d+]] = OpIMul %int [[a2]] [[b2]]
+// CHECK-NEXT: OpStore %c [[c2]]
+    c = a * b;
+// CHECK-NEXT: [[i2:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[j2:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[k2:%\d+]] = OpIMul %uint [[i2]] [[j2]]
+// CHECK-NEXT: OpStore %k [[k2]]
+    k = i * j;
+// CHECK-NEXT: [[o2:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[p2:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[q2:%\d+]] = OpFMul %float [[o2]] [[p2]]
+// CHECK-NEXT: OpStore %q [[q2]]
+    q = o * p;
+
+// CHECK-NEXT: [[a3:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[b3:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[c3:%\d+]] = OpSDiv %int [[a3]] [[b3]]
+// CHECK-NEXT: OpStore %c [[c3]]
+    c = a / b;
+// CHECK-NEXT: [[i3:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[j3:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[k3:%\d+]] = OpUDiv %uint [[i3]] [[j3]]
+// CHECK-NEXT: OpStore %k [[k3]]
+    k = i / j;
+// CHECK-NEXT: [[o3:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[p3:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[q3:%\d+]] = OpFDiv %float [[o3]] [[p3]]
+// CHECK-NEXT: OpStore %q [[q3]]
+    q = o / p;
+
+// CHECK-NEXT: [[a4:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[b4:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[c4:%\d+]] = OpSRem %int [[a4]] [[b4]]
+// CHECK-NEXT: OpStore %c [[c4]]
+    c = a % b;
+// CHECK-NEXT: [[i4:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[j4:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[k4:%\d+]] = OpUMod %uint [[i4]] [[j4]]
+// CHECK-NEXT: OpStore %k [[k4]]
+    k = i % j;
+// CHECK-NEXT: [[o4:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[p4:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[q4:%\d+]] = OpFRem %float [[o4]] [[p4]]
+// CHECK-NEXT: OpStore %q [[q4]]
+    q = o % p;
+
+// CHECK-NEXT: [[a5:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[b5:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[a6:%\d+]] = OpLoad %int %a
+// CHECK-NEXT: [[in0:%\d+]] = OpIMul %int [[b5]] [[a6]]
+// CHECK-NEXT: [[c5:%\d+]] = OpLoad %int %c
+// CHECK-NEXT: [[in1:%\d+]] = OpSDiv %int [[in0]] [[c5]]
+// CHECK-NEXT: [[b6:%\d+]] = OpLoad %int %b
+// CHECK-NEXT: [[in2:%\d+]] = OpSRem %int [[in1]] [[b6]]
+// CHECK-NEXT: [[in3:%\d+]] = OpIAdd %int [[a5]] [[in2]]
+// CHECK-NEXT: OpStore %c [[in3]]
+    c = a + b * a / c % b;
+// CHECK-NEXT: [[i5:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[j5:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[i6:%\d+]] = OpLoad %uint %i
+// CHECK-NEXT: [[in4:%\d+]] = OpIMul %uint [[j5]] [[i6]]
+// CHECK-NEXT: [[k5:%\d+]] = OpLoad %uint %k
+// CHECK-NEXT: [[in5:%\d+]] = OpUDiv %uint [[in4]] [[k5]]
+// CHECK-NEXT: [[j6:%\d+]] = OpLoad %uint %j
+// CHECK-NEXT: [[in6:%\d+]] = OpUMod %uint [[in5]] [[j6]]
+// CHECK-NEXT: [[in7:%\d+]] = OpIAdd %uint [[i5]] [[in6]]
+// CHECK-NEXT: OpStore %k [[in7]]
+    k = i + j * i / k % j;
+// CHECK-NEXT: [[o5:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[p5:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[o6:%\d+]] = OpLoad %float %o
+// CHECK-NEXT: [[in8:%\d+]] = OpFMul %float [[p5]] [[o6]]
+// CHECK-NEXT: [[q5:%\d+]] = OpLoad %float %q
+// CHECK-NEXT: [[in9:%\d+]] = OpFDiv %float [[in8]] [[q5]]
+// CHECK-NEXT: [[p6:%\d+]] = OpLoad %float %p
+// CHECK-NEXT: [[in10:%\d+]] = OpFRem %float [[in9]] [[p6]]
+// CHECK-NEXT: [[in11:%\d+]] = OpFAdd %float [[o5]] [[in10]]
+// CHECK-NEXT: OpStore %q [[in11]]
+    q = o + p * o / q % p;
+}

+ 4 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -40,4 +40,8 @@ TEST_F(FileTest, ScalarConstants) { runFileTest("constant.scalar.hlsl"); }
 
 TEST_F(FileTest, BinaryOpAssign) { runFileTest("binary-op.assign.hlsl"); }
 
+TEST_F(FileTest, BinaryOpScalarArithmetic) {
+  runFileTest("binary-op.arithmetic.scalar.hlsl");
+}
+
 } // namespace