Răsfoiți Sursa

[spirv] Create SpirvArrayLength instruction class.

Sadly, OpArrayLength takes an integer literal and does not fit into
SpirvBinaryOp instruction class.
Ehsan Nasiri 6 ani în urmă
părinte
comite
d73dd2aba6

+ 1 - 0
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -284,6 +284,7 @@ public:
   bool visit(SpirvStore *);
   bool visit(SpirvUnaryOp *);
   bool visit(SpirvVectorShuffle *);
+  bool visit(SpirvArrayLength *);
 
   // Returns the assembled binary built up in this visitor.
   std::vector<uint32_t> takeBinary();

+ 5 - 0
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -411,6 +411,11 @@ public:
   /// \brief Creates an OpEndPrimitive instruction.
   void createEndPrimitive(SourceLocation loc = {});
 
+  /// \brief Creates an OpArrayLength instruction.
+  SpirvArrayLength *createArrayLength(QualType resultType, SourceLocation loc,
+                                      SpirvInstruction *structure,
+                                      uint32_t arrayMember);
+
   // === SPIR-V Module Structure ===
 
   inline void requireCapability(spv::Capability, SourceLocation loc = {});

+ 22 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -111,6 +111,7 @@ public:
     IK_Store,                     // OpStore
     IK_UnaryOp,                   // Unary operations
     IK_VectorShuffle,             // OpVectorShuffle
+    IK_ArrayLength,               // OpArrayLength
   };
 
   virtual ~SpirvInstruction() = default;
@@ -1741,6 +1742,27 @@ private:
   llvm::SmallVector<uint32_t, 4> components;
 };
 
+class SpirvArrayLength : public SpirvInstruction {
+public:
+  SpirvArrayLength(QualType resultType, uint32_t resultId, SourceLocation loc,
+                   SpirvInstruction *structure, uint32_t arrayMember);
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_ArrayLength;
+  }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvArrayLength)
+
+  SpirvInstruction *getStructure() const { return structure; }
+  uint32_t getArrayMember() const { return arrayMember; }
+
+private:
+  SpirvInstruction *structure;
+  uint32_t arrayMember;
+};
+
+
 #undef DECLARE_INVOKE_VISITOR_FOR_CLASS
 
 } // namespace spirv

+ 1 - 0
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -111,6 +111,7 @@ public:
   DEFINE_VISIT_METHOD(SpirvStore)
   DEFINE_VISIT_METHOD(SpirvUnaryOp)
   DEFINE_VISIT_METHOD(SpirvVectorShuffle)
+  DEFINE_VISIT_METHOD(SpirvArrayLength)
 
 #undef DEFINE_VISIT_METHOD
 

+ 12 - 0
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -1050,6 +1050,18 @@ bool EmitVisitor::visit(SpirvVectorShuffle *inst) {
   return true;
 }
 
+bool EmitVisitor::visit(SpirvArrayLength *inst) {
+  initInstruction(inst);
+  curInst.push_back(inst->getResultTypeId());
+  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
+  curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getStructure()));
+  curInst.push_back(inst->getArrayMember());
+  finalizeInstruction();
+  emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
+                              inst->getDebugName());
+  return true;
+}
+
 // EmitTypeHandler ------
 
 void EmitTypeHandler::initTypeInstruction(spv::Op op) {

+ 4 - 6
tools/clang/lib/SPIRV/SPIRVEmitter.cpp

@@ -2719,8 +2719,8 @@ SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
   // (RW)ByteAddressBuffers/(RW)StructuredBuffers are represented as a structure
   // with only one member that is a runtime array. We need to perform
   // OpArrayLength on member 0.
-  auto *length = spvBuilder.createBinaryOp(
-      spv::Op::OpArrayLength, astContext.UnsignedIntTy, objectInstr, 0);
+  SpirvInstruction *length = spvBuilder.createArrayLength(
+      astContext.UnsignedIntTy, expr->getExprLoc(), objectInstr, 0);
   // For (RW)ByteAddressBuffers, GetDimensions() must return the array length
   // in bytes, but OpArrayLength returns the number of uints in the runtime
   // array. Therefore we must multiply the results by 4.
@@ -2732,7 +2732,6 @@ SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
   spvBuilder.createStore(doExpr(expr->getArg(0)), length);
 
   if (isStructuredBuffer) {
-    /*
     // TODO (ehsan): We don't want to use getAlignmentAndSize :-(
 
     // For (RW)StructuredBuffer, the stride of the runtime array (which is the
@@ -2740,9 +2739,8 @@ SPIRVEmitter::processByteAddressBufferStructuredBufferGetDimensions(
     uint32_t size = 0, stride = 0;
     std::tie(std::ignore, size) = typeTranslator.getAlignmentAndSize(
         type, spirvOptions.sBufferLayoutRule, &stride);
-    const auto sizeId = theBuilder.getConstantUint32(size);
-    theBuilder.createStore(doExpr(expr->getArg(1)), sizeId);
-    */
+    auto *sizeInstr = spvBuilder.getConstantUint32(size);
+    spvBuilder.createStore(doExpr(expr->getArg(1)), sizeInstr);
   }
 
   return nullptr;

+ 14 - 0
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -146,6 +146,7 @@ SpirvCompositeExtract *SpirvBuilder::createCompositeExtract(
   assert(insertPoint && "null insert point");
   auto *instruction = new (context)
       SpirvCompositeExtract(resultType, /*id*/ 0, loc, composite, indexes);
+  instruction->setRValue();
   insertPoint->addInstruction(instruction);
   return instruction;
 }
@@ -167,6 +168,7 @@ SpirvVectorShuffle *SpirvBuilder::createVectorShuffle(
   assert(insertPoint && "null insert point");
   auto *instruction = new (context) SpirvVectorShuffle(
       resultType, /*id*/ 0, loc, vector1, vector2, selectors);
+  instruction->setRValue();
   insertPoint->addInstruction(instruction);
   return instruction;
 }
@@ -210,6 +212,7 @@ SpirvBuilder::createFunctionCall(QualType returnType, SpirvFunction *func,
   assert(insertPoint && "null insert point");
   auto *instruction =
       new (context) SpirvFunctionCall(returnType, /*id*/ 0, loc, func, params);
+  instruction->setRValue();
   insertPoint->addInstruction(instruction);
   return instruction;
 }
@@ -743,6 +746,17 @@ void SpirvBuilder::createEndPrimitive(SourceLocation loc) {
   insertPoint->addInstruction(inst);
 }
 
+SpirvArrayLength *SpirvBuilder::createArrayLength(QualType resultType,
+                                                  SourceLocation loc,
+                                                  SpirvInstruction *structure,
+                                                  uint32_t arrayMember) {
+  assert(insertPoint && "null insert point");
+  auto *inst = new (context) SpirvArrayLength(resultType, /*result-id*/ 0, loc,
+                                              structure, arrayMember);
+  insertPoint->addInstruction(inst);
+  return inst;
+}
+
 void SpirvBuilder::addExtension(Extension ext, llvm::StringRef target,
                                 SourceLocation loc) {
   // TODO: The extension management should be removed from here and added as a

+ 9 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -79,6 +79,7 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSpecConstantUnaryOp)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvStore)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUnaryOp)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvVectorShuffle)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvArrayLength)
 
 #undef DEFINE_INVOKE_VISITOR_FOR_CLASS
 
@@ -880,5 +881,13 @@ SpirvVectorShuffle::SpirvVectorShuffle(QualType resultType, uint32_t resultId,
       vec1(vec1Inst), vec2(vec2Inst),
       components(componentsVec.begin(), componentsVec.end()) {}
 
+SpirvArrayLength::SpirvArrayLength(QualType resultType, uint32_t resultId,
+                                   SourceLocation loc,
+                                   SpirvInstruction *structure_,
+                                   uint32_t memberLiteral)
+    : SpirvInstruction(IK_ArrayLength, spv::Op::OpArrayLength, resultType,
+                       resultId, loc),
+      structure(structure_), arrayMember(memberLiteral) {}
+
 } // namespace spirv
 } // namespace clang