Browse Source

[spirv] handle bindless opaque array type argument passing (#2928)

* Handle bindless opaque type array argument passing

* Remove paramTypes from SpirvFunction - we can use SpirvType info of parameters instead of it

* clang-format

* Correct access to pointer to opaque array type local variable

* Add unit tests

* Update based on code review

* Fix build failure
Jaebaek Seo 5 years ago
parent
commit
d4a248078c

+ 3 - 0
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -276,6 +276,9 @@ bool isOrContainsNonFpColMajorMatrix(const ASTContext &,
 /// \bried Returns true if the given type is a String or StringLiteral type.
 bool isStringType(QualType);
 
+/// \bried Returns true if the given type is a bindless array of an opaque type.
+bool isBindlessOpaqueArray(QualType type);
+
 /// \brief Generates the corresponding SPIR-V vector type for the given Clang
 /// frontend matrix type's vector component and returns the <result-id>.
 ///

+ 0 - 1
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -71,7 +71,6 @@ public:
   ///
   /// At any time, there can only exist at most one function under building.
   SpirvFunction *beginFunction(QualType returnType,
-                               llvm::ArrayRef<QualType> paramTypes,
                                SourceLocation, llvm::StringRef name = "",
                                bool isPrecise = false,
                                SpirvFunction *func = nullptr);

+ 5 - 9
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -24,9 +24,8 @@ class SpirvVisitor;
 /// The class representing a SPIR-V function in memory.
 class SpirvFunction {
 public:
-  SpirvFunction(QualType astReturnType, llvm::ArrayRef<QualType> astParamTypes,
-                SourceLocation, llvm::StringRef name = "",
-                bool precise = false);
+  SpirvFunction(QualType astReturnType, SourceLocation,
+                llvm::StringRef name = "", bool precise = false);
   ~SpirvFunction() = default;
 
   // Forbid copy construction and assignment
@@ -53,12 +52,10 @@ public:
   // Gets the function AST return type
   QualType getAstReturnType() const { return astReturnType; }
 
-  // Sets the vector of parameter QualTypes.
-  void setAstParamTypes(llvm::ArrayRef<QualType> paramTypes) {
-    astParamTypes.append(paramTypes.begin(), paramTypes.end());
+  // Gets the vector of parameters.
+  llvm::SmallVector<SpirvFunctionParameter *, 8> getParameters() const {
+    return parameters;
   }
-  // Gets the vector of parameter QualTypes.
-  llvm::ArrayRef<QualType> getAstParamTypes() const { return astParamTypes; }
 
   // Sets the SPIR-V type of the function
   void setFunctionType(SpirvType *type) { fnType = type; }
@@ -99,7 +96,6 @@ private:
   uint32_t functionId; ///< This function's <result-id>
 
   QualType astReturnType;                       ///< The return type
-  llvm::SmallVector<QualType, 4> astParamTypes; ///< The paratemer types in AST
   SpirvType *returnType;                        ///< The lowered return type
   SpirvType *fnType;                            ///< The SPIR-V function type
 

+ 7 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -454,6 +454,10 @@ public:
   SpirvVariable(QualType resultType, SourceLocation loc, spv::StorageClass sc,
                 bool isPrecise, SpirvInstruction *initializerId = 0);
 
+  SpirvVariable(const SpirvType *spvType, SourceLocation loc,
+                spv::StorageClass sc, bool isPrecise,
+                SpirvInstruction *initializerId = 0);
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Variable;
@@ -482,6 +486,9 @@ public:
   SpirvFunctionParameter(QualType resultType, bool isPrecise,
                          SourceLocation loc);
 
+  SpirvFunctionParameter(const SpirvType *spvType, bool isPrecise,
+                         SourceLocation loc);
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_FunctionParameter;

+ 5 - 0
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -1124,6 +1124,11 @@ bool isStringType(QualType type) {
   return hlsl::IsStringType(type) || hlsl::IsStringLiteralType(type);
 }
 
+bool isBindlessOpaqueArray(QualType type) {
+  return !type.isNull() && isOpaqueArrayType(type) &&
+         !type->isConstantArrayType();
+}
+
 QualType getComponentVectorType(const ASTContext &astContext,
                                 QualType matrixType) {
   assert(isMxNMatrix(matrixType));

+ 15 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -574,6 +574,19 @@ SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
           {spvBuilder.getConstantInt(
               astContext.IntTy, llvm::APInt(32, info->indexInCTBuffer, true))},
           loc);
+    } else if (auto *type = info->instr->getResultType()) {
+      const auto *ptrTy = dyn_cast<HybridPointerType>(type);
+
+      // If it is a local variable or function parameter with a bindless
+      // array of an opaque type, we have to load it because we pass a
+      // pointer of a global variable that has the bindless opaque array.
+      if (ptrTy != nullptr && isBindlessOpaqueArray(decl->getType())) {
+        auto *load = spvBuilder.createLoad(ptrTy, info->instr, loc);
+        load->setRValue(false);
+        return load;
+      } else {
+        return *info;
+      }
     } else {
       return *info;
     }
@@ -594,7 +607,6 @@ DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
   const auto loc = param->getLocation();
   SpirvFunctionParameter *fnParamInstr = spvBuilder.addFnParam(
       type, param->hasAttr<HLSLPreciseAttr>(), loc, param->getName());
-
   bool isAlias = false;
   (void)getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias);
   fnParamInstr->setContainsAliasComponent(isAlias);
@@ -1040,9 +1052,8 @@ SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
   // definition is seen, the parameter types will be set properly and take into
   // account whether the function is a member function of a class/struct (in
   // which case a 'this' parameter is added at the beginnig).
-  SpirvFunction *spirvFunction = new (spvContext)
-      SpirvFunction(fn->getReturnType(), /* param QualTypes */ {},
-                    fn->getLocation(), fn->getName(), isPrecise);
+  SpirvFunction *spirvFunction = new (spvContext) SpirvFunction(
+      fn->getReturnType(), fn->getLocation(), fn->getName(), isPrecise);
 
   // No need to dereference to get the pointer. Function returns that are
   // stand-alone aliases are already pointers to values. All other cases should

+ 4 - 8
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -36,7 +36,7 @@ namespace clang {
 namespace spirv {
 
 bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
-  if (phase == Visitor::Phase::Init) {
+  if (phase == Visitor::Phase::Done) {
     // Lower the function return type.
     const SpirvType *spirvReturnType =
         lowerType(fn->getAstReturnType(), SpirvLayoutRule::Void,
@@ -45,14 +45,10 @@ bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
     fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
 
     // Lower the function parameter types.
-    auto paramQualTypes = fn->getAstParamTypes();
+    auto params = fn->getParameters();
     llvm::SmallVector<const SpirvType *, 4> spirvParamTypes;
-    for (auto qualtype : paramQualTypes) {
-      const auto *spirvParamType =
-          lowerType(qualtype, SpirvLayoutRule::Void,
-                    /*isRowMajor*/ llvm::None, fn->getSourceLocation());
-      spirvParamTypes.push_back(spvContext.getPointerType(
-          spirvParamType, spv::StorageClass::Function));
+    for (auto *param : params) {
+      spirvParamTypes.push_back(param->getResultType());
     }
     fn->setFunctionType(
         spvContext.getFunctionType(spirvReturnType, spirvParamTypes));

+ 24 - 7
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -29,7 +29,6 @@ SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
 }
 
 SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
-                                           llvm::ArrayRef<QualType> paramTypes,
                                            SourceLocation loc,
                                            llvm::StringRef funcName,
                                            bool isPrecise,
@@ -38,13 +37,12 @@ SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
   if (func) {
     function = func;
     function->setAstReturnType(returnType);
-    function->setAstParamTypes(paramTypes);
     function->setSourceLocation(loc);
     function->setFunctionName(funcName);
     function->setPrecise(isPrecise);
   } else {
-    function = new (context)
-        SpirvFunction(returnType, paramTypes, loc, funcName, isPrecise);
+    function =
+        new (context) SpirvFunction(returnType, loc, funcName, isPrecise);
   }
 
   return function;
@@ -55,7 +53,16 @@ SpirvFunctionParameter *SpirvBuilder::addFnParam(QualType ptrType,
                                                  SourceLocation loc,
                                                  llvm::StringRef name) {
   assert(function && "found detached parameter");
-  auto *param = new (context) SpirvFunctionParameter(ptrType, isPrecise, loc);
+  SpirvFunctionParameter *param = nullptr;
+  if (isBindlessOpaqueArray(ptrType)) {
+    // If it is a bindless array of an opaque type, we have to use
+    // a pointer to a pointer of the runtime array.
+    param = new (context) SpirvFunctionParameter(
+        context.getPointerType(ptrType, spv::StorageClass::UniformConstant),
+        isPrecise, loc);
+  } else {
+    param = new (context) SpirvFunctionParameter(ptrType, isPrecise, loc);
+  }
   param->setStorageClass(spv::StorageClass::Function);
   param->setDebugName(name);
   function->addParameter(param);
@@ -66,8 +73,18 @@ SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
                                       llvm::StringRef name, bool isPrecise,
                                       SpirvInstruction *init) {
   assert(function && "found detached local variable");
-  auto *var = new (context) SpirvVariable(
-      valueType, loc, spv::StorageClass::Function, isPrecise, init);
+  SpirvVariable *var = nullptr;
+  if (isBindlessOpaqueArray(valueType)) {
+    // If it is a bindless array of an opaque type, we have to use
+    // a pointer to a pointer of the runtime array.
+    var = new (context)
+        SpirvVariable(context.getPointerType(
+                          valueType, spv::StorageClass::UniformConstant),
+                      loc, spv::StorageClass::Function, isPrecise, init);
+  } else {
+    var = new (context) SpirvVariable(
+        valueType, loc, spv::StorageClass::Function, isPrecise, init);
+  }
   var->setDebugName(name);
   function->addVariable(var);
   return var;

+ 21 - 33
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -1032,39 +1032,23 @@ void SpirvEmitter::doFunctionDecl(const FunctionDecl *decl) {
   const QualType retType =
       declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);
 
-  // Construct the function signature.
-  llvm::SmallVector<QualType, 4> paramTypes;
+  spvBuilder.beginFunction(retType, decl->getLocStart(), funcName,
+                           decl->hasAttr<HLSLPreciseAttr>(), func);
 
-  bool isNonStaticMemberFn = false;
   if (const auto *memberFn = dyn_cast<CXXMethodDecl>(decl)) {
-    isNonStaticMemberFn = !memberFn->isStatic();
-
-    if (isNonStaticMemberFn) {
+    if (!memberFn->isStatic()) {
       // For non-static member function, the first parameter should be the
       // object on which we are invoking this method.
-      const QualType valueType =
-          memberFn->getThisType(astContext)->getPointeeType();
-      paramTypes.push_back(valueType);
-    }
-  }
-
-  for (const auto *param : decl->params()) {
-    const QualType valueType =
-        declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
-    paramTypes.push_back(valueType);
-  }
-
-  spvBuilder.beginFunction(retType, paramTypes, decl->getLocStart(), funcName,
-                           decl->hasAttr<HLSLPreciseAttr>(), func);
-
-  if (isNonStaticMemberFn) {
-    // Remember the parameter for the 'this' object so later we can handle
-    // CXXThisExpr correctly.
-    curThis = spvBuilder.addFnParam(paramTypes[0], /*isPrecise*/ false,
-                                    decl->getLocStart(), "param.this");
-    if (isOrContainsAKindOfStructuredOrByteBuffer(paramTypes[0])) {
-      curThis->setContainsAliasComponent(true);
-      needsLegalization = true;
+      QualType valueType = memberFn->getThisType(astContext)->getPointeeType();
+
+      // Remember the parameter for the 'this' object so later we can handle
+      // CXXThisExpr correctly.
+      curThis = spvBuilder.addFnParam(valueType, /*isPrecise*/ false,
+                                      decl->getLocStart(), "param.this");
+      if (isOrContainsAKindOfStructuredOrByteBuffer(valueType)) {
+        curThis->setContainsAliasComponent(true);
+        needsLegalization = true;
+      }
     }
   }
 
@@ -5228,9 +5212,14 @@ void SpirvEmitter::storeValue(SpirvInstruction *lhsPtr,
     // wholesale handling here, they will be in the final transformed code.
     // Drivers don't like that.
     // TODO: consider moving this hack into SPIRV-Tools as a transformation.
-    assert(lhsValType->isConstantArrayType());
     assert(!rhsVal->isRValue());
 
+    if (!lhsValType->isConstantArrayType()) {
+      spvBuilder.createStore(lhsPtr, rhsVal, loc);
+      needsLegalization = true;
+      return;
+    }
+
     const auto *arrayType = astContext.getAsConstantArrayType(lhsValType);
     const auto elemType = arrayType->getElementType();
     const auto arraySize =
@@ -10738,9 +10727,8 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
   // The wrapper entry function surely does not have pre-assigned <result-id>
   // for it like other functions that got added to the work queue following
   // function calls. And the wrapper is the entry function.
-  entryFunction =
-      spvBuilder.beginFunction(astContext.VoidTy, /* param QualTypes */ {},
-                               decl->getLocStart(), decl->getName());
+  entryFunction = spvBuilder.beginFunction(
+      astContext.VoidTy, decl->getLocStart(), decl->getName());
   // Note this should happen before using declIdMapper for other tasks.
   declIdMapper.setEntryFunction(entryFunction);
 

+ 3 - 6
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -15,12 +15,9 @@
 namespace clang {
 namespace spirv {
 
-SpirvFunction::SpirvFunction(QualType returnType,
-                             llvm::ArrayRef<QualType> paramTypes,
-                             SourceLocation loc, llvm::StringRef name,
-                             bool isPrecise)
-    : functionId(0), astReturnType(returnType),
-      astParamTypes(paramTypes.begin(), paramTypes.end()), returnType(nullptr),
+SpirvFunction::SpirvFunction(QualType returnType, SourceLocation loc,
+                             llvm::StringRef name, bool isPrecise)
+    : functionId(0), astReturnType(returnType), returnType(nullptr),
       fnType(nullptr), relaxedPrecision(false), precise(isPrecise),
       containsAlias(false), rvalue(false), functionLoc(loc),
       functionName(name) {}

+ 20 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -255,6 +255,17 @@ SpirvVariable::SpirvVariable(QualType resultType, SourceLocation loc,
   setPrecise(precise);
 }
 
+SpirvVariable::SpirvVariable(const SpirvType *spvType, SourceLocation loc,
+                             spv::StorageClass sc, bool precise,
+                             SpirvInstruction *initializerInst)
+    : SpirvInstruction(IK_Variable, spv::Op::OpVariable, QualType(), loc),
+      initializer(initializerInst), descriptorSet(-1), binding(-1),
+      hlslUserType("") {
+  setResultType(spvType);
+  setStorageClass(sc);
+  setPrecise(precise);
+}
+
 SpirvFunctionParameter::SpirvFunctionParameter(QualType resultType,
                                                bool isPrecise,
                                                SourceLocation loc)
@@ -263,6 +274,15 @@ SpirvFunctionParameter::SpirvFunctionParameter(QualType resultType,
   setPrecise(isPrecise);
 }
 
+SpirvFunctionParameter::SpirvFunctionParameter(const SpirvType *spvType,
+                                               bool isPrecise,
+                                               SourceLocation loc)
+    : SpirvInstruction(IK_FunctionParameter, spv::Op::OpFunctionParameter,
+                       QualType(), loc) {
+  setResultType(spvType);
+  setPrecise(isPrecise);
+}
+
 SpirvMerge::SpirvMerge(Kind kind, spv::Op op, SourceLocation loc,
                        SpirvBasicBlock *mergeLabel)
     : SpirvInstruction(kind, op, QualType(), loc), mergeBlock(mergeLabel) {}

+ 23 - 0
tools/clang/test/CodeGenSPIRV/fn.param.unsized-opaque-array-o3.hlsl

@@ -0,0 +1,23 @@
+// Run: %dxc -T ps_6_0 -E main -O3
+
+struct PSInput
+{
+    float4 color : COLOR;
+};
+
+Texture2D bindless[];
+
+sampler DummySampler;
+
+// CHECK: [[src:%\d+]] = OpAccessChain %_ptr_UniformConstant_type_2d_image %bindless %uint_4
+// CHECK:                OpLoad %type_2d_image [[src]]
+
+float4 SampleArray(Texture2D src[], uint index, float2 uv)
+{
+    return src[index].Sample(DummySampler, uv);
+}
+
+float4 main(PSInput input) : SV_TARGET
+{
+    return input.color * SampleArray(bindless, 4, float2(1,1));
+}

+ 29 - 0
tools/clang/test/CodeGenSPIRV/fn.param.unsized-opaque-array.hlsl

@@ -0,0 +1,29 @@
+// Run: %dxc -T ps_6_0 -E main
+
+struct PSInput
+{
+    float4 color : COLOR;
+};
+
+Texture2D bindless[];
+
+sampler DummySampler;
+
+// CHECK: %_ptr_Function__ptr_UniformConstant__runtimearr_type_2d_image = OpTypePointer Function %_ptr_UniformConstant__runtimearr_type_2d_image
+// CHECK: %param_var_src = OpVariable %_ptr_Function__ptr_UniformConstant__runtimearr_type_2d_image Function
+// CHECK:                OpStore %param_var_src %bindless
+// CHECK:                OpFunctionCall
+// CHECK:         %src = OpFunctionParameter %_ptr_Function__ptr_UniformConstant__runtimearr_type_2d_image
+// CHECK: [[idx:%\d+]] = OpLoad %uint %index
+// CHECK: [[src:%\d+]] = OpLoad %_ptr_UniformConstant__runtimearr_type_2d_image %src
+// CHECK:                OpAccessChain %_ptr_Function_type_2d_image [[src]] [[idx]]
+
+float4 SampleArray(Texture2D src[], uint index, float2 uv)
+{
+    return src[index].Sample(DummySampler, uv);
+}
+
+float4 main(PSInput input) : SV_TARGET
+{
+    return input.color * SampleArray(bindless, 4, float2(1,1));
+}

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -561,6 +561,12 @@ TEST_F(FileTest, FunctionParamUnsizedArray) {
   // Unsized ararys as function params are not supported.
   runFileTest("fn.param.unsized-array.hlsl", Expect::Failure);
 }
+TEST_F(FileTest, FunctionParamUnsizedOpaqueArray) {
+  runFileTest("fn.param.unsized-opaque-array.hlsl", Expect::Success, false);
+}
+TEST_F(FileTest, FunctionParamUnsizedOpaqueArrayO3) {
+  runFileTest("fn.param.unsized-opaque-array-o3.hlsl");
+}
 TEST_F(FileTest, FunctionInOutParamTypeMismatch) {
   // The type for the inout parameter doesn't match the argument type.
   runFileTest("fn.param.inout.type-mismatch.hlsl");