Browse Source

[spirv] Fix storage class for ptr-to-ptr in Vulkan 1.2. (#3344)

* [spirv] Fix storage class for ptr-to-ptr in Vulkan 1.2.

* [spirv] Add one more test case.

* Address code review comments.
Ehsan 4 years ago
parent
commit
f11daf7d42

+ 9 - 9
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -45,9 +45,9 @@ public:
   void setResultId(uint32_t id) { functionId = id; }
   void setResultId(uint32_t id) { functionId = id; }
 
 
   // Sets the lowered (SPIR-V) return type.
   // Sets the lowered (SPIR-V) return type.
-  void setReturnType(SpirvType *type) { returnType = type; }
+  void setReturnType(const SpirvType *type) { returnType = type; }
   // Returns the lowered (SPIR-V) return type.
   // Returns the lowered (SPIR-V) return type.
-  SpirvType *getReturnType() const { return returnType; }
+  const SpirvType *getReturnType() const { return returnType; }
 
 
   // Sets the function AST return type
   // Sets the function AST return type
   void setAstReturnType(QualType type) { astReturnType = type; }
   void setAstReturnType(QualType type) { astReturnType = type; }
@@ -116,13 +116,13 @@ public:
   }
   }
 
 
 private:
 private:
-  uint32_t functionId;    ///< This function's <result-id>
-  QualType astReturnType; ///< The return type
-  SpirvType *returnType;  ///< The lowered return type
-  SpirvType *fnType;      ///< The SPIR-V function type
-  bool relaxedPrecision;  ///< Whether the return type is at relaxed precision
-  bool precise;           ///< Whether the return value is 'precise'
-  bool noInline;          ///< The function is marked as no inline
+  uint32_t functionId;         ///< This function's <result-id>
+  QualType astReturnType;      ///< The return type
+  const SpirvType *returnType; ///< The lowered return type
+  SpirvType *fnType;           ///< The SPIR-V function type
+  bool relaxedPrecision; ///< Whether the return type is at relaxed precision
+  bool precise;          ///< Whether the return value is 'precise'
+  bool noInline;         ///< The function is marked as no inline
 
 
   /// Legalization-specific code
   /// Legalization-specific code
   ///
   ///

+ 1 - 1
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -42,7 +42,7 @@ bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
         lowerType(fn->getAstReturnType(), SpirvLayoutRule::Void,
         lowerType(fn->getAstReturnType(), SpirvLayoutRule::Void,
                   /*isRowMajor*/ llvm::None,
                   /*isRowMajor*/ llvm::None,
                   /*SourceLocation*/ {});
                   /*SourceLocation*/ {});
-    fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
+    fn->setReturnType(spirvReturnType);
 
 
     // Lower the function parameter types.
     // Lower the function parameter types.
     auto params = fn->getParameters();
     auto params = fn->getParameters();

+ 80 - 6
tools/clang/lib/SPIRV/RemoveBufferBlockVisitor.cpp

@@ -9,6 +9,7 @@
 
 
 #include "RemoveBufferBlockVisitor.h"
 #include "RemoveBufferBlockVisitor.h"
 #include "clang/SPIRV/SpirvContext.h"
 #include "clang/SPIRV/SpirvContext.h"
+#include "clang/SPIRV/SpirvFunction.h"
 
 
 namespace clang {
 namespace clang {
 namespace spirv {
 namespace spirv {
@@ -71,15 +72,88 @@ bool RemoveBufferBlockVisitor::visitInstruction(SpirvInstruction *inst) {
 
 
   // For all instructions, if the result type is a pointer pointing to a struct
   // For all instructions, if the result type is a pointer pointing to a struct
   // with StorageBuffer interface, the storage class must be updated.
   // with StorageBuffer interface, the storage class must be updated.
-  if (auto *ptrResultType = dyn_cast<SpirvPointerType>(inst->getResultType())) {
-    if (hasStorageBufferInterfaceType(ptrResultType->getPointeeType()) &&
-        ptrResultType->getStorageClass() != spv::StorageClass::StorageBuffer) {
-      inst->setStorageClass(spv::StorageClass::StorageBuffer);
-      inst->setResultType(context.getPointerType(
-          ptrResultType->getPointeeType(), spv::StorageClass::StorageBuffer));
+  const auto *instType = inst->getResultType();
+  const auto *newInstType = instType;
+  spv::StorageClass newInstStorageClass = spv::StorageClass::Max;
+  if (updateStorageClass(instType, &newInstType, &newInstStorageClass)) {
+    inst->setResultType(newInstType);
+    inst->setStorageClass(newInstStorageClass);
+  }
+
+  return true;
+}
+
+bool RemoveBufferBlockVisitor::updateStorageClass(
+    const SpirvType *type, const SpirvType **newType,
+    spv::StorageClass *newStorageClass) {
+  auto *ptrType = dyn_cast<SpirvPointerType>(type);
+  if (ptrType == nullptr)
+    return false;
+
+  const auto *innerType = ptrType->getPointeeType();
+
+  // For usual cases such as _ptr_Uniform_StructuredBuffer_float.
+  if (hasStorageBufferInterfaceType(innerType) &&
+      ptrType->getStorageClass() != spv::StorageClass::StorageBuffer) {
+    *newType =
+        context.getPointerType(innerType, spv::StorageClass::StorageBuffer);
+    *newStorageClass = spv::StorageClass::StorageBuffer;
+    return true;
+  }
+
+  // For pointer-to-pointer cases (which need legalization), we could have a
+  // type like: _ptr_Function__ptr_Uniform_type_StructuredBuffer_float.
+  // In such cases, we need to update the storage class for the inner pointer.
+  if (const auto *innerPtrType = dyn_cast<SpirvPointerType>(innerType)) {
+    if (hasStorageBufferInterfaceType(innerPtrType->getPointeeType()) &&
+        innerPtrType->getStorageClass() != spv::StorageClass::StorageBuffer) {
+      auto *newInnerType = context.getPointerType(
+          innerPtrType->getPointeeType(), spv::StorageClass::StorageBuffer);
+      *newType =
+          context.getPointerType(newInnerType, ptrType->getStorageClass());
+      *newStorageClass = ptrType->getStorageClass();
+      return true;
     }
     }
   }
   }
 
 
+  return false;
+}
+
+bool RemoveBufferBlockVisitor::visit(SpirvFunction *fn, Phase phase) {
+  if (phase == Visitor::Phase::Init) {
+    llvm::SmallVector<const SpirvType *, 4> paramTypes;
+    bool updatedParamTypes = false;
+    for (auto *param : fn->getParameters()) {
+      const auto *paramType = param->getResultType();
+      // This pass is run after all types are lowered.
+      assert(paramType != nullptr);
+
+      // Update the parameter type if needed (update storage class of pointers).
+      const auto *newParamType = paramType;
+      spv::StorageClass newParamSC = spv::StorageClass::Max;
+      if (updateStorageClass(paramType, &newParamType, &newParamSC)) {
+        param->setStorageClass(newParamSC);
+        param->setResultType(newParamType);
+        updatedParamTypes = true;
+      }
+      paramTypes.push_back(newParamType);
+    }
+
+    // Update the return type if needed (update storage class of pointers).
+    const auto *returnType = fn->getReturnType();
+    const auto *newReturnType = returnType;
+    spv::StorageClass newReturnSC = spv::StorageClass::Max;
+    bool updatedReturnType =
+        updateStorageClass(returnType, &newReturnType, &newReturnSC);
+    if (updatedReturnType) {
+      fn->setReturnType(newReturnType);
+    }
+
+    if (updatedParamTypes || updatedReturnType) {
+      fn->setFunctionType(context.getFunctionType(newReturnType, paramTypes));
+    }
+    return true;
+  }
   return true;
   return true;
 }
 }
 
 

+ 28 - 2
tools/clang/lib/SPIRV/RemoveBufferBlockVisitor.h

@@ -26,6 +26,7 @@ public:
       : Visitor(opts, spvCtx), featureManager(astCtx.getDiagnostics(), opts) {}
       : Visitor(opts, spvCtx), featureManager(astCtx.getDiagnostics(), opts) {}
 
 
   bool visit(SpirvModule *, Phase) override;
   bool visit(SpirvModule *, Phase) override;
+  bool visit(SpirvFunction *, Phase) override;
 
 
   using Visitor::visit;
   using Visitor::visit;
 
 
@@ -41,10 +42,35 @@ private:
   /// StorageBuffer.
   /// StorageBuffer.
   bool hasStorageBufferInterfaceType(const SpirvType *type);
   bool hasStorageBufferInterfaceType(const SpirvType *type);
 
 
-  ///  Returns true if the BufferBlock decoration is deprecated (Vulkan 1.2 or
-  ///  above).
+  /// Returns true if the BufferBlock decoration is deprecated (Vulkan 1.2 or
+  /// above).
   bool isBufferBlockDecorationDeprecated();
   bool isBufferBlockDecorationDeprecated();
 
 
+  /// Transforms the given |type| if it is one of the following cases:
+  ///
+  /// 1- a pointer to a structure with StorageBuffer interface
+  /// 2- a pointer to a pointer to a structure with StorageBuffer interface
+  ///
+  /// by updating the storage class of the pointer whose pointee is the struct.
+  ///
+  /// Example of case (1):
+  /// type:              _ptr_Uniform_SturcturedBuffer_float
+  /// new type:          _ptr_StorageBuffer_SturcturedBuffer_float
+  /// new storage class: StorageBuffer
+  ///
+  /// Example of case (2):
+  /// type:              _ptr_Function__ptr_Uniform_SturcturedBuffer_float
+  /// new type:          _ptr_Function__ptr_StorageBuffer_SturcturedBuffer_float
+  /// new storage class: Function
+  ///
+  /// If |type| is transformed, the |newType| and |newStorageClass| are
+  /// returned by reference and the function returns true.
+  ///
+  /// If |type| is not transformed, |newType| and |newStorageClass| are
+  /// untouched, and the function returns false.
+  bool updateStorageClass(const SpirvType *type, const SpirvType **newType,
+                          spv::StorageClass *newStorageClass);
+
   FeatureManager featureManager;
   FeatureManager featureManager;
 };
 };
 
 

+ 12 - 0
tools/clang/test/CodeGenSPIRV/vk.1p2.remove.bufferblock.ptr-to-ptr.example2.hlsl

@@ -0,0 +1,12 @@
+// Run: %dxc -T cs_6_4 -E main -fspv-target-env=vulkan1.2
+
+
+// CHECK: OpDecorate %type_ByteAddressBuffer Block
+
+ByteAddressBuffer g_byteAddressBuffer[] : register(t0, space3);
+[numthreads(1,1,1)]
+void main() {
+// CHECK: %flat_bucket_indices = OpVariable %_ptr_Function__ptr_StorageBuffer_type_ByteAddressBuffer Function
+// CHECK:             {{%\d+}} = OpAccessChain %_ptr_StorageBuffer_type_ByteAddressBuffer %g_byteAddressBuffer %int_0
+  ByteAddressBuffer flat_bucket_indices = g_byteAddressBuffer[0];
+}

+ 46 - 0
tools/clang/test/CodeGenSPIRV/vk.1p2.remove.bufferblock.ptr-to-ptr.hlsl

@@ -0,0 +1,46 @@
+// Run: %dxc -T lib_6_4 -fspv-target-env=vulkan1.2
+
+// We cannot use BufferBlock decoration for SPIR-V 1.4 or above.
+// Instead, we must use Block decorated StorageBuffer Storage Class.
+//
+// As a result, we transform the storage class of types where needed.
+//
+// If a resource is used as a function parameter or function return value,
+// it represents a case that requires legalization. In such cases, a
+// pointer-to-pointer type may be present. We must make sure that in such cases
+// the inner pointer's storage class is also updated if needed.
+
+// CHECK: ; Version: 1.5
+
+// CHECK: OpDecorate %type_StructuredBuffer_float Block
+// CHECK: OpDecorate %type_RWStructuredBuffer_float Block
+// CHECK: OpDecorate %type_ACSBuffer_counter Block
+
+// CHECK: [[fn-type:%\d+]] = OpTypeFunction %_ptr_StorageBuffer_type_StructuredBuffer_float %_ptr_Function__ptr_StorageBuffer_type_StructuredBuffer_float
+
+// CHECK:   %gSBuffer = OpVariable %_ptr_StorageBuffer_type_StructuredBuffer_float StorageBuffer
+StructuredBuffer<float> gSBuffer;
+// CHECK: %gRWSBuffer = OpVariable %_ptr_StorageBuffer_type_RWStructuredBuffer_float StorageBuffer
+RWStructuredBuffer<float> gRWSBuffer;
+
+StructuredBuffer<float> foo(StructuredBuffer<float> pSBuffer);
+
+[shader("raygeneration")]
+void main() {
+// CHECK: %param_var_pSBuffer = OpVariable %_ptr_Function__ptr_StorageBuffer_type_StructuredBuffer_float Function
+// CHECK:                       OpStore %param_var_pSBuffer %gSBuffer
+// CHECK:                       OpFunctionCall %_ptr_StorageBuffer_type_StructuredBuffer_float %foo %param_var_pSBuffer
+  float a = foo(gSBuffer)[0];
+}
+
+// CHECK:      %foo = OpFunction %_ptr_StorageBuffer_type_StructuredBuffer_float None [[fn-type]]
+// CHECK: %pSBuffer = OpFunctionParameter %_ptr_Function__ptr_StorageBuffer_type_StructuredBuffer_float
+StructuredBuffer<float> foo(StructuredBuffer<float> pSBuffer) {
+// CHECK: OpLoad %_ptr_StorageBuffer_type_StructuredBuffer_float %pSBuffer
+// CHECK: OpAccessChain %_ptr_StorageBuffer_float
+  float x = pSBuffer[0];
+// CHECK: [[buf:%\d+]] = OpLoad %_ptr_StorageBuffer_type_StructuredBuffer_float %pSBuffer
+// CHECK: OpReturnValue [[buf]]
+  return pSBuffer;
+}
+

+ 8 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -2431,6 +2431,14 @@ TEST_F(FileTest, Vk1p2BlockDecoration) {
 TEST_F(FileTest, Vk1p2RemoveBufferBlockRuntimeArray) {
 TEST_F(FileTest, Vk1p2RemoveBufferBlockRuntimeArray) {
   runFileTest("vk.1p2.remove.bufferblock.runtimearray.hlsl");
   runFileTest("vk.1p2.remove.bufferblock.runtimearray.hlsl");
 }
 }
+TEST_F(FileTest, Vk1p2RemoveBufferBlockPtrToPtr) {
+  setBeforeHLSLLegalization();
+  runFileTest("vk.1p2.remove.bufferblock.ptr-to-ptr.hlsl");
+}
+TEST_F(FileTest, Vk1p2RemoveBufferBlockPtrToPtr2) {
+  setBeforeHLSLLegalization();
+  runFileTest("vk.1p2.remove.bufferblock.ptr-to-ptr.example2.hlsl");
+}
 
 
 // Test shaders that require Vulkan1.1 support with
 // Test shaders that require Vulkan1.1 support with
 // -fspv-target-env=vulkan1.2 option to make sure that enabling
 // -fspv-target-env=vulkan1.2 option to make sure that enabling