浏览代码

[spirv] Fix bug in CapabilityVisitor.

Ehsan Nasiri 6 年之前
父节点
当前提交
2906cd30b3
共有 2 个文件被更改,包括 21 次插入21 次删除
  1. 19 18
      tools/clang/lib/SPIRV/CapabilityVisitor.cpp
  2. 2 3
      tools/clang/lib/SPIRV/CapabilityVisitor.h

+ 19 - 18
tools/clang/lib/SPIRV/CapabilityVisitor.cpp

@@ -48,7 +48,10 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
     switch (floatType->getBitwidth()) {
     case 16: {
       // Usage of a 16-bit float type.
-      spvBuilder.requireCapability(spv::Capability::Float16);
+      // It looks like the validator does not approve of Float16
+      // capability even though we do use the necessary extension.
+      // TODO: Re-enable adding Float16 capability below.
+      // spvBuilder.requireCapability(spv::Capability::Float16);
       spvBuilder.addExtension(Extension::AMD_gpu_shader_half_float,
                               "16-bit float", {});
 
@@ -70,15 +73,15 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
   }
   // Vectors
   else if (const auto *vecType = dyn_cast<VectorType>(type)) {
-    addCapabilityForType(vecType->getElementType());
+    addCapabilityForType(vecType->getElementType(), loc, sc);
   }
   // Matrices
   else if (const auto *matType = dyn_cast<MatrixType>(type)) {
-    addCapabilityForType(matType->getElementType());
+    addCapabilityForType(matType->getElementType(), loc, sc);
   }
   // Arrays
   else if (const auto *arrType = dyn_cast<ArrayType>(type)) {
-    addCapabilityForType(arrType->getElementType());
+    addCapabilityForType(arrType->getElementType(), loc, sc);
   }
   // Runtime array of resources requires additional capability.
   else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
@@ -88,7 +91,7 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
                               "runtime array of resources", {});
       spvBuilder.requireCapability(spv::Capability::RuntimeDescriptorArrayEXT);
     }
-    addCapabilityForType(raType->getElementType());
+    addCapabilityForType(raType->getElementType(), loc, sc);
   }
   // Image types
   else if (const auto *imageType = dyn_cast<ImageType>(type)) {
@@ -155,15 +158,15 @@ void CapabilityVisitor::addCapabilityForType(const SpirvType *type,
     if (imageType->isArrayedImage() && imageType->isMSImage())
       spvBuilder.requireCapability(spv::Capability::ImageMSArray);
 
-    addCapabilityForType(imageType->getSampledType());
+    addCapabilityForType(imageType->getSampledType(), loc, sc);
   }
   // Sampled image type
   else if (const auto *sampledImageType = dyn_cast<SampledImageType>(type)) {
-    addCapabilityForType(sampledImageType->getImageType());
+    addCapabilityForType(sampledImageType->getImageType(), loc, sc);
   }
   // Pointer type
   else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
-    addCapabilityForType(ptrType->getPointeeType());
+    addCapabilityForType(ptrType->getPointeeType(), loc, sc);
   }
   // Struct type
   else if (const auto *structType = dyn_cast<StructType>(type)) {
@@ -341,19 +344,22 @@ CapabilityVisitor::getNonUniformCapability(const SpirvType *type) {
 }
 
 bool CapabilityVisitor::visit(SpirvImageQuery *instr) {
-  addCapabilityForType(instr->getResultType());
+  addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
+                       instr->getStorageClass());
   spvBuilder.requireCapability(spv::Capability::ImageQuery);
   return true;
 }
 
 bool CapabilityVisitor::visit(SpirvImageSparseTexelsResident *instr) {
-  addCapabilityForType(instr->getResultType());
+  addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
+                       instr->getStorageClass());
   spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
   return true;
 }
 
 bool CapabilityVisitor::visit(SpirvImageOp *instr) {
-  addCapabilityForType(instr->getResultType());
+  addCapabilityForType(instr->getResultType(), instr->getSourceLocation(),
+                       instr->getStorageClass());
   if (instr->hasOffset() || instr->hasConstOffsets())
     spvBuilder.requireCapability(spv::Capability::ImageGatherExtended);
   if (instr->hasMinLod())
@@ -364,18 +370,13 @@ bool CapabilityVisitor::visit(SpirvImageOp *instr) {
   return true;
 }
 
-bool CapabilityVisitor::visit(SpirvVariable *var) {
-  addCapabilityForType(var->getResultType(), var->getSourceLocation(),
-                       var->getStorageClass());
-  return true;
-}
-
 bool CapabilityVisitor::visitInstruction(SpirvInstruction *instr) {
   const SpirvType *resultType = instr->getResultType();
   const auto opcode = instr->getopcode();
 
   // Add result-type-specific capabilities
-  addCapabilityForType(resultType, instr->getSourceLocation());
+  addCapabilityForType(resultType, instr->getSourceLocation(),
+                       instr->getStorageClass());
 
   // Add NonUniform capabilities if necessary
   if (instr->isNonUniform()) {

+ 2 - 3
tools/clang/lib/SPIRV/CapabilityVisitor.h

@@ -32,7 +32,6 @@ public:
   bool visit(SpirvImageQuery *);
   bool visit(SpirvImageOp *);
   bool visit(SpirvImageSparseTexelsResident *);
-  bool visit(SpirvVariable *);
 
   /// The "sink" visit function for all instructions.
   ///
@@ -46,8 +45,8 @@ private:
   /// The called may also provide the storage class for variable types, because
   /// in the case of variable types, the storage class may affect the capability
   /// that is used.
-  void addCapabilityForType(const SpirvType *, SourceLocation loc = {},
-                            spv::StorageClass sc = spv::StorageClass::Max);
+  void addCapabilityForType(const SpirvType *, SourceLocation loc,
+                            spv::StorageClass sc);
 
   /// Returns the capability required to non-uniformly index into the given
   /// type.