فهرست منبع

[spirv] support half zero constant correctly

Jaebaek Seo 6 سال پیش
والد
کامیت
30acc73562

+ 5 - 9
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -9097,18 +9097,14 @@ SpirvConstant *SpirvEmitter::getValueZero(QualType type) {
   {
     QualType scalarType = {};
     if (isScalarType(type, &scalarType)) {
-      if (scalarType->isSignedIntegerType()) {
-        return spvBuilder.getConstantInt(astContext.IntTy, llvm::APInt(32, 0));
+      if (scalarType->isBooleanType()) {
+        return spvBuilder.getConstantBool(false);
       }
-
-      if (scalarType->isUnsignedIntegerType()) {
-        return spvBuilder.getConstantInt(astContext.UnsignedIntTy,
-                                         llvm::APInt(32, 0));
+      if (scalarType->isIntegerType()) {
+        return spvBuilder.getConstantInt(scalarType, llvm::APInt(32, 0));
       }
-
       if (scalarType->isFloatingType()) {
-        return spvBuilder.getConstantFloat(astContext.FloatTy,
-                                           llvm::APFloat(0.0f));
+        return spvBuilder.getConstantFloat(scalarType, llvm::APFloat(0.0f));
       }
     }
   }

+ 49 - 0
tools/clang/test/CodeGenSPIRV/constant.scalar.16bit.enabled.half.zero.hlsl

@@ -0,0 +1,49 @@
+// Run: %dxc -T ps_6_2 -E main -enable-16bit-types
+
+// CHECK: [[ext:%\d+]] = OpExtInstImport "GLSL.std.450"
+
+void main() {
+// CHECK:      [[a:%\d+]] = OpLoad %bool %a
+// CHECK-NEXT: [[b:%\d+]] = OpSelect %half [[a]] %half_0x1p_0 %half_0x0p_0
+// CHECK-NEXT:              OpStore %b [[b]]
+  bool a;
+  half b = a;
+
+// CHECK:      [[c:%\d+]] = OpLoad %v2bool %c
+// CHECK-NEXT: [[d:%\d+]] = OpSelect %v2half [[c]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT:              OpStore %d [[d]]
+  bool2 c;
+  half2 d = c;
+
+// CHECK:      [[d:%\d+]] = OpLoad %v2half %d
+// CHECK-NEXT: [[e:%\d+]] = OpExtInst %v2half [[ext]] FClamp [[d]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT:              OpStore %e [[e]]
+  half2 e = saturate(d);
+
+// CHECK:      [[b:%\d+]] = OpLoad %half %b
+// CHECK-NEXT: [[f:%\d+]] = OpExtInst %half [[ext]] FClamp [[b]] %half_0x0p_0 %half_0x1p_0
+// CHECK-NEXT:              OpStore %f [[f]]
+  half f = saturate(b);
+
+// CHECK:      [[a:%\d+]] = OpLoad %bool %a
+// CHECK-NEXT: [[x:%\d+]] = OpSelect %float [[a]] %float_1 %float_0
+// CHECK-NEXT: [[y:%\d+]] = OpExtInst %float [[ext]] FClamp [[x]] %float_0 %float_1
+// CHECK-NEXT: [[g:%\d+]] = OpFConvert %half [[y]]
+// CHECK-NEXT:              OpStore %g [[g]]
+  half g = (half)saturate(a);
+
+// CHECK:      [[h:%\d+]] = OpLoad %v2int %h
+// CHECK-NEXT: [[x:%\d+]] = OpConvertSToF %v2float [[h]]
+// CHECK-NEXT: [[y:%\d+]] = OpExtInst %v2float [[ext]] FClamp [[x]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT: [[i:%\d+]] = OpFConvert %v2half [[y]]
+// CHECK-NEXT:              OpStore %i [[i]]
+  int2 h;
+  half2 i = (half2)saturate(h);
+
+// CHECK:      [[j:%\d+]] = OpLoad %v2float %j
+// CHECK-NEXT: [[x:%\d+]] = OpExtInst %v2float [[ext]] FClamp [[j]] {{%\d+}} {{%\d+}}
+// CHECK-NEXT: [[k:%\d+]] = OpFConvert %v2half [[x]]
+// CHECK-NEXT:              OpStore %k [[k]]
+  float2 j;
+  half2 k = (half2)saturate(j);
+}

+ 2 - 2
tools/clang/test/CodeGenSPIRV/ternary-op.cond-op.hlsl

@@ -117,8 +117,8 @@ void main() {
   // CHECK-NEXT:       {{%\d+}} = OpSelect %uint {{%\d+}} %uint_9 [[inner]]
   uint h = cond ? 9 : (cond ? 1 : 2);
 
-  //CHECK:      [[i_int:%\d+]] = OpSelect %int {{%\d+}} %int_1 %int_0
-  //CHECK-NEXT:       {{%\d+}} = OpINotEqual %bool [[i_int]] %int_0
+  //CHECK:      [[i_int:%\d+]] = OpSelect %uint {{%\d+}} %uint_1 %uint_0
+  //CHECK-NEXT:       {{%\d+}} = OpINotEqual %bool [[i_int]] %uint_0
   bool i = cond ? 1 : 0;
 
   // CHECK:     [[foo:%\d+]] = OpFunctionCall %uint %foo

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -122,6 +122,9 @@ TEST_F(FileTest, 16BitEnabledScalarConstants) {
   // needed extension.
   runFileTest("constant.scalar.16bit.enabled.hlsl", Expect::Success, false);
 }
+TEST_F(FileTest, 16BitEnabledScalarConstantsHalfZero) {
+  runFileTest("constant.scalar.16bit.enabled.half.zero.hlsl");
+}
 TEST_F(FileTest, 64BitScalarConstants) {
   runFileTest("constant.scalar.64bit.hlsl");
 }