Jelajahi Sumber

[spirv] Fully translate constant pixel shader (#448)

* Added support for Constant and SpecConstant
* Fixes duplicate constant issue
* Fixes type and constant inter-dependency when emitting SPIR-V
* Added support for translating a pixel shader that returns a
  constant float4.
Ehsan 8 tahun lalu
induk
melakukan
4e1e9a78db

+ 31 - 0
tools/clang/include/clang/SPIRV/BitwiseCast.h

@@ -0,0 +1,31 @@
+//===-- BitwiseCast.h - Bitwise cast ----------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_CLANG_SPIRV_BITWISECAST_H
+#define LLVM_CLANG_SPIRV_BITWISECAST_H
+
+#include <cstring>
+
+namespace clang {
+namespace spirv {
+namespace cast {
+
+/// \brief Performs bitwise copy of source to the destination type Dest.
+template <typename Dest, typename Src> Dest BitwiseCast(Src source) {
+  Dest dest;
+  static_assert(sizeof(source) == sizeof(dest),
+                "BitwiseCast: source and destination must have the same size.");
+  std::memcpy(&dest, &source, sizeof(dest));
+  return dest;
+}
+
+} // end namespace cast
+} // end namespace spirv
+} // end namespace clang
+
+#endif

+ 124 - 0
tools/clang/include/clang/SPIRV/Constant.h

@@ -0,0 +1,124 @@
+//===-- Constant.h - SPIR-V Constant ----------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+#ifndef LLVM_CLANG_SPIRV_CONSTANT_H
+#define LLVM_CLANG_SPIRV_CONSTANT_H
+
+#include <set>
+#include <unordered_set>
+#include <vector>
+
+#include "spirv/1.0/spirv.hpp11"
+#include "clang/SPIRV/Decoration.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/Optional.h"
+
+namespace clang {
+namespace spirv {
+
+class SPIRVContext;
+
+/// \brief SPIR-V Constant
+///
+/// This class defines a unique SPIR-V constant.
+/// A SPIR-V constant includes its <opcode> defined by the SPIR-V Spec.
+/// It also incldues any arguments (32-bit words) needed to initialize that
+/// constant. It also includes a set of decorations that are applied to
+/// that constant.
+///
+/// The class includes static getXXX(...) functions for getting pointers of any
+/// needed constant. A unique constant has a unique pointer (e.g. calling
+/// 'getTrue' function will always return the same pointer for the given
+/// context).
+class Constant {
+public:
+  using DecorationSet = std::set<const Decoration *>;
+
+  spv::Op getOpcode() const { return opcode; }
+  uint32_t getTypeId() const { return typeId; }
+  const std::vector<uint32_t> &getArgs() const { return args; }
+  const DecorationSet &getDecorations() const { return decorations; }
+  bool hasDecoration(const Decoration *) const;
+
+  // OpConstantTrue and OpConstantFalse are boolean.
+  // OpSpecConstantTrue and OpSpecConstantFalse are boolean.
+  bool isBoolean() const;
+
+  // OpConstant and OpSpecConstant are only allowed to take integers and floats.
+  bool isNumerical() const;
+
+  // OpConstantComposite and OpSpecConstantComposite.
+  bool isComposite() const;
+
+  // Get constants.
+  static const Constant *getTrue(SPIRVContext &ctx, uint32_t type_id,
+                                 DecorationSet dec = {});
+  static const Constant *getFalse(SPIRVContext &ctx, uint32_t type_id,
+                                  DecorationSet dec = {});
+  static const Constant *getInt32(SPIRVContext &ctx, uint32_t type_id,
+                                  int32_t value, DecorationSet dec = {});
+  static const Constant *getUint32(SPIRVContext &ctx, uint32_t type_id,
+                                   uint32_t value, DecorationSet dec = {});
+  static const Constant *getFloat32(SPIRVContext &ctx, uint32_t type_id,
+                                    float value, DecorationSet dec = {});
+
+  // TODO: 64-bit float and integer constant implementation
+
+  static const Constant *getComposite(SPIRVContext &ctx, uint32_t type_id,
+                                      llvm::ArrayRef<uint32_t> constituents,
+                                      DecorationSet dec = {});
+  static const Constant *getSampler(SPIRVContext &ctx, uint32_t type_id,
+                                    spv::SamplerAddressingMode, uint32_t param,
+                                    spv::SamplerFilterMode,
+                                    DecorationSet dec = {});
+  static const Constant *getNull(SPIRVContext &ctx, uint32_t type_id,
+                                 DecorationSet dec = {});
+
+  // Get specialization constants.
+  static const Constant *getSpecTrue(SPIRVContext &ctx, uint32_t type_id,
+                                     DecorationSet dec = {});
+  static const Constant *getSpecFalse(SPIRVContext &ctx, uint32_t type_id,
+                                      DecorationSet dec = {});
+  static const Constant *getSpecInt32(SPIRVContext &ctx, uint32_t type_id,
+                                      int32_t value, DecorationSet dec = {});
+  static const Constant *getSpecUint32(SPIRVContext &ctx, uint32_t type_id,
+                                       uint32_t value, DecorationSet dec = {});
+  static const Constant *getSpecFloat32(SPIRVContext &ctx, uint32_t type_id,
+                                        float value, DecorationSet dec = {});
+  static const Constant *getSpecComposite(SPIRVContext &ctx, uint32_t type_id,
+                                          llvm::ArrayRef<uint32_t> constituents,
+                                          DecorationSet dec = {});
+
+  bool operator==(const Constant &other) const {
+    return opcode == other.opcode && args == other.args &&
+           decorations == other.decorations;
+  }
+
+  // \brief Construct the SPIR-V words for this constant with the given
+  // <result-id>.
+  std::vector<uint32_t> withResultId(uint32_t resultId) const;
+
+private:
+  /// \brief Private constructor.
+  Constant(spv::Op, uint32_t type, llvm::ArrayRef<uint32_t> arg = {},
+           std::set<const Decoration *> dec = {});
+
+  /// \brief Returns the unique constant pointer within the given context.
+  static const Constant *getUniqueConstant(SPIRVContext &, const Constant &);
+
+private:
+  spv::Op opcode;             ///< OpCode of the constant
+  uint32_t typeId;            ///< <result-id> of the type of the constant
+  std::vector<uint32_t> args; ///< Arguments needed to define the constant
+  DecorationSet decorations;  ///< Decorations applied to the constant
+};
+
+} // end namespace spirv
+} // end namespace clang
+
+#endif

+ 1 - 1
tools/clang/include/clang/SPIRV/InstBuilder.h

@@ -832,4 +832,4 @@ private:
 } // end namespace spirv
 } // end namespace spirv
 } // end namespace clang
 } // end namespace clang
 
 
-#endif
+#endif

+ 6 - 2
tools/clang/include/clang/SPIRV/ModuleBuilder.h

@@ -128,6 +128,7 @@ public:
   // === Type ===
   // === Type ===
 
 
   uint32_t getVoidType();
   uint32_t getVoidType();
+  uint32_t getUint32Type();
   uint32_t getInt32Type();
   uint32_t getInt32Type();
   uint32_t getFloatType();
   uint32_t getFloatType();
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
   uint32_t getVecType(uint32_t elemType, uint32_t elemCount);
@@ -137,8 +138,11 @@ public:
                            const std::vector<uint32_t> &paramTypes);
                            const std::vector<uint32_t> &paramTypes);
 
 
   // === Constant ===
   // === Constant ===
-
-  uint32_t getInt32Value(uint32_t value);
+  uint32_t getConstantFloat32(float value);
+  uint32_t getConstantInt32(int32_t value);
+  uint32_t getConstantUint32(uint32_t value);
+  uint32_t getConstantComposite(uint32_t typeId,
+                                llvm::ArrayRef<uint32_t> constituents);
 
 
 private:
 private:
   /// \brief Map from basic blocks' <label-id> to their structured
   /// \brief Map from basic blocks' <label-id> to their structured

+ 25 - 0
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -12,6 +12,7 @@
 #include <unordered_map>
 #include <unordered_map>
 
 
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/Frontend/FrontendAction.h"
+#include "clang/SPIRV/Constant.h"
 #include "clang/SPIRV/Decoration.h"
 #include "clang/SPIRV/Decoration.h"
 #include "clang/SPIRV/Type.h"
 #include "clang/SPIRV/Type.h"
 
 
@@ -30,6 +31,12 @@ struct DecorationHash {
     return std::hash<uint32_t>{}(static_cast<uint32_t>(d.getValue()));
     return std::hash<uint32_t>{}(static_cast<uint32_t>(d.getValue()));
   }
   }
 };
 };
+struct ConstantHash {
+  std::size_t operator()(const Constant &c) const {
+    // TODO: We could improve this hash function if necessary.
+    return std::hash<uint32_t>{}(static_cast<uint32_t>(c.getTypeId()));
+  }
+};
 
 
 /// \brief A class for holding various data needed in SPIR-V codegen.
 /// \brief A class for holding various data needed in SPIR-V codegen.
 /// It should outlive all SPIR-V codegen components that requires/allocates
 /// It should outlive all SPIR-V codegen components that requires/allocates
@@ -54,16 +61,25 @@ public:
   /// has not been defined, it will define and store its instruction.
   /// has not been defined, it will define and store its instruction.
   uint32_t getResultIdForType(const Type *);
   uint32_t getResultIdForType(const Type *);
 
 
+  /// \brief Returns the <result-id> that defines the given Constant. If the
+  /// constant has not been defined, it will define and return its result-id.
+  uint32_t getResultIdForConstant(const Constant *);
+
   /// \brief Registers the existence of the given type in the current context,
   /// \brief Registers the existence of the given type in the current context,
   /// and returns the unique Type pointer.
   /// and returns the unique Type pointer.
   const Type *registerType(const Type &);
   const Type *registerType(const Type &);
 
 
+  /// \brief Registers the existence of the given constant in the current
+  /// context, and returns the unique pointer to it.
+  const Constant *registerConstant(const Constant &);
+
   /// \brief Registers the existence of the given decoration in the current
   /// \brief Registers the existence of the given decoration in the current
   /// context, and returns the unique Decoration pointer.
   /// context, and returns the unique Decoration pointer.
   const Decoration *registerDecoration(const Decoration &);
   const Decoration *registerDecoration(const Decoration &);
 
 
 private:
 private:
   using TypeSet = std::unordered_set<Type, TypeHash>;
   using TypeSet = std::unordered_set<Type, TypeHash>;
+  using ConstantSet = std::unordered_set<Constant, ConstantHash>;
   using DecorationSet = std::unordered_set<Decoration, DecorationHash>;
   using DecorationSet = std::unordered_set<Decoration, DecorationHash>;
 
 
   uint32_t nextId;
   uint32_t nextId;
@@ -74,10 +90,19 @@ private:
   /// \brief All the unique types defined in the current context.
   /// \brief All the unique types defined in the current context.
   TypeSet existingTypes;
   TypeSet existingTypes;
 
 
+  /// \brief All constants defined in the current context.
+  /// These can be boolean, integer, float, or composite constants.
+  ConstantSet existingConstants;
+
   /// \brief Maps a given type to the <result-id> that is defined for
   /// \brief Maps a given type to the <result-id> that is defined for
   /// that type. If a Type* does not exist in the map, the type
   /// that type. If a Type* does not exist in the map, the type
   /// is not yet defined and is not associated with a <result-id>.
   /// is not yet defined and is not associated with a <result-id>.
   std::unordered_map<const Type *, uint32_t> typeResultIdMap;
   std::unordered_map<const Type *, uint32_t> typeResultIdMap;
+
+  /// \brief Maps a given constant to the <result-id> that is defined for
+  /// that constant. If a Constant* does not exist in the map, the constant
+  /// is not yet defined and is not associated with a <result-id>.
+  std::unordered_map<const Constant *, uint32_t> constantResultIdMap;
 };
 };
 
 
 SPIRVContext::SPIRVContext() : nextId(1) {}
 SPIRVContext::SPIRVContext() : nextId(1) {}

+ 19 - 15
tools/clang/include/clang/SPIRV/Structure.h

@@ -24,6 +24,7 @@
 #include <vector>
 #include <vector>
 
 
 #include "spirv/1.0/spirv.hpp11"
 #include "spirv/1.0/spirv.hpp11"
+#include "clang/SPIRV/Constant.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/InstBuilder.h"
 #include "clang/SPIRV/Type.h"
 #include "clang/SPIRV/Type.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -188,14 +189,6 @@ struct TypeIdPair {
   const uint32_t resultId;
   const uint32_t resultId;
 };
 };
 
 
-/// \brief The struct representing a constant and its type.
-struct Constant {
-  inline Constant(const Type &ty, Instruction &&value);
-
-  const Type &type;
-  Instruction constant;
-};
-
 /// \brief The class representing a SPIR-V module.
 /// \brief The class representing a SPIR-V module.
 class SPIRVModule {
 class SPIRVModule {
 public:
 public:
@@ -238,10 +231,23 @@ public:
                            llvm::Optional<uint32_t> memberIndex = llvm::None);
                            llvm::Optional<uint32_t> memberIndex = llvm::None);
   inline void addDecoration(const Decoration &decoration, uint32_t targetId);
   inline void addDecoration(const Decoration &decoration, uint32_t targetId);
   inline void addType(const Type *type, uint32_t resultId);
   inline void addType(const Type *type, uint32_t resultId);
-  inline void addConstant(const Type &type, Instruction &&constant);
+  inline void addConstant(const Constant *constant, uint32_t resultId);
   inline void addVariable(Instruction &&);
   inline void addVariable(Instruction &&);
   inline void addFunction(std::unique_ptr<Function>);
   inline void addFunction(std::unique_ptr<Function>);
 
 
+private:
+  /// \brief Collects all the Integer type definitions in this module and
+  /// consumes them using the consumer within the given InstBuilder.
+  /// After this method is called, all integer types are remove from the list of
+  /// types in this object.
+  void takeIntegerTypes(InstBuilder *builder);
+
+  /// \brief Finds the constant on which the given array type depends.
+  /// If found, (a) defines the constant by passing it to the consumer in the
+  /// given InstBuilder. (b) Removes the constant from the list of constants
+  /// in this object.
+  void takeConstantForArrayType(const Type *arrType, InstBuilder *ib);
+
 private:
 private:
   Header header; ///< SPIR-V module header.
   Header header; ///< SPIR-V module header.
   std::vector<spv::Capability> capabilities;
   std::vector<spv::Capability> capabilities;
@@ -262,7 +268,8 @@ private:
   // their corresponding types. We store types and constants separately, but
   // their corresponding types. We store types and constants separately, but
   // they should be handled together.
   // they should be handled together.
   llvm::MapVector<const Type *, uint32_t> types;
   llvm::MapVector<const Type *, uint32_t> types;
-  std::vector<Constant> constants;
+  llvm::MapVector<const Constant *, uint32_t> constants;
+
   std::vector<Instruction> variables;
   std::vector<Instruction> variables;
   std::vector<std::unique_ptr<Function>> functions;
   std::vector<std::unique_ptr<Function>> functions;
 };
 };
@@ -329,9 +336,6 @@ DecorationIdPair::DecorationIdPair(const Decoration &decor, uint32_t id)
 
 
 TypeIdPair::TypeIdPair(const Type &ty, uint32_t id) : type(ty), resultId(id) {}
 TypeIdPair::TypeIdPair(const Type &ty, uint32_t id) : type(ty), resultId(id) {}
 
 
-Constant::Constant(const Type &ty, Instruction &&value)
-    : type(ty), constant(std::move(value)) {}
-
 SPIRVModule::SPIRVModule()
 SPIRVModule::SPIRVModule()
     : addressingModel(llvm::None), memoryModel(llvm::None) {}
     : addressingModel(llvm::None), memoryModel(llvm::None) {}
 
 
@@ -371,8 +375,8 @@ void SPIRVModule::addDecoration(const Decoration &decoration,
 void SPIRVModule::addType(const Type *type, uint32_t resultId) {
 void SPIRVModule::addType(const Type *type, uint32_t resultId) {
   types.insert(std::make_pair(type, resultId));
   types.insert(std::make_pair(type, resultId));
 }
 }
-void SPIRVModule::addConstant(const Type &type, Instruction &&constant) {
-  constants.emplace_back(type, std::move(constant));
+void SPIRVModule::addConstant(const Constant *constant, uint32_t resultId) {
+  constants.insert(std::make_pair(constant, resultId));
 };
 };
 void SPIRVModule::addVariable(Instruction &&var) {
 void SPIRVModule::addVariable(Instruction &&var) {
   variables.push_back(std::move(var));
   variables.push_back(std::move(var));

+ 2 - 4
tools/clang/include/clang/SPIRV/Type.h

@@ -40,9 +40,7 @@ public:
 
 
   spv::Op getOpcode() const { return opcode; }
   spv::Op getOpcode() const { return opcode; }
   const std::vector<uint32_t> &getArgs() const { return args; }
   const std::vector<uint32_t> &getArgs() const { return args; }
-  const std::set<const Decoration *> &getDecorations() const {
-    return decorations;
-  }
+  const DecorationSet &getDecorations() const { return decorations; }
   bool hasDecoration(const Decoration *) const;
   bool hasDecoration(const Decoration *) const;
 
 
   bool isBooleanType() const;
   bool isBooleanType() const;
@@ -131,7 +129,7 @@ private:
 private:
 private:
   spv::Op opcode;             ///< OpCode of the Type defined in SPIR-V Spec
   spv::Op opcode;             ///< OpCode of the Type defined in SPIR-V Spec
   std::vector<uint32_t> args; ///< Arguments needed to define the type
   std::vector<uint32_t> args; ///< Arguments needed to define the type
-  std::set<const Decoration *> decorations; ///< decorations applied to the type
+  DecorationSet decorations;  ///< decorations applied to the type
 };
 };
 
 
 } // end namespace spirv
 } // end namespace spirv

+ 1 - 0
tools/clang/lib/SPIRV/CMakeLists.txt

@@ -3,6 +3,7 @@ set(LLVM_LINK_COMPONENTS
   )
   )
 
 
 add_clang_library(clangSPIRV
 add_clang_library(clangSPIRV
+  Constant.cpp
   Decoration.cpp
   Decoration.cpp
   EmitSPIRVAction.cpp
   EmitSPIRVAction.cpp
   InstBuilderAuto.cpp
   InstBuilderAuto.cpp

+ 159 - 0
tools/clang/lib/SPIRV/Constant.cpp

@@ -0,0 +1,159 @@
+//===--- Constant.cpp - SPIR-V Constant implementation --------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/SPIRV/Constant.h"
+#include "clang/SPIRV/BitwiseCast.h"
+#include "clang/SPIRV/SPIRVContext.h"
+
+namespace clang {
+namespace spirv {
+
+Constant::Constant(spv::Op op, uint32_t type, llvm::ArrayRef<uint32_t> arg,
+                   std::set<const Decoration *> decs)
+    : opcode(op), typeId(type), args(arg), decorations(decs) {}
+
+const Constant *Constant::getUniqueConstant(SPIRVContext &context,
+                                            const Constant &c) {
+  return context.registerConstant(c);
+}
+
+const Constant *Constant::getTrue(SPIRVContext &ctx, uint32_t type_id,
+                                  DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpConstantTrue, type_id, {}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getFalse(SPIRVContext &ctx, uint32_t type_id,
+                                   DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpConstantFalse, type_id, {}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getFloat32(SPIRVContext &ctx, uint32_t type_id,
+                                     float value, DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {cast::BitwiseCast<uint32_t, float>(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getUint32(SPIRVContext &ctx, uint32_t type_id,
+                                    uint32_t value, DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpConstant, type_id, {value}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getInt32(SPIRVContext &ctx, uint32_t type_id,
+                                   int32_t value, DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpConstant, type_id,
+                        {cast::BitwiseCast<uint32_t, int32_t>(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getComposite(SPIRVContext &ctx, uint32_t type_id,
+                                       llvm::ArrayRef<uint32_t> constituents,
+                                       DecorationSet dec) {
+  Constant c =
+      Constant(spv::Op::OpConstantComposite, type_id, constituents, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getSampler(SPIRVContext &ctx, uint32_t type_id,
+                                     spv::SamplerAddressingMode sam,
+                                     uint32_t param, spv::SamplerFilterMode sfm,
+                                     DecorationSet dec) {
+  Constant c = Constant(
+      spv::Op::OpConstantSampler, type_id,
+      {static_cast<uint32_t>(sam), param, static_cast<uint32_t>(sfm)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getNull(SPIRVContext &ctx, uint32_t type_id,
+                                  DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpConstantNull, type_id, {}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getSpecTrue(SPIRVContext &ctx, uint32_t type_id,
+                                      DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpSpecConstantTrue, type_id, {}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getSpecFalse(SPIRVContext &ctx, uint32_t type_id,
+                                       DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpSpecConstantFalse, type_id, {}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getSpecFloat32(SPIRVContext &ctx, uint32_t type_id,
+                                         float value, DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpSpecConstant, type_id,
+                        {cast::BitwiseCast<uint32_t, float>(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getSpecUint32(SPIRVContext &ctx, uint32_t type_id,
+                                        uint32_t value, DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpSpecConstant, type_id, {value}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *Constant::getSpecInt32(SPIRVContext &ctx, uint32_t type_id,
+                                       int32_t value, DecorationSet dec) {
+  Constant c = Constant(spv::Op::OpSpecConstant, type_id,
+                        {cast::BitwiseCast<uint32_t, int32_t>(value)}, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+const Constant *
+Constant::getSpecComposite(SPIRVContext &ctx, uint32_t type_id,
+                           llvm::ArrayRef<uint32_t> constituents,
+                           DecorationSet dec) {
+  Constant c =
+      Constant(spv::Op::OpSpecConstantComposite, type_id, constituents, dec);
+  return getUniqueConstant(ctx, c);
+}
+
+bool Constant::hasDecoration(const Decoration *d) const {
+  return decorations.find(d) != decorations.end();
+}
+
+bool Constant::isBoolean() const {
+  return (opcode == spv::Op::OpConstantTrue ||
+          opcode == spv::Op::OpConstantFalse ||
+          opcode == spv::Op::OpSpecConstantTrue ||
+          opcode == spv::Op::OpSpecConstantFalse);
+}
+
+bool Constant::isNumerical() const {
+  return (opcode == spv::Op::OpConstant || opcode == spv::Op::OpSpecConstant);
+}
+
+bool Constant::isComposite() const {
+  return (opcode == spv::Op::OpConstantComposite ||
+          opcode == spv::Op::OpSpecConstantComposite);
+}
+
+std::vector<uint32_t> Constant::withResultId(uint32_t resultId) const {
+  std::vector<uint32_t> words;
+
+  // TODO: we are essentially duplicate the work InstBuilder is responsible for.
+  // Should figure out a way to unify them.
+  words.reserve(3 + args.size());
+  words.push_back(static_cast<uint32_t>(opcode));
+  words.push_back(typeId);
+  words.push_back(resultId);
+  words.insert(words.end(), args.begin(), args.end());
+  words.front() |= static_cast<uint32_t>(words.size()) << 16;
+
+  return words;
+}
+
+} // end namespace spirv
+} // end namespace clang

+ 32 - 2
tools/clang/lib/SPIRV/EmitSPIRVAction.cpp

@@ -608,7 +608,7 @@ public:
         const uint32_t ptrType = theBuilder.getPointerType(
         const uint32_t ptrType = theBuilder.getPointerType(
             typeTranslator.translateType(field->getType()),
             typeTranslator.translateType(field->getType()),
             spv::StorageClass::Function);
             spv::StorageClass::Function);
-        const uint32_t indexId = theBuilder.getInt32Value(fieldIndex++);
+        const uint32_t indexId = theBuilder.getConstantInt32(fieldIndex++);
         const uint32_t valuePtr =
         const uint32_t valuePtr =
             theBuilder.createAccessChain(ptrType, retValue, {indexId});
             theBuilder.createAccessChain(ptrType, retValue, {indexId});
         const uint32_t value = theBuilder.createLoad(valueType, valuePtr);
         const uint32_t value = theBuilder.createLoad(valueType, valuePtr);
@@ -644,7 +644,8 @@ public:
       const uint32_t base = doExpr(memberExpr->getBase());
       const uint32_t base = doExpr(memberExpr->getBase());
       auto *memberDecl = memberExpr->getMemberDecl();
       auto *memberDecl = memberExpr->getMemberDecl();
       if (auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
       if (auto *fieldDecl = dyn_cast<FieldDecl>(memberDecl)) {
-        const auto index = theBuilder.getInt32Value(fieldDecl->getFieldIndex());
+        const auto index =
+            theBuilder.getConstantInt32(fieldDecl->getFieldIndex());
         const uint32_t fieldType =
         const uint32_t fieldType =
             typeTranslator.translateType(fieldDecl->getType());
             typeTranslator.translateType(fieldDecl->getType());
         const uint32_t ptrType =
         const uint32_t ptrType =
@@ -654,6 +655,35 @@ public:
         emitError("Decl '%0' in MemberExpr is not supported yet.")
         emitError("Decl '%0' in MemberExpr is not supported yet.")
             << memberDecl->getDeclKindName();
             << memberDecl->getDeclKindName();
       }
       }
+    } else if (auto *cxxFunctionalCastExpr =
+                   dyn_cast<CXXFunctionalCastExpr>(expr)) {
+      // Explicit cast is a NO-OP (e.g. vector<float, 4> -> float4)
+      if (cxxFunctionalCastExpr->getCastKind() == CK_NoOp) {
+        return doExpr(cxxFunctionalCastExpr->getSubExpr());
+      } else {
+        emitError("Found unhandled CXXFunctionalCastExpr cast type: %0")
+            << cxxFunctionalCastExpr->getCastKindName();
+      }
+    } else if (auto *initListExpr = dyn_cast<InitListExpr>(expr)) {
+      const bool isConstantInitializer = expr->isConstantInitializer(
+          theCompilerInstance.getASTContext(), false);
+      const uint32_t resultType =
+          typeTranslator.translateType(initListExpr->getType());
+      std::vector<uint32_t> constituents;
+      for (size_t i = 0; i < initListExpr->getNumInits(); ++i) {
+        constituents.push_back(doExpr(initListExpr->getInit(i)));
+      }
+      if (isConstantInitializer) {
+        return theBuilder.getConstantComposite(resultType, constituents);
+      } else {
+        // TODO: use OpCompositeConstruct if it is not a constant initializer
+        // list.
+        emitError("Non-const initializer lists are currently not supported.");
+      }
+    } else if (auto *floatingLiteral = dyn_cast<FloatingLiteral>(expr)) {
+      // TODO: use floatingLiteral->getType() to also handle float64 cases.
+      const float value = floatingLiteral->getValue().convertToFloat();
+      return theBuilder.getConstantFloat32(value);
     }
     }
     emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
     emitError("Expr '%0' is not supported yet.") << expr->getStmtClassName();
     // TODO: handle other expressions
     // TODO: handle other expressions

+ 1 - 1
tools/clang/lib/SPIRV/InstBuilderAuto.cpp

@@ -7536,4 +7536,4 @@ InstBuilder &InstBuilder::literalString(std::string value) {
 }
 }
 
 
 } // end namespace spirv
 } // end namespace spirv
-} // end namespace clang
+} // end namespace clang

+ 42 - 9
tools/clang/lib/SPIRV/ModuleBuilder.cpp

@@ -142,6 +142,41 @@ void ModuleBuilder::createReturnValue(uint32_t value) {
   insertPoint->appendInstruction(std::move(constructSite));
   insertPoint->appendInstruction(std::move(constructSite));
 }
 }
 
 
+uint32_t
+ModuleBuilder::getConstantComposite(uint32_t typeId,
+                                    llvm::ArrayRef<uint32_t> constituents) {
+  const Constant *constant =
+      Constant::getComposite(theContext, typeId, constituents);
+  const uint32_t constId = theContext.getResultIdForConstant(constant);
+  theModule.addConstant(constant, constId);
+  return constId;
+}
+
+uint32_t ModuleBuilder::getConstantFloat32(float value) {
+  const uint32_t floatTypeId = getFloatType();
+  const Constant *constant =
+      Constant::getFloat32(theContext, floatTypeId, value);
+  const uint32_t constId = theContext.getResultIdForConstant(constant);
+  theModule.addConstant(constant, constId);
+  return constId;
+}
+
+uint32_t ModuleBuilder::getConstantInt32(int32_t value) {
+  const uint32_t intTypeId = getInt32Type();
+  const Constant *constant = Constant::getInt32(theContext, intTypeId, value);
+  const uint32_t constId = theContext.getResultIdForConstant(constant);
+  theModule.addConstant(constant, constId);
+  return constId;
+}
+
+uint32_t ModuleBuilder::getConstantUint32(uint32_t value) {
+  const uint32_t uintTypeId = getUint32Type();
+  const Constant *constant = Constant::getUint32(theContext, uintTypeId, value);
+  const uint32_t constId = theContext.getResultIdForConstant(constant);
+  theModule.addConstant(constant, constId);
+  return constId;
+}
+
 uint32_t ModuleBuilder::getVoidType() {
 uint32_t ModuleBuilder::getVoidType() {
   const Type *type = Type::getVoid(theContext);
   const Type *type = Type::getVoid(theContext);
   const uint32_t typeId = theContext.getResultIdForType(type);
   const uint32_t typeId = theContext.getResultIdForType(type);
@@ -149,6 +184,13 @@ uint32_t ModuleBuilder::getVoidType() {
   return typeId;
   return typeId;
 }
 }
 
 
+uint32_t ModuleBuilder::getUint32Type() {
+  const Type *type = Type::getUint32(theContext);
+  const uint32_t typeId = theContext.getResultIdForType(type);
+  theModule.addType(type, typeId);
+  return typeId;
+}
+
 uint32_t ModuleBuilder::getInt32Type() {
 uint32_t ModuleBuilder::getInt32Type() {
   const Type *type = Type::getInt32(theContext);
   const Type *type = Type::getInt32(theContext);
   const uint32_t typeId = theContext.getResultIdForType(type);
   const uint32_t typeId = theContext.getResultIdForType(type);
@@ -211,15 +253,6 @@ uint32_t ModuleBuilder::getPointerType(uint32_t pointeeType,
   return typeId;
   return typeId;
 }
 }
 
 
-uint32_t ModuleBuilder::getInt32Value(uint32_t value) {
-  const Type *i32Type = Type::getInt32(theContext);
-  const uint32_t i32TypeId = getInt32Type();
-  const uint32_t constantId = theContext.takeNextId();
-  instBuilder.opConstant(i32TypeId, constantId, value).x();
-  theModule.addConstant(*i32Type, std::move(constructSite));
-  return constantId;
-}
-
 uint32_t ModuleBuilder::addStageIOVariable(uint32_t type,
 uint32_t ModuleBuilder::addStageIOVariable(uint32_t type,
                                            spv::StorageClass storageClass) {
                                            spv::StorageClass storageClass) {
   const uint32_t pointerType = getPointerType(type, storageClass);
   const uint32_t pointerType = getPointerType(type, storageClass);

+ 24 - 0
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -32,6 +32,23 @@ uint32_t SPIRVContext::getResultIdForType(const Type *t) {
   return result_id;
   return result_id;
 }
 }
 
 
+uint32_t SPIRVContext::getResultIdForConstant(const Constant *c) {
+  assert(c != nullptr);
+  uint32_t result_id = 0;
+
+  auto iter = constantResultIdMap.find(c);
+  if (iter == constantResultIdMap.end()) {
+    // The constant has not been defined yet. Reserve an ID for it.
+    result_id = takeNextId();
+    constantResultIdMap[c] = result_id;
+  } else {
+    result_id = iter->second;
+  }
+
+  assert(result_id != 0);
+  return result_id;
+}
+
 const Type *SPIRVContext::registerType(const Type &t) {
 const Type *SPIRVContext::registerType(const Type &t) {
   // Insert function will only insert if it doesn't already exist in the set.
   // Insert function will only insert if it doesn't already exist in the set.
   TypeSet::iterator it;
   TypeSet::iterator it;
@@ -39,6 +56,13 @@ const Type *SPIRVContext::registerType(const Type &t) {
   return &*it;
   return &*it;
 }
 }
 
 
+const Constant *SPIRVContext::registerConstant(const Constant &c) {
+  // Insert function will only insert if it doesn't already exist in the set.
+  ConstantSet::iterator it;
+  std::tie(it, std::ignore) = existingConstants.insert(c);
+  return &*it;
+}
+
 const Decoration *SPIRVContext::registerDecoration(const Decoration &d) {
 const Decoration *SPIRVContext::registerDecoration(const Decoration &d) {
   // Insert function will only insert if it doesn't already exist in the set.
   // Insert function will only insert if it doesn't already exist in the set.
   DecorationSet::iterator it;
   DecorationSet::iterator it;

+ 46 - 3
tools/clang/lib/SPIRV/Structure.cpp

@@ -159,6 +159,36 @@ void SPIRVModule::clear() {
   functions.clear();
   functions.clear();
 }
 }
 
 
+void SPIRVModule::takeIntegerTypes(InstBuilder *ib) {
+  const auto &consumer = ib->getConsumer();
+  // If it finds any integer type, feeds it into the consumer, and removes it
+  // from the types collection.
+  types.remove_if([&consumer](std::pair<const Type *, uint32_t> &item) {
+    const bool isInteger = item.first->isIntegerType();
+    if (isInteger)
+      consumer(item.first->withResultId(item.second));
+    return isInteger;
+  });
+}
+
+void SPIRVModule::takeConstantForArrayType(const Type *arrType,
+                                           InstBuilder *ib) {
+  assert(arrType->isArrayType() &&
+         "takeConstantForArrayType was called with a non-array type.");
+  const auto &consumer = ib->getConsumer();
+  const uint32_t arrayLengthResultId = arrType->getArgs().back();
+
+  // If it finds the constant, feeds it into the consumer, and removes it
+  // from the constants collection.
+  constants.remove_if([&consumer, arrayLengthResultId](
+                          std::pair<const Constant *, uint32_t> &item) {
+    const bool isArrayLengthConstant = (item.second == arrayLengthResultId);
+    if (isArrayLengthConstant)
+      consumer(item.first->withResultId(item.second));
+    return isArrayLengthConstant;
+  });
+}
+
 void SPIRVModule::take(InstBuilder *builder) {
 void SPIRVModule::take(InstBuilder *builder) {
   const auto &consumer = builder->getConsumer();
   const auto &consumer = builder->getConsumer();
 
 
@@ -206,14 +236,27 @@ void SPIRVModule::take(InstBuilder *builder) {
     consumer(d.decoration.withTargetId(d.targetId));
     consumer(d.decoration.withTargetId(d.targetId));
   }
   }
 
 
-  // TODO: handle the interdependency between types and constants
+  // Note on interdependence of types and constants:
+  // There is only one type (OpTypeArray) that requires the result-id of a
+  // constant. As a result, the constant integer should be defined before the
+  // array is defined. The integer type should also be defined before the
+  // constant integer is defined.
+
+  // First define all integer types
+  takeIntegerTypes(builder);
 
 
   for (const auto &t : types) {
   for (const auto &t : types) {
+    // If we have an array type, we must first define the integer constant that
+    // defines its length.
+    if (t.first->isArrayType()) {
+      takeConstantForArrayType(t.first, builder);
+    }
+
     consumer(t.first->withResultId(t.second));
     consumer(t.first->withResultId(t.second));
   }
   }
 
 
-  for (auto &c : constants) {
-    consumer(std::move(c.constant));
+  for (const auto &c : constants) {
+    consumer(c.first->withResultId(c.second));
   }
   }
 
 
   for (auto &v : variables) {
   for (auto &v : variables) {

+ 34 - 0
tools/clang/test/CodeGenSPIRV/constant-ps.hlsl2spv

@@ -0,0 +1,34 @@
+// Run: %dxc -T ps_6_0 -E main
+float4 main(): SV_TARGET
+{
+  return float4(1.0f, 2.0f, 3.5f, 4.7f);
+}
+
+// CHECK-WHOLE-SPIR-V:
+// ; SPIR-V
+// ; Version: 1.0
+// ; Generator: Google spiregg; 0
+// ; Bound: 14
+// ; Schema: 0
+// OpCapability Shader
+// OpMemoryModel Logical GLSL450
+// OpEntryPoint Fragment %main "main" %4
+// OpExecutionMode %main OriginUpperLeft
+// OpName %main "main"
+// OpDecorate %4 Location 0
+// %float = OpTypeFloat 32
+// %v4float = OpTypeVector %float 4
+// %_ptr_Output_v4float = OpTypePointer Output %v4float
+// %void = OpTypeVoid
+// %6 = OpTypeFunction %void
+// %float_1 = OpConstant %float 1
+// %float_2 = OpConstant %float 2
+// %float_3_5 = OpConstant %float 3.5
+// %float_4_7 = OpConstant %float 4.7
+// %13 = OpConstantComposite %v4float %float_1 %float_2 %float_3_5 %float_4_7
+// %4 = OpVariable %_ptr_Output_v4float Output
+// %main = OpFunction %void None %6
+// %8 = OpLabel
+// OpStore %4 %13
+// OpReturn
+// OpFunctionEnd

+ 8 - 13
tools/clang/test/CodeGenSPIRV/passthru-vs.hlsl2spv

@@ -12,15 +12,12 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
   return result;
   return result;
 }
 }
 
 
-// TODO:
-// Deduplicate integer constants
-
 
 
 // CHECK-WHOLE-SPIR-V:
 // CHECK-WHOLE-SPIR-V:
 // ; SPIR-V
 // ; SPIR-V
 // ; Version: 1.0
 // ; Version: 1.0
 // ; Generator: Google spiregg; 0
 // ; Generator: Google spiregg; 0
-// ; Bound: 30
+// ; Bound: 28
 // ; Schema: 0
 // ; Schema: 0
 // OpCapability Shader
 // OpCapability Shader
 // OpMemoryModel Logical GLSL450
 // OpMemoryModel Logical GLSL450
@@ -30,6 +27,7 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // OpDecorate %7 Location 0
 // OpDecorate %7 Location 0
 // OpDecorate %8 Location 1
 // OpDecorate %8 Location 1
 // OpDecorate %5 Location 0
 // OpDecorate %5 Location 0
+// %int = OpTypeInt 32 1
 // %float = OpTypeFloat 32
 // %float = OpTypeFloat 32
 // %v4float = OpTypeVector %float 4
 // %v4float = OpTypeVector %float 4
 // %_ptr_Output_v4float = OpTypePointer Output %v4float
 // %_ptr_Output_v4float = OpTypePointer Output %v4float
@@ -38,12 +36,9 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // %10 = OpTypeFunction %void
 // %10 = OpTypeFunction %void
 // %_struct_13 = OpTypeStruct %v4float %v4float
 // %_struct_13 = OpTypeStruct %v4float %v4float
 // %_ptr_Function__struct_13 = OpTypePointer Function %_struct_13
 // %_ptr_Function__struct_13 = OpTypePointer Function %_struct_13
-// %int = OpTypeInt 32 1
 // %_ptr_Function_v4float = OpTypePointer Function %v4float
 // %_ptr_Function_v4float = OpTypePointer Function %v4float
 // %int_0 = OpConstant %int 0
 // %int_0 = OpConstant %int 0
 // %int_1 = OpConstant %int 1
 // %int_1 = OpConstant %int 1
-// %int_0_0 = OpConstant %int 0
-// %int_1_0 = OpConstant %int 1
 // %gl_Position = OpVariable %_ptr_Output_v4float Output
 // %gl_Position = OpVariable %_ptr_Output_v4float Output
 // %5 = OpVariable %_ptr_Output_v4float Output
 // %5 = OpVariable %_ptr_Output_v4float Output
 // %7 = OpVariable %_ptr_Input_v4float Input
 // %7 = OpVariable %_ptr_Input_v4float Input
@@ -57,11 +52,11 @@ PSInput VSmain(float4 position: POSITION, float4 color: COLOR) {
 // %22 = OpAccessChain %_ptr_Function_v4float %15 %int_1
 // %22 = OpAccessChain %_ptr_Function_v4float %15 %int_1
 // %23 = OpLoad %v4float %8
 // %23 = OpLoad %v4float %8
 // OpStore %22 %23
 // OpStore %22 %23
-// %25 = OpAccessChain %_ptr_Function_v4float %15 %int_0_0
-// %26 = OpLoad %v4float %25
-// OpStore %gl_Position %26
-// %28 = OpAccessChain %_ptr_Function_v4float %15 %int_1_0
-// %29 = OpLoad %v4float %28
-// OpStore %5 %29
+// %24 = OpAccessChain %_ptr_Function_v4float %15 %int_0
+// %25 = OpLoad %v4float %24
+// OpStore %gl_Position %25
+// %26 = OpAccessChain %_ptr_Function_v4float %15 %int_1
+// %27 = OpLoad %v4float %26
+// OpStore %5 %27
 // OpReturn
 // OpReturn
 // OpFunctionEnd
 // OpFunctionEnd

+ 1 - 0
tools/clang/unittests/SPIRV/CMakeLists.txt

@@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
 
 
 add_clang_unittest(clang-spirv-tests
 add_clang_unittest(clang-spirv-tests
   CodeGenSPIRVTest.cpp
   CodeGenSPIRVTest.cpp
+  ConstantTest.cpp
   DecorationTest.cpp
   DecorationTest.cpp
   InstBuilderTest.cpp
   InstBuilderTest.cpp
   ModuleBuilderTest.cpp
   ModuleBuilderTest.cpp

+ 6 - 0
tools/clang/unittests/SPIRV/CodeGenSPIRVTest.cpp

@@ -29,3 +29,9 @@ TEST_F(WholeFileTest, PassThruVertexShader) {
                    /*generateHeader*/ true,
                    /*generateHeader*/ true,
                    /*runValidation*/ true);
                    /*runValidation*/ true);
 }
 }
+
+TEST_F(WholeFileTest, ConstantPixelShader) {
+  runWholeFileTest("constant-ps.hlsl2spv",
+                   /*generateHeader*/ true,
+                   /*runValidation*/ false);
+}

+ 265 - 0
tools/clang/unittests/SPIRV/ConstantTest.cpp

@@ -0,0 +1,265 @@
+//===- unittests/SPIRV/ConstantTest.cpp ---------- Constant tests ---------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "SPIRVTestUtils.h"
+#include "gmock/gmock.h"
+#include "clang/SPIRV/BitwiseCast.h"
+#include "clang/SPIRV/Constant.h"
+#include "clang/SPIRV/SPIRVContext.h"
+#include "gtest/gtest.h"
+
+using namespace clang::spirv;
+
+namespace {
+using ::testing::ElementsAre;
+using ::testing::ContainerEq;
+
+TEST(Constant, True) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getTrue(ctx, 2);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(spv::Op::OpConstantTrue, {2, 3});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, False) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getFalse(ctx, 2);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(spv::Op::OpConstantFalse, {2, 3});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, Uint32) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getUint32(ctx, 2, 7u);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(spv::Op::OpConstant, {2, 3, 7u});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, Int32) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getInt32(ctx, 2, -7);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(spv::Op::OpConstant, {2, 3, 0xFFFFFFF9});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, Float32) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getFloat32(ctx, 2, 7.0);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(
+      spv::Op::OpConstant, {2, 3, cast::BitwiseCast<uint32_t, float>(7.0)});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, Composite) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getComposite(ctx, 8, {4, 5, 6, 7});
+  const auto result = c->withResultId(9);
+  const auto expected =
+      constructInst(spv::Op::OpConstantComposite, {8, 9, 4, 5, 6, 7});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, Sampler) {
+  SPIRVContext ctx;
+  const Constant *c =
+      Constant::getSampler(ctx, 8, spv::SamplerAddressingMode::Repeat, 1,
+                           spv::SamplerFilterMode::Linear);
+  const auto result = c->withResultId(9);
+  const auto expected = constructInst(
+      spv::Op::OpConstantSampler,
+      {8, 9, static_cast<uint32_t>(spv::SamplerAddressingMode::Repeat), 1,
+       static_cast<uint32_t>(spv::SamplerFilterMode::Linear)});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, Null) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getNull(ctx, 8);
+  const auto result = c->withResultId(9);
+  const auto expected = constructInst(spv::Op::OpConstantNull, {8, 9});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, SpecTrue) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getSpecTrue(ctx, 2);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(spv::Op::OpSpecConstantTrue, {2, 3});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, SpecFalse) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getSpecFalse(ctx, 2);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(spv::Op::OpSpecConstantFalse, {2, 3});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, SpecUint32) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getSpecUint32(ctx, 2, 7u);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(spv::Op::OpSpecConstant, {2, 3, 7u});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, SpecInt32) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getSpecInt32(ctx, 2, -7);
+  const auto result = c->withResultId(3);
+  const auto expected =
+      constructInst(spv::Op::OpSpecConstant, {2, 3, 0xFFFFFFF9});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, SpecFloat32) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getSpecFloat32(ctx, 2, 7.0);
+  const auto result = c->withResultId(3);
+  const auto expected = constructInst(
+      spv::Op::OpSpecConstant, {2, 3, cast::BitwiseCast<uint32_t, float>(7.0)});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, SpecComposite) {
+  SPIRVContext ctx;
+  const Constant *c = Constant::getSpecComposite(ctx, 8, {4, 5, 6, 7});
+  const auto result = c->withResultId(9);
+  const auto expected =
+      constructInst(spv::Op::OpSpecConstantComposite, {8, 9, 4, 5, 6, 7});
+  EXPECT_THAT(result, ContainerEq(expected));
+}
+TEST(Constant, DecoratedTrue) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getTrue(ctx, 2, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstantTrue);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_TRUE(c->getArgs().empty());
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedFalse) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getFalse(ctx, 2, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstantFalse);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_TRUE(c->getArgs().empty());
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedUint32) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getUint32(ctx, 2, 7u, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstant);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_THAT(c->getArgs(), ElementsAre(7u));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedInt32) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getInt32(ctx, 2, -7, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstant);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_THAT(c->getArgs(), ElementsAre(0xFFFFFFF9));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedFloat32) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getFloat32(ctx, 2, 7.0f, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstant);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_THAT(c->getArgs(),
+              ElementsAre(cast::BitwiseCast<uint32_t, float>(7.0)));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedComposite) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getComposite(ctx, 8, {4, 5, 6, 7}, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstantComposite);
+  EXPECT_EQ(c->getTypeId(), 8);
+  EXPECT_THAT(c->getArgs(), ElementsAre(4, 5, 6, 7));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedSampler) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c =
+      Constant::getSampler(ctx, 8, spv::SamplerAddressingMode::Repeat, 1,
+                           spv::SamplerFilterMode::Linear, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstantSampler);
+  EXPECT_EQ(c->getTypeId(), 8);
+  EXPECT_THAT(
+      c->getArgs(),
+      ElementsAre(static_cast<uint32_t>(spv::SamplerAddressingMode::Repeat), 1,
+                  static_cast<uint32_t>(spv::SamplerFilterMode::Linear)));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedNull) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getNull(ctx, 2, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpConstantNull);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_TRUE(c->getArgs().empty());
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedSpecTrue) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getSpecTrue(ctx, 2, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpSpecConstantTrue);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_TRUE(c->getArgs().empty());
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedSpecFalse) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getSpecFalse(ctx, 2, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpSpecConstantFalse);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_TRUE(c->getArgs().empty());
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedSpecUint32) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getSpecUint32(ctx, 2, 7u, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpSpecConstant);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_THAT(c->getArgs(), ElementsAre(7u));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedSpecInt32) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getSpecInt32(ctx, 2, -7, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpSpecConstant);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_THAT(c->getArgs(), ElementsAre(0xFFFFFFF9));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedSpecFloat32) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getSpecFloat32(ctx, 2, 7.0f, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpSpecConstant);
+  EXPECT_EQ(c->getTypeId(), 2);
+  EXPECT_THAT(c->getArgs(),
+              ElementsAre(cast::BitwiseCast<uint32_t, float>(7.0)));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+TEST(Constant, DecoratedSpecComposite) {
+  SPIRVContext ctx;
+  const Decoration *d = Decoration::getSpecId(ctx, 5);
+  const Constant *c = Constant::getSpecComposite(ctx, 8, {4, 5, 6, 7}, {d});
+  EXPECT_EQ(c->getOpcode(), spv::Op::OpSpecConstantComposite);
+  EXPECT_EQ(c->getTypeId(), 8);
+  EXPECT_THAT(c->getArgs(), ElementsAre(4, 5, 6, 7));
+  EXPECT_THAT(c->getDecorations(), ElementsAre(d));
+}
+
+} // anonymous namespace

+ 79 - 7
tools/clang/unittests/SPIRV/StructureTest.cpp

@@ -161,6 +161,11 @@ TEST(Structure, TakeModuleHaveAllContents) {
                    {entryPointId,
                    {entryPointId,
                     static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)}));
                     static_cast<uint32_t>(spv::Decoration::RelaxedPrecision)}));
 
 
+  const auto *i32Type = Type::getInt32(context);
+  const uint32_t i32Id = context.getResultIdForType(i32Type);
+  m.addType(i32Type, i32Id);
+  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i32Id, 32, 1}));
+
   const auto *voidType = Type::getVoid(context);
   const auto *voidType = Type::getVoid(context);
   const uint32_t voidId = context.getResultIdForType(voidType);
   const uint32_t voidId = context.getResultIdForType(voidType);
   m.addType(voidType, voidId);
   m.addType(voidType, voidId);
@@ -172,14 +177,9 @@ TEST(Structure, TakeModuleHaveAllContents) {
   appendVector(&expected, constructInst(spv::Op::OpTypeFunction,
   appendVector(&expected, constructInst(spv::Op::OpTypeFunction,
                                         {funcTypeId, voidId, voidId}));
                                         {funcTypeId, voidId, voidId}));
 
 
-  const auto *i32Type = Type::getInt32(context);
-  const uint32_t i32Id = context.getResultIdForType(i32Type);
-  m.addType(i32Type, i32Id);
-  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i32Id, 32, 1}));
-
+  const auto *i32Const = Constant::getInt32(context, i32Id, 42);
   const uint32_t constantId = context.takeNextId();
   const uint32_t constantId = context.takeNextId();
-  m.addConstant(*i32Type,
-                constructInst(spv::Op::OpConstant, {i32Id, constantId, 42}));
+  m.addConstant(i32Const, constantId);
   appendVector(&expected,
   appendVector(&expected,
                constructInst(spv::Op::OpConstant, {i32Id, constantId, 42}));
                constructInst(spv::Op::OpConstant, {i32Id, constantId, 42}));
   // TODO: global variable
   // TODO: global variable
@@ -209,4 +209,76 @@ TEST(Structure, TakeModuleHaveAllContents) {
   EXPECT_TRUE(m.isEmpty());
   EXPECT_TRUE(m.isEmpty());
 }
 }
 
 
+TEST(Structure, TakeModuleWithArrayAndConstantDependency) {
+  SPIRVContext context;
+  auto m = SPIRVModule();
+  std::vector<uint32_t> expected = getModuleHeader(0);
+
+  // Add void type
+  const auto *voidType = Type::getVoid(context);
+  const uint32_t voidId = context.getResultIdForType(voidType);
+  m.addType(voidType, voidId);
+
+  // Add float type
+  const auto *f32Type = Type::getFloat32(context);
+  const uint32_t f32Id = context.getResultIdForType(f32Type);
+  m.addType(f32Type, f32Id);
+
+  // Add int64 type
+  const auto *i64Type = Type::getInt64(context);
+  const uint32_t i64Id = context.getResultIdForType(i64Type);
+  m.addType(i64Type, i64Id);
+
+  // Add int32 type
+  const auto *i32Type = Type::getInt32(context);
+  const uint32_t i32Id = context.getResultIdForType(i32Type);
+  m.addType(i32Type, i32Id);
+
+  // Add 32-bit integer constant (8)
+  const auto *i32Const = Constant::getInt32(context, i32Id, 8);
+  const uint32_t constantId = context.getResultIdForConstant(i32Const);
+  m.addConstant(i32Const, constantId);
+
+  // Add array of 8 32-bit integers type
+  const auto *arrType = Type::getArray(context, i32Id, constantId);
+  const uint32_t arrId = context.getResultIdForType(arrType);
+  m.addType(arrType, arrId);
+  m.setBound(context.getNextId());
+
+  // Add another array of the same size. The constant does not need to be
+  // redefined.
+  const auto *arrStride = Decoration::getArrayStride(context, 4);
+  const auto *secondArrType =
+      Type::getArray(context, i32Id, constantId, {arrStride});
+  const uint32_t secondArrId = context.getResultIdForType(arrType);
+  m.addType(secondArrType, secondArrId);
+  m.addDecoration(*arrStride, secondArrId);
+  m.setBound(context.getNextId());
+
+  // Decorations
+  appendVector(
+      &expected,
+      constructInst(spv::Op::OpDecorate,
+                    {secondArrId,
+                     static_cast<uint32_t>(spv::Decoration::ArrayStride), 4}));
+  // Now the expected order: int64, int32, void, float, constant(8), array
+  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i64Id, 64, 1}));
+  appendVector(&expected, constructInst(spv::Op::OpTypeInt, {i32Id, 32, 1}));
+  appendVector(&expected, constructInst(spv::Op::OpTypeVoid, {voidId}));
+  appendVector(&expected, constructInst(spv::Op::OpTypeFloat, {f32Id, 32}));
+  appendVector(&expected,
+               constructInst(spv::Op::OpConstant, {i32Id, constantId, 8}));
+  appendVector(&expected,
+               constructInst(spv::Op::OpTypeArray, {arrId, i32Id, constantId}));
+  appendVector(&expected, constructInst(spv::Op::OpTypeArray,
+                                        {secondArrId, i32Id, constantId}));
+  expected[3] = context.getNextId();
+
+  std::vector<uint32_t> result;
+  auto ib = constructInstBuilder(result);
+  m.take(&ib);
+
+  EXPECT_THAT(result, ContainerEq(expected));
+  EXPECT_TRUE(m.isEmpty());
+}
 } // anonymous namespace
 } // anonymous namespace