浏览代码

[spirv] Add createImageQuery API to SpirvBuilder.

e.g. binaryOp with SampledImage return type.
Ehsan Nasiri 6 年之前
父节点
当前提交
566c95fb4b

+ 0 - 16
tools/clang/include/clang/SPIRV/LowerTypeVisitor.h

@@ -30,22 +30,6 @@ public:
   bool visit(SpirvFunction *, Phase);
   bool visit(SpirvFunction *, Phase);
   bool visit(SpirvBasicBlock *, Phase) { return true; }
   bool visit(SpirvBasicBlock *, Phase) { return true; }
 
 
-  // Custom visitor for variables. Variables must have a pointer result-type.
-  bool visit(SpirvVariable *);
-
-  // Custom visitor for function parameters. We use pointer type for function
-  // parameters.
-  bool visit(SpirvFunctionParameter *);
-
-  // Custom visitor for OpSampledImage. The result type of OpSampledImage should
-  // be OpTypeSampledImage, but instruction stores the QualType for the
-  // underlying image.
-  bool visit(SpirvSampledImage *);
-
-  // Custom visitor for sparse image operations: the result type must be the
-  // Sparse Residency Struct.
-  bool visit(SpirvImageOp *);
-
   /// The "sink" visit function for all instructions.
   /// The "sink" visit function for all instructions.
   ///
   ///
   /// By default, all other visit instructions redirect to this visit function.
   /// By default, all other visit instructions redirect to this visit function.

+ 9 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -314,6 +314,15 @@ public:
   createImageSparseTexelsResident(SpirvInstruction *resident_code,
   createImageSparseTexelsResident(SpirvInstruction *resident_code,
                                   SourceLocation loc = {});
                                   SourceLocation loc = {});
 
 
+  /// \brief Creates an image query instruction.
+  /// The given 'lod' is used as the Lod argument in the case of
+  /// OpImageQuerySizeLod, and it is used as the 'coordinate' parameter in the
+  /// case of OpImageQueryLod.
+  SpirvImageQuery *
+  SpirvBuilder::createImageQuery(spv::Op opcode, QualType resultType,
+                                 SourceLocation loc, SpirvInstruction *image,
+                                 SpirvInstruction *lod = nullptr);
+
   /// \brief Creates a select operation with the given values for true and false
   /// \brief Creates a select operation with the given values for true and false
   /// cases and returns the instruction pointer.
   /// cases and returns the instruction pointer.
   SpirvSelect *createSelect(QualType resultType, SpirvInstruction *condition,
   SpirvSelect *createSelect(QualType resultType, SpirvInstruction *condition,

+ 32 - 44
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -42,7 +42,6 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
     const SpirvType *spirvType =
     const SpirvType *spirvType =
         lowerType(astType, instr->getLayoutRule(), instr->getSourceLocation());
         lowerType(astType, instr->getLayoutRule(), instr->getSourceLocation());
     instr->setResultType(spirvType);
     instr->setResultType(spirvType);
-    return spirvType != nullptr;
   }
   }
   // Lower Hybrid type to SpirvType
   // Lower Hybrid type to SpirvType
   else if (hybridType) {
   else if (hybridType) {
@@ -51,60 +50,49 @@ bool LowerTypeVisitor::visitInstruction(SpirvInstruction *instr) {
     instr->setResultType(spirvType);
     instr->setResultType(spirvType);
   }
   }
 
 
-  // The instruction does not have a result-type, so nothing to do.
-  return true;
-}
-
-bool LowerTypeVisitor::visit(SpirvVariable *var) {
-  if (!visitInstruction(var))
-    return false;
-
-  const SpirvType *valueType = var->getResultType();
-  const SpirvType *pointerType =
-      spvContext.getPointerType(valueType, var->getStorageClass());
-  var->setResultType(pointerType);
-  return true;
-}
-
-bool LowerTypeVisitor::visit(SpirvFunctionParameter *param) {
-  if (!visitInstruction(param))
-    return false;
-
-  const SpirvType *valueType = param->getResultType();
-  const SpirvType *pointerType =
-      spvContext.getPointerType(valueType, param->getStorageClass());
-  param->setResultType(pointerType);
-  return true;
-}
+  // Instruction-specific type updates
 
 
-bool LowerTypeVisitor::visit(SpirvSampledImage *instr) {
-  if (!visitInstruction(instr))
-    return false;
-
-  // Wrap the image type in sampled image type if necessary.
   const auto *resultType = instr->getResultType();
   const auto *resultType = instr->getResultType();
-  if (!isa<SampledImageType>(resultType)) {
-    assert(isa<ImageType>(resultType));
-    instr->setResultType(
-        spvContext.getSampledImageType(cast<ImageType>(resultType)));
+  switch (instr->getopcode()) {
+  case spv::Op::OpSampledImage: {
+    // Wrap the image type in sampled image type if necessary.
+    if (!isa<SampledImageType>(resultType)) {
+      assert(isa<ImageType>(resultType));
+      instr->setResultType(
+          spvContext.getSampledImageType(cast<ImageType>(resultType)));
+    }
+    break;
+  }
+  // Variables and function parameters must have a pointer type.
+  case spv::Op::OpFunctionParameter:
+  case spv::Op::OpVariable: {
+    const SpirvType *pointerType =
+        spvContext.getPointerType(resultType, instr->getStorageClass());
+    instr->setResultType(pointerType);
+    break;
   }
   }
-  return true;
-}
-
-bool LowerTypeVisitor::visit(SpirvImageOp *instr) {
-  if (!visitInstruction(instr))
-    return false;
-
   // Sparse image operations return a sparse residency struct.
   // Sparse image operations return a sparse residency struct.
-  const auto *resultType = instr->getResultType();
-  if (instr->isSparse()) {
+  case spv::Op::OpImageSparseSampleImplicitLod:
+  case spv::Op::OpImageSparseSampleExplicitLod:
+  case spv::Op::OpImageSparseSampleDrefImplicitLod:
+  case spv::Op::OpImageSparseSampleDrefExplicitLod:
+  case spv::Op::OpImageSparseFetch:
+  case spv::Op::OpImageSparseGather:
+  case spv::Op::OpImageSparseDrefGather:
+  case spv::Op::OpImageSparseRead: {
     const auto *uintType = spvContext.getUIntType(32);
     const auto *uintType = spvContext.getUIntType(32);
     const auto *sparseResidencyStruct = spvContext.getStructType(
     const auto *sparseResidencyStruct = spvContext.getStructType(
         {StructType::FieldInfo(uintType, "Residency.Code"),
         {StructType::FieldInfo(uintType, "Residency.Code"),
          StructType::FieldInfo(resultType, "Result.Type")},
          StructType::FieldInfo(resultType, "Result.Type")},
         "SparseResidencyStruct");
         "SparseResidencyStruct");
     instr->setResultType(sparseResidencyStruct);
     instr->setResultType(sparseResidencyStruct);
+    break;
+  }
+  default:
+    break;
   }
   }
+
+  // The instruction does not have a result-type, so nothing to do.
   return true;
   return true;
 }
 }
 
 

+ 14 - 11
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2795,9 +2795,9 @@ SpirvInstruction *SPIRVEmitter::processRWByteAddressBufferAtomicMethods(
 SpirvInstruction *
 SpirvInstruction *
 SPIRVEmitter::processGetSamplePosition(const CXXMemberCallExpr *expr) {
 SPIRVEmitter::processGetSamplePosition(const CXXMemberCallExpr *expr) {
   const auto *object = expr->getImplicitObjectArgument()->IgnoreParens();
   const auto *object = expr->getImplicitObjectArgument()->IgnoreParens();
-  auto *sampleCount =
-      spvBuilder.createUnaryOp(spv::Op::OpImageQuerySamples,
-                               astContext.UnsignedIntTy, loadIfGLValue(object));
+  auto *sampleCount = spvBuilder.createImageQuery(
+      spv::Op::OpImageQuerySamples, astContext.UnsignedIntTy,
+      expr->getExprLoc(), loadIfGLValue(object));
   if (!spirvOptions.noWarnEmulatedFeatures)
   if (!spirvOptions.noWarnEmulatedFeatures)
     emitWarning("GetSamplePosition is emulated using many SPIR-V instructions "
     emitWarning("GetSamplePosition is emulated using many SPIR-V instructions "
                 "due to lack of direct SPIR-V equivalent, so it only supports "
                 "due to lack of direct SPIR-V equivalent, so it only supports "
@@ -2921,10 +2921,12 @@ SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
   }
   }
 
 
   SpirvInstruction *query =
   SpirvInstruction *query =
-      lod ? cast<SpirvInstruction>(spvBuilder.createBinaryOp(
-                spv::Op::OpImageQuerySizeLod, resultQualType, objectInstr, lod))
-          : cast<SpirvInstruction>(spvBuilder.createUnaryOp(
-                spv::Op::OpImageQuerySize, resultQualType, objectInstr));
+      lod ? cast<SpirvInstruction>(spvBuilder.createImageQuery(
+                spv::Op::OpImageQuerySizeLod, resultQualType,
+                expr->getExprLoc(), objectInstr, lod))
+          : cast<SpirvInstruction>(spvBuilder.createImageQuery(
+                spv::Op::OpImageQuerySize, resultQualType, expr->getExprLoc(),
+                objectInstr));
 
 
   if (querySize == 1) {
   if (querySize == 1) {
     const uint32_t argIndex = mipLevel ? 1 : 0;
     const uint32_t argIndex = mipLevel ? 1 : 0;
@@ -2945,8 +2947,8 @@ SPIRVEmitter::processBufferTextureGetDimensions(const CXXMemberCallExpr *expr) {
     const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples;
     const Expr *numLevelsSamplesArg = numLevels ? numLevels : numSamples;
     const spv::Op opcode =
     const spv::Op opcode =
         numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples;
         numLevels ? spv::Op::OpImageQueryLevels : spv::Op::OpImageQuerySamples;
-    auto *numLevelsSamplesQuery =
-        spvBuilder.createUnaryOp(opcode, astContext.UnsignedIntTy, objectInstr);
+    auto *numLevelsSamplesQuery = spvBuilder.createImageQuery(
+        opcode, astContext.UnsignedIntTy, expr->getExprLoc(), objectInstr);
     storeToOutputArg(numLevelsSamplesArg, numLevelsSamplesQuery,
     storeToOutputArg(numLevelsSamplesArg, numLevelsSamplesQuery,
                      astContext.UnsignedIntTy);
                      astContext.UnsignedIntTy);
   }
   }
@@ -2981,8 +2983,9 @@ SPIRVEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr,
   // The result type of OpImageQueryLod must be a float2.
   // The result type of OpImageQueryLod must be a float2.
   const QualType queryResultType =
   const QualType queryResultType =
       astContext.getExtVectorType(astContext.FloatTy, 2u);
       astContext.getExtVectorType(astContext.FloatTy, 2u);
-  auto *query = spvBuilder.createBinaryOp(
-      spv::Op::OpImageQueryLod, queryResultType, sampledImage, coordinate);
+  auto *query =
+      spvBuilder.createImageQuery(spv::Op::OpImageQueryLod, queryResultType,
+                                  expr->getExprLoc(), sampledImage, coordinate);
 
 
   // The first component of the float2 contains the mipmap array layer.
   // The first component of the float2 contains the mipmap array layer.
   // The second component of the float2 represents the unclamped lod.
   // The second component of the float2 represents the unclamped lod.

+ 20 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -568,6 +568,26 @@ SpirvBuilder::createImageSparseTexelsResident(SpirvInstruction *residentCode,
   return inst;
   return inst;
 }
 }
 
 
+SpirvImageQuery *SpirvBuilder::createImageQuery(spv::Op opcode,
+                                                QualType resultType,
+                                                SourceLocation loc,
+                                                SpirvInstruction *image,
+                                                SpirvInstruction *lod) {
+  assert(insertPoint && "null insert point");
+  SpirvInstruction *lodParam = nullptr;
+  SpirvInstruction *coordinateParam = nullptr;
+  if (opcode == spv::Op::OpImageQuerySizeLod)
+    lodParam = lod;
+  if (opcode == spv::Op::OpImageQueryLod)
+    coordinateParam = lod;
+
+  auto *inst =
+      new (context) SpirvImageQuery(opcode, resultType, /*result-id*/ 0, loc,
+                                    image, lodParam, coordinateParam);
+  insertPoint->addInstruction(inst);
+  return inst;
+}
+
 SpirvSelect *SpirvBuilder::createSelect(QualType resultType,
 SpirvSelect *SpirvBuilder::createSelect(QualType resultType,
                                         SpirvInstruction *condition,
                                         SpirvInstruction *condition,
                                         SpirvInstruction *trueValue,
                                         SpirvInstruction *trueValue,