瀏覽代碼

[spirv] Better handling of ternary and binary ops. (#995)

* [spirv] Better handling of ternary and binary ops.

This change does type handling more properly and removes a hack
regarding binary ops. Also fixes the way ternary ops provide hints about
usage of literal types.
Ehsan 7 年之前
父節點
當前提交
fb8eac6c83

+ 90 - 45
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -1611,9 +1611,8 @@ SpirvEvalInfo SPIRVEmitter::doBinaryOperator(const BinaryOperator *expr) {
       return result;
   }
 
-  const uint32_t resultType = typeTranslator.translateType(expr->getType());
-  return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode, resultType,
-                         expr->getSourceRange());
+  return processBinaryOp(expr->getLHS(), expr->getRHS(), opcode,
+                         expr->getType(), expr->getSourceRange());
 }
 
 SpirvEvalInfo SPIRVEmitter::doCallExpr(const CallExpr *callExpr) {
@@ -2136,8 +2135,7 @@ SPIRVEmitter::doCompoundAssignOperator(const CompoundAssignOperator *expr) {
   const auto *lhs = expr->getLHS();
 
   SpirvEvalInfo lhsPtr = 0;
-  const uint32_t resultType = typeTranslator.translateType(expr->getType());
-  const auto result = processBinaryOp(lhs, rhs, opcode, resultType,
+  const auto result = processBinaryOp(lhs, rhs, opcode, expr->getType(),
                                       expr->getSourceRange(), &lhsPtr);
   return processAssignment(lhs, result, true, lhsPtr);
 }
@@ -2161,13 +2159,24 @@ SPIRVEmitter::doConditionalOperator(const ConditionalOperator *expr) {
   // Note that if the literal is in fact large enough that it can't be
   // represented in 32 bits (e.g. integer larger than 3e+9), we should *not*
   // provide a hint.
+
   TypeTranslator::LiteralTypeHint hint(typeTranslator);
-  if (canBeRepresentedIn32Bits(expr->getTrueExpr()) &&
-      canBeRepresentedIn32Bits(expr->getFalseExpr())) {
-    if (type->isSpecificBuiltinType(BuiltinType::LitInt))
-      hint.setHint(astContext.IntTy);
-    else if (type->isSpecificBuiltinType(BuiltinType::LitFloat))
-      hint.setHint(astContext.FloatTy);
+  const bool isLitInt = type->isSpecificBuiltinType(BuiltinType::LitInt);
+  const bool isLitFloat = type->isSpecificBuiltinType(BuiltinType::LitFloat);
+  // Return type of ConditionalOperator is a 'literal int' or 'literal float'
+  if (isLitInt || isLitFloat) {
+    // There is no hint about the intended usage of the literal type.
+    if (typeTranslator.getIntendedLiteralType(type) == type) {
+      // If either branch is a literal that is larger than 32-bits, do not
+      // provide a hint.
+      if (!isLiteralLargerThan32Bits(expr->getTrueExpr()) &&
+          !isLiteralLargerThan32Bits(expr->getFalseExpr())) {
+        if (isLitInt)
+          hint.setHint(astContext.IntTy);
+        else if (isLitFloat)
+          hint.setHint(astContext.FloatTy);
+      }
+    }
   }
 
   // According to HLSL doc, all sides of the ?: expression are always
@@ -4079,10 +4088,25 @@ void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
 
 SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
                                             const BinaryOperatorKind opcode,
-                                            const uint32_t resultType,
+                                            const QualType resultType,
                                             SourceRange sourceRange,
                                             SpirvEvalInfo *lhsInfo,
                                             const spv::Op mandateGenOpcode) {
+  const uint32_t resultTypeId = typeTranslator.translateType(resultType);
+
+  // Binary logical operations (such as ==, !=, etc) that return a boolean type
+  // may get a literal (e.g. 0, 1, etc.) as lhs or rhs args. Since only
+  // non-zero-ness of these literals matter, they can be translated as 32-bits.
+  TypeTranslator::LiteralTypeHint hint(typeTranslator);
+  if (resultType->isBooleanType()) {
+    if (lhs->getType()->isSpecificBuiltinType(BuiltinType::LitInt) ||
+        rhs->getType()->isSpecificBuiltinType(BuiltinType::LitInt))
+      hint.setHint(astContext.IntTy);
+    if (lhs->getType()->isSpecificBuiltinType(BuiltinType::LitFloat) ||
+        rhs->getType()->isSpecificBuiltinType(BuiltinType::LitFloat))
+      hint.setHint(astContext.FloatTy);
+  }
+
   // If the operands are of matrix type, we need to dispatch the operation
   // onto each element vector iff the operands are not degenerated matrices
   // and we don't have a matrix specific SPIR-V instruction for the operation.
@@ -4153,7 +4177,7 @@ SpirvEvalInfo SPIRVEmitter::processBinaryOp(const Expr *lhs, const Expr *rhs,
   case BO_ShlAssign:
   case BO_ShrAssign: {
     const auto valId =
-        theBuilder.createBinaryOp(spvOp, resultType, lhsVal, rhsVal);
+        theBuilder.createBinaryOp(spvOp, resultTypeId, lhsVal, rhsVal);
     auto result = SpirvEvalInfo(valId).setRValue();
     return lhsVal.isRelaxedPrecision() || rhsVal.isRelaxedPrecision()
                ? result.setRelaxedPrecision()
@@ -4432,7 +4456,7 @@ SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
   if (hlsl::IsHLSLVecType(lhs->getType())) {
     if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
       if (cast->getCastKind() == CK_HLSLVectorSplat) {
-        const uint32_t vecType = typeTranslator.translateType(expr->getType());
+        const QualType vecType = expr->getType();
         if (isa<CompoundAssignOperator>(expr)) {
           SpirvEvalInfo lhsPtr = 0;
           const auto result = processBinaryOp(
@@ -4452,7 +4476,7 @@ SPIRVEmitter::tryToGenFloatVectorScale(const BinaryOperator *expr) {
   if (hlsl::IsHLSLVecType(rhs->getType())) {
     if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
       if (cast->getCastKind() == CK_HLSLVectorSplat) {
-        const uint32_t vecType = typeTranslator.translateType(expr->getType());
+        const QualType vecType = expr->getType();
         // We need to switch the positions of lhs and rhs here because
         // OpVectorTimesScalar requires the first operand to be a vector and
         // the second to be a scalar.
@@ -4496,7 +4520,7 @@ SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
   if (hlsl::IsHLSLMatType(lhsType)) {
     if (const auto *cast = dyn_cast<ImplicitCastExpr>(rhs)) {
       if (cast->getCastKind() == CK_HLSLMatrixSplat) {
-        const uint32_t matType = typeTranslator.translateType(expr->getType());
+        const QualType matType = expr->getType();
         const spv::Op opcode = selectOpcode(lhsType);
         if (isa<CompoundAssignOperator>(expr)) {
           SpirvEvalInfo lhsPtr = 0;
@@ -4516,7 +4540,7 @@ SPIRVEmitter::tryToGenFloatMatrixScale(const BinaryOperator *expr) {
   if (hlsl::IsHLSLMatType(rhsType)) {
     if (const auto *cast = dyn_cast<ImplicitCastExpr>(lhs)) {
       if (cast->getCastKind() == CK_HLSLMatrixSplat) {
-        const uint32_t matType = typeTranslator.translateType(expr->getType());
+        const QualType matType = expr->getType();
         const spv::Op opcode = selectOpcode(rhsType);
         // We need to switch the positions of lhs and rhs here because
         // OpMatrixTimesScalar requires the first operand to be a matrix and
@@ -4934,23 +4958,13 @@ uint32_t SPIRVEmitter::castToInt(const uint32_t fromVal, QualType fromType,
   uint32_t intType = typeTranslator.translateType(toIntType);
 
   // AST may include a 'literal int' to 'int' conversion. No-op.
-  // 'literal int' to 'uint' must still go through conversion.
   if (fromType->isSpecificBuiltinType(BuiltinType::LitInt) &&
-      toIntType->isSignedIntegerType())
+      toIntType->isIntegerType())
     return fromVal;
 
   if (isBoolOrVecOfBoolType(fromType)) {
     const uint32_t one = getValueOne(toIntType);
     const uint32_t zero = getValueZero(toIntType);
-    if (toIntType->isScalarType() && toIntType->isLiteralType(astContext)) {
-      // Special case for handling casting from boolean values to literal ints.
-      // For source code like (a == b) != 5, an IntegralCast will be inserted
-      // for (a == b), whose return type will be 64-bit integer if following the
-      // normal path.
-      // TODO: This is not beautiful. But other ways are even worse.
-      intType = toIntType->isSignedIntegerType() ? theBuilder.getInt32Type()
-                                                 : theBuilder.getUint32Type();
-    }
     return theBuilder.createSelect(intType, fromVal, one, zero);
   }
 
@@ -4983,8 +4997,8 @@ uint32_t SPIRVEmitter::castToFloat(const uint32_t fromVal, QualType fromType,
   const uint32_t floatType = typeTranslator.translateType(toFloatType);
 
   // AST may include a 'literal float' to 'float' conversion. No-op.
-  if (fromType->isLiteralType(astContext) && fromType->isFloatingType() &&
-      typeTranslator.translateType(fromType) == floatType)
+  if (fromType->isSpecificBuiltinType(BuiltinType::LitFloat) &&
+      toFloatType->isFloatingType())
     return fromVal;
 
   if (isBoolOrVecOfBoolType(fromType)) {
@@ -6965,12 +6979,28 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
     }
   }
   case 32: {
-    if (isSigned)
+    if (isSigned) {
+      if (!intValue.isSignedIntN(32)) {
+        emitError("evaluating integer literal %0 as a 32-bit integer loses "
+                  "inforamtion",
+                  {})
+            << std::to_string(intValue.getSExtValue());
+        return 0;
+      }
       return theBuilder.getConstantInt32(
           static_cast<int32_t>(intValue.getSExtValue()));
-    else
+    }
+    else {
+      if (!intValue.isIntN(32)) {
+        emitError("evaluating integer literal %0 as a 32-bit integer loses "
+                  "inforamtion",
+                  {})
+            << std::to_string(intValue.getZExtValue());
+        return 0;
+      }
       return theBuilder.getConstantUint32(
           static_cast<uint32_t>(intValue.getZExtValue()));
+    }
   }
   case 64: {
     if (isSigned)
@@ -6986,12 +7016,12 @@ uint32_t SPIRVEmitter::translateAPInt(const llvm::APInt &intValue,
   return 0;
 }
 
-bool SPIRVEmitter::canBeRepresentedIn32Bits(const Expr *expr) {
+bool SPIRVEmitter::isLiteralLargerThan32Bits(const Expr *expr) {
   if (const auto *intLiteral = dyn_cast<IntegerLiteral>(expr)) {
     const bool isSigned = expr->getType()->isSignedIntegerType();
     const llvm::APInt &value = intLiteral->getValue();
-    return (isSigned && value.isSignedIntN(32)) ||
-           (!isSigned && value.isIntN(32));
+    return (isSigned && !value.isSignedIntN(32)) ||
+           (!isSigned && !value.isIntN(32));
   }
 
   if (const auto *floatLiteral = dyn_cast<FloatingLiteral>(expr)) {
@@ -7007,7 +7037,8 @@ bool SPIRVEmitter::canBeRepresentedIn32Bits(const Expr *expr) {
     const auto convertStatus =
         value.convert(llvm::APFloat::IEEEsingle,
                       llvm::APFloat::rmNearestTiesToEven, &losesInfo);
-    if (convertStatus == llvm::APFloat::opOK && !losesInfo)
+    if (convertStatus != llvm::APFloat::opOK &&
+        convertStatus != llvm::APFloat::opInexact)
       return true;
   }
 
@@ -7053,13 +7084,14 @@ uint32_t SPIRVEmitter::tryToEvaluateAsFloat32(const llvm::APFloat &floatValue) {
 
 uint32_t SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
                                         QualType targetType) {
-  const auto valueBitwidth =
-      llvm::APFloat::getSizeInBits(floatValue.getSemantics());
+  using llvm::APFloat;
+  const auto originalValue = floatValue;
+  const auto valueBitwidth = APFloat::getSizeInBits(floatValue.getSemantics());
 
   // Find out the target bitwidth.
   targetType = typeTranslator.getIntendedLiteralType(targetType);
-  auto targetBitwidth = llvm::APFloat::getSizeInBits(
-      astContext.getFloatTypeSemantics(targetType));
+  auto targetBitwidth =
+      APFloat::getSizeInBits(astContext.getFloatTypeSemantics(targetType));
   // If 16-bit types are not enabled, treat them as 32-bit float.
   if (targetBitwidth == 16 && !spirvOptions.enable16BitTypes)
     targetBitwidth = 32;
@@ -7067,11 +7099,24 @@ uint32_t SPIRVEmitter::translateAPFloat(llvm::APFloat floatValue,
   if (targetBitwidth != valueBitwidth) {
     bool losesInfo = false;
     const llvm::fltSemantics &targetSemantics =
-        targetBitwidth == 16 ? llvm::APFloat::IEEEhalf
-                             : targetBitwidth == 32 ? llvm::APFloat::IEEEsingle
-                                                    : llvm::APFloat::IEEEdouble;
-    floatValue.convert(targetSemantics,
-                       llvm::APFloat::roundingMode::rmTowardZero, &losesInfo);
+        targetBitwidth == 16
+            ? APFloat::IEEEhalf
+            : targetBitwidth == 32 ? APFloat::IEEEsingle : APFloat::IEEEdouble;
+    const auto status = floatValue.convert(
+        targetSemantics, APFloat::roundingMode::rmTowardZero, &losesInfo);
+    if (status != APFloat::opStatus::opOK &&
+        status != APFloat::opStatus::opInexact) {
+      emitError(
+          "evaluating float literal %0 at a lower bitwidth loses information",
+          {})
+          << std::to_string(
+                 valueBitwidth == 16
+                     ? static_cast<float>(
+                           originalValue.bitcastToAPInt().getZExtValue())
+                     : valueBitwidth == 32 ? originalValue.convertToFloat()
+                                           : originalValue.convertToDouble());
+      return 0;
+    }
   }
 
   switch (targetBitwidth) {

+ 4 - 4
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -153,7 +153,7 @@ private:
   /// mandateGenOpcode is not spv::Op::Max, it will used as the SPIR-V opcode
   /// instead of deducing from Clang frontend opcode.
   SpirvEvalInfo processBinaryOp(const Expr *lhs, const Expr *rhs,
-                                BinaryOperatorKind opcode, uint32_t resultType,
+                                BinaryOperatorKind opcode, QualType resultType,
                                 SourceRange, SpirvEvalInfo *lhsInfo = nullptr,
                                 spv::Op mandateGenOpcode = spv::Op::Max);
 
@@ -450,11 +450,11 @@ private:
   /// constant for that value.
   uint32_t tryToEvaluateAsInt32(const llvm::APInt &, bool isSigned);
 
-  /// Returns true iff the given expression is a literal integer that can be
-  /// represented in a 32-bit integer type or a literal float that can be
+  /// Returns true iff the given expression is a literal integer that cannot be
+  /// represented in a 32-bit integer type or a literal float that cannot be
   /// represented in a 32-bit float type without losing info. Returns false
   /// otherwise.
-  bool canBeRepresentedIn32Bits(const Expr *expr);
+  bool isLiteralLargerThan32Bits(const Expr *expr);
 
 private:
   /// Translates the given HLSL loop attribute into SPIR-V loop control mask.

+ 5 - 4
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -184,11 +184,12 @@ QualType TypeTranslator::getIntendedLiteralType(QualType type) {
     // In the above example, we have no hints about how '2' or '3' should be
     // used.
     QualType potentialHint = intendedLiteralTypes.top();
+    const bool hintIsInteger =
+        potentialHint->isIntegerType() && !potentialHint->isBooleanType();
+    const bool hintIsFloating = potentialHint->isFloatingType();
     const bool isDifferentBasicType =
-        (type->isSpecificBuiltinType(BuiltinType::LitInt) &&
-         !potentialHint->isIntegerType()) ||
-        (type->isSpecificBuiltinType(BuiltinType::LitFloat) &&
-         !potentialHint->isFloatingType());
+        (type->isSpecificBuiltinType(BuiltinType::LitInt) && !hintIsInteger) ||
+        (type->isSpecificBuiltinType(BuiltinType::LitFloat) && !hintIsFloating);
 
     if (!isDifferentBasicType)
       return intendedLiteralTypes.top();

+ 17 - 7
tools/clang/test/CodeGenSPIRV/cast.2literal-int.implicit.hlsl

@@ -1,9 +1,19 @@
-// Run: %dxc -T vs_6_0 -E main
+// Run: %dxc -T ps_6_0 -E main
 
-bool main(int a : A, int b : B) : C {
-// CHECK:       [[a:%\d+]] = OpLoad %int %a
-// CHECK-NEXT:  [[b:%\d+]] = OpLoad %int %b
-// CHECK-NEXT: [[eq:%\d+]] = OpIEqual %bool [[a]] [[b]]
-// CHECK-NEXT:    {{%\d+}} = OpSelect %int [[eq]] %int_1 %int_0
-    return (a == b) != 0;
+void main() {
+  int a, b;
+
+// CHECK:          [[a:%\d+]] = OpLoad %int %a
+// CHECK-NEXT:     [[b:%\d+]] = OpLoad %int %b
+// CHECK-NEXT:    [[eq:%\d+]] = OpIEqual %bool [[a]] [[b]]
+// CHECK-NEXT: [[c_int:%\d+]] = OpSelect %int [[eq]] %int_1 %int_0
+// CHECK-NEXT:       {{%\d+}} = OpINotEqual %bool [[c_int]] %int_1
+  bool c = (a == b) != 1;
+
+// CHECK:            [[a:%\d+]] = OpLoad %int %a
+// CHECK-NEXT:       [[b:%\d+]] = OpLoad %int %b
+// CHECK-NEXT:      [[eq:%\d+]] = OpIEqual %bool [[a]] [[b]]
+// CHECK-NEXT: [[d_float:%\d+]] = OpSelect %float [[eq]] %float_1 %float_0
+// CHECK-NEXT:         {{%\d+}} = OpFOrdNotEqual %bool [[d_float]] %float_1
+  bool d = (a == b) != 1.0;
 }

+ 21 - 2
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -2,6 +2,9 @@
 
 // CHECK: [[v3i0:%\d+]] = OpConstantComposite %v3int %int_0 %int_0 %int_0
 
+uint foo() { return 1; }
+float bar() { return 3.0; }
+
 void main() {
 // CHECK-LABEL: %bb_entry = OpLabel
 
@@ -87,8 +90,7 @@ void main() {
     // TODO: Use OpSConvert to first convert long to int. Then use OpConvertSToF.
     float c = cond ? 3000000000 : 4000000000;
 
-// CHECK:      [[d_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
-// CHECK-NEXT:       {{%\d+}} = OpBitcast %uint [[d_int]]
+// CHECK:      [[d_int:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_0
     uint d = cond ? 1 : 0;
 
     float2x3 e;
@@ -108,4 +110,21 @@ void main() {
 // CHECK-NEXT:[[temp:%\d+]] = OpLoad %mat2v3float %temp_var_ternary
 // CHECK-NEXT:                OpStore %g [[temp]]
     float2x3 g = cond ? e : f;
+
+// CHECK:      [[inner:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_2
+// CHECK-NEXT:       {{%\d+}} = OpSelect %uint {{%\d+}} %uint_9 [[inner]]
+    uint h = cond ? 9 : (cond ? 1 : 2);
+
+//CHECK:      [[i_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
+//CHECK-NEXT:       {{%\d+}} = OpINotEqual %bool [[i_int]] %int_0
+    bool i = cond ? 1 : 0;
+
+// CHECK:     [[foo:%\d+]] = OpFunctionCall %uint %foo
+// CHECKNEXT:     {{%\d+}} = OpSelect %uint {{%\d+}} %uint_3 [[foo]]
+    uint j = cond ? 3 : foo();
+
+// CHECK:          [[bar:%\d+]] = OpFunctionCall %float %bar
+// CHECK-NEXT: [[k_float:%\d+]] = OpSelect %float {{%\d+}} %float_4 [[bar]]
+// CHECK-NEXT:         {{%\d+}} = OpConvertFToU %uint [[k_float]]
+    uint k = cond ? 4 : bar();
 }