Browse Source

[spirv] support half zero constant correctly

Jaebaek Seo 6 years ago
parent
commit
ea0d1896ec

+ 3 - 5
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -9097,17 +9097,15 @@ SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
     QualType scalarType = {};
     QualType scalarType = {};
     if (isScalarType(type, &scalarType)) {
     if (isScalarType(type, &scalarType)) {
       if (scalarType->isSignedIntegerType()) {
       if (scalarType->isSignedIntegerType()) {
-        return spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0));
+        return spvBuilder.getConstantInt(scalarType, llvm::APInt(32, 0));
       }
       }
 
 
       if (scalarType->isUnsignedIntegerType()) {
       if (scalarType->isUnsignedIntegerType()) {
-        return spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                         llvm::APInt(32, 0));
+        return spvBuilder.getConstantInt(scalarType, llvm::APInt(32, 0));
       }
       }
 
 
       if (scalarType->isFloatingType()) {
       if (scalarType->isFloatingType()) {
-        return spvBuilder.getConstantFloat(astContext.FloatTy,
-                                           llvm::APFloat(0.0f));
+        return spvBuilder.getConstantFloat(scalarType, llvm::APFloat(0.0f));
       }
       }
     }
     }
   }
   }

+ 49 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.enabled.half.zero.hlsl

@@ -0,0 +1,49 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// CHECK: [[ext:%\d+]] = OpExtInstImport "GLSL.std.450"
+
+void main() {
+// CHECK:      [[a:%\d+]] = OpLoad %bool %a
+// CHECK-NEXT: [[b:%\d+]] = OpSelect %half [[a]] %half_0x1p_0 %half_0x0p_0
+// CHECK-NEXT:              OpStore %b [[b]]
+  bool a;
+  half b = a;
+
+// CHECK:      [[c:%\d+]] = OpLoad %v2bool %c
+// CHECK-NEXT: [[d:%\d+]] = OpSelect %v2half [[c]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT:              OpStore %d [[d]]
+  bool2 c;
+  half2 d = c;
+
+// CHECK:      [[d:%\d+]] = OpLoad %v2half %d
+// CHECK-NEXT: [[e:%\d+]] = OpExtInst %v2half [[ext]] FClamp [[d]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT:              OpStore %e [[e]]
+  half2 e = saturate(d);
+
+// CHECK:      [[b:%\d+]] = OpLoad %half %b
+// CHECK-NEXT: [[f:%\d+]] = OpExtInst %half [[ext]] FClamp [[b]] %half_0x0p_0 %half_0x1p_0
+// CHECK-NEXT:              OpStore %f [[f]]
+  half f = saturate(b);
+
+// CHECK:      [[a:%\d+]] = OpLoad %bool %a
+// CHECK-NEXT: [[x:%\d+]] = OpSelect %float [[a]] %float_1 %float_0
+// CHECK-NEXT: [[y:%\d+]] = OpExtInst %float [[ext]] FClamp [[x]] %float_0 %float_1
+// CHECK-NEXT: [[g:%\d+]] = OpFConvert %half [[y]]
+// CHECK-NEXT:              OpStore %g [[g]]
+  half g = (half)saturate(a);
+
+// CHECK:      [[h:%\d+]] = OpLoad %v2int %h
+// CHECK-NEXT: [[x:%\d+]] = OpConvertSToF %v2float [[h]]
+// CHECK-NEXT: [[y:%\d+]] = OpExtInst %v2float [[ext]] FClamp [[x]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT: [[i:%\d+]] = OpFConvert %v2half [[y]]
+// CHECK-NEXT:              OpStore %i [[i]]
+  int2 h;
+  half2 i = (half2)saturate(h);
+
+// CHECK:      [[j:%\d+]] = OpLoad %v2float %j
+// CHECK-NEXT: [[x:%\d+]] = OpExtInst %v2float [[ext]] FClamp [[j]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT: [[k:%\d+]] = OpFConvert %v2half [[x]]
+// CHECK-NEXT:              OpStore %k [[k]]
+  float2 j;
+  half2 k = (half2)saturate(j);
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -122,6 +122,9 @@ TEST_F(FileTest, 16BitEnabledScalarConstants) {
   // needed extension.
   // needed extension.
   runFileTest("constant.scalar.16bit.enabled.hlsl", Expect::Success, false);
   runFileTest("constant.scalar.16bit.enabled.hlsl", Expect::Success, false);
 }
 }
+TEST_F(FileTest, 16BitEnabledScalarConstantsHalfZero) {
+  runFileTest("constant.scalar.16bit.enabled.half.zero.hlsl");
+}
 TEST_F(FileTest, 64BitScalarConstants) {
 TEST_F(FileTest, 64BitScalarConstants) {
   runFileTest("constant.scalar.64bit.hlsl");
   runFileTest("constant.scalar.64bit.hlsl");
 }
 }