ソースを参照

[spirv] Convert floats of different bit widths (#844)

Also support 'literal float' to 'float' conversion.
Ehsan 7 年 前
コミット
31855e75a8

+ 6 - 3
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -4139,6 +4139,11 @@ uint32_t SPIRVEmitter::castToFloat(const uint32_t fromVal, QualType fromType,
 
   const uint32_t floatType = typeTranslator.translateType(toFloatType);
 
+  // AST may include a 'literal float' to 'float' conversion. No-op.
+  if (fromType->isLiteralType(astContext) && fromType->isFloatingType() &&
+      typeTranslator.translateType(fromType) == floatType)
+    return fromVal;
+
   if (isBoolOrVecOfBoolType(fromType)) {
     const uint32_t one = getValueOne(toFloatType);
     const uint32_t zero = getValueZero(toFloatType);
@@ -4154,9 +4159,7 @@ uint32_t SPIRVEmitter::castToFloat(const uint32_t fromVal, QualType fromType,
   }
 
   if (isFloatOrVecOfFloatType(fromType)) {
-    emitError("casting between different floating point bitwidth unimplemented",
-              srcLoc);
-    return 0;
+    return theBuilder.createUnaryOp(spv::Op::OpFConvert, floatType, fromVal);
   }
 
   emitError("casting to floating point unimplemented", srcLoc);

+ 8 - 0
tools/clang/lib/SPIRV/TypeTranslator.cpp

@@ -148,6 +148,14 @@ uint32_t TypeTranslator::translateType(QualType type, LayoutRule rule,
           return theBuilder.getFloat32Type();
         case BuiltinType::Double:
           return theBuilder.getFloat64Type();
+        case BuiltinType::LitFloat: {
+          const auto &semantics = astContext.getFloatTypeSemantics(type);
+          const auto bitwidth = llvm::APFloat::getSizeInBits(semantics);
+          if (bitwidth <= 32)
+            return theBuilder.getFloat32Type();
+          else
+            return theBuilder.getFloat64Type();
+        }
         default:
           emitError("primitive type %0 unimplemented")
               << builtinType->getTypeClassName();

+ 17 - 0
tools/clang/test/CodeGenSPIRV/op.vector.swizzle.const-scalar.hlsl

@@ -0,0 +1,17 @@
+// Run: %dxc -T ps_6_0 -E main
+
+// CHECK:  [[v4f1:%\d+]] = OpConstantComposite %v4float %float_1 %float_1 %float_1 %float_1
+// CHECK: [[v4f25:%\d+]] = OpConstantComposite %v4float %float_2_5 %float_2_5 %float_2_5 %float_2_5
+// CHECK:  [[v4f0:%\d+]] = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
+
+void main() {
+
+// CHECK: %a = OpVariable %_ptr_Function_v4float Function [[v4f1]]
+  float4 a = (1).xxxx;
+
+// CHECK: %b = OpVariable %_ptr_Function_v4float Function [[v4f25]]
+  float4 b = (2.5).xxxx;
+
+// CHECK: %c = OpVariable %_ptr_Function_v4float Function [[v4f0]]
+  float4 c = (false).xxxx;
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -223,6 +223,9 @@ TEST_F(FileTest, OpVectorSwizzleAfterBufferAccess) {
 TEST_F(FileTest, OpVectorSwizzleAfterTextureAccess) {
   runFileTest("op.vector.swizzle.texture-access.hlsl");
 }
+TEST_F(FileTest, OpVectorSwizzleConstScalar) {
+  runFileTest("op.vector.swizzle.const-scalar.hlsl");
+}
 TEST_F(FileTest, OpVectorAccess) { runFileTest("op.vector.access.hlsl"); }
 
 // For matrix accessing/swizzling operators