2
0
Эх сурвалжийг харах

[spirv] Use isResourceType to cover SubpassInput. (#3253)

Fixes #3169.

hlsl::IsResourceType does not cover SubpassInput (which is
vulkan-specific).
Ehsan 4 жил өмнө
parent
commit
c294ebf2c1

+ 2 - 2
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -98,14 +98,14 @@ bool isTextureBuffer(QualType);
 /// or an array of ConstantBuffers/TextureBuffers.
 /// or an array of ConstantBuffers/TextureBuffers.
 bool isConstantTextureBuffer(QualType);
 bool isConstantTextureBuffer(QualType);
 
 
-/// \brief Returns true if the decl will have a SPIR-V resource type.
+/// \brief Returns true if the given type will have a SPIR-V resource type.
 ///
 ///
 /// Note that this function covers the following HLSL types:
 /// Note that this function covers the following HLSL types:
 /// * ConstantBuffer/TextureBuffer
 /// * ConstantBuffer/TextureBuffer
 /// * Various structured buffers
 /// * Various structured buffers
 /// * (RW)ByteAddressBuffer
 /// * (RW)ByteAddressBuffer
 /// * SubpassInput(MS)
 /// * SubpassInput(MS)
-bool isResourceType(const ValueDecl *decl);
+bool isResourceType(QualType);
 
 
 /// Returns true if the given type is or contains a 16-bit type.
 /// Returns true if the given type is or contains a 16-bit type.
 /// The caller must also specify whether 16-bit types have been enabled via
 /// The caller must also specify whether 16-bit types have been enabled via

+ 13 - 13
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -291,18 +291,16 @@ bool isConstantTextureBuffer(QualType type) {
   return isConstantBuffer(type) || isTextureBuffer(type);
   return isConstantBuffer(type) || isTextureBuffer(type);
 }
 }
 
 
-bool isResourceType(const ValueDecl *decl) {
-  QualType declType = decl->getType();
-
+bool isResourceType(QualType type) {
   // Deprive the arrayness to see the element type
   // Deprive the arrayness to see the element type
-  while (declType->isArrayType()) {
-    declType = declType->getAsArrayTypeUnsafe()->getElementType();
+  while (type->isArrayType()) {
+    type = type->getAsArrayTypeUnsafe()->getElementType();
   }
   }
 
 
-  if (isSubpassInput(declType) || isSubpassInputMS(declType))
+  if (isSubpassInput(type) || isSubpassInputMS(type))
     return true;
     return true;
 
 
-  return hlsl::IsHLSLResourceType(declType);
+  return hlsl::IsHLSLResourceType(type);
 }
 }
 
 
 bool isOrContains16BitType(QualType type, bool enable16BitTypesOption) {
 bool isOrContains16BitType(QualType type, bool enable16BitTypesOption) {
@@ -1256,9 +1254,9 @@ bool isResourceOnlyStructure(QualType type) {
 
 
   if (const auto *structType = type->getAs<RecordType>()) {
   if (const auto *structType = type->getAs<RecordType>()) {
     for (const auto *field : structType->getDecl()->fields()) {
     for (const auto *field : structType->getDecl()->fields()) {
+      const auto fieldType = field->getType();
       // isResourceType does remove arrayness for the field if needed.
       // isResourceType does remove arrayness for the field if needed.
-      if (!isResourceType(field) &&
-          !isResourceOnlyStructure(field->getType())) {
+      if (!isResourceType(fieldType) && !isResourceOnlyStructure(fieldType)) {
         return false;
         return false;
       }
       }
     }
     }
@@ -1275,10 +1273,11 @@ bool isStructureContainingResources(QualType type) {
 
 
   if (const auto *structType = type->getAs<RecordType>()) {
   if (const auto *structType = type->getAs<RecordType>()) {
     for (const auto *field : structType->getDecl()->fields()) {
     for (const auto *field : structType->getDecl()->fields()) {
+      const auto fieldType = field->getType();
       // isStructureContainingResources and isResourceType functions both remove
       // isStructureContainingResources and isResourceType functions both remove
       // arrayness for the field if needed.
       // arrayness for the field if needed.
-      if (isStructureContainingResources(field->getType()) ||
-          isResourceType(field)) {
+      if (isStructureContainingResources(fieldType) ||
+          isResourceType(fieldType)) {
         return true;
         return true;
       }
       }
     }
     }
@@ -1293,10 +1292,11 @@ bool isStructureContainingNonResources(QualType type) {
 
 
   if (const auto *structType = type->getAs<RecordType>()) {
   if (const auto *structType = type->getAs<RecordType>()) {
     for (const auto *field : structType->getDecl()->fields()) {
     for (const auto *field : structType->getDecl()->fields()) {
+      const auto fieldType = field->getType();
       // isStructureContainingNonResources and isResourceType functions both
       // isStructureContainingNonResources and isResourceType functions both
       // remove arrayness for the field if needed.
       // remove arrayness for the field if needed.
-      if (isStructureContainingNonResources(field->getType()) ||
-          !isResourceType(field)) {
+      if (isStructureContainingNonResources(fieldType) ||
+          !isResourceType(fieldType)) {
         return true;
         return true;
       }
       }
     }
     }

+ 7 - 6
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -69,11 +69,11 @@ uint32_t getNumBindingsUsedByResourceType(QualType type) {
 
 
   // Once we remove the arrayness, we expect the given type to be either a
   // Once we remove the arrayness, we expect the given type to be either a
   // resource OR a structure that only contains resources.
   // resource OR a structure that only contains resources.
-  assert(hlsl::IsHLSLResourceType(type) || isResourceOnlyStructure(type));
+  assert(isResourceType(type) || isResourceOnlyStructure(type));
 
 
   // In the case of a resource, each resource takes 1 binding slot, so in total
   // In the case of a resource, each resource takes 1 binding slot, so in total
   // it consumes: 1 * arrayFactor.
   // it consumes: 1 * arrayFactor.
-  if (hlsl::IsHLSLResourceType(type))
+  if (isResourceType(type))
     return arrayFactor;
     return arrayFactor;
 
 
   // In the case of a struct of resources, we need to sum up the number of
   // In the case of a struct of resources, we need to sum up the number of
@@ -229,10 +229,11 @@ bool shouldSkipInStructLayout(const Decl *decl) {
       return true;
       return true;
 
 
     // Other resource types
     // Other resource types
-    if (const auto *valueDecl = dyn_cast<ValueDecl>(decl))
-      if (isResourceType(valueDecl) ||
-          isResourceOnlyStructure((valueDecl->getType())))
+    if (const auto *valueDecl = dyn_cast<ValueDecl>(decl)) {
+      const auto declType = valueDecl->getType();
+      if (isResourceType(declType) || isResourceOnlyStructure(declType))
         return true;
         return true;
+    }
   }
   }
 
 
   return false;
   return false;
@@ -791,7 +792,7 @@ SpirvVariable *DeclResultIdMapper::createExternVar(const VarDecl *var) {
   const auto rule = getLayoutRuleForExternVar(type, spirvOptions);
   const auto rule = getLayoutRuleForExternVar(type, spirvOptions);
   const auto loc = var->getLocation();
   const auto loc = var->getLocation();
 
 
-  if (!isGroupShared && !isResourceType(var) &&
+  if (!isGroupShared && !isResourceType(type) &&
       !isResourceOnlyStructure(type)) {
       !isResourceOnlyStructure(type)) {
 
 
     // We currently cannot support global structures that contain both resources
     // We currently cannot support global structures that contain both resources

+ 14 - 16
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -2178,6 +2178,7 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
     // the LValueToRValue implicit cast here.
     // the LValueToRValue implicit cast here.
     auto *arg = callExpr->getArg(i)->IgnoreParenLValueCasts();
     auto *arg = callExpr->getArg(i)->IgnoreParenLValueCasts();
     const auto *param = callee->getParamDecl(i);
     const auto *param = callee->getParamDecl(i);
+    const auto paramType = param->getType();
 
 
     // Get the evaluation info if this argument is referencing some variable
     // Get the evaluation info if this argument is referencing some variable
     // *as a whole*, in which case we can avoid creating the temporary variable
     // *as a whole*, in which case we can avoid creating the temporary variable
@@ -2199,8 +2200,8 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
     // expects are point-to-pointer argument for resources, which will be
     // expects are point-to-pointer argument for resources, which will be
     // resolved by legalization.
     // resolved by legalization.
     if ((argInfo || (argInst && !argInst->isRValue())) &&
     if ((argInfo || (argInst && !argInst->isRValue())) &&
-        canActAsOutParmVar(param) && !isResourceType(param) &&
-        paramTypeMatchesArgType(param->getType(), arg->getType())) {
+        canActAsOutParmVar(param) && !isResourceType(paramType) &&
+        paramTypeMatchesArgType(paramType, arg->getType())) {
       // Based on SPIR-V spec, function parameter must be always Function
       // Based on SPIR-V spec, function parameter must be always Function
       // scope. In addition, we must pass memory object declaration argument
       // scope. In addition, we must pass memory object declaration argument
       // to function. If we pass an argument that is not function scope
       // to function. If we pass an argument that is not function scope
@@ -2246,16 +2247,14 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
       // the function. And we will cast back the results once the function call
       // the function. And we will cast back the results once the function call
       // has returned.
       // has returned.
       if (canActAsOutParmVar(param) &&
       if (canActAsOutParmVar(param) &&
-          !paramTypeMatchesArgType(param->getType(), arg->getType())) {
-        auto paramType = param->getType();
+          !paramTypeMatchesArgType(paramType, arg->getType())) {
         if (const auto *refType = paramType->getAs<ReferenceType>())
         if (const auto *refType = paramType->getAs<ReferenceType>())
-          paramType = refType->getPointeeType();
-        rhsVal =
-            castToType(rhsVal, arg->getType(), paramType, arg->getLocStart());
+          rhsVal = castToType(rhsVal, arg->getType(), refType->getPointeeType(),
+                              arg->getLocStart());
       }
       }
 
 
       // Initialize the temporary variables using the contents of the arguments
       // Initialize the temporary variables using the contents of the arguments
-      storeValue(tempVar, rhsVal, param->getType(), arg->getLocStart());
+      storeValue(tempVar, rhsVal, paramType, arg->getLocStart());
     }
     }
   }
   }
 
 
@@ -2280,6 +2279,7 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
   // Go through all parameters and write those marked as out/inout
   // Go through all parameters and write those marked as out/inout
   for (uint32_t i = 0; i < numParams; ++i) {
   for (uint32_t i = 0; i < numParams; ++i) {
     const auto *param = callee->getParamDecl(i);
     const auto *param = callee->getParamDecl(i);
+    const auto paramType = param->getType();
     // If it calls a non-static member function, the object itself is argument
     // If it calls a non-static member function, the object itself is argument
     // 0, and therefore all other argument positions are shifted by 1.
     // 0, and therefore all other argument positions are shifted by 1.
     const uint32_t index = i + isNonStaticMemberCall;
     const uint32_t index = i + isNonStaticMemberCall;
@@ -2288,21 +2288,19 @@ SpirvInstruction *SpirvEmitter::processCall(const CallExpr *callExpr) {
     // there is no reason to copy back the results after the function call into
     // there is no reason to copy back the results after the function call into
     // the resource.
     // the resource.
     if (isTempVar[index] && canActAsOutParmVar(param) &&
     if (isTempVar[index] && canActAsOutParmVar(param) &&
-        !isResourceType(param)) {
+        !isResourceType(paramType)) {
       const auto *arg = callExpr->getArg(i);
       const auto *arg = callExpr->getArg(i);
-      SpirvInstruction *value = spvBuilder.createLoad(
-          param->getType(), vars[index], arg->getLocStart());
+      SpirvInstruction *value =
+          spvBuilder.createLoad(paramType, vars[index], arg->getLocStart());
 
 
       // Now we want to assign 'value' to arg. But first, in rare cases when
       // Now we want to assign 'value' to arg. But first, in rare cases when
       // using 'out' or 'inout' where the parameter and argument have a type
       // using 'out' or 'inout' where the parameter and argument have a type
       // mismatch, we need to first cast 'value' to the type of 'arg' because
       // mismatch, we need to first cast 'value' to the type of 'arg' because
       // the AST will not include a cast node.
       // the AST will not include a cast node.
-      if (!paramTypeMatchesArgType(param->getType(), arg->getType())) {
-        auto paramType = param->getType();
+      if (!paramTypeMatchesArgType(paramType, arg->getType())) {
         if (const auto *refType = paramType->getAs<ReferenceType>())
         if (const auto *refType = paramType->getAs<ReferenceType>())
-          paramType = refType->getPointeeType();
-        value =
-            castToType(value, paramType, arg->getType(), arg->getLocStart());
+          value = castToType(value, refType->getPointeeType(), arg->getType(),
+                             arg->getLocStart());
       }
       }
 
 
       processAssignment(arg, value, false, args[index]);
       processAssignment(arg, value, false, args[index]);

+ 4 - 1
tools/clang/test/CodeGenSPIRV/vk.binding.global-struct-of-resources.1.hlsl

@@ -11,7 +11,9 @@
 //
 //
 // CHECK: OpDecorate %globalSamplerState DescriptorSet 0
 // CHECK: OpDecorate %globalSamplerState DescriptorSet 0
 // CHECK: OpDecorate %globalSamplerState Binding 3
 // CHECK: OpDecorate %globalSamplerState Binding 3
-
+//
+// CHECK: OpDecorate %MySubpassInput DescriptorSet 0
+// CHECK: OpDecorate %MySubpassInput Binding 4
 
 
 // CHECK:                          %S = OpTypeStruct %type_2d_image %type_sampler
 // CHECK:                          %S = OpTypeStruct %type_2d_image %type_sampler
 // CHECK:     %_ptr_UniformConstant_S = OpTypePointer UniformConstant %S
 // CHECK:     %_ptr_UniformConstant_S = OpTypePointer UniformConstant %S
@@ -31,6 +33,7 @@ S globalS;
 
 
 Texture2D globalTexture;
 Texture2D globalTexture;
 SamplerState globalSamplerState;
 SamplerState globalSamplerState;
+[[vk::input_attachment_index (0)]] SubpassInput<float4> MySubpassInput;
 
 
 float4 main() : SV_Target {
 float4 main() : SV_Target {
 // CHECK: [[globalS:%\d+]] = OpLoad %S %globalS
 // CHECK: [[globalS:%\d+]] = OpLoad %S %globalS