瀏覽代碼

[spirv] Fix double dispatch

Lei Zhang 7 年之前
父節點
當前提交
539ae179a6

+ 110 - 3
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -98,6 +98,9 @@ public:
 
   virtual ~SpirvInstruction() = default;
 
+  // Invokes SPIR-V visitor on this instruction.
+  virtual bool invokeVisitor(Visitor *) = 0;
+
   Kind getKind() const { return kind; }
   spv::Op getopcode() const { return opcode; }
   QualType getResultType() const { return resultType; }
@@ -114,9 +117,6 @@ public:
 
   clang::SourceLocation getSourceLocation() const { return srcLoc; }
 
-  // Handle SPIR-V instruction visitors.
-  bool invokeVisitor(Visitor *);
-
 protected:
   // Forbid creating SpirvInstruction directly
   SpirvInstruction(Kind kind, spv::Op opcode, QualType resultType,
@@ -131,6 +131,9 @@ private:
   SourceLocation srcLoc;
 };
 
+#define DECLARE_INVOKE_VISITOR_FOR_CLASS(cls)                                  \
+  bool invokeVisitor(Visitor *v) override;
+
 /// \brief OpCapability instruction
 class SpirvCapability : public SpirvInstruction {
 public:
@@ -141,6 +144,8 @@ public:
     return inst->getKind() == IK_Capability;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvCapability)
+
   spv::Capability getCapability() const { return capability; }
 
 private:
@@ -157,6 +162,8 @@ public:
     return inst->getKind() == IK_Extension;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvExtension)
+
   llvm::StringRef getExtensionName() const { return extName; }
 
 private:
@@ -174,6 +181,8 @@ public:
     return inst->getKind() == IK_ExtInstImport;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInstImport)
+
   llvm::StringRef getExtendedInstSetName() const { return extName; }
 
 private:
@@ -190,6 +199,8 @@ public:
     return inst->getKind() == IK_MemoryModel;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvMemoryModel)
+
   spv::AddressingModel getAddressingModel() const { return addressModel; }
   spv::MemoryModel getMemoryModel() const { return memoryModel; }
 
@@ -210,6 +221,8 @@ public:
     return inst->getKind() == IK_EntryPoint;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvEntryPoint)
+
   spv::ExecutionModel getExecModel() const { return execModel; }
   uint32_t getEntryPointId() const { return entryPoint; }
   llvm::StringRef getEntryPointName() const { return name; }
@@ -234,6 +247,8 @@ public:
     return inst->getKind() == IK_ExecutionMode;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionMode)
+
   uint32_t getEntryPointId() const { return entryPointId; }
   spv::ExecutionMode getExecutionMode() const { return execMode; }
   llvm::ArrayRef<uint32_t> getParams() const { return params; }
@@ -254,6 +269,8 @@ public:
     return inst->getKind() == IK_String;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvString)
+
   llvm::StringRef getString() const { return str; }
 
 private:
@@ -271,6 +288,8 @@ public:
     return inst->getKind() == IK_Source;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvSource)
+
   spv::SourceLanguage getSourceLanguage() const { return lang; }
   uint32_t getVersion() const { return version; }
   bool hasFileId() const { return file != 0; }
@@ -295,6 +314,8 @@ public:
     return inst->getKind() == IK_Name;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvName)
+
   uint32_t getTarget() const { return target; }
   bool isForMember() const { return member.hasValue(); }
   uint32_t getMember() const { return member.getValue(); }
@@ -316,6 +337,8 @@ public:
     return inst->getKind() == IK_ModuleProcessed;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvModuleProcessed)
+
   llvm::StringRef getProcess() const { return process; }
 
 private:
@@ -334,6 +357,8 @@ public:
     return inst->getKind() == IK_Decoration;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvDecoration)
+
   // Returns the <result-id> of the target of the decoration. It may be the id
   // of an object or the id of a structure type whose member is being decorated.
   uint32_t getTarget() const { return target; }
@@ -361,6 +386,8 @@ public:
     return inst->getKind() == IK_Variable;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvVariable)
+
   bool hasInitializer() const { return initializer != 0; }
   uint32_t getInitializer() const { return initializer; }
   spv::StorageClass getStorageClass() const { return storageClass; }
@@ -379,6 +406,8 @@ public:
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_FunctionParameter;
   }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvFunctionParameter)
 };
 
 /// \brief Merge instructions include OpLoopMerge and OpSelectionMerge
@@ -396,6 +425,8 @@ protected:
            inst->getKind() == IK_SelectionMerge;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvMerge)
+
 private:
   uint32_t mergeBlock;
 };
@@ -410,6 +441,8 @@ public:
     return inst->getKind() == IK_LoopMerge;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvLoopMerge)
+
   uint32_t getContinueTarget() const { return continueTarget; }
   spv::LoopControlMask getLoopControlMask() const { return loopControlMask; }
 
@@ -428,6 +461,8 @@ public:
     return inst->getKind() == IK_SelectionMerge;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvSelectionMerge)
+
   spv::SelectionControlMask getSelectionControlMask() const {
     return selControlMask;
   }
@@ -452,6 +487,8 @@ public:
     return inst->getKind() >= IK_Branch && inst->getKind() <= IK_Unreachable;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvTerminator)
+
 protected:
   SpirvTerminator(Kind kind, spv::Op opcode, SourceLocation loc);
 };
@@ -465,6 +502,8 @@ public:
            inst->getKind() <= IK_BranchConditional;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBranching)
+
   virtual llvm::ArrayRef<uint32_t> getTargetBranches() const = 0;
 
 protected:
@@ -481,6 +520,8 @@ public:
     return inst->getKind() == IK_Branch;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBranch)
+
   uint32_t getTargetLabel() const { return targetLabel; }
 
   // Returns all possible branches that could be taken by the branching
@@ -502,6 +543,8 @@ public:
     return inst->getKind() == IK_BranchConditional;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBranchConditional)
+
   llvm::ArrayRef<uint32_t> getTargetBranches() const {
     return {trueLabel, falseLabel};
   }
@@ -525,6 +568,8 @@ public:
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Kill;
   }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvKill)
 };
 
 /// \brief OpReturn and OpReturnValue instructions
@@ -537,6 +582,8 @@ public:
     return inst->getKind() == IK_Return;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvReturn)
+
   bool hasReturnValue() const { return returnValue != 0; }
   uint32_t getReturnValue() const { return returnValue; }
 
@@ -555,6 +602,8 @@ public:
     return inst->getKind() == IK_Switch;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvSwitch)
+
   uint32_t getSelector() const { return selector; }
   uint32_t getDefaultLabel() const { return defaultLabel; }
   llvm::ArrayRef<std::pair<uint32_t, uint32_t>> getTargets() const {
@@ -580,6 +629,8 @@ public:
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Unreachable;
   }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvUnreachable)
 };
 
 /// \brief Access Chain instruction representation (OpAccessChain)
@@ -596,6 +647,8 @@ public:
     return inst->getKind() == IK_AccessChain;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvAccessChain)
+
   uint32_t getBase() const { return base; }
   llvm::ArrayRef<uint32_t> getIndexes() const { return indices; }
 
@@ -641,6 +694,8 @@ public:
     return inst->getKind() == IK_Atomic;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvAtomic)
+
   uint32_t getPointer() const { return pointer; }
   spv::Scope getScope() const { return scope; }
   spv::MemorySemanticsMask getMemorySemantics() const { return memorySemantic; }
@@ -676,6 +731,8 @@ public:
     return inst->getKind() == IK_Barrier;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBarrier)
+
   spv::Scope getMemoryScope() const { return memoryScope; }
   spv::MemorySemanticsMask getMemorySemantics() const {
     return memorySemantics;
@@ -761,6 +818,8 @@ public:
     return inst->getKind() == IK_BinaryOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBinaryOp)
+
   uint32_t getOperand1() const { return operand1; }
   uint32_t getOperand2() const { return operand2; }
   bool isSpecConstantOp() const {
@@ -784,6 +843,8 @@ public:
            inst->getKind() == IK_BitFieldInsert;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBitField)
+
   virtual uint32_t getBase() const { return base; }
   virtual uint32_t getOffset() const { return offset; }
   virtual uint32_t getCount() const { return count; }
@@ -810,6 +871,8 @@ public:
     return inst->getKind() == IK_BitFieldExtract;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldExtract)
+
   uint32_t isSigned() const {
     return getopcode() == spv::Op::OpBitFieldSExtract;
   }
@@ -826,6 +889,8 @@ public:
     return inst->getKind() == IK_BitFieldInsert;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldInsert)
+
   uint32_t getInsert() const { return insert; }
 
 private:
@@ -847,6 +912,8 @@ public:
     return inst->getKind() == IK_Composite;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvComposite)
+
   bool isConstantComposite() const {
     return getopcode() == spv::Op::OpConstantComposite;
   }
@@ -871,6 +938,8 @@ public:
     return inst->getKind() == IK_CompositeExtract;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeExtract)
+
   uint32_t getComposite() const { return composite; }
   llvm::ArrayRef<uint32_t> getIndexes() const { return indices; }
 
@@ -891,6 +960,8 @@ public:
     return inst->getKind() == IK_ExtInst;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInst)
+
   uint32_t getInstructionSetId() const { return instructionSetId; }
   GLSLstd450 getInstruction() const { return instruction; }
   llvm::ArrayRef<uint32_t> getOperands() const { return operands; }
@@ -912,6 +983,8 @@ public:
     return inst->getKind() == IK_FunctionCall;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvFunctionCall)
+
   uint32_t getFunction() const { return function; }
   llvm::ArrayRef<uint32_t> getArgs() const { return args; }
 
@@ -929,6 +1002,8 @@ public:
            inst->getKind() <= IK_GroupNonUniformUnaryOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvGroupNonUniformOp)
+
   spv::Scope getExecutionScope() const { return execScope; }
 
 protected:
@@ -952,6 +1027,8 @@ public:
     return inst->getKind() == IK_GroupNonUniformBinaryOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformBinaryOp)
+
   uint32_t getArg1() const { return arg1; }
   uint32_t getArg2() const { return arg2; }
 
@@ -971,6 +1048,8 @@ public:
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_GroupNonUniformElect;
   }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformElect)
 };
 
 /// \brief OpGroupNonUniform* unary instructions.
@@ -986,6 +1065,8 @@ public:
     return inst->getKind() == IK_GroupNonUniformUnaryOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformUnaryOp)
+
   uint32_t getArg() const { return arg; }
   bool hasGroupOp() const { return groupOp.hasValue(); }
   spv::GroupOperation getGroupOp() const { return groupOp.getValue(); }
@@ -1037,6 +1118,8 @@ public:
     return inst->getKind() == IK_ImageOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvImageOp)
+
   uint32_t getImage() const { return image; }
   uint32_t getCoordinate() const { return coordinate; }
   spv::ImageOperandsMask getImageOperandsMask() const { return operandsMask; }
@@ -1108,6 +1191,8 @@ public:
     return inst->getKind() == IK_ImageQuery;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvImageQuery)
+
   uint32_t getImage() const { return image; }
   uint32_t hasLod() const { return lod != 0; }
   uint32_t getLod() const { return lod; }
@@ -1131,6 +1216,8 @@ public:
     return inst->getKind() == IK_ImageSparseTexelsResident;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvImageSparseTexelsResident)
+
   uint32_t getResidentCode() const { return residentCode; }
 
 private:
@@ -1149,6 +1236,8 @@ public:
     return inst->getKind() == IK_ImageTexelPointer;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvImageTexelPointer)
+
   uint32_t getImage() const { return image; }
   uint32_t getCoordinate() const { return coordinate; }
   uint32_t getSample() const { return sample; }
@@ -1170,6 +1259,8 @@ public:
     return inst->getKind() == IK_Load;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvLoad)
+
   uint32_t getPointer() const { return pointer; }
   bool hasMemoryAccessSemantics() const { return memoryAccess.hasValue(); }
   spv::MemoryAccessMask getMemoryAccess() const {
@@ -1192,6 +1283,8 @@ public:
     return inst->getKind() == IK_SampledImage;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvSampledImage)
+
   uint32_t getImage() const { return image; }
   uint32_t getSampler() const { return sampler; }
 
@@ -1211,6 +1304,8 @@ public:
     return inst->getKind() == IK_Select;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvSelect)
+
   uint32_t getCondition() const { return condition; }
   uint32_t getTrueObject() const { return trueObject; }
   uint32_t getFalseObject() const { return falseObject; }
@@ -1233,6 +1328,8 @@ public:
     return inst->getKind() == IK_SpecConstantBinaryOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvSpecConstantBinaryOp)
+
   spv::Op getSpecConstantopcode() const { return specOp; }
   uint32_t getOperand1() const { return operand1; }
   uint32_t getOperand2() const { return operand2; }
@@ -1255,6 +1352,8 @@ public:
     return inst->getKind() == IK_SpecConstantUnaryOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvSpecConstantUnaryOp)
+
   spv::Op getSpecConstantopcode() const { return specOp; }
   uint32_t getOperand() const { return operand; }
 
@@ -1274,6 +1373,8 @@ public:
     return inst->getKind() == IK_Store;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvStore)
+
   uint32_t getPointer() const { return pointer; }
   uint32_t getObject() const { return object; }
   bool hasMemoryAccessSemantics() const { return memoryAccess.hasValue(); }
@@ -1335,6 +1436,8 @@ public:
     return inst->getKind() == IK_UnaryOp;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvUnaryOp)
+
   uint32_t getOperand() const { return operand; }
 
 private:
@@ -1353,6 +1456,8 @@ public:
     return inst->getKind() == IK_VectorShuffle;
   }
 
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvVectorShuffle)
+
   uint32_t getVec1() const { return vec1; }
   uint32_t getVec2() const { return vec2; }
   llvm::ArrayRef<uint32_t> getComponents() const { return components; }
@@ -1363,6 +1468,8 @@ private:
   llvm::SmallVector<uint32_t, 4> components;
 };
 
+#undef DECLARE_INVOKE_VISITOR_FOR_CLASS
+
 } // namespace spirv
 } // namespace clang
 

+ 1 - 1
tools/clang/include/clang/SPIRV/SpirvModule.h

@@ -51,7 +51,7 @@ public:
   SpirvModule &operator=(SpirvModule &&) = delete;
 
   // Handle SPIR-V module visitors.
-  bool visit(Visitor *);
+  bool invokeVisitor(Visitor *);
 
 private:
   uint32_t bound; ///< The <result-id> bound: the next unused one

+ 62 - 48
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -42,54 +42,68 @@ public:
   virtual bool visit(SpirvModule *, Phase) {}
   virtual bool visit(SpirvFunction *, Phase) {}
   virtual bool visit(SpirvBasicBlock *, Phase) {}
-  virtual bool visit(SpirvInstruction *) {}
-  virtual bool visit(SpirvCapability *) {}
-  virtual bool visit(SpirvExtension *) {}
-  virtual bool visit(SpirvExtInstImport *) {}
-  virtual bool visit(SpirvMemoryModel *) {}
-  virtual bool visit(SpirvEntryPoint *) {}
-  virtual bool visit(SpirvExecutionMode *) {}
-  virtual bool visit(SpirvString *) {}
-  virtual bool visit(SpirvSource *) {}
-  virtual bool visit(SpirvName *) {}
-  virtual bool visit(SpirvModuleProcessed *) {}
-  virtual bool visit(SpirvDecoration *) {}
-  virtual bool visit(SpirvVariable *) {}
-  virtual bool visit(SpirvFunctionParameter *) {}
-  virtual bool visit(SpirvLoopMerge *) {}
-  virtual bool visit(SpirvSelectionMerge *) {}
-  virtual bool visit(SpirvBranching *) {}
-  virtual bool visit(SpirvBranch *) {}
-  virtual bool visit(SpirvBranchConditional *) {}
-  virtual bool visit(SpirvKill *) {}
-  virtual bool visit(SpirvReturn *) {}
-  virtual bool visit(SpirvSwitch *) {}
-  virtual bool visit(SpirvUnreachable *) {}
-  virtual bool visit(SpirvAccessChain *) {}
-  virtual bool visit(SpirvAtomic *) {}
-  virtual bool visit(SpirvBarrier *) {}
-  virtual bool visit(SpirvBinaryOp *) {}
-  virtual bool visit(SpirvBitFieldExtract *) {}
-  virtual bool visit(SpirvBitFieldInsert *) {}
-  virtual bool visit(SpirvComposite *) {}
-  virtual bool visit(SpirvCompositeExtract *) {}
-  virtual bool visit(SpirvExtInst *) {}
-  virtual bool visit(SpirvFunctionCall *) {}
-  virtual bool visit(SpirvNonUniformBinaryOp *) {}
-  virtual bool visit(SpirvNonUniformElect *) {}
-  virtual bool visit(SpirvNonUniformUnaryOp *) {}
-  virtual bool visit(SpirvImageOp *) {}
-  virtual bool visit(SpirvImageQuery *) {}
-  virtual bool visit(SpirvImageSparseTexelsResident *) {}
-  virtual bool visit(SpirvImageTexelPointer *) {}
-  virtual bool visit(SpirvLoad *) {}
-  virtual bool visit(SpirvSampledImage *) {}
-  virtual bool visit(SpirvSelect *) {}
-  virtual bool visit(SpirvSpecConstantBinaryOp *) {}
-  virtual bool visit(SpirvSpecConstantUnaryOp *) {}
-  virtual bool visit(SpirvStore *) {}
-  virtual bool visit(SpirvUnaryOp *) {}
-  virtual bool visit(SpirvVectorShuffle *) {}
+
+  /// The "sink" visit function for all instructions.
+  ///
+  /// By default, all other visit instructions redirect to this visit function.
+  /// So that you want override this visit function to handle all instructions,
+  /// regardless of their polymorphism.
+  virtual bool visitInstruction(SpirvInstruction *) {}
+
+#define DEFINE_VISIT_METHOD(cls)                                               \
+  virtual bool visit(cls *i) { visitInstruction(i); }
+
+  DEFINE_VISIT_METHOD(SpirvCapability)
+  DEFINE_VISIT_METHOD(SpirvExtension)
+  DEFINE_VISIT_METHOD(SpirvExtInstImport)
+  DEFINE_VISIT_METHOD(SpirvMemoryModel)
+  DEFINE_VISIT_METHOD(SpirvEntryPoint)
+  DEFINE_VISIT_METHOD(SpirvExecutionMode)
+  DEFINE_VISIT_METHOD(SpirvString)
+  DEFINE_VISIT_METHOD(SpirvSource)
+  DEFINE_VISIT_METHOD(SpirvName)
+  DEFINE_VISIT_METHOD(SpirvModuleProcessed)
+  DEFINE_VISIT_METHOD(SpirvDecoration)
+  DEFINE_VISIT_METHOD(SpirvVariable)
+
+  DEFINE_VISIT_METHOD(SpirvFunctionParameter)
+  DEFINE_VISIT_METHOD(SpirvLoopMerge)
+  DEFINE_VISIT_METHOD(SpirvSelectionMerge)
+  DEFINE_VISIT_METHOD(SpirvBranching)
+  DEFINE_VISIT_METHOD(SpirvBranch)
+  DEFINE_VISIT_METHOD(SpirvBranchConditional)
+  DEFINE_VISIT_METHOD(SpirvKill)
+  DEFINE_VISIT_METHOD(SpirvReturn)
+  DEFINE_VISIT_METHOD(SpirvSwitch)
+  DEFINE_VISIT_METHOD(SpirvUnreachable)
+
+  DEFINE_VISIT_METHOD(SpirvAccessChain)
+  DEFINE_VISIT_METHOD(SpirvAtomic)
+  DEFINE_VISIT_METHOD(SpirvBarrier)
+  DEFINE_VISIT_METHOD(SpirvBinaryOp)
+  DEFINE_VISIT_METHOD(SpirvBitFieldExtract)
+  DEFINE_VISIT_METHOD(SpirvBitFieldInsert)
+  DEFINE_VISIT_METHOD(SpirvComposite)
+  DEFINE_VISIT_METHOD(SpirvCompositeExtract)
+  DEFINE_VISIT_METHOD(SpirvExtInst)
+  DEFINE_VISIT_METHOD(SpirvFunctionCall)
+  DEFINE_VISIT_METHOD(SpirvNonUniformBinaryOp)
+  DEFINE_VISIT_METHOD(SpirvNonUniformElect)
+  DEFINE_VISIT_METHOD(SpirvNonUniformUnaryOp)
+  DEFINE_VISIT_METHOD(SpirvImageOp)
+  DEFINE_VISIT_METHOD(SpirvImageQuery)
+  DEFINE_VISIT_METHOD(SpirvImageSparseTexelsResident)
+  DEFINE_VISIT_METHOD(SpirvImageTexelPointer)
+  DEFINE_VISIT_METHOD(SpirvLoad)
+  DEFINE_VISIT_METHOD(SpirvSampledImage)
+  DEFINE_VISIT_METHOD(SpirvSelect)
+  DEFINE_VISIT_METHOD(SpirvSpecConstantBinaryOp)
+  DEFINE_VISIT_METHOD(SpirvSpecConstantUnaryOp)
+  DEFINE_VISIT_METHOD(SpirvStore)
+  DEFINE_VISIT_METHOD(SpirvUnaryOp)
+  DEFINE_VISIT_METHOD(SpirvVectorShuffle)
+
+#undef DEFINE_VISIT_METHOD
 
 protected:
   Visitor() = default;

+ 54 - 4
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -16,14 +16,64 @@
 namespace clang {
 namespace spirv {
 
+#define DEFINE_INVOKE_VISITOR_FOR_CLASS(cls)                                   \
+  bool cls::invokeVisitor(Visitor *v) { return v->visit(this); }
+
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCapability)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtension)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInstImport)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvMemoryModel)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEntryPoint)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionMode)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvString)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSource)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvName)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvModuleProcessed)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvDecoration)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvVariable)
+
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvFunctionParameter)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvLoopMerge)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSelectionMerge)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBranch)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBranchConditional)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvKill)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvReturn)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSwitch)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUnreachable)
+
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvAccessChain)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvAtomic)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBarrier)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBinaryOp)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldExtract)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldInsert)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvComposite)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeExtract)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInst)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvFunctionCall)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformBinaryOp)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformElect)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvNonUniformUnaryOp)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvImageOp)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvImageQuery)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvImageSparseTexelsResident)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvImageTexelPointer)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvLoad)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSampledImage)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSelect)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSpecConstantBinaryOp)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSpecConstantUnaryOp)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvStore)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvUnaryOp)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvVectorShuffle)
+
+#undef DEFINE_INVOKE_VISITOR_FOR_CLASS
+
 SpirvInstruction::SpirvInstruction(Kind k, spv::Op op, QualType type,
                                    uint32_t id, SourceLocation loc)
     : kind(k), opcode(op), resultType(type), resultId(id), srcLoc(loc) {}
 
-bool SpirvInstruction::invokeVisitor(Visitor *visitor) {
-  return visitor->visit(this);
-}
-
 SpirvCapability::SpirvCapability(SourceLocation loc, spv::Capability cap)
     : SpirvInstruction(IK_Capability, spv::Op::OpCapability, QualType(),
                        /*resultId=*/0, loc),

+ 1 - 1
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -16,7 +16,7 @@ namespace spirv {
 SpirvModule::SpirvModule()
     : bound(1), memoryModel(nullptr), debugSource(nullptr) {}
 
-bool SpirvModule::visit(Visitor *visitor) {
+bool SpirvModule::invokeVisitor(Visitor *visitor) {
   if (!visitor->visit(this, Visitor::Phase::Init))
     return false;