Browse Source

[spirv] Add translation for intrinsic rcp. (#713)

Also includes creating 64-bit float constants that are needed for
reciprocating a 64-bit float.
Ehsan 7 years ago
parent
commit
2fa7a957a0

+ 1 - 0
docs/SPIR-V.rst

@@ -1101,6 +1101,7 @@ The following intrinsic HLSL functions are currently supported:
 - ``ddx_fine``: High precision partial derivative with respect to the screen-space x-coordinate. Uses SIR-V ``OpDPdxFine``.
 - ``ddy_fine``: High precision partial derivative with respect to the screen-space y-coordinate. Uses SIR-V ``OpDPdyFine``.
 - ``fwidth``: Returns the absolute value of the partial derivatives of the specified value. Uses SIR-V ``OpFwidth``.
+- ``rcp``: Calculates a fast, approximate, per-component reciprocal. Uses SIR-V ``OpFDiv``.
 
 
 Using GLSL extended instructions

+ 2 - 0
tools/clang/include/clang/SPIRV/Constant.h

@@ -67,6 +67,8 @@ public:
                                    uint32_t value, DecorationSet dec = {});
   static const Constant *getFloat32(SPIRVContext &ctx, uint32_t type_id,
                                     float value, DecorationSet dec = {});
+  static const Constant *getFloat64(SPIRVContext &ctx, uint32_t type_id,
+                                    double value, DecorationSet dec = {});
 
   // TODO: 64-bit float and integer constant implementation
 

+ 1 - 0
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -339,6 +339,7 @@ public:
   uint32_t getConstantInt32(int32_t value);
   uint32_t getConstantUint32(uint32_t value);
   uint32_t getConstantFloat32(float value);
+  uint32_t getConstantFloat64(double value);
   uint32_t getConstantComposite(uint32_t typeId,
                                 llvm::ArrayRef<uint32_t> constituents);
   uint32_t getConstantNull(uint32_t type);

+ 14 - 0
tools/clang/lib/SPIRV/Constant.cpp

@@ -44,6 +44,20 @@ const Constant *Constant::getFloat32(SPIRVContext &ctx, uint32_t type_id,
   return getUniqueConstant(ctx, c);
 }
 
+const Constant *Constant::getFloat64(SPIRVContext &ctx, uint32_t type_id,
+                                     double value, DecorationSet dec) {
+  // TODO: The ordering of the 2 words depends on the endian-ness of the host
+  // machine.
+  struct wideFloat {
+    uint32_t word0;
+    uint32_t word1;
+  };
+  wideFloat words = cast::BitwiseCast<wideFloat, double>(value);
+  Constant c =
+      Constant(spv::Op::OpConstant, type_id, {words.word0, words.word1}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
 const Constant *Constant::getUint32(SPIRVContext &ctx, uint32_t type_id,
                                     uint32_t value, DecorationSet dec) {
   Constant c = Constant(spv::Op::OpConstant, type_id, {value}, dec);

+ 1 - 0
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -927,6 +927,7 @@ uint32_t ModuleBuilder::getConstant##builderTy(cppTy value) {                  \
 IMPL_GET_PRIMITIVE_CONST(Int32, int32_t)
 IMPL_GET_PRIMITIVE_CONST(Uint32, uint32_t)
 IMPL_GET_PRIMITIVE_CONST(Float32, float)
+IMPL_GET_PRIMITIVE_CONST(Float64, double)
 
 #undef IMPL_GET_PRIMITIVE_VALUE
 

+ 44 - 2
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -3771,6 +3771,9 @@ uint32_t SPIRVEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
   case hlsl::IntrinsicOp::IOP_sincos: {
     return processIntrinsicSinCos(callExpr);
   }
+  case hlsl::IntrinsicOp::IOP_rcp: {
+    return processIntrinsicRcp(callExpr);
+  }
   case hlsl::IntrinsicOp::IOP_saturate: {
     return processIntrinsicSaturate(callExpr);
   }
@@ -4349,6 +4352,35 @@ uint32_t SPIRVEmitter::processIntrinsicDot(const CallExpr *callExpr) {
   }
 }
 
+uint32_t SPIRVEmitter::processIntrinsicRcp(const CallExpr *callExpr) {
+  // 'rcp' takes only 1 argument that is a scalar, vector, or matrix of type
+  // float or double.
+  assert(callExpr->getNumArgs() == 1u);
+  const QualType returnType = callExpr->getType();
+  const uint32_t returnTypeId = typeTranslator.translateType(returnType);
+  const Expr *arg = callExpr->getArg(0);
+  const uint32_t argId = doExpr(arg);
+  const QualType argType = arg->getType();
+
+  // For cases with matrix argument.
+  QualType elemType = {};
+  uint32_t numRows = 0, numCols = 0;
+  if (TypeTranslator::isMxNMatrix(argType, &elemType, &numRows, &numCols)) {
+    const uint32_t vecOne = getVecValueOne(elemType, numCols);
+    const auto actOnEachVec = [this, vecOne](uint32_t /*index*/,
+                                             uint32_t vecType,
+                                             uint32_t curRowId) {
+      return theBuilder.createBinaryOp(spv::Op::OpFDiv, vecType, vecOne,
+                                       curRowId);
+    };
+    return processEachVectorInMatrix(arg, argId, actOnEachVec);
+  }
+
+  // For cases with scalar or vector arguments.
+  return theBuilder.createBinaryOp(spv::Op::OpFDiv, returnTypeId,
+                                   getValueOne(argType), argId);
+}
+
 uint32_t SPIRVEmitter::processIntrinsicAllOrAny(const CallExpr *callExpr,
                                                 spv::Op spvOp) {
   // 'all' and 'any' take only 1 parameter.
@@ -4739,6 +4771,8 @@ uint32_t SPIRVEmitter::getValueOne(QualType type) {
   {
     QualType scalarType = {};
     if (TypeTranslator::isScalarType(type, &scalarType)) {
+      // TODO: Support other types such as short, half, etc.
+
       if (scalarType->isSignedIntegerType()) {
         return theBuilder.getConstantInt32(1);
       }
@@ -4747,8 +4781,14 @@ uint32_t SPIRVEmitter::getValueOne(QualType type) {
         return theBuilder.getConstantUint32(1);
       }
 
-      if (scalarType->isFloatingType()) {
-        return theBuilder.getConstantFloat32(1.0);
+      if (const auto *builtinType = scalarType->getAs<BuiltinType>()) {
+        // TODO: Add support for other types that are not covered yet.
+        switch (builtinType->getKind()) {
+        case BuiltinType::Double:
+          return theBuilder.getConstantFloat64(1.0);
+        case BuiltinType::Float:
+          return theBuilder.getConstantFloat32(1.0);
+        }
       }
     }
   }
@@ -4861,6 +4901,8 @@ uint32_t SPIRVEmitter::translateAPFloat(const llvm::APFloat &floatValue,
   switch (bitwidth) {
   case 32:
     return theBuilder.getConstantFloat32(floatValue.convertToFloat());
+  case 64:
+    return theBuilder.getConstantFloat64(floatValue.convertToDouble());
   default:
     break;
   }

+ 3 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -290,6 +290,9 @@ private:
   /// Processes the 'isFinite' intrinsic function.
   uint32_t processIntrinsicIsFinite(const CallExpr *);
 
+  /// Processes the 'rcp' intrinsic function.
+  uint32_t processIntrinsicRcp(const CallExpr *);
+
   /// Processes the 'sign' intrinsic function for float types.
   /// The FSign instruction in the GLSL instruction set returns a floating point
   /// result. The HLSL sign function, however, returns an integer. An extra

+ 16 - 1
tools/clang/test/CodeGenSPIRV/constant.scalar.hlsl

@@ -1,7 +1,8 @@
 // Run: %dxc -T ps_6_0 -E main
 
 // TODO
-// 16bit & 64bit integer & floats (require additional capability)
+// 16bit & 64bit integer (require additional capability)
+// 16bit floats (require additional capability)
 // float: denormalized numbers, Inf, NaN
 
 void main() {
@@ -46,4 +47,18 @@ void main() {
   float c_float_4_2 = 4.2;
 // CHECK-DAG: %float_n4_2 = OpConstant %float -4.2
   float c_float_n4_2 = -4.2;
+  
+  // double constants
+// CHECK-DAG: %double_0 = OpConstant %double 0
+  double c_double_0 = 0.;
+// CHECK-DAG: %double_n0 = OpConstant %double -0
+  double c_double_n0 = -0.;
+// CHECK-DAG: %double_4_5 = OpConstant %double 4.5
+  double c_double_4_5 = 4.5;
+// CHECK-DAG: %double_n8_2 = OpConstant %double -8.2
+  double c_double_n8_2 = -8.2;
+// CHECK-DAG: %double_1234567898765_32 = OpConstant %double 1234567898765.32
+  double c_large  =  1234567898765.32;
+// CHECK-DAG: %double_n1234567898765_32 = OpConstant %double -1234567898765.32
+  double c_nlarge = -1234567898765.32;
 }

+ 48 - 0
tools/clang/test/CodeGenSPIRV/intrinsics.rcp.hlsl

@@ -0,0 +1,48 @@
+// Run: %dxc -T vs_6_0 -E main
+
+// CHECK: [[v4f1:%\d+]] = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+// CHECK: [[v3f1:%\d+]] = OpConstantComposite %v3float %float_1 %float_1 %float_1
+// CHECK: [[v4d1:%\d+]] = OpConstantComposite %v4double %double_1 %double_1 %double_1 %double_1
+// CHECK: [[v3d1:%\d+]] = OpConstantComposite %v3double %double_1 %double_1 %double_1
+
+void main() {
+  float    a, rcpa;
+  float4   b, rcpb;
+  float2x3 c, rcpc;
+  
+  double    d, rcpd;
+  double4   e, rcpe;
+  double2x3 f, rcpf;
+
+// CHECK:      [[a:%\d+]] = OpLoad %float %a
+// CHECK-NEXT:   {{%\d+}} = OpFDiv %float %float_1 [[a]]
+  rcpa = rcp(a);
+
+// CHECK:      [[b:%\d+]] = OpLoad %v4float %b
+// CHECK-NEXT:   {{%\d+}} = OpFDiv %v4float [[v4f1]] [[b]]
+  rcpb = rcp(b);
+
+// CHECK:          [[c:%\d+]] = OpLoad %mat2v3float %c
+// CHECK-NEXT:    [[c0:%\d+]] = OpCompositeExtract %v3float [[c]] 0
+// CHECK-NEXT: [[rcpc0:%\d+]] = OpFDiv %v3float [[v3f1]] [[c0]]
+// CHECK-NEXT:    [[c1:%\d+]] = OpCompositeExtract %v3float [[c]] 1
+// CHECK-NEXT: [[rcpc1:%\d+]] = OpFDiv %v3float [[v3f1]] [[c1]]
+// CHECK-NEXT:       {{%\d+}} = OpCompositeConstruct %mat2v3float [[rcpc0]] [[rcpc1]]
+  rcpc = rcp(c);
+
+// CHECK:      [[d:%\d+]] = OpLoad %double %d
+// CHECK-NEXT:   {{%\d+}} = OpFDiv %double %double_1 [[d]]
+  rcpd = rcp(d);  
+
+// CHECK:    [[e:%\d+]] = OpLoad %v4double %e
+// CHECK-NEXT: {{%\d+}} = OpFDiv %v4double [[v4d1]] [[e]]
+  rcpe = rcp(e);
+
+// CHECK:          [[f:%\d+]] = OpLoad %mat2v3double %f
+// CHECK-NEXT:    [[f0:%\d+]] = OpCompositeExtract %v3double [[f]] 0
+// CHECK-NEXT: [[rcpf0:%\d+]] = OpFDiv %v3double [[v3d1]] [[f0]]
+// CHECK-NEXT:    [[f1:%\d+]] = OpCompositeExtract %v3double [[f]] 1
+// CHECK-NEXT: [[rcpf1:%\d+]] = OpFDiv %v3double [[v3d1]] [[f1]]
+// CHECK-NEXT:       {{%\d+}} = OpCompositeConstruct %mat2v3double [[rcpf0]] [[rcpf1]]
+  rcpf = rcp(f);
+}

+ 1 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -576,6 +576,7 @@ TEST_F(FileTest, IntrinsicsFloatSign) {
   runFileTest("intrinsics.floatsign.hlsl");
 }
 TEST_F(FileTest, IntrinsicsIntSign) { runFileTest("intrinsics.intsign.hlsl"); }
+TEST_F(FileTest, IntrinsicsRcp) { runFileTest("intrinsics.rcp.hlsl"); }
 TEST_F(FileTest, IntrinsicsReflect) { runFileTest("intrinsics.reflect.hlsl"); }
 TEST_F(FileTest, IntrinsicsRefract) { runFileTest("intrinsics.refract.hlsl"); }
 TEST_F(FileTest, IntrinsicsReverseBits) {