Ver código fonte

[spirv] Fix bug: Must perform Bitcast rather than type cast.

Ehsan Nasiri 6 anos atrás
pai
commit
a03a69546a

+ 26 - 13
tools/clang/lib/SPIRV/RawBufferMethods.cpp

@@ -17,6 +17,20 @@
 namespace clang {
 namespace spirv {
 
+SpirvInstruction *
+RawBufferHandler::bitCastToNumericalOrBool(SpirvInstruction *instr,
+                                           QualType fromType, QualType toType,
+                                           SourceLocation loc) {
+  if (isSameType(astContext, fromType, toType))
+    return instr;
+
+  if (toType->isBooleanType())
+    return theEmitter.castToType(instr, fromType, toType, loc);
+
+  // Perform a bitcast
+  return spvBuilder.createUnaryOp(spv::Op::OpBitcast, toType, instr, loc);
+}
+
 SpirvInstruction *RawBufferHandler::load16BitsAtBitOffset0(
     SpirvInstruction *buffer, SpirvInstruction *&index,
     QualType target16BitType, uint32_t &bitOffset) {
@@ -34,8 +48,8 @@ SpirvInstruction *RawBufferHandler::load16BitsAtBitOffset0(
   // OpUConvert can perform truncation in this case.
   result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                     astContext.UnsignedShortTy, result, loc);
-  result = theEmitter.castToType(result, astContext.UnsignedShortTy,
-                                 target16BitType, loc);
+  result = bitCastToNumericalOrBool(result, astContext.UnsignedShortTy,
+                                    target16BitType, loc);
   result->setRValue();
 
   // Now that a 16-bit load at bit-offset 0 has been performed, the next load
@@ -58,8 +72,8 @@ SpirvInstruction *RawBufferHandler::load32BitsAtBitOffset0(
   auto *loadPtr = spvBuilder.createAccessChain(astContext.UnsignedIntTy, buffer,
                                                {constUint0, index}, loc);
   result = spvBuilder.createLoad(astContext.UnsignedIntTy, loadPtr, loc);
-  result = theEmitter.castToType(result, astContext.UnsignedIntTy,
-                                 target32BitType, loc);
+  result = bitCastToNumericalOrBool(result, astContext.UnsignedIntTy,
+                                    target32BitType, loc);
   result->setRValue();
   // Now that a 32-bit load at bit-offset 0 has been performed, the next load
   // should be done at *the next base index* at bit-offset 0.
@@ -114,9 +128,8 @@ SpirvInstruction *RawBufferHandler::load64BitsAtBitOffset0(
   // BitwiseOr word0 and word1.
   result = spvBuilder.createBinaryOp(
       spv::Op::OpBitwiseOr, astContext.UnsignedLongLongTy, word0, word1, loc);
-
-  result = theEmitter.castToType(result, astContext.UnsignedLongLongTy,
-                                 target64BitType, loc);
+  result = bitCastToNumericalOrBool(result, astContext.UnsignedLongLongTy,
+                                    target64BitType, loc);
   result->setRValue();
   // Now that a 64-bit load at bit-offset 0 has been performed, the next load
   // should be done at *the base index + 2* at bit-offset 0. The index has
@@ -151,8 +164,8 @@ SpirvInstruction *RawBufferHandler::load16BitsAtBitOffset16(
                                      constUint16, loc);
   result = spvBuilder.createUnaryOp(spv::Op::OpUConvert,
                                     astContext.UnsignedShortTy, result, loc);
-  result = theEmitter.castToType(result, astContext.UnsignedShortTy,
-                                 target16BitType, loc);
+  result = bitCastToNumericalOrBool(result, astContext.UnsignedShortTy,
+                                    target16BitType, loc);
   result->setRValue();
 
   // Now that a 16-bit load at bit-offset 16 has been performed, the next load
@@ -212,8 +225,8 @@ SpirvInstruction *RawBufferHandler::load32BitsAtBitOffset16(
   result = spvBuilder.createBinaryOp(spv::Op::OpBitwiseOr,
                                      astContext.UnsignedIntTy, lsb, msb, loc);
 
-  result = theEmitter.castToType(result, astContext.UnsignedIntTy,
-                                 target32BitType, loc);
+  result = bitCastToNumericalOrBool(result, astContext.UnsignedIntTy,
+                                    target32BitType, loc);
   result->setRValue();
 
   // Now that a 32-bit load at bit-offset 16 has been performed, the next load
@@ -293,8 +306,8 @@ SpirvInstruction *RawBufferHandler::load64BitsAtBitOffset16(
   result = spvBuilder.createBinaryOp(
       spv::Op::OpBitwiseOr, astContext.UnsignedLongLongTy, result, last16, loc);
 
-  result = theEmitter.castToType(result, astContext.UnsignedLongLongTy,
-                                 target64BitType, loc);
+  result = bitCastToNumericalOrBool(result, astContext.UnsignedLongLongTy,
+                                    target64BitType, loc);
   result->setRValue();
 
   // Now that a 64-bit load at bit-offset 16 has been performed, the next load

+ 12 - 0
tools/clang/lib/SPIRV/RawBufferMethods.h

@@ -69,6 +69,18 @@ private:
                                             QualType target64BitType,
                                             uint32_t &bitOffset);
 
+private:
+  /// \brief Performs an OpBitCast from |fromType| to |toType| on the given
+  /// instruction.
+  ///
+  /// If the |toType| is a boolean type, it performs a regular type cast.
+  ///
+  /// If the |fromType| and |toType| are the same, does not thing and returns
+  /// the given instruction
+  SpirvInstruction *bitCastToNumericalOrBool(SpirvInstruction *instr,
+                                             QualType fromType, QualType toType,
+                                             SourceLocation loc);
+
 private:
   SpirvEmitter &theEmitter;
   const ASTContext &astContext;

+ 4 - 4
tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.templated-load.matrix.hlsl

@@ -71,7 +71,7 @@ void main(uint3 tid : SV_DispatchThreadId)
 // CHECK:         [[word1_ulong:%\d+]] = OpUConvert %ulong [[word1]]
 // CHECK: [[word1_ulong_shifted:%\d+]] = OpShiftLeftLogical %ulong [[word1_ulong]] %uint_32
 // CHECK:          [[val0_ulong:%\d+]] = OpBitwiseOr %ulong [[word0_ulong]] [[word1_ulong_shifted]]
-// CHECK:                [[val0:%\d+]] = OpConvertUToF %double [[val0_ulong]]
+// CHECK:                [[val0:%\d+]] = OpBitcast %double [[val0_ulong]]
 // CHECK:             [[index_2:%\d+]] = OpIAdd %uint [[index_1]] %uint_1
 // CHECK:                 [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2]]
 // CHECK:               [[word2:%\d+]] = OpLoad %uint [[ptr]]
@@ -82,7 +82,7 @@ void main(uint3 tid : SV_DispatchThreadId)
 // CHECK:         [[word3_ulong:%\d+]] = OpUConvert %ulong [[word3]]
 // CHECK: [[word3_ulong_shifted:%\d+]] = OpShiftLeftLogical %ulong [[word3_ulong]] %uint_32
 // CHECK:          [[val1_ulong:%\d+]] = OpBitwiseOr %ulong [[word2_ulong]] [[word3_ulong_shifted]]
-// CHECK:                [[val1:%\d+]] = OpConvertUToF %double [[val1_ulong]]
+// CHECK:                [[val1:%\d+]] = OpBitcast %double [[val1_ulong]]
 // CHECK:             [[index_4:%\d+]] = OpIAdd %uint [[index_3]] %uint_1
 // CHECK:                [[row0:%\d+]] = OpCompositeConstruct %v2double [[val0]] [[val1]]
 // CHECK:                 [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_4]]
@@ -94,7 +94,7 @@ void main(uint3 tid : SV_DispatchThreadId)
 // CHECK:         [[word5_ulong:%\d+]] = OpUConvert %ulong [[word5]]
 // CHECK: [[word5_ulong_shifted:%\d+]] = OpShiftLeftLogical %ulong [[word5_ulong]] %uint_32
 // CHECK:          [[val2_ulong:%\d+]] = OpBitwiseOr %ulong [[word4_ulong]] [[word5_ulong_shifted]]
-// CHECK:                [[val2:%\d+]] = OpConvertUToF %double [[val2_ulong]]
+// CHECK:                [[val2:%\d+]] = OpBitcast %double [[val2_ulong]]
 // CHECK:             [[index_6:%\d+]] = OpIAdd %uint [[index_5]] %uint_1
 // CHECK:                 [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_6]]
 // CHECK:               [[word6:%\d+]] = OpLoad %uint [[ptr]]
@@ -105,7 +105,7 @@ void main(uint3 tid : SV_DispatchThreadId)
 // CHECK:         [[word7_ulong:%\d+]] = OpUConvert %ulong [[word7]]
 // CHECK: [[word7_ulong_shifted:%\d+]] = OpShiftLeftLogical %ulong [[word7_ulong]] %uint_32
 // CHECK:          [[val3_ulong:%\d+]] = OpBitwiseOr %ulong [[word6_ulong]] [[word7_ulong_shifted]]
-// CHECK:                [[val3:%\d+]] = OpConvertUToF %double [[val3_ulong]]
+// CHECK:                [[val3:%\d+]] = OpBitcast %double [[val3_ulong]]
 // CHECK:                [[row1:%\d+]] = OpCompositeConstruct %v2double [[val2]] [[val3]]
 // CHECK:              [[matrix:%\d+]] = OpCompositeConstruct %mat2v2double [[row0]] [[row1]]
 // CHECK:                                OpStore %f64 [[matrix]]

+ 5 - 5
tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.templated-load.scalar.hlsl

@@ -22,7 +22,7 @@ ByteAddressBuffer buf;
   // CHECK:    [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 {{%\d+}}
   // CHECK:   [[uint:%\d+]] = OpLoad %uint [[ptr]]
   // CHECK: [[ushort:%\d+]] = OpUConvert %ushort [[uint]]
-  // CHECK:   [[half:%\d+]] = OpConvertUToF %half [[ushort]]
+  // CHECK:   [[half:%\d+]] = OpBitcast %half [[ushort]]
   // CHECK:                   OpStore %f16 [[half]]
   float16_t f16 = buf.Load<float16_t>(tid.x);
 
@@ -41,7 +41,7 @@ ByteAddressBuffer buf;
 
   // CHECK:   [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 {{%\d+}}
   // CHECK:  [[uint:%\d+]] = OpLoad %uint [[ptr]]
-  // CHECK: [[float:%\d+]] = OpConvertUToF %float [[uint]]
+  // CHECK: [[float:%\d+]] = OpBitcast %float [[uint]]
   // CHECK:                  OpStore %f [[float]]
   float f = buf.Load<float>(tid.x);
 
@@ -87,7 +87,7 @@ ByteAddressBuffer buf;
 // CHECK:        [[word1Long:%\d+]] = OpUConvert %ulong [[word1]]
 // CHECK: [[shiftedWord1Long:%\d+]] = OpShiftLeftLogical %ulong [[word1Long]] %uint_32
 // CHECK:        [[val_ulong:%\d+]] = OpBitwiseOr %ulong [[word0Long]] [[shiftedWord1Long]]
-// CHECK:       [[val_double:%\d+]] = OpConvertUToF %double [[val_ulong]]
+// CHECK:       [[val_double:%\d+]] = OpBitcast %double [[val_ulong]]
 // CHECK:                             OpStore %f64 [[val_double]]
   double f64 = buf.Load<double>(tid.x);
 
@@ -128,7 +128,7 @@ ByteAddressBuffer buf;
 // CHECK:         [[val0_word1_ulong:%\d+]] = OpUConvert %ulong [[val0_word1_uint]]
 // CHECK: [[shifted_val0_word1_ulong:%\d+]] = OpShiftLeftLogical %ulong [[val0_word1_ulong]] %uint_32
 // CHECK:               [[val0_ulong:%\d+]] = OpBitwiseOr %ulong [[val0_word0_ulong]] [[shifted_val0_word1_ulong]]
-// CHECK:              [[val0_double:%\d+]] = OpConvertUToF %double [[val0_ulong]]
+// CHECK:              [[val0_double:%\d+]] = OpBitcast %double [[val0_ulong]]
 //
 // CHECK:                   [[addr_2:%\d+]] = OpIAdd %uint [[addr_1]] %uint_1
 // CHECK:                      [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[addr_2]]
@@ -140,7 +140,7 @@ ByteAddressBuffer buf;
 // CHECK:         [[val1_word1_ulong:%\d+]] = OpUConvert %ulong [[val1_word1_uint]]
 // CHECK: [[shifted_val1_word1_ulong:%\d+]] = OpShiftLeftLogical %ulong [[val1_word1_ulong]] %uint_32
 // CHECK:               [[val1_ulong:%\d+]] = OpBitwiseOr %ulong [[val1_word0_ulong]] [[shifted_val1_word1_ulong]]
-// CHECK:              [[val1_double:%\d+]] = OpConvertUToF %double [[val1_ulong]]
+// CHECK:              [[val1_double:%\d+]] = OpBitcast %double [[val1_ulong]]
 //
 // CHECK:                     [[fArr:%\d+]] = OpCompositeConstruct %_arr_double_uint_2 [[val0_double]] [[val1_double]]
 // CHECK:                                     OpStore %fArr [[fArr]]

+ 2 - 2
tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.templated-load.vector.hlsl

@@ -57,7 +57,7 @@ void main(uint3 tid : SV_DispatchThreadId)
 // CHECK:          [[word1_ulong:%\d+]] = OpUConvert %ulong [[word1]]
 // CHECK:  [[shifted_word1_ulong:%\d+]] = OpShiftLeftLogical %ulong [[word1_ulong]] %uint_32
 // CHECK:           [[val0_ulong:%\d+]] = OpBitwiseOr %ulong [[word0_ulong]] [[shifted_word1_ulong]]
-// CHECK:                 [[val0:%\d+]] = OpConvertUToF %double [[val0_ulong]]
+// CHECK:                 [[val0:%\d+]] = OpBitcast %double [[val0_ulong]]
 // CHECK:              [[index_2:%\d+]] = OpIAdd %uint [[index_1]] %uint_1
 // CHECK:                  [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_uint %buf %uint_0 [[index_2]]
 // CHECK:                [[word0:%\d+]] = OpLoad %uint [[ptr]]
@@ -68,7 +68,7 @@ void main(uint3 tid : SV_DispatchThreadId)
 // CHECK:          [[word1_ulong:%\d+]] = OpUConvert %ulong [[word1]]
 // CHECK:  [[shifted_word1_ulong:%\d+]] = OpShiftLeftLogical %ulong [[word1_ulong]] %uint_32
 // CHECK:           [[val1_ulong:%\d+]] = OpBitwiseOr %ulong [[word0_ulong]] [[shifted_word1_ulong]]
-// CHECK:                 [[val1:%\d+]] = OpConvertUToF %double [[val1_ulong]]
+// CHECK:                 [[val1:%\d+]] = OpBitcast %double [[val1_ulong]]
 // CHECK:                 [[fVec:%\d+]] = OpCompositeConstruct %v2double [[val0]] [[val1]]
 // CHECK:                                 OpStore %f64 [[fVec]]
   float64_t2 f64 = buf.Load<float64_t2>(tid.x);