瀏覽代碼

[spirv] Forbid using buffers as initializers and code refactoring (#992)

* Counter variable emission is merged into getTypeForPotentialAliasVar
* Changed to use DeclaratorDecl instead of ValueDecl for some methods
* collectArrayStructIndices is extended to suppor raw index
Lei Zhang 7 年之前
父節點
當前提交
1525a56558

+ 16 - 25
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -250,18 +250,13 @@ SpirvEvalInfo DeclResultIdMapper::getDeclResultId(const ValueDecl *decl,
 uint32_t DeclResultIdMapper::createFnParam(const ParmVarDecl *param) {
   bool isAlias = false;
   auto &info = astDecls[param].info;
-  const uint32_t type = getTypeForPotentialAliasVar(param, &isAlias, &info);
+  const uint32_t type =
+      getTypeAndCreateCounterForPotentialAliasVar(param, &isAlias, &info);
   const uint32_t ptrType =
       theBuilder.getPointerType(type, spv::StorageClass::Function);
   const uint32_t id = theBuilder.addFnParam(ptrType, param->getName());
   info.setResultId(id);
 
-  // The counter variable may be created before by forward declaration.
-  if (!counterVars.count(param))
-    // Create alias counter variable if suitable
-    if (isAlias && TypeTranslator::isRWAppendConsumeSBuffer(param->getType()))
-      createCounterVar(param, /*isAlias=*/true);
-
   return id;
 }
 
@@ -277,14 +272,11 @@ uint32_t DeclResultIdMapper::createFnVar(const VarDecl *var,
                                          llvm::Optional<uint32_t> init) {
   bool isAlias = false;
   auto &info = astDecls[var].info;
-  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias, &info);
+  const uint32_t type =
+      getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
   const uint32_t id = theBuilder.addFnVar(type, var->getName(), init);
   info.setResultId(id);
 
-  // Create alias counter variable if suitable
-  if (isAlias && TypeTranslator::isRWAppendConsumeSBuffer(var->getType()))
-    createCounterVar(var, /*isAlias=*/true);
-
   return id;
 }
 
@@ -292,17 +284,14 @@ uint32_t DeclResultIdMapper::createFileVar(const VarDecl *var,
                                            llvm::Optional<uint32_t> init) {
   bool isAlias = false;
   auto &info = astDecls[var].info;
-  const uint32_t type = getTypeForPotentialAliasVar(var, &isAlias, &info);
+  const uint32_t type =
+      getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias, &info);
   const uint32_t id = theBuilder.addModuleVar(type, spv::StorageClass::Private,
                                               var->getName(), init);
   info.setResultId(id);
   if (!isAlias)
     info.setStorageClass(spv::StorageClass::Private);
 
-  // Create alias counter variable if suitable
-  if (isAlias && TypeTranslator::isRWAppendConsumeSBuffer(var->getType()))
-    createCounterVar(var, /*isAlias=*/true);
-
   return id;
 }
 
@@ -513,7 +502,8 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
   auto &info = astDecls[fn].info;
 
   bool isAlias = false;
-  const uint32_t type = getTypeForPotentialAliasVar(fn, &isAlias, &info);
+  const uint32_t type =
+      getTypeAndCreateCounterForPotentialAliasVar(fn, &isAlias, &info);
 
   const uint32_t id = theBuilder.getSPIRVContext()->takeNextId();
   info.setResultId(id);
@@ -524,22 +514,19 @@ uint32_t DeclResultIdMapper::getOrRegisterFnResultId(const FunctionDecl *fn) {
       !TypeTranslator::isAKindOfStructuredOrByteBuffer(fn->getReturnType()))
     info.setRValue();
 
-  // Create alias counter variable if suitable
-  if (TypeTranslator::isRWAppendConsumeSBuffer(fn->getReturnType()))
-    createCounterVar(fn, /*isAlias=*/true);
-
   return id;
 }
 
 const CounterIdAliasPair *
-DeclResultIdMapper::getCounterIdAliasPair(const ValueDecl *decl) {
+DeclResultIdMapper::getCounterIdAliasPair(const DeclaratorDecl *decl) {
   const auto counter = counterVars.find(decl);
   if (counter != counterVars.end())
     return &counter->second;
   return nullptr;
 }
 
-void DeclResultIdMapper::createCounterVar(const ValueDecl *decl, bool isAlias) {
+void DeclResultIdMapper::createCounterVar(const DeclaratorDecl *decl,
+                                          bool isAlias) {
   const std::string counterName = "counter.var." + decl->getName().str();
   uint32_t counterType = typeTranslator.getACSBufferCounter();
   // {RW|Append|Consume}StructuredBuffer are all in Uniform storage class.
@@ -1876,7 +1863,7 @@ DeclResultIdMapper::getStorageClassForSigPoint(const hlsl::SigPoint *sigPoint) {
   return sc;
 }
 
-uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar(
+uint32_t DeclResultIdMapper::getTypeAndCreateCounterForPotentialAliasVar(
     const DeclaratorDecl *decl, bool *shouldBeAlias, SpirvEvalInfo *info) {
   if (const auto *varDecl = dyn_cast<VarDecl>(decl)) {
     // This method is only intended to be used to create SPIR-V variables in the
@@ -1902,6 +1889,10 @@ uint32_t DeclResultIdMapper::getTypeForPotentialAliasVar(
   if (genAlias) {
     needsLegalization = true;
 
+    if (!counterVars.count(decl) &&
+        TypeTranslator::isRWAppendConsumeSBuffer(type))
+      createCounterVar(decl, /*isAlias=*/true);
+
     if (info)
       info->setContainsAliasComponent(true);
   }

+ 9 - 7
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -255,15 +255,17 @@ public:
   /// \brief Returns the suitable type for the given decl, considering the
   /// given decl could possibly be created as an alias variable. If true, a
   /// pointer-to-the-value type will be returned, otherwise, just return the
-  /// normal value type.
+  /// normal value type. For an alias variable having a associated counter, the
+  /// counter variable will also be emitted.
   ///
   /// If the type is for an alias variable, writes true to *shouldBeAlias and
   /// writes storage class, layout rule, and valTypeId to *info.
   ///
   /// Note: legalization specific code
-  uint32_t getTypeForPotentialAliasVar(const DeclaratorDecl *var,
-                                       bool *shouldBeAlias = nullptr,
-                                       SpirvEvalInfo *info = nullptr);
+  uint32_t
+  getTypeAndCreateCounterForPotentialAliasVar(const DeclaratorDecl *var,
+                                              bool *shouldBeAlias = nullptr,
+                                              SpirvEvalInfo *info = nullptr);
 
   /// \brief Sets the <result-id> of the entry function.
   void setEntryFunctionId(uint32_t id) { entryFunctionId = id; }
@@ -308,7 +310,7 @@ public:
   /// pair for the given {RW|Append|Consume}StructuredBuffer variable.
   /// Returns nullptr if the given decl has no associated counter variable
   /// created.
-  const CounterIdAliasPair *getCounterIdAliasPair(const ValueDecl *decl);
+  const CounterIdAliasPair *getCounterIdAliasPair(const DeclaratorDecl *decl);
 
   /// \brief Returns the <type-id> for the given cbuffer, tbuffer,
   /// ConstantBuffer, TextureBuffer, or push constant block.
@@ -489,7 +491,7 @@ private:
   /// pointer-to-pointer type in Private storage class) if isAlias is true.
   ///
   /// Note: isAlias - legalization specific code
-  void createCounterVar(const ValueDecl *decl, bool isAlias);
+  void createCounterVar(const DeclaratorDecl *decl, bool isAlias);
 
   /// Decorates varId of the given asType with proper interpolation modes
   /// considering the attributes on the given decl.
@@ -530,7 +532,7 @@ private:
   llvm::SmallVector<ResourceVar, 8> resourceVars;
   /// Mapping from {RW|Append|Consume}StructuredBuffers to their
   /// counter variables' (<result-id>, is-alias-or-not) pairs
-  llvm::DenseMap<const ValueDecl *, CounterIdAliasPair> counterVars;
+  llvm::DenseMap<const DeclaratorDecl *, CounterIdAliasPair> counterVars;
 
   /// Mapping from cbuffer/tbuffer/ConstantBuffer/TextureBufer/push-constant
   /// to the <type-id>

+ 10 - 1
tools/clang/lib/SPIRV/InitListHandler.cpp

@@ -211,6 +211,12 @@ uint32_t InitListHandler::createInitForType(QualType type,
   if (TypeTranslator::isOpaqueType(type))
     return createInitForSamplerImageType(type, srcLoc);
 
+  // This should happen before the check for normal struct types
+  if (TypeTranslator::isAKindOfStructuredOrByteBuffer(type)) {
+    emitError("cannot handle structured/byte buffer as initializer", srcLoc);
+    return 0;
+  }
+
   if (type->isStructureType())
     return createInitForStructType(type);
 
@@ -365,8 +371,11 @@ uint32_t InitListHandler::createInitForStructType(QualType type) {
 
   llvm::SmallVector<uint32_t, 4> fields;
   const RecordDecl *structDecl = type->getAsStructureType()->getDecl();
-  for (const auto *field : structDecl->fields())
+  for (const auto *field : structDecl->fields()) {
     fields.push_back(createInitForType(field->getType(), field->getLocation()));
+    if (!fields.back())
+      return 0;
+  }
 
   const uint32_t typeId = typeTranslator.translateType(type);
   // TODO: use OpConstantComposite when all components are constants

+ 68 - 56
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -373,11 +373,14 @@ inline const FunctionDecl *getCalleeDefinition(const CallExpr *expr) {
 
 /// Returns the referenced definition. The given expr is expected to be a
 /// DeclRefExpr or CallExpr after ignoring casts. Returns nullptr otherwise.
-const ValueDecl *getReferencedDef(const Expr *expr) {
+const DeclaratorDecl *getReferencedDef(const Expr *expr) {
+  if (!expr)
+    return nullptr;
+
   expr = expr->IgnoreParenCasts();
 
   if (const auto *declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
-    return declRefExpr->getDecl();
+    return dyn_cast_or_null<DeclaratorDecl>(declRefExpr->getDecl());
   }
 
   if (const auto *callExpr = dyn_cast<CallExpr>(expr)) {
@@ -655,45 +658,45 @@ SpirvEvalInfo SPIRVEmitter::doExpr(const Expr *expr) {
 SpirvEvalInfo SPIRVEmitter::loadIfGLValue(const Expr *expr,
                                           SpirvEvalInfo info) {
 
-  if (!info.isRValue()) {
-    // Check whether we are trying to load an externally visible structured/byte
-    // buffer as a whole. If true, it means we are creating alias for it. Avoid
-    // the load and write the pointer directly to the alias variable then.
-    //
-    // Also for the case of alias function returns. If we are trying to load an
-    // alias function return as a whole, it means we are assigning it to another
-    // alias variable. Avoid the load and write the pointer directly.
-    //
-    // Note: legalization specific code
-    if (isReferencingNonAliasStructuredOrByteBuffer(expr)) {
-      return info.setRValue();
-    }
+  // Do nothing if this is already rvalue
+  if (info.isRValue())
+    return info;
 
-    if (loadIfAliasVarRef(expr, info)) {
-      // We are loading an alias variable as a whole here. This is likely for
-      // wholesale assignments or function returns. Need to load the pointer.
-      //
-      // Note: legalization specific code
-      // TODO: It seems we should not set rvalue here since info is still
-      // holding a pointer. But it fails structured buffer assignment because
-      // of double loadIfGLValue() calls if we do not. Fix it.
-      return info.setRValue();
-    }
-
-    uint32_t valType = 0;
-    // TODO: Ouch. Very hacky. We need special path to get the value type if
-    // we are loading a whole ConstantBuffer/TextureBuffer since the normal
-    // type translation path won't work.
-    if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
-      valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
-    } else {
-      valType =
-          typeTranslator.translateType(expr->getType(), info.getLayoutRule());
-    }
-    info.setResultId(theBuilder.createLoad(valType, info)).setRValue();
+  // Check whether we are trying to load an externally visible structured/byte
+  // buffer as a whole. If true, it means we are creating alias for it. Avoid
+  // the load and write the pointer directly to the alias variable then.
+  //
+  // Also for the case of alias function returns. If we are trying to load an
+  // alias function return as a whole, it means we are assigning it to another
+  // alias variable. Avoid the load and write the pointer directly.
+  //
+  // Note: legalization specific code
+  if (isReferencingNonAliasStructuredOrByteBuffer(expr)) {
+    return info.setRValue();
   }
 
-  return info;
+  if (loadIfAliasVarRef(expr, info)) {
+    // We are loading an alias variable as a whole here. This is likely for
+    // wholesale assignments or function returns. Need to load the pointer.
+    //
+    // Note: legalization specific code
+    // TODO: It seems we should not set rvalue here since info is still
+    // holding a pointer. But it fails structured buffer assignment because
+    // of double loadIfGLValue() calls if we do not. Fix it.
+    return info.setRValue();
+  }
+
+  uint32_t valType = 0;
+  // TODO: Ouch. Very hacky. We need special path to get the value type if
+  // we are loading a whole ConstantBuffer/TextureBuffer since the normal
+  // type translation path won't work.
+  if (const auto *declContext = isConstantTextureBufferDeclRef(expr)) {
+    valType = declIdMapper.getCTBufferPushConstantTypeId(declContext);
+  } else {
+    valType =
+        typeTranslator.translateType(expr->getType(), info.getLayoutRule());
+  }
+  return info.setResultId(theBuilder.createLoad(valType, info)).setRValue();
 }
 
 SpirvEvalInfo SPIRVEmitter::loadIfAliasVarRef(const Expr *expr) {
@@ -794,7 +797,8 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
     funcId = declIdMapper.getDeclResultId(decl);
   }
 
-  const uint32_t retType = declIdMapper.getTypeForPotentialAliasVar(decl);
+  const uint32_t retType =
+      declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(decl);
 
   // Construct the function signature.
   llvm::SmallVector<uint32_t, 4> paramTypes;
@@ -819,7 +823,8 @@ void SPIRVEmitter::doFunctionDecl(const FunctionDecl *decl) {
   }
 
   for (const auto *param : decl->params()) {
-    const uint32_t valueType = declIdMapper.getTypeForPotentialAliasVar(param);
+    const uint32_t valueType =
+        declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
     const uint32_t ptrType =
         theBuilder.getPointerType(valueType, spv::StorageClass::Function);
     paramTypes.push_back(ptrType);
@@ -1698,7 +1703,8 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
 
     // We need to create variables for holding the values to be used as
     // arguments. The variables themselves are of pointer types.
-    const uint32_t varType = declIdMapper.getTypeForPotentialAliasVar(param);
+    const uint32_t varType =
+        declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(param);
     const std::string varName = "param.var." + param->getNameAsString();
     const uint32_t tempVarId = theBuilder.addFnVar(varType, varName);
 
@@ -1718,7 +1724,8 @@ SpirvEvalInfo SPIRVEmitter::processCall(const CallExpr *callExpr) {
     workQueue.insert(callee);
   }
 
-  const uint32_t retType = declIdMapper.getTypeForPotentialAliasVar(callee);
+  const uint32_t retType =
+      declIdMapper.getTypeAndCreateCounterForPotentialAliasVar(callee);
   // Get or forward declare the function <result-id>
   const uint32_t funcId = declIdMapper.getOrRegisterFnResultId(callee);
 
@@ -2761,7 +2768,7 @@ uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
     (void)doExpr(object);
   }
 
-  const ValueDecl *buffer = getReferencedDef(object);
+  const auto *buffer = getReferencedDef(object);
   if (!buffer) {
     emitError("method call syntax unimplemented", expr->getExprLoc());
     return 0;
@@ -2788,7 +2795,7 @@ uint32_t SPIRVEmitter::incDecRWACSBufferCounter(const CXXMemberCallExpr *expr,
   return index;
 }
 
-bool SPIRVEmitter::tryToAssignCounterVar(const ValueDecl *dstDecl,
+bool SPIRVEmitter::tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
                                          const Expr *srcExpr) {
   // For parameters of forward-declared functions. We must make sure the
   // associated counter variable is created. But for forward-declared functions,
@@ -3959,13 +3966,7 @@ void SPIRVEmitter::storeValue(const SpirvEvalInfo &lhsPtr,
     theBuilder.createStore(lhsPtr, rhsVal);
   } else if (const auto *recordType = lhsValType->getAs<RecordType>()) {
     uint32_t index = 0;
-    for (const auto *decl : recordType->getDecl()->decls()) {
-      // Ignore implicit generated struct declarations/constructors/destructors.
-      if (decl->isImplicit())
-        continue;
-
-      const auto *field = cast<FieldDecl>(decl);
-
+    for (const auto *field : recordType->getDecl()->fields()) {
       const auto subRhsValType = typeTranslator.translateType(
           field->getType(), rhsVal.getLayoutRule());
       const auto subRhsVal =
@@ -4735,7 +4736,7 @@ SPIRVEmitter::processMatrixBinaryOp(const Expr *lhs, const Expr *rhs,
 }
 
 const Expr *SPIRVEmitter::collectArrayStructIndices(
-    const Expr *expr, llvm::SmallVectorImpl<uint32_t> *indices) {
+    const Expr *expr, llvm::SmallVectorImpl<uint32_t> *indices, bool rawIndex) {
   if (const auto *indexing = dyn_cast<MemberExpr>(expr)) {
     // First check whether this is referring to a static member. If it is, we
     // create a DeclRefExpr for it.
@@ -4747,12 +4748,14 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
             varDecl->getType(), VK_LValue);
 
     const Expr *base = collectArrayStructIndices(
-        indexing->getBase()->IgnoreParenNoopCasts(astContext), indices);
+        indexing->getBase()->IgnoreParenNoopCasts(astContext), indices,
+        rawIndex);
 
     // Append the index of the current level
     const auto *fieldDecl = cast<FieldDecl>(indexing->getMemberDecl());
     assert(fieldDecl);
-    indices->push_back(theBuilder.getConstantInt32(fieldDecl->getFieldIndex()));
+    const uint32_t index = fieldDecl->getFieldIndex();
+    indices->push_back(rawIndex ? index : theBuilder.getConstantInt32(index));
 
     return base;
   }
@@ -4762,20 +4765,26 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
   TypeTranslator::LiteralTypeHint hint(typeTranslator, astContext.IntTy);
 
   if (const auto *indexing = dyn_cast<ArraySubscriptExpr>(expr)) {
+    if (rawIndex)
+      return nullptr; // TODO: handle constant array index
+
     // The base of an ArraySubscriptExpr has a wrapping LValueToRValue implicit
     // cast. We need to ingore it to avoid creating OpLoad.
     const Expr *thisBase = indexing->getBase()->IgnoreParenLValueCasts();
-    const Expr *base = collectArrayStructIndices(thisBase, indices);
+    const Expr *base = collectArrayStructIndices(thisBase, indices, rawIndex);
     indices->push_back(doExpr(indexing->getIdx()));
     return base;
   }
 
   if (const auto *indexing = dyn_cast<CXXOperatorCallExpr>(expr))
     if (indexing->getOperator() == OverloadedOperatorKind::OO_Subscript) {
+      if (rawIndex)
+        return nullptr; // TODO: handle constant array index
+
       const Expr *thisBase =
           indexing->getArg(0)->IgnoreParenNoopCasts(astContext);
       const auto thisBaseType = thisBase->getType();
-      const Expr *base = collectArrayStructIndices(thisBase, indices);
+      const Expr *base = collectArrayStructIndices(thisBase, indices, rawIndex);
 
       if (thisBaseType != base->getType() &&
           TypeTranslator::isAKindOfStructuredOrByteBuffer(thisBaseType)) {
@@ -4810,6 +4819,9 @@ const Expr *SPIRVEmitter::collectArrayStructIndices(
     const Expr *index = nullptr;
     // TODO: the following is duplicating the logic in doCXXMemberCallExpr.
     if (const auto *object = isStructuredBufferLoad(expr, &index)) {
+      if (rawIndex)
+        return nullptr; // TODO: handle constant array index
+
       // For object.Load(index), there should be no more indexing into the
       // object.
       indices->push_back(theBuilder.getConstantInt32(0));

+ 8 - 4
tools/clang/lib/SPIRV/SPIRVEmitter.h

@@ -260,10 +260,13 @@ private:
 
   /// Collects all indices (SPIR-V constant values) from consecutive MemberExprs
   /// or ArraySubscriptExprs or operator[] calls and writes into indices.
-  /// Returns the real base.
+  /// Returns the real base. If rawIndex is set to true, the indices collected
+  /// will not be turned into SPIR-V constant values, and the base returned can
+  /// be nullptr, which means some indices are not constant.
   const Expr *
   collectArrayStructIndices(const Expr *expr,
-                            llvm::SmallVectorImpl<uint32_t> *indices);
+                            llvm::SmallVectorImpl<uint32_t> *indices,
+                            bool rawIndex = false);
 
   /// Creates an access chain to index into the given SPIR-V evaluation result
   /// and overwrites and returns the new SPIR-V evaluation result.
@@ -451,7 +454,7 @@ private:
   /// represented in a 32-bit integer type or a literal float that can be
   /// represented in a 32-bit float type without losing info. Returns false
   /// otherwise.
-  bool canBeRepresentedIn32Bits(const Expr* expr);
+  bool canBeRepresentedIn32Bits(const Expr *expr);
 
 private:
   /// Translates the given HLSL loop attribute into SPIR-V loop control mask.
@@ -686,7 +689,8 @@ private:
   /// srcExpr or dstDecl.
   ///
   /// Note: legalization specific code
-  bool tryToAssignCounterVar(const ValueDecl *dstDecl, const Expr *srcExpr);
+  bool tryToAssignCounterVar(const DeclaratorDecl *dstDecl,
+                             const Expr *srcExpr);
 
   /// \brief Loads numWords 32-bit unsigned integers or stores numWords 32-bit
   /// unsigned integers (based on the doStore parameter) to the given