ソースを参照

[spirv] Lower SPIR-V types that may contain hybrid types.

Ehsan 6 年 前
コミット
acdf54e2f8

+ 7 - 1
tools/clang/include/clang/SPIRV/LowerTypeVisitor.h

@@ -32,10 +32,16 @@ public:
 
   // Custom visitor for variables. Variables must have a pointer result-type.
   bool visit(SpirvVariable *);
+
   // Custom visitor for function parameters. We use pointer type for function
   // parameters.
   bool visit(SpirvFunctionParameter *);
 
+  // Custom visitor for OpSampledImage. The result type of OpSampledImage should
+  // be OpTypeSampledImage, but instruction stores the QualType for the
+  // underlying image.
+  bool visit(SpirvSampledImage *);
+
   /// The "sink" visit function for all instructions.
   ///
   /// By default, all other visit instructions redirect to this visit function.
@@ -61,7 +67,7 @@ private:
   ///
   /// Uses the above lowerType method to lower the QualType components of hybrid
   /// types.
-  const SpirvType *lowerType(const HybridType *, SpirvLayoutRule,
+  const SpirvType *lowerType(const SpirvType *, SpirvLayoutRule,
                              SourceLocation);
 
   /// Lowers the given HLSL resource type into its SPIR-V type.

+ 120 - 16
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -25,12 +25,10 @@ bool LowerTypeVisitor::visit(SpirvFunction *fn, Phase phase) {
                   /*SourceLocation*/ {});
     fn->setReturnType(const_cast<SpirvType *>(spirvReturnType));
 
-    // In case the function type is a hybrid type, we should also lower the
-    // return type of the SPIR-V function type.
-    if (auto *fnRetType = dyn_cast<HybridType>(fn->getFunctionType())) {
-      fn->setFunctionType(const_cast<SpirvType *>(lowerType(
-          fnRetType, SpirvLayoutRule::Void, fn->getSourceLocation())));
-    }
+    // Lower the SPIR-V function type if necessary.
+    fn->setFunctionType(const_cast<SpirvType *>(
+        lowerType(fn->getFunctionType(), SpirvLayoutRule::Void,
+                  fn->getSourceLocation())));
   }
   return true;
 }
@@ -48,11 +46,9 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
   }
   // Lower Hybrid type to SpirvType
   else if (hybridType) {
-    if (const auto *hybridType = dyn_cast<HybridType>(instr->getResultType())) {
-      const SpirvType *spirvType = lowerType(hybridType, instr->getLayoutRule(),
-                                             instr->getSourceLocation());
-      instr->setResultType(spirvType);
-    }
+    const SpirvType *spirvType = lowerType(hybridType, instr->getLayoutRule(),
+                                           instr->getSourceLocation());
+    instr->setResultType(spirvType);
   }
 
   // The instruction does not have a result-type, so nothing to do.
@@ -81,21 +77,35 @@ bool LowerTypeVisitor::visit(SpirvFunctionParameter *param) {
   return true;
 }
 
-const SpirvType *LowerTypeVisitor::lowerType(const HybridType *hybrid,
+bool LowerTypeVisitor::visit(SpirvSampledImage *instr) {
+  if (!visitInstruction(instr))
+    return false;
+
+  // Wrap the image type in sampled image type if necessary.
+  const auto *resultType = instr->getResultType();
+  if (!isa<SampledImageType>(resultType)) {
+    assert(isa<ImageType>(resultType));
+    instr->setResultType(
+        spvContext.getSampledImageType(cast<ImageType>(resultType)));
+  }
+  return true;
+}
+
+const SpirvType *LowerTypeVisitor::lowerType(const SpirvType *type,
                                              SpirvLayoutRule rule,
                                              SourceLocation loc) {
-  if (const auto *hybridPointer = dyn_cast<HybridPointerType>(hybrid)) {
+  if (const auto *hybridPointer = dyn_cast<HybridPointerType>(type)) {
     const QualType pointeeType = hybridPointer->getPointeeType();
     const SpirvType *pointeeSpirvType = lowerType(pointeeType, rule, loc);
     return spvContext.getPointerType(pointeeSpirvType,
                                      hybridPointer->getStorageClass());
   } else if (const auto *hybridSampledImage =
-                 dyn_cast<HybridSampledImageType>(hybrid)) {
+                 dyn_cast<HybridSampledImageType>(type)) {
     const QualType imageAstType = hybridSampledImage->getImageType();
     const SpirvType *imageSpirvType = lowerType(imageAstType, rule, loc);
     assert(isa<ImageType>(imageSpirvType));
     return spvContext.getSampledImageType(cast<ImageType>(imageSpirvType));
-  } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(hybrid)) {
+  } else if (const auto *hybridFn = dyn_cast<HybridFunctionType>(type)) {
     // Lower the return type.
     const QualType astReturnType = hybridFn->getAstReturnType();
     const SpirvType *spirvReturnType = lowerType(astReturnType, rule, loc);
@@ -111,7 +121,7 @@ const SpirvType *LowerTypeVisitor::lowerType(const HybridType *hybrid,
     }
 
     return spvContext.getFunctionType(spirvReturnType, paramTypes);
-  } else if (const auto *hybridStruct = dyn_cast<HybridStructType>(hybrid)) {
+  } else if (const auto *hybridStruct = dyn_cast<HybridStructType>(type)) {
     // lower all fields of the struct.
     std::vector<StructType::FieldInfo> structFields;
     for (auto field : hybridStruct->getFields()) {
@@ -124,6 +134,100 @@ const SpirvType *LowerTypeVisitor::lowerType(const HybridType *hybrid,
                                     hybridStruct->isReadOnly(),
                                     hybridStruct->getInterfaceType());
   }
+  // Void, bool, int, float cannot be further lowered.
+  // Matrices cannot contain hybrid types. Only matrices of scalars are valid.
+  // sampledType in image types can only be numberical type.
+  // Sampler types cannot be further lowered.
+  // SampledImage types cannot be further lowered.
+  else if (isa<VoidType>(type) || isa<ScalarType>(type) ||
+           isa<MatrixType>(type) || isa<ImageType>(type) ||
+           isa<SamplerType>(type) || isa<SampledImageType>(type)) {
+    return type;
+  }
+  // Vectors could contain a hybrid type
+  else if (const auto *vecType = dyn_cast<VectorType>(type)) {
+    const auto *loweredElemType =
+        lowerType(vecType->getElementType(), rule, loc);
+    // If vector didn't contain any hybrid types, return itself.
+    if (vecType->getElementType() == loweredElemType)
+      return vecType;
+    return spvContext.getVectorType(loweredElemType,
+                                    vecType->getElementCount());
+  }
+  // Arrays could contain a hybrid type
+  else if (const auto *arrType = dyn_cast<ArrayType>(type)) {
+    const auto *loweredElemType =
+        lowerType(arrType->getElementType(), rule, loc);
+    // If array didn't contain any hybrid types, return itself.
+    if (arrType->getElementType() == loweredElemType)
+      return arrType;
+    return spvContext.getArrayType(loweredElemType, arrType->getElementCount());
+  }
+  // Runtime arrays could contain a hybrid type
+  else if (const auto *raType = dyn_cast<RuntimeArrayType>(type)) {
+    const auto *loweredElemType =
+        lowerType(raType->getElementType(), rule, loc);
+    // If runtime array didn't contain any hybrid types, return itself.
+    if (raType->getElementType() == loweredElemType)
+      return arrType;
+    return spvContext.getRuntimeArrayType(loweredElemType);
+  }
+  // Struct types could contain a hybrid type
+  else if (const auto *structType = dyn_cast<StructType>(type)) {
+    const auto &fields = structType->getFields();
+    llvm::SmallVector<StructType::FieldInfo, 4> loweredFields;
+    bool wasLowered = false;
+    for (auto &field : fields) {
+      const auto *loweredFieldType = lowerType(field.type, rule, loc);
+      if (loweredFieldType != field.type) {
+        wasLowered = true;
+        loweredFields.push_back(
+            StructType::FieldInfo(loweredFieldType, field.name,
+                                  field.vkOffsetAttr, field.packOffsetAttr));
+      } else {
+        loweredFields.push_back(field);
+      }
+    }
+    // If the struct didn't contain any hybrid types, return itself.
+    if (!wasLowered)
+      return structType;
+
+    return spvContext.getStructType(loweredFields, structType->getStructName(),
+                                    structType->isReadOnly(),
+                                    structType->getInterfaceType());
+  }
+  // Pointer types could point to a hybrid type.
+  else if (const auto *ptrType = dyn_cast<SpirvPointerType>(type)) {
+    const auto *loweredPointee =
+        lowerType(ptrType->getPointeeType(), rule, loc);
+    // If the pointer type didn't point to any hybrid type, return itself.
+    if (ptrType->getPointeeType() == loweredPointee)
+      return ptrType;
+
+    return spvContext.getPointerType(loweredPointee,
+                                     ptrType->getStorageClass());
+  }
+  // Function types may have a parameter or return type that is hybrid.
+  else if (const auto *fnType = dyn_cast<FunctionType>(type)) {
+    const auto *loweredRetType = lowerType(fnType->getReturnType(), rule, loc);
+    bool wasLowered = fnType->getReturnType() != loweredRetType;
+    llvm::SmallVector<const SpirvType *, 4> loweredParams;
+    const auto &paramTypes = fnType->getParamTypes();
+    for (auto *paramType : paramTypes) {
+      const auto *loweredParamType = lowerType(paramType, rule, loc);
+      loweredParams.push_back(loweredParamType);
+      if (loweredParamType != paramType) {
+        wasLowered = true;
+      }
+    }
+    // If the function type didn't include any hybrid types, return itself.
+    if (!wasLowered) {
+      return fnType;
+    }
+
+    return spvContext.getFunctionType(loweredRetType, loweredParams);
+  }
+
   llvm_unreachable("lowering of hybrid type not implemented");
 }
 

+ 1 - 33
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -524,37 +524,6 @@ std::string getFnName(const FunctionDecl *fn) {
   return getNamespacePrefix(fn) + classOrStructName + fn->getName().str();
 }
 
-/// Returns the capability required to non-uniformly index into the given type.
-spv::Capability getNonUniformCapability(QualType type) {
-  using spv::Capability;
-
-  if (type->isArrayType()) {
-    return getNonUniformCapability(
-        type->getAsArrayTypeUnsafe()->getElementType());
-  }
-  if (TypeTranslator::isTexture(type) || TypeTranslator::isSampler(type)) {
-    return Capability::SampledImageArrayNonUniformIndexingEXT;
-  }
-  if (TypeTranslator::isRWTexture(type)) {
-    return Capability::StorageImageArrayNonUniformIndexingEXT;
-  }
-  if (TypeTranslator::isBuffer(type)) {
-    return Capability::UniformTexelBufferArrayNonUniformIndexingEXT;
-  }
-  if (TypeTranslator::isRWBuffer(type)) {
-    return Capability::StorageTexelBufferArrayNonUniformIndexingEXT;
-  }
-  if (const auto *recordType = type->getAs<RecordType>()) {
-    const auto name = recordType->getDecl()->getName();
-
-    if (name == "SubpassInput" || name == "SubpassInputMS") {
-      return Capability::InputAttachmentArrayNonUniformIndexingEXT;
-    }
-  }
-
-  return Capability::Max;
-}
-
 } // namespace
 
 SPIRVEmitter::SPIRVEmitter(CompilerInstance &ci)
@@ -2999,9 +2968,8 @@ SPIRVEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr,
   auto *samplerState = doExpr(expr->getArg(0));
   auto *coordinate = doExpr(expr->getArg(1));
 
-  auto *sampledImageType = spvContext.getSampledImageType(object->getType());
   auto *sampledImage = spvBuilder.createBinaryOp(
-      spv::Op::OpSampledImage, sampledImageType, objectInfo, samplerState);
+      spv::Op::OpSampledImage, object->getType(), objectInfo, samplerState);
 
   if (objectInfo->isNonUniform() || samplerState->isNonUniform()) {
     // The sampled image will be used to access resource's memory, so we need