Browse Source

[spirv] Add support for casting involving vector decomposition (#1677)

Fixes https://github.com/Microsoft/DirectXShaderCompiler/issues/1673
Fixes https://github.com/Microsoft/DirectXShaderCompiler/issues/1675
Fixes https://github.com/Microsoft/DirectXShaderCompiler/issues/1676
Lei Zhang 6 years ago
parent
commit
740b6701af

+ 15 - 2
tools/clang/lib/SPIRV/InitListHandler.cpp

@@ -26,7 +26,7 @@ InitListHandler::InitListHandler(SPIRVEmitter &emitter)
       typeTranslator(emitter.getTypeTranslator()),
       diags(emitter.getDiagnosticsEngine()) {}
 
-uint32_t InitListHandler::process(const InitListExpr *expr) {
+uint32_t InitListHandler::processInit(const InitListExpr *expr) {
   initializers.clear();
   scalars.clear();
 
@@ -35,7 +35,20 @@ uint32_t InitListHandler::process(const InitListExpr *expr) {
   // tail of the vector. This is more efficient than using a deque.
   std::reverse(std::begin(initializers), std::end(initializers));
 
-  const uint32_t init = createInitForType(expr->getType(), expr->getExprLoc());
+  return doProcess(expr->getType(), expr->getExprLoc());
+}
+
+uint32_t InitListHandler::processCast(QualType toType, const Expr *expr) {
+  initializers.clear();
+  scalars.clear();
+
+  initializers.push_back(expr);
+
+  return doProcess(toType, expr->getExprLoc());
+}
+
+uint32_t InitListHandler::doProcess(QualType type, SourceLocation srcLoc) {
+  const uint32_t init = createInitForType(type, srcLoc);
 
   if (init) {
     // For successful translation, we should have consumed all initializers and

+ 9 - 1
tools/clang/lib/SPIRV/InitListHandler.h

@@ -85,7 +85,11 @@ public:
 
   /// Processes the given InitListExpr and returns the <result-id> for the final
   /// SPIR-V value.
-  uint32_t process(const InitListExpr *expr);
+  uint32_t processInit(const InitListExpr *expr);
+
+  /// Casts the given Expr to the given toType and returns the <result-id> for
+  /// the final SPIR-V value.
+  uint32_t processCast(QualType toType, const Expr *expr);
 
 private:
   /// \brief Wrapper method to create an error message and report it
@@ -97,6 +101,10 @@ private:
     return diags.Report(loc, diagId);
   }
 
+  /// Processes the expressions in initializers and returns the <result-id> for
+  /// the final SPIR-V value of the given type.
+  uint32_t doProcess(QualType type, SourceLocation srcLoc);
+
   /// Flattens the given InitListExpr and puts all non-InitListExpr AST nodes
   /// into initializers.
   void flatten(const InitListExpr *expr);

+ 9 - 1
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2465,6 +2465,14 @@ SpirvEvalInfo SPIRVEmitter::doCastExpr(const CastExpr *expr) {
              typeTranslator.isSameType(expr->getType(), subExprType)) {
       return doExpr(subExpr);
     }
+    // We can have casts changing the shape but without affecting memory order,
+    // e.g., `float4 a[2]; float b[8] = (float[8])a;`. This is also represented
+    // as FlatConversion. For such cases, we can rely on the InitListHandler,
+    // which can decompse vectors/matrices.
+    else if (subExprType->isArrayType()) {
+      auto valId = InitListHandler(*this).processCast(expr->getType(), subExpr);
+      return SpirvEvalInfo(valId).setRValue();
+    }
 
     if (!subExprId)
       subExprId = doExpr(subExpr);
@@ -4652,7 +4660,7 @@ SpirvEvalInfo SPIRVEmitter::doInitListExpr(const InitListExpr *expr) {
   if (const uint32_t id = tryToEvaluateAsConst(expr))
     return SpirvEvalInfo(id).setRValue();
 
-  return SpirvEvalInfo(InitListHandler(*this).process(expr)).setRValue();
+  return SpirvEvalInfo(InitListHandler(*this).processInit(expr)).setRValue();
 }
 
 SpirvEvalInfo SPIRVEmitter::doMemberExpr(const MemberExpr *expr) {

+ 23 - 0
tools/clang/test/CodeGenSPIRV/cast.flat-conversion.vector.hlsl

@@ -0,0 +1,23 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct S {
+    float2 data[2];
+};
+
+StructuredBuffer<S> MySB;
+
+float4 main() : SV_TARGET
+{
+// CHECK:      [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySB %int_0 %uint_0 %int_0 %uint_0
+// CHECK-NEXT: [[vec:%\d+]] = OpLoad %v2float [[ptr]]
+// CHECK-NEXT:  [[v1:%\d+]] = OpCompositeExtract %float [[vec]] 0
+// CHECK-NEXT:  [[v2:%\d+]] = OpCompositeExtract %float [[vec]] 1
+// CHECK-NEXT: [[ptr:%\d+]] = OpAccessChain %_ptr_Uniform_v2float %MySB %int_0 %uint_0 %int_0 %uint_1
+// CHECK-NEXT: [[vec:%\d+]] = OpLoad %v2float [[ptr]]
+// CHECK-NEXT:  [[v3:%\d+]] = OpCompositeExtract %float [[vec]] 0
+// CHECK-NEXT:  [[v4:%\d+]] = OpCompositeExtract %float [[vec]] 1
+// CHECK-NEXT: [[val:%\d+]] = OpCompositeConstruct %_arr_float_uint_4 [[v1]] [[v2]] [[v3]] [[v4]]
+// CHECK-NEXT:                OpStore %data [[val]]
+    float data[4] = (float[4])MySB[0].data;
+    return data[1];
+}

+ 3 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -383,6 +383,9 @@ TEST_F(FileTest, CastFlatConversionNoOp) {
 TEST_F(FileTest, CastFlatConversionLiteralInitializer) {
   runFileTest("cast.flat-conversion.literal-initializer.hlsl");
 }
+TEST_F(FileTest, CastFlatConversionDecomposeVector) {
+  runFileTest("cast.flat-conversion.vector.hlsl");
+}
 TEST_F(FileTest, CastExplicitVecToMat) {
   runFileTest("cast.vec-to-mat.explicit.hlsl");
 }