Selaa lähdekoodia

[spirv] Fix memory leaks in the SPIR-V backend. (#3091)

* [spirv] Fix memory leaks in the SPIR-V backend.

* [spirv] Avoid leaking memory if there are failures during AST traversal.
Ehsan 5 vuotta sitten
vanhempi
commit
3da289217c

+ 1 - 1
tools/clang/include/clang/SPIRV/SpirvBasicBlock.h

@@ -35,7 +35,7 @@ public:
 class SpirvBasicBlock {
 public:
   SpirvBasicBlock(llvm::StringRef name);
-  ~SpirvBasicBlock() = default;
+  ~SpirvBasicBlock();
 
   // Forbid copy construction and assignment
   SpirvBasicBlock(const SpirvBasicBlock &) = delete;

+ 18 - 18
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -61,17 +61,22 @@ public:
   SpirvBuilder &operator=(SpirvBuilder &&) = delete;
 
   /// Returns the SPIR-V module being built.
-  SpirvModule *getModule() { return mod; }
+  SpirvModule *getModule() { return mod.get(); }
 
   // === Function and Basic Block ===
 
+  /// \brief Creates a SpirvFunction object with the given information and adds
+  /// it to list of all discovered functions in the SpirvModule.
+  SpirvFunction *createSpirvFunction(QualType returnType, SourceLocation,
+                                     llvm::StringRef name, bool isPrecise);
+
   /// \brief Begins building a SPIR-V function by allocating a SpirvFunction
   /// object. Returns the pointer for the function on success. Returns nullptr
   /// on failure.
   ///
   /// At any time, there can only exist at most one function under building.
-  SpirvFunction *beginFunction(QualType returnType,
-                               SourceLocation, llvm::StringRef name = "",
+  SpirvFunction *beginFunction(QualType returnType, SourceLocation,
+                               llvm::StringRef name = "",
                                bool isPrecise = false,
                                SpirvFunction *func = nullptr);
 
@@ -605,16 +610,9 @@ private:
   ASTContext &astContext;
   SpirvContext &context; ///< From which we allocate various SPIR-V object
 
-  SpirvModule *mod;             ///< The current module being built
-  SpirvFunction *function;      ///< The current function being built
-  SpirvBasicBlock *insertPoint; ///< The current basic block being built
-
-  /// \brief List of basic blocks being built.
-  ///
-  /// We need a vector here to remember the order of insertion. Order matters
-  /// here since, for example, we'll know for sure the first basic block is
-  /// the entry block.
-  std::vector<SpirvBasicBlock *> basicBlocks;
+  std::unique_ptr<SpirvModule> mod; ///< The current module being built
+  SpirvFunction *function;          ///< The current function being built
+  SpirvBasicBlock *insertPoint;     ///< The current basic block being built
 
   const SpirvCodeGenOptions &spirvOptions; ///< Command line options.
 
@@ -635,11 +633,15 @@ private:
 };
 
 void SpirvBuilder::requireCapability(spv::Capability cap, SourceLocation loc) {
-  mod->addCapability(new (context) SpirvCapability(loc, cap));
+  auto *capability = new (context) SpirvCapability(loc, cap);
+  if (!mod->addCapability(capability))
+    capability->releaseMemory();
 }
 
 void SpirvBuilder::requireExtension(llvm::StringRef ext, SourceLocation loc) {
-  mod->addExtension(new (context) SpirvExtension(loc, ext));
+  auto *extension = new (context) SpirvExtension(loc, ext);
+  if (!mod->addExtension(extension))
+    extension->releaseMemory();
 }
 
 void SpirvBuilder::setMemoryModel(spv::AddressingModel addrModel,
@@ -661,9 +663,7 @@ SpirvBuilder::setDebugSource(uint32_t major, uint32_t minor,
   uint32_t version = 100 * major + 10 * minor;
   SpirvSource *mainSource = nullptr;
   for (const auto &name : fileNames) {
-    SpirvString *fileString =
-        name.empty() ? nullptr
-                     : new (context) SpirvString(/*SourceLocation*/ {}, name);
+    SpirvString *fileString = name.empty() ? nullptr : getString(name);
     SpirvSource *debugSource = new (context)
         SpirvSource(/*SourceLocation*/ {}, spv::SourceLanguage::HLSL, version,
                     fileString, content);

+ 8 - 1
tools/clang/include/clang/SPIRV/SpirvContext.h

@@ -121,7 +121,7 @@ class SpirvContext {
 public:
   using ShaderModelKind = hlsl::ShaderModel::Kind;
   SpirvContext();
-  ~SpirvContext() = default;
+  ~SpirvContext();
 
   // Forbid copy construction and assignment
   SpirvContext(const SpirvContext &) = delete;
@@ -259,7 +259,11 @@ private:
   std::array<const IntegerType *, 7> uintTypes;
   std::array<const FloatType *, 7> floatTypes;
 
+  // The VectorType at index i has the length of i. For example, vector of
+  // size 4 would be at index 4. Valid SPIR-V vector sizes are 2,3,4.
+  // Therefore, index 0 and 1 of this array are unused (nullptr).
   using VectorTypeArray = std::array<const VectorType *, 5>;
+
   using MatrixTypeVector = std::vector<const MatrixType *>;
   using SCToPtrTyMap =
       llvm::DenseMap<spv::StorageClass, const SpirvPointerType *,
@@ -273,11 +277,14 @@ private:
   llvm::DenseSet<const ImageType *, ImageTypeMapInfo> imageTypes;
   const SamplerType *samplerType;
   llvm::DenseMap<const ImageType *, const SampledImageType *> sampledImageTypes;
+  llvm::SmallVector<const HybridSampledImageType *, 4> hybridSampledImageTypes;
   llvm::DenseSet<const ArrayType *, ArrayTypeMapInfo> arrayTypes;
   llvm::DenseSet<const RuntimeArrayType *, RuntimeArrayTypeMapInfo>
       runtimeArrayTypes;
   llvm::SmallVector<const StructType *, 8> structTypes;
+  llvm::SmallVector<const HybridStructType *, 8> hybridStructTypes;
   llvm::DenseMap<const SpirvType *, SCToPtrTyMap> pointerTypes;
+  llvm::SmallVector<const HybridPointerType *, 8> hybridPointerTypes;
   llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
   const AccelerationStructureTypeNV *accelerationStructureTypeNV;
   const RayQueryProvisionalTypeKHR *rayQueryProvisionalTypeKHR;

+ 8 - 9
tools/clang/include/clang/SPIRV/SpirvFunction.h

@@ -26,7 +26,8 @@ class SpirvFunction {
 public:
   SpirvFunction(QualType astReturnType, SourceLocation,
                 llvm::StringRef name = "", bool precise = false);
-  ~SpirvFunction() = default;
+
+  ~SpirvFunction();
 
   // Forbid copy construction and assignment
   SpirvFunction(const SpirvFunction &) = delete;
@@ -93,14 +94,12 @@ public:
   bool isRValue() { return rvalue; }
 
 private:
-  uint32_t functionId; ///< This function's <result-id>
-
-  QualType astReturnType;                       ///< The return type
-  SpirvType *returnType;                        ///< The lowered return type
-  SpirvType *fnType;                            ///< The SPIR-V function type
-
-  bool relaxedPrecision; ///< Whether the return type is at relaxed precision
-  bool precise;          ///< Whether the return value is 'precise'
+  uint32_t functionId;    ///< This function's <result-id>
+  QualType astReturnType; ///< The return type
+  SpirvType *returnType;  ///< The lowered return type
+  SpirvType *fnType;      ///< The SPIR-V function type
+  bool relaxedPrecision;  ///< Whether the return type is at relaxed precision
+  bool precise;           ///< Whether the return value is 'precise'
 
   /// Legalization-specific code
   ///

+ 125 - 0
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -32,6 +32,9 @@ class SpirvVariable;
 class SpirvString;
 class Visitor;
 
+#define DEFINE_RELEASE_MEMORY_FOR_CLASS(cls)                                   \
+  void releaseMemory() override { this->~cls(); }
+
 /// \brief The base class for representing SPIR-V instructions.
 class SpirvInstruction {
 public:
@@ -122,6 +125,11 @@ public:
     IK_VectorShuffle,             // OpVectorShuffle
   };
 
+  // All instruction classes should include a releaseMemory method.
+  // This is needed in order to avoid leaking memory for classes that include
+  // members that are not trivially destructible.
+  virtual void releaseMemory() = 0;
+
   virtual ~SpirvInstruction() = default;
 
   // Invokes SPIR-V visitor on this instruction.
@@ -218,6 +226,8 @@ class SpirvCapability : public SpirvInstruction {
 public:
   SpirvCapability(SourceLocation loc, spv::Capability cap);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvCapability)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Capability;
@@ -238,6 +248,8 @@ class SpirvExtension : public SpirvInstruction {
 public:
   SpirvExtension(SourceLocation loc, llvm::StringRef extensionName);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExtension)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Extension;
@@ -258,6 +270,8 @@ class SpirvExtInstImport : public SpirvInstruction {
 public:
   SpirvExtInstImport(SourceLocation loc, llvm::StringRef extensionName);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExtInstImport)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ExtInstImport;
@@ -276,6 +290,8 @@ class SpirvMemoryModel : public SpirvInstruction {
 public:
   SpirvMemoryModel(spv::AddressingModel addrModel, spv::MemoryModel memModel);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvMemoryModel)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_MemoryModel;
@@ -298,6 +314,8 @@ public:
                   SpirvFunction *entryPoint, llvm::StringRef nameStr,
                   llvm::ArrayRef<SpirvVariable *> iface);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEntryPoint)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_EntryPoint;
@@ -324,6 +342,8 @@ public:
                      spv::ExecutionMode, llvm::ArrayRef<uint32_t> params,
                      bool usesIdParams);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionMode)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ExecutionMode;
@@ -346,6 +366,8 @@ class SpirvString : public SpirvInstruction {
 public:
   SpirvString(SourceLocation loc, llvm::StringRef stringLiteral);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvString)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_String;
@@ -365,6 +387,8 @@ public:
   SpirvSource(SourceLocation loc, spv::SourceLanguage language, uint32_t ver,
               SpirvString *file, llvm::StringRef src);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSource)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Source;
@@ -390,6 +414,8 @@ class SpirvModuleProcessed : public SpirvInstruction {
 public:
   SpirvModuleProcessed(SourceLocation loc, llvm::StringRef processStr);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvModuleProcessed)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ModuleProcessed;
@@ -418,6 +444,8 @@ public:
                   spv::Decoration decor,
                   llvm::ArrayRef<SpirvInstruction *> params);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvDecoration)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Decoration;
@@ -458,6 +486,8 @@ public:
                 spv::StorageClass sc, bool isPrecise,
                 SpirvInstruction *initializerId = 0);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvVariable)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Variable;
@@ -489,6 +519,8 @@ public:
   SpirvFunctionParameter(const SpirvType *spvType, bool isPrecise,
                          SourceLocation loc);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvFunctionParameter)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_FunctionParameter;
@@ -521,6 +553,8 @@ public:
   SpirvLoopMerge(SourceLocation loc, SpirvBasicBlock *mergeBlock,
                  SpirvBasicBlock *contTarget, spv::LoopControlMask mask);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvLoopMerge)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_LoopMerge;
@@ -541,6 +575,8 @@ public:
   SpirvSelectionMerge(SourceLocation loc, SpirvBasicBlock *mergeBlock,
                       spv::SelectionControlMask mask);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSelectionMerge)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_SelectionMerge;
@@ -596,6 +632,8 @@ class SpirvBranch : public SpirvBranching {
 public:
   SpirvBranch(SourceLocation loc, SpirvBasicBlock *target);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvBranch)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Branch;
@@ -622,6 +660,8 @@ public:
                          SpirvBasicBlock *trueLabel,
                          SpirvBasicBlock *falseLabel);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvBranchConditional)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_BranchConditional;
@@ -648,6 +688,8 @@ class SpirvKill : public SpirvTerminator {
 public:
   SpirvKill(SourceLocation loc);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvKill)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Kill;
@@ -661,6 +703,8 @@ class SpirvReturn : public SpirvTerminator {
 public:
   SpirvReturn(SourceLocation loc, SpirvInstruction *retVal = 0);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvReturn)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Return;
@@ -683,6 +727,8 @@ public:
       SpirvBasicBlock *defaultLabel,
       llvm::ArrayRef<std::pair<uint32_t, SpirvBasicBlock *>> &targetsVec);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSwitch)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Switch;
@@ -711,6 +757,8 @@ class SpirvUnreachable : public SpirvTerminator {
 public:
   SpirvUnreachable(SourceLocation loc);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvUnreachable)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Unreachable;
@@ -729,6 +777,8 @@ public:
                    SpirvInstruction *base,
                    llvm::ArrayRef<SpirvInstruction *> indexVec);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvAccessChain)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_AccessChain;
@@ -776,6 +826,8 @@ public:
               spv::MemorySemanticsMask semanticsUnequal,
               SpirvInstruction *value, SpirvInstruction *comparator);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvAtomic)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Atomic;
@@ -813,6 +865,8 @@ public:
                spv::MemorySemanticsMask memorySemantics,
                llvm::Optional<spv::Scope> executionScope = llvm::None);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvBarrier)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Barrier;
@@ -900,6 +954,8 @@ public:
   SpirvBinaryOp(spv::Op opcode, QualType resultType, SourceLocation loc,
                 SpirvInstruction *op1, SpirvInstruction *op2);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvBinaryOp)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_BinaryOp;
@@ -951,6 +1007,8 @@ public:
                        SpirvInstruction *base, SpirvInstruction *offset,
                        SpirvInstruction *count, bool isSigned);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvBitFieldExtract)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_BitFieldExtract;
@@ -969,6 +1027,8 @@ public:
                       SpirvInstruction *base, SpirvInstruction *insert,
                       SpirvInstruction *offset, SpirvInstruction *count);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvBitFieldInsert)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_BitFieldInsert;
@@ -1001,6 +1061,8 @@ class SpirvConstantBoolean : public SpirvConstant {
 public:
   SpirvConstantBoolean(QualType type, bool value, bool isSpecConst = false);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantBoolean)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ConstantBoolean;
@@ -1022,6 +1084,8 @@ public:
   SpirvConstantInteger(QualType type, llvm::APInt value,
                        bool isSpecConst = false);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantInteger)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ConstantInteger;
@@ -1042,6 +1106,8 @@ public:
   SpirvConstantFloat(QualType type, llvm::APFloat value,
                      bool isSpecConst = false);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantFloat)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ConstantFloat;
@@ -1063,6 +1129,8 @@ public:
                          llvm::ArrayRef<SpirvConstant *> constituents,
                          bool isSpecConst = false);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantComposite)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ConstantComposite;
@@ -1082,6 +1150,8 @@ class SpirvConstantNull : public SpirvConstant {
 public:
   SpirvConstantNull(QualType type);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvConstantNull)
+
   bool invokeVisitor(Visitor *v) override;
 
   // For LLVM-style RTTI
@@ -1098,6 +1168,8 @@ public:
   SpirvCompositeConstruct(QualType resultType, SourceLocation loc,
                           llvm::ArrayRef<SpirvInstruction *> constituentsVec);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvCompositeConstruct)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_CompositeConstruct;
@@ -1120,6 +1192,8 @@ public:
                         SpirvInstruction *composite,
                         llvm::ArrayRef<uint32_t> indices);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvCompositeExtract)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_CompositeExtract;
@@ -1142,6 +1216,8 @@ public:
                        SpirvInstruction *composite, SpirvInstruction *object,
                        llvm::ArrayRef<uint32_t> indices);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvCompositeInsert)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_CompositeInsert;
@@ -1164,6 +1240,8 @@ class SpirvEmitVertex : public SpirvInstruction {
 public:
   SpirvEmitVertex(SourceLocation loc);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEmitVertex)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_EmitVertex;
@@ -1177,6 +1255,8 @@ class SpirvEndPrimitive : public SpirvInstruction {
 public:
   SpirvEndPrimitive(SourceLocation loc);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvEndPrimitive)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_EndPrimitive;
@@ -1191,6 +1271,8 @@ public:
   SpirvExtInst(QualType resultType, SourceLocation loc, SpirvExtInstImport *set,
                uint32_t inst, llvm::ArrayRef<SpirvInstruction *> operandsVec);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExtInst)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ExtInst;
@@ -1215,6 +1297,8 @@ public:
                     SpirvFunction *function,
                     llvm::ArrayRef<SpirvInstruction *> argsVec);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvFunctionCall)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_FunctionCall;
@@ -1256,6 +1340,8 @@ public:
                           SourceLocation loc, spv::Scope scope,
                           SpirvInstruction *arg1, SpirvInstruction *arg2);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformBinaryOp)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_GroupNonUniformBinaryOp;
@@ -1278,6 +1364,8 @@ public:
   SpirvNonUniformElect(QualType resultType, SourceLocation loc,
                        spv::Scope scope);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformElect)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_GroupNonUniformElect;
@@ -1294,6 +1382,8 @@ public:
                          llvm::Optional<spv::GroupOperation> group,
                          SpirvInstruction *arg);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvNonUniformUnaryOp)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_GroupNonUniformUnaryOp;
@@ -1353,6 +1443,8 @@ public:
                SpirvInstruction *component = nullptr,
                SpirvInstruction *texelToWrite = nullptr);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvImageOp)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ImageOp;
@@ -1427,6 +1519,8 @@ public:
                   SpirvInstruction *img, SpirvInstruction *lod = nullptr,
                   SpirvInstruction *coord = nullptr);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvImageQuery)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ImageQuery;
@@ -1452,6 +1546,8 @@ public:
   SpirvImageSparseTexelsResident(QualType resultType, SourceLocation loc,
                                  SpirvInstruction *resCode);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvImageSparseTexelsResident)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ImageSparseTexelsResident;
@@ -1475,6 +1571,8 @@ public:
                          SpirvInstruction *image, SpirvInstruction *coordinate,
                          SpirvInstruction *sample);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvImageTexelPointer)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ImageTexelPointer;
@@ -1498,6 +1596,8 @@ public:
   SpirvLoad(QualType resultType, SourceLocation loc, SpirvInstruction *pointer,
             llvm::Optional<spv::MemoryAccessMask> mask = llvm::None);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvLoad)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Load;
@@ -1522,6 +1622,8 @@ public:
   SpirvCopyObject(QualType resultType, SourceLocation loc,
                   SpirvInstruction *pointer);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvCopyObject)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_CopyObject;
@@ -1544,6 +1646,8 @@ public:
   SpirvSampledImage(QualType resultType, SourceLocation loc,
                     SpirvInstruction *image, SpirvInstruction *sampler);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSampledImage)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_SampledImage;
@@ -1565,6 +1669,8 @@ public:
   SpirvSelect(QualType resultType, SourceLocation loc, SpirvInstruction *cond,
               SpirvInstruction *trueId, SpirvInstruction *falseId);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSelect)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Select;
@@ -1589,6 +1695,8 @@ public:
                             SourceLocation loc, SpirvInstruction *operand1,
                             SpirvInstruction *operand2);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSpecConstantBinaryOp)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_SpecConstantBinaryOp;
@@ -1612,6 +1720,8 @@ public:
   SpirvSpecConstantUnaryOp(spv::Op specConstantOp, QualType resultType,
                            SourceLocation loc, SpirvInstruction *operand);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvSpecConstantUnaryOp)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_SpecConstantUnaryOp;
@@ -1634,6 +1744,8 @@ public:
              SpirvInstruction *object,
              llvm::Optional<spv::MemoryAccessMask> mask = llvm::None);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvStore)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_Store;
@@ -1697,6 +1809,8 @@ public:
   SpirvUnaryOp(spv::Op opcode, QualType resultType, SourceLocation loc,
                SpirvInstruction *op);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvUnaryOp)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_UnaryOp;
@@ -1718,6 +1832,8 @@ public:
                      SpirvInstruction *vec1, SpirvInstruction *vec2,
                      llvm::ArrayRef<uint32_t> componentsVec);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvVectorShuffle)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_VectorShuffle;
@@ -1740,6 +1856,8 @@ public:
   SpirvArrayLength(QualType resultType, SourceLocation loc,
                    SpirvInstruction *structure, uint32_t arrayMember);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvArrayLength)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_ArrayLength;
@@ -1765,6 +1883,8 @@ public:
                       llvm::ArrayRef<SpirvInstruction *> vecOperands,
                       SourceLocation loc);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvRayTracingOpNV)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_RayTracingOpNV;
@@ -1777,12 +1897,15 @@ public:
 private:
   llvm::SmallVector<SpirvInstruction *, 4> operands;
 };
+
 class SpirvRayQueryOpKHR : public SpirvInstruction {
 public:
   SpirvRayQueryOpKHR(QualType resultType, spv::Op opcode,
                      llvm::ArrayRef<SpirvInstruction *> vecOperands, bool flags,
                      SourceLocation loc);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvRayQueryOpKHR)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_RayQueryOpKHR;
@@ -1810,6 +1933,8 @@ class SpirvDemoteToHelperInvocationEXT : public SpirvInstruction {
 public:
   SpirvDemoteToHelperInvocationEXT(SourceLocation);
 
+  DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvDemoteToHelperInvocationEXT)
+
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
     return inst->getKind() == IK_DemoteToHelperInvocationEXT;

+ 17 - 4
tools/clang/include/clang/SPIRV/SpirvModule.h

@@ -76,7 +76,7 @@ struct CapabilityComparisonInfo {
 class SpirvModule {
 public:
   SpirvModule();
-  ~SpirvModule() = default;
+  ~SpirvModule();
 
   // Forbid copy construction and assignment
   SpirvModule(const SpirvModule &) = delete;
@@ -90,10 +90,16 @@ public:
   bool invokeVisitor(Visitor *, bool reverseOrder = false);
 
   // Add a function to the list of module functions.
+  void addFunctionToListOfSortedModuleFunctions(SpirvFunction *);
+
+  // Adds the given function to the vector of all discovered functions. Calling
+  // this function will not result in emitting the function.
   void addFunction(SpirvFunction *);
 
   // Add a capability to the list of module capabilities.
-  void addCapability(SpirvCapability *cap);
+  // Returns true if the capability was added.
+  // Returns false otherwise (e.g. if the capability already existed).
+  bool addCapability(SpirvCapability *cap);
 
   // Set the memory model of the module.
   void setMemoryModel(SpirvMemoryModel *model);
@@ -104,8 +110,9 @@ public:
   // Adds an execution mode to the module.
   void addExecutionMode(SpirvExecutionMode *);
 
-  // Adds an extension to the module.
-  void addExtension(SpirvExtension *);
+  // Adds an extension to the module. Returns true if the extension was added.
+  // Returns false otherwise (e.g. if the extension already existed).
+  bool addExtension(SpirvExtension *);
 
   // Adds an extended instruction set to the module.
   void addExtInstSet(SpirvExtInstImport *);
@@ -169,7 +176,13 @@ private:
 
   std::vector<SpirvConstant *> constants;
   std::vector<SpirvVariable *> variables;
+  // A vector of functions in the module in the order that they should be
+  // emitted. The order starts with the entry-point function followed by a
+  // depth-first discovery of functions reachable from the entry-point function.
   std::vector<SpirvFunction *> functions;
+  // A vector of all functions that have been visited in the AST tree. This
+  // vector is not in any particular order, and may contain unused functions.
+  llvm::SetVector<SpirvFunction *> allFunctions;
 };
 
 } // end namespace spirv

+ 1 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -1142,7 +1142,7 @@ SpirvFunction *DeclResultIdMapper::getOrRegisterFn(const FunctionDecl *fn) {
   // definition is seen, the parameter types will be set properly and take into
   // account whether the function is a member function of a class/struct (in
   // which case a 'this' parameter is added at the beginnig).
-  SpirvFunction *spirvFunction = new (spvContext) SpirvFunction(
+  SpirvFunction *spirvFunction = spvBuilder.createSpirvFunction(
       fn->getReturnType(), fn->getLocation(), fn->getName(), isPrecise);
 
   // No need to dereference to get the pointer. Function returns that are

+ 5 - 10
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -161,9 +161,8 @@ void EmitVisitor::emitDebugLine(spv::Op op, const SourceLocation &loc) {
     auto it = debugFileIdMap.find(fileName);
     if (it == debugFileIdMap.end()) {
       // Emit the OpString for this new fileName.
-      SpirvString *inst =
-          new (context) SpirvString(/*SourceLocation*/ {}, fileName);
-      visit(inst);
+      SpirvString inst(/*SourceLocation*/ {}, fileName);
+      visit(&inst);
       it = debugFileIdMap.find(fileName);
     }
     fileId = it->second;
@@ -435,13 +434,9 @@ bool EmitVisitor::visit(SpirvString *inst) {
 }
 
 bool EmitVisitor::visit(SpirvSource *inst) {
-  // Emit the OpString for the file name.
-  if (inst->hasFile()) {
-    visit(inst->getFile());
-
-    if (spvOptions.debugInfoLine && !debugMainFileId)
-      debugMainFileId = debugFileIdMap[inst->getFile()->getString()];
-  }
+  // Set the debugMainFileId.
+  if (inst->hasFile() && spvOptions.debugInfoLine && !debugMainFileId)
+    debugMainFileId = debugFileIdMap[inst->getFile()->getString()];
 
   // Chop up the source into multiple segments if it is too long.
   llvm::Optional<llvm::StringRef> firstSnippet = llvm::None;

+ 5 - 0
tools/clang/lib/SPIRV/SpirvBasicBlock.cpp

@@ -17,6 +17,11 @@ SpirvBasicBlock::SpirvBasicBlock(llvm::StringRef name)
     : labelId(0), labelName(name), mergeTarget(nullptr),
       continueTarget(nullptr) {}
 
+SpirvBasicBlock::~SpirvBasicBlock() {
+  for (auto instructionNode : instructions)
+    instructionNode.instruction->releaseMemory();
+}
+
 bool SpirvBasicBlock::hasTerminator() const {
   return !instructions.empty() &&
          isa<SpirvTerminator>(instructions.back().instruction);

+ 16 - 20
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -23,9 +23,16 @@ namespace spirv {
 
 SpirvBuilder::SpirvBuilder(ASTContext &ac, SpirvContext &ctx,
                            const SpirvCodeGenOptions &opt)
-    : astContext(ac), context(ctx), mod(nullptr), function(nullptr),
-      spirvOptions(opt), builtinVars(), stringLiterals() {
-  mod = new (context) SpirvModule;
+    : astContext(ac), context(ctx), mod(llvm::make_unique<SpirvModule>()),
+      function(nullptr), spirvOptions(opt), builtinVars(), stringLiterals() {}
+
+SpirvFunction *SpirvBuilder::createSpirvFunction(QualType returnType,
+                                                 SourceLocation loc,
+                                                 llvm::StringRef name,
+                                                 bool isPrecise) {
+  auto *fn = new (context) SpirvFunction(returnType, loc, name, isPrecise);
+  mod->addFunction(fn);
+  return fn;
 }
 
 SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
@@ -41,8 +48,7 @@ SpirvFunction *SpirvBuilder::beginFunction(QualType returnType,
     function->setFunctionName(funcName);
     function->setPrecise(isPrecise);
   } else {
-    function =
-        new (context) SpirvFunction(returnType, loc, funcName, isPrecise);
+    function = createSpirvFunction(returnType, loc, funcName, isPrecise);
   }
 
   return function;
@@ -77,10 +83,9 @@ SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
   if (isBindlessOpaqueArray(valueType)) {
     // If it is a bindless array of an opaque type, we have to use
     // a pointer to a pointer of the runtime array.
-    var = new (context)
-        SpirvVariable(context.getPointerType(
-                          valueType, spv::StorageClass::UniformConstant),
-                      loc, spv::StorageClass::Function, isPrecise, init);
+    var = new (context) SpirvVariable(
+        context.getPointerType(valueType, spv::StorageClass::UniformConstant),
+        loc, spv::StorageClass::Function, isPrecise, init);
   } else {
     var = new (context) SpirvVariable(
         valueType, loc, spv::StorageClass::Function, isPrecise, init);
@@ -92,16 +97,7 @@ SpirvVariable *SpirvBuilder::addFnVar(QualType valueType, SourceLocation loc,
 
 void SpirvBuilder::endFunction() {
   assert(function && "no active function");
-
-  // Move all basic blocks into the current function.
-  // TODO: we should adjust the order the basic blocks according to
-  // SPIR-V validation rules.
-  for (auto *bb : basicBlocks) {
-    function->addBasicBlock(bb);
-  }
-  basicBlocks.clear();
-
-  mod->addFunction(function);
+  mod->addFunctionToListOfSortedModuleFunctions(function);
   function = nullptr;
   insertPoint = nullptr;
 }
@@ -109,7 +105,7 @@ void SpirvBuilder::endFunction() {
 SpirvBasicBlock *SpirvBuilder::createBasicBlock(llvm::StringRef name) {
   assert(function && "found detached basic block");
   auto *bb = new (context) SpirvBasicBlock(name);
-  basicBlocks.push_back(bb);
+  function->addBasicBlock(bb);
   return bb;
 }
 

+ 71 - 3
tools/clang/lib/SPIRV/SpirvContext.cpp

@@ -27,6 +27,66 @@ SpirvContext::SpirvContext()
   rayQueryProvisionalTypeKHR = new (this) RayQueryProvisionalTypeKHR;
 }
 
+SpirvContext::~SpirvContext() {
+  voidType->~VoidType();
+  boolType->~BoolType();
+  samplerType->~SamplerType();
+  accelerationStructureTypeNV->~AccelerationStructureTypeNV();
+  rayQueryProvisionalTypeKHR->~RayQueryProvisionalTypeKHR();
+
+  for (auto *sintType : sintTypes)
+    if (sintType) // sintTypes may contain nullptr
+      sintType->~IntegerType();
+
+  for (auto *uintType : uintTypes)
+    if (uintType) // uintTypes may contain nullptr
+      uintType->~IntegerType();
+
+  for (auto *floatType : floatTypes)
+    if (floatType) // floatTypes may contain nullptr
+      floatType->~FloatType();
+
+  for (auto &pair : vecTypes)
+    for (auto *vecType : pair.second)
+      if (vecType) // vecTypes may contain nullptr
+        vecType->~VectorType();
+
+  for (auto &pair : matTypes)
+    for (auto *matType : pair.second)
+      matType->~MatrixType();
+
+  for (auto *arrType : arrayTypes)
+    arrType->~ArrayType();
+
+  for (auto *raType : runtimeArrayTypes)
+    raType->~RuntimeArrayType();
+
+  for (auto *fnType : functionTypes)
+    fnType->~FunctionType();
+
+  for (auto *structType : structTypes)
+    structType->~StructType();
+
+  for (auto *hybridStructType : hybridStructTypes)
+    hybridStructType->~HybridStructType();
+
+  for (auto pair : sampledImageTypes)
+    pair.second->~SampledImageType();
+
+  for (auto *hybridSampledImageType : hybridSampledImageTypes)
+    hybridSampledImageType->~HybridSampledImageType();
+
+  for (auto *imgType : imageTypes)
+    imgType->~ImageType();
+
+  for (auto &pair : pointerTypes)
+    for (auto &scPtrTypePair : pair.second)
+      scPtrTypePair.second->~SpirvPointerType();
+
+  for (auto *hybridPtrType : hybridPointerTypes)
+    hybridPtrType->~HybridPointerType();
+}
+
 inline uint32_t log2ForBitwidth(uint32_t bitwidth) {
   assert(bitwidth >= 16 && bitwidth <= 64 && llvm::isPowerOf2_32(bitwidth));
 
@@ -148,7 +208,10 @@ SpirvContext::getSampledImageType(const ImageType *image) {
 
 const HybridSampledImageType *
 SpirvContext::getSampledImageType(QualType image) {
-  return new (this) HybridSampledImageType(image);
+  const HybridSampledImageType *result =
+      new (this) HybridSampledImageType(image);
+  hybridSampledImageTypes.push_back(result);
+  return result;
 }
 
 const ArrayType *
@@ -207,7 +270,10 @@ SpirvContext::getStructType(llvm::ArrayRef<StructType::FieldInfo> fields,
 const HybridStructType *SpirvContext::getHybridStructType(
     llvm::ArrayRef<HybridStructType::FieldInfo> fields, llvm::StringRef name,
     bool isReadOnly, StructInterfaceType interfaceType) {
-  return new (this) HybridStructType(fields, name, isReadOnly, interfaceType);
+  const HybridStructType *result =
+      new (this) HybridStructType(fields, name, isReadOnly, interfaceType);
+  hybridStructTypes.push_back(result);
+  return result;
 }
 
 const SpirvPointerType *SpirvContext::getPointerType(const SpirvType *pointee,
@@ -227,7 +293,9 @@ const SpirvPointerType *SpirvContext::getPointerType(const SpirvType *pointee,
 
 const HybridPointerType *SpirvContext::getPointerType(QualType pointee,
                                                       spv::StorageClass sc) {
-  return new (this) HybridPointerType(pointee, sc);
+  const HybridPointerType *result = new (this) HybridPointerType(pointee, sc);
+  hybridPointerTypes.push_back(result);
+  return result;
 }
 
 FunctionType *

+ 9 - 0
tools/clang/lib/SPIRV/SpirvFunction.cpp

@@ -22,6 +22,15 @@ SpirvFunction::SpirvFunction(QualType returnType, SourceLocation loc,
       containsAlias(false), rvalue(false), functionLoc(loc),
       functionName(name) {}
 
+SpirvFunction::~SpirvFunction() {
+  for (auto *param : parameters)
+    param->releaseMemory();
+  for (auto *var : variables)
+    var->releaseMemory();
+  for (auto *bb : basicBlocks)
+    bb->~SpirvBasicBlock();
+}
+
 bool SpirvFunction::invokeVisitor(Visitor *visitor, bool reverseOrder) {
   if (!visitor->visit(this, Visitor::Phase::Init))
     return false;

+ 39 - 5
tools/clang/lib/SPIRV/SpirvModule.cpp

@@ -19,6 +19,35 @@ SpirvModule::SpirvModule()
       entryPoints({}), executionModes({}), moduleProcesses({}), decorations({}),
       constants({}), variables({}), functions({}) {}
 
+SpirvModule::~SpirvModule() {
+  for (auto *cap : capabilities)
+    cap->releaseMemory();
+  for (auto *ext : extensions)
+    ext->releaseMemory();
+  for (auto *set : extInstSets)
+    set->releaseMemory();
+  if (memoryModel)
+    memoryModel->releaseMemory();
+  for (auto *entry : entryPoints)
+    entry->releaseMemory();
+  for (auto *exec : executionModes)
+    exec->releaseMemory();
+  for (auto *str : constStrings)
+    str->releaseMemory();
+  for (auto *d : debugSources)
+    d->releaseMemory();
+  for (auto *mp : moduleProcesses)
+    mp->releaseMemory();
+  for (auto *decoration : decorations)
+    decoration->releaseMemory();
+  for (auto *constant : constants)
+    constant->releaseMemory();
+  for (auto *var : variables)
+    var->releaseMemory();
+  for (auto *f : allFunctions)
+    f->~SpirvFunction();
+}
+
 bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
   // Note: It is debatable whether reverse order of visiting the module should
   // reverse everything in this method. For the time being, we just reverse the
@@ -180,14 +209,19 @@ bool SpirvModule::invokeVisitor(Visitor *visitor, bool reverseOrder) {
   return true;
 }
 
-void SpirvModule::addFunction(SpirvFunction *fn) {
+void SpirvModule::addFunctionToListOfSortedModuleFunctions(SpirvFunction *fn) {
   assert(fn && "cannot add null function to the module");
   functions.push_back(fn);
 }
 
-void SpirvModule::addCapability(SpirvCapability *cap) {
+void SpirvModule::addFunction(SpirvFunction *fn) {
+  assert(fn && "cannot add null function to the module");
+  allFunctions.insert(fn);
+}
+
+bool SpirvModule::addCapability(SpirvCapability *cap) {
   assert(cap && "cannot add null capability to the module");
-  capabilities.insert(cap);
+  return capabilities.insert(cap);
 }
 
 void SpirvModule::setMemoryModel(SpirvMemoryModel *model) {
@@ -205,9 +239,9 @@ void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
   executionModes.push_back(em);
 }
 
-void SpirvModule::addExtension(SpirvExtension *ext) {
+bool SpirvModule::addExtension(SpirvExtension *ext) {
   assert(ext && "cannot add null extension");
-  extensions.insert(ext);
+  return extensions.insert(ext);
 }
 
 void SpirvModule::addExtInstSet(SpirvExtInstImport *set) {

+ 34 - 23
tools/clang/unittests/SPIRV/SpirvBasicBlockTest.cpp

@@ -9,6 +9,8 @@
 
 #include "clang/SPIRV/SpirvBasicBlock.h"
 #include "clang/SPIRV/SpirvInstruction.h"
+
+#include "SpirvTestBase.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
@@ -16,32 +18,34 @@ using namespace clang::spirv;
 
 namespace {
 
-TEST(SpirvBasicBlockTest, CheckName) {
+class SpirvBasicBlockTest : public SpirvTestBase {};
+
+TEST_F(SpirvBasicBlockTest, CheckName) {
   SpirvBasicBlock bb("myBasicBlock");
   EXPECT_EQ(bb.getName(), "myBasicBlock");
 }
 
-TEST(SpirvBasicBlockTest, CheckResultId) {
+TEST_F(SpirvBasicBlockTest, CheckResultId) {
   SpirvBasicBlock bb("myBasicBlock");
   bb.setResultId(5);
   EXPECT_EQ(bb.getResultId(), 5u);
 }
 
-TEST(SpirvBasicBlockTest, CheckMergeTarget) {
+TEST_F(SpirvBasicBlockTest, CheckMergeTarget) {
   SpirvBasicBlock bb1("bb1");
   SpirvBasicBlock bb2("bb2");
   bb1.setMergeTarget(&bb2);
   EXPECT_EQ(bb1.getMergeTarget(), &bb2);
 }
 
-TEST(SpirvBasicBlockTest, CheckContinueTarget) {
+TEST_F(SpirvBasicBlockTest, CheckContinueTarget) {
   SpirvBasicBlock bb1("bb1");
   SpirvBasicBlock bb2("bb2");
   bb1.setContinueTarget(&bb2);
   EXPECT_EQ(bb1.getContinueTarget(), &bb2);
 }
 
-TEST(SpirvBasicBlockTest, CheckSuccessors) {
+TEST_F(SpirvBasicBlockTest, CheckSuccessors) {
   SpirvBasicBlock bb1("bb1");
   SpirvBasicBlock bb2("bb2");
   SpirvBasicBlock bb3("bb3");
@@ -52,45 +56,52 @@ TEST(SpirvBasicBlockTest, CheckSuccessors) {
   EXPECT_EQ(successors[1], &bb3);
 }
 
-TEST(SpirvBasicBlockTest, CheckTerminatedByKill) {
+TEST_F(SpirvBasicBlockTest, CheckTerminatedByKill) {
   SpirvBasicBlock bb("bb");
-  SpirvKill kill({});
-  bb.addInstruction(&kill);
+  SpirvContext &context = getSpirvContext();
+  auto *kill = new (context) SpirvKill({});
+  bb.addInstruction(kill);
   EXPECT_TRUE(bb.hasTerminator());
 }
 
-TEST(SpirvBasicBlockTest, CheckTerminatedByBranch) {
+TEST_F(SpirvBasicBlockTest, CheckTerminatedByBranch) {
   SpirvBasicBlock bb("bb");
-  SpirvBranch branch({}, nullptr);
-  bb.addInstruction(&branch);
+  SpirvContext &context = getSpirvContext();
+  auto *branch = new (context) SpirvBranch({}, nullptr);
+  bb.addInstruction(branch);
   EXPECT_TRUE(bb.hasTerminator());
 }
 
-TEST(SpirvBasicBlockTest, CheckTerminatedByBranchConditional) {
+TEST_F(SpirvBasicBlockTest, CheckTerminatedByBranchConditional) {
   SpirvBasicBlock bb("bb");
-  SpirvBranchConditional branch({}, nullptr, nullptr, nullptr);
-  bb.addInstruction(&branch);
+  SpirvContext &context = getSpirvContext();
+  auto *branch =
+      new (context) SpirvBranchConditional({}, nullptr, nullptr, nullptr);
+  bb.addInstruction(branch);
   EXPECT_TRUE(bb.hasTerminator());
 }
 
-TEST(SpirvBasicBlockTest, CheckTerminatedByReturn) {
+TEST_F(SpirvBasicBlockTest, CheckTerminatedByReturn) {
   SpirvBasicBlock bb("bb");
-  SpirvReturn returnInstr({});
-  bb.addInstruction(&returnInstr);
+  SpirvContext &context = getSpirvContext();
+  auto *returnInstr = new (context) SpirvReturn({});
+  bb.addInstruction(returnInstr);
   EXPECT_TRUE(bb.hasTerminator());
 }
 
-TEST(SpirvBasicBlockTest, CheckTerminatedByUnreachable) {
+TEST_F(SpirvBasicBlockTest, CheckTerminatedByUnreachable) {
   SpirvBasicBlock bb("bb");
-  SpirvUnreachable unreachable({});
-  bb.addInstruction(&unreachable);
+  SpirvContext &context = getSpirvContext();
+  auto *unreachable = new (context) SpirvUnreachable({});
+  bb.addInstruction(unreachable);
   EXPECT_TRUE(bb.hasTerminator());
 }
 
-TEST(SpirvBasicBlockTest, CheckNotTerminated) {
+TEST_F(SpirvBasicBlockTest, CheckNotTerminated) {
   SpirvBasicBlock bb("bb");
-  SpirvLoad load({}, {}, nullptr);
-  bb.addInstruction(&load);
+  SpirvContext &context = getSpirvContext();
+  auto *load = new (context) SpirvLoad({}, {}, nullptr);
+  bb.addInstruction(load);
   EXPECT_FALSE(bb.hasTerminator());
 }