Parcourir la source

[spirv] Better handling of ConditionalOperator. (#983)

Also improved the literal type hints.
Ehsan il y a 7 ans
Parent
commit
7ad893a79e

+ 85 - 25
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1568,9 +1568,6 @@ void SPIRVEmitter::doSwitchStmt(const SwitchStmt *switchStmt,
 
 SpirvEvalInfo
 SPIRVEmitter::doArraySubscriptExpr(const ArraySubscriptExpr *expr) {
-  // Provide a hint to the TypeTranslator that the integer literal used to
-  // index into the array should be translated as a 32-bit integer.
-  TypeTranslator::LiteralTypeHint hint(typeTranslator, astContext.IntTy);
 
   llvm::SmallVector<uint32_t, 4> indices;
   auto info = loadIfAliasVarRef(collectArrayStructIndices(expr, &indices));
@@ -2138,6 +2135,30 @@ SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
 
 SpirvEvalInfo
 SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
+  // Enhancement for special case when the ConditionalOperator return type is a
+  // literal type. For example:
+  //
+  // float a = cond ? 1 : 2;
+  // int   b = cond ? 1.5 : 2.5;
+  //
+  // There will be no indications about whether '1' and '2' should be used as
+  // 32-bit or 64-bit integers. Similarly, there will be no indication about
+  // whether '1.5' and '2.5' should be used as 32-bit or 64-bit floats.
+  //
+  // We want to avoid using 64-bit int and 64-bit float as much as possible.
+  //
+  // Note that if the literal is in fact large enough that it can't be
+  // represented in 32 bits (e.g. integer larger than 3e+9), we should *not*
+  // provide a hint.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator);
+  if (canBeRepresentedIn32Bits(expr->getTrueExpr()) &&
+      canBeRepresentedIn32Bits(expr->getFalseExpr())) {
+    if (expr->getType()->isSpecificBuiltinType(BuiltinType::LitInt))
+      hint.setHint(astContext.IntTy);
+    else if (expr->getType()->isSpecificBuiltinType(BuiltinType::LitFloat))
+      hint.setHint(astContext.FloatTy);
+  }
+
   // According to HLSL doc, all sides of the ?: expression are always
   // evaluated.
   const uint32_t type = typeTranslator.translateType(expr->getType());
@@ -4703,6 +4724,10 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
     return base;
   }
 
+  // Provide a hint to the TypeTranslator that the integer literal used to
+  // index into the following cases should be translated as a 32-bit integer.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator, astContext.IntTy);
+
   if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
     // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
     // cast. We need to ingore it to avoid creating OpLoad.
@@ -6846,6 +6871,34 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
   return 0;
 }
 
+bool SPIRVEmitter::canBeRepresentedIn32Bits(const Expr *expr) {
+  if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
+    const bool isSigned = expr->getType()->isSignedIntegerType();
+    const llvm::APInt &value = intLiteral->getValue();
+    return (isSigned && value.isSignedIntN(32)) ||
+           (!isSigned && value.isIntN(32));
+  }
+
+  if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
+    llvm::APFloat value = floatLiteral->getValue();
+    const auto &semantics = value.getSemantics();
+    // regular 'half' and 'float' can be represented in 32 bits.
+    if (&semantics == &llvm::APFloat::IEEEsingle ||
+        &semantics == &llvm::APFloat::IEEEhalf)
+      return true;
+
+    // See if 'double' value can be represented in 32 bits without losing info.
+    bool losesInfo = false;
+    const auto convertStatus =
+        value.convert(llvm::APFloat::IEEEsingle,
+                      llvm::APFloat::rmNearestTiesToEven, &losesInfo);
+    if (convertStatus == llvm::APFloat::opOK && !losesInfo)
+      return true;
+  }
+
+  return false;
+}
+
 uint32_t SPIRVEmitter::tryToEvaluateAsInt32(const llvm::APInt &intValue,
                                             bool isSigned) {
   if (isSigned && intValue.isSignedIntN(32)) {
@@ -6885,34 +6938,41 @@ uint32_t SPIRVEmitter::tryToEvaluateAsFloat32(const llvm::APFloat &floatValue) {
 
 uint32_t SPIRVEmitter::translateAPFloat(const llvm::APFloat &floatValue,
                                         QualType targetType) {
+  // The float value may have to go through conversion, so work on a local copy.
+  llvm::APFloat value = floatValue;
+  const auto valueBitwidth = llvm::APFloat::getSizeInBits(value.getSemantics());
+
+  // Find out the target bitwidth.
   targetType = typeTranslator.getIntendedLiteralType(targetType);
-  const auto &semantics = astContext.getFloatTypeSemantics(targetType);
-  const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);
-  switch (bitwidth) {
-  case 16: {
-    if (spirvOptions.enable16BitTypes) {
-      return theBuilder.getConstantFloat16(
-          static_cast<uint16_t>(floatValue.bitcastToAPInt().getZExtValue()));
-    } else {
-      // If 16-bit types are not enabled, treat as 32-bit float.
-      llvm::APFloat f32 = floatValue;
-      bool losesInfo = false;
-      f32.convert(llvm::APFloat::IEEEsingle,
-                  llvm::APFloat::roundingMode::rmTowardZero, &losesInfo);
-      // Conversion from 16-bit float value to 32-bit float value should be
-      // loss-less.
-      assert(!losesInfo);
-      return theBuilder.getConstantFloat32(f32.convertToFloat());
-    }
-  }
+  auto targetBitwidth = llvm::APFloat::getSizeInBits(
+      astContext.getFloatTypeSemantics(targetType));
+  // If 16-bit types are not enabled, treat them as 32-bit float.
+  if (targetBitwidth == 16 && !spirvOptions.enable16BitTypes)
+    targetBitwidth = 32;
+
+  if (targetBitwidth != valueBitwidth) {
+    bool losesInfo = false;
+    const llvm::fltSemantics &targetSemantics =
+        targetBitwidth == 16 ? llvm::APFloat::IEEEhalf
+                             : targetBitwidth == 32 ? llvm::APFloat::IEEEsingle
+                                                    : llvm::APFloat::IEEEdouble;
+    value.convert(targetSemantics, llvm::APFloat::roundingMode::rmTowardZero,
+                  &losesInfo);
+  }
+
+  switch (targetBitwidth) {
+  case 16:
+    return theBuilder.getConstantFloat16(
+        static_cast<uint16_t>(value.bitcastToAPInt().getZExtValue()));
   case 32:
-    return theBuilder.getConstantFloat32(floatValue.convertToFloat());
+    return theBuilder.getConstantFloat32(value.convertToFloat());
   case 64:
-    return theBuilder.getConstantFloat64(floatValue.convertToDouble());
+    return theBuilder.getConstantFloat64(value.convertToDouble());
   default:
     break;
   }
-  emitError("APFloat for target bitwidth %0 unimplemented", {}) << bitwidth;
+  emitError("APFloat for target bitwidth %0 unimplemented", {})
+      << targetBitwidth;
   return 0;
 }
 

+ 6 - 0
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -448,6 +448,12 @@ private:
   /// constant for that value.
   uint32_t tryToEvaluateAsInt32(const llvm::APInt &, bool isSigned);
 
+  /// Returns true iff the given expression is a literal integer that can be
+  /// represented in a 32-bit integer type or a literal float that can be
+  /// represented in a 32-bit float type without losing info. Returns false
+  /// otherwise.
+  bool canBeRepresentedIn32Bits(const Expr* expr);
+
 private:
   /// Translates the given HLSL loop attribute into SPIR-V loop control mask.
   /// Emits an error if the given attribute is not a loop attribute.

+ 40 - 3
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -118,13 +118,24 @@ bool TypeTranslator::isOpaqueStructType(QualType type) {
   return false;
 }
 
+void TypeTranslator::LiteralTypeHint::setHint(QualType ty) {
+  // You can set hint only once for each object.
+  assert(type == QualType());
+  type = ty;
+  translator.pushIntendedLiteralType(type);
+}
+
+TypeTranslator::LiteralTypeHint::LiteralTypeHint(TypeTranslator &t)
+    : translator(t), type({}) {}
+
 TypeTranslator::LiteralTypeHint::LiteralTypeHint(TypeTranslator &t, QualType ty)
     : translator(t), type(ty) {
   if (!isLiteralType(type))
     translator.pushIntendedLiteralType(type);
 }
+
 TypeTranslator::LiteralTypeHint::~LiteralTypeHint() {
-  if (!isLiteralType(type))
+  if (type != QualType() && !isLiteralType(type))
     translator.popIntendedLiteralType();
 }
 
@@ -154,8 +165,34 @@ void TypeTranslator::pushIntendedLiteralType(QualType type) {
 }
 
 QualType TypeTranslator::getIntendedLiteralType(QualType type) {
-  if (!intendedLiteralTypes.empty())
-    return intendedLiteralTypes.top();
+  if (!intendedLiteralTypes.empty()) {
+    // If the stack is not empty, there is potentially a useful hint about how a
+    // given literal should be translated.
+    //
+    // However, a hint should not be returned blindly. It is possible that casts
+    // are occuring. For Example:
+    //
+    //   TU
+    //    |_ n1: <IntegralToFloating> float
+    //       |_ n2: ConditionalOperator 'literal int'
+    //          |_ n3: cond, bool
+    //          |_ n4: 'literal int' 2
+    //          |_ n5: 'literal int' 3
+    //
+    // When evaluating the return type of ConditionalOperator, we shouldn't
+    // provide 'float' as hint. The cast AST node should take care of that.
+    // In the above example, we have no hints about how '2' or '3' should be
+    // used.
+    QualType potentialHint = intendedLiteralTypes.top();
+    const bool isDifferentBasicType =
+        (type->isSpecificBuiltinType(BuiltinType::LitInt) &&
+         !potentialHint->isIntegerType()) ||
+        (type->isSpecificBuiltinType(BuiltinType::LitFloat) &&
+         !potentialHint->isFloatingType());
+
+    if (!isDifferentBasicType)
+      return intendedLiteralTypes.top();
+  }
 
   // We don't have any useful hints, return the given type itself.
   return type;

+ 2 - 0
tools/clang/lib/SPIRV/TypeTranslator.h

@@ -261,6 +261,8 @@ public:
   class LiteralTypeHint {
   public:
     LiteralTypeHint(TypeTranslator &t, QualType ty);
+    LiteralTypeHint(TypeTranslator &t);
+    void setHint(QualType ty);
     ~LiteralTypeHint();
 
   private:

+ 1 - 1
tools/clang/test/CodeGenSPIRV/method.consume-structured-buffer.consume.hlsl

@@ -49,7 +49,7 @@ float4 main() : A {
 // CHECK-NEXT: [[prev:%\d+]] = OpAtomicISub %int [[counter]] %uint_1 %uint_0 %int_1
 // CHECK-NEXT: [[index:%\d+]] = OpISub %int [[prev]] %int_1
 // CHECK-NEXT: [[buffer3:%\d+]] = OpAccessChain %_ptr_Uniform_T %buffer3 %uint_0 [[index]]
-// CHECK-NEXT: [[ac:%\d+]] = OpAccessChain %_ptr_Uniform_v3float [[buffer3]] %int_0 %uint_3 %int_1
+// CHECK-NEXT: [[ac:%\d+]] = OpAccessChain %_ptr_Uniform_v3float [[buffer3]] %int_0 %int_3 %int_1
 // CHECK-NEXT: [[val:%\d+]] = OpLoad %v3float [[ac]]
 // CHECK-NEXT: OpStore %val [[val]]
     float3 val = buffer3.Consume().s[3].b;

+ 1 - 1
tools/clang/test/CodeGenSPIRV/method.structured-buffer.load.hlsl

@@ -31,7 +31,7 @@ float4 main(int index: A) : SV_Target {
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[x]]
 
 // CHECK:      [[index:%\d+]] = OpLoad %int %index
-// CHECK-NEXT: [[f012:%\d+]] = OpAccessChain %_ptr_Uniform_float %mySBuffer2 %int_0 [[index]] %int_1 %uint_0 %uint_1 %uint_2
+// CHECK-NEXT: [[f012:%\d+]] = OpAccessChain %_ptr_Uniform_float %mySBuffer2 %int_0 [[index]] %int_1 %int_0 %uint_1 %uint_2
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[f012]]
     return mySBuffer1.Load(5).f1.x + mySBuffer2.Load(index).f2[0][1][2];
 }

+ 3 - 2
tools/clang/test/CodeGenSPIRV/op.array.access.hlsl

@@ -35,7 +35,7 @@ float main(float val: A, uint index: B) : C {
 // CHECK-NEXT:  [[res:%\d+]] = OpVectorShuffle %v4float [[vec4]] [[vec2]] 0 1 5 4
 // CHECK-NEXT:                 OpStore [[ptr0]] [[res]]
     vecvar[3].ab = val;
-// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Function_float %vecvar %uint_2 %uint_1
+// CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Function_float %vecvar %int_2 %uint_1
 // CHECK-NEXT: [[load:%\d+]] = OpLoad %float [[ptr2]]
 // CHECK-NEXT:                 OpStore %r [[load]]
     r = vecvar[2][1];
@@ -50,10 +50,11 @@ float main(float val: A, uint index: B) : C {
 // CHECK-NEXT: [[ptr2:%\d+]] = OpAccessChain %_ptr_Function_float [[ptr0]] %int_1 %int_2
 // CHECK-NEXT:                 OpStore [[ptr2]] [[val1]]
     matvar[2]._12_23 = val;
-// CHECK-NEXT: [[ptr4:%\d+]] = OpAccessChain %_ptr_Function_float %matvar %uint_0 %uint_1 %uint_2
+// CHECK-NEXT: [[ptr4:%\d+]] = OpAccessChain %_ptr_Function_float %matvar %int_0 %uint_1 %uint_2
 // CHECK-NEXT: [[load:%\d+]] = OpLoad %float [[ptr4]]
 // CHECK-NEXT:                 OpStore %r [[load]]
     r = matvar[0][1][2];
 
     return r;
 }
+// CHECK-WHOLE-SPIR-V:

+ 4 - 4
tools/clang/test/CodeGenSPIRV/op.rw-structured-buffer.access.hlsl

@@ -16,10 +16,10 @@ struct T {
 RWStructuredBuffer<T> MySbuffer;
 
 void main(uint index: A) {
-// CHECK:      [[c12:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %uint_2 %uint_1 %uint_2
+// CHECK:      [[c12:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %int_2 %uint_1 %uint_2
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[c12]]
 
-// CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %uint_0 %int_0
+// CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %int_0 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
     float val = MySbuffer[2].c[2][1][2] + MySbuffer[3].s[0].f;
 
@@ -29,10 +29,10 @@ void main(uint index: A) {
 // CHECK-NEXT:  [[t3:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 [[index]] %int_4 %int_3
 // CHECK-NEXT:  OpStore [[t3]] [[val]]
 
-// CHECK:       [[f:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %uint_0 %int_0
+// CHECK:       [[f:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %int_0 %int_0
 // CHECK-NEXT:  OpStore [[f]] [[val]]
 
-// CHECK-NEXT:  [[c212:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %uint_2 %uint_1 %uint_2
+// CHECK-NEXT:  [[c212:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %int_2 %uint_1 %uint_2
 // CHECK-NEXT:  OpStore [[c212]] [[val]]
 
 // CHECK-NEXT:  [[b1:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySbuffer %int_0 %uint_1 %int_1 %int_1

+ 2 - 2
tools/clang/test/CodeGenSPIRV/op.structured-buffer.access.hlsl

@@ -23,10 +23,10 @@ float4 main(uint index: A) : SV_Target {
 // CHECK-NEXT: [[x:%\d+]] = OpAccessChain %_ptr_Uniform_float [[b1]] %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[x]]
 
-// CHECK:      [[c12:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %uint_2 %uint_1 %uint_2
+// CHECK:      [[c12:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_2 %int_2 %int_2 %uint_1 %uint_2
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[c12]]
 
-// CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %uint_0 %int_0
+// CHECK:      [[s:%\d+]] = OpAccessChain %_ptr_Uniform_float %MySbuffer %int_0 %uint_3 %int_3 %int_0 %int_0
 // CHECK-NEXT: {{%\d+}} = OpLoad %float [[s]]
 
 // CHECK:      [[index:%\d+]] = OpLoad %uint %index

+ 16 - 0
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -69,4 +69,20 @@ void main() {
 // CHECK-NEXT:         [[v:%\d+]] = OpLoad %v3int %v
 // CHECK-NEXT:           {{%\d+}} = OpSelect %v3int [[bool3Cond]] [[u]] [[v]]
     w = int3Cond ? u : v;
+
+// Make sure literal types are handled correctly in ternary ops
+
+// CHECK: [[b_float:%\d+]] = OpSelect %float {{%\d+}} %float_1_5 %float_2_5
+// CHECK-NEXT:    {{%\d+}} = OpConvertFToS %int [[b_float]]
+    int   b = cond ? 1.5 : 2.5;
+
+// CHECK:      [[a_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+// CHECK-NEXT:       {{%\d+}} = OpConvertSToF %float [[a_int]]
+    float a = cond ? 1 : 0;
+
+
+// CHECK:      [[c_long:%\d+]] = OpSelect %long {{%\d+}} %long_3000000000 %long_4000000000
+// CHECK-NEXT:        {{%\d+}} = OpConvertSToF %float [[c_long]]
+    // TODO: Use OpSConvert to first convert long to int. Then use OpConvertSToF.
+    float c = cond ? 3000000000 : 4000000000;
 }