Browse Source

[spirv] SpirvConstant instructions.

Ehsan 6 years ago
parent
commit
6af3eb80ff

+ 5 - 0
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -57,6 +57,11 @@ public:
   bool visit(SpirvBinaryOp *);
   bool visit(SpirvBitFieldExtract *);
   bool visit(SpirvBitFieldInsert *);
+  bool visit(SpirvConstantBoolean *);
+  bool visit(SpirvConstantInteger *);
+  bool visit(SpirvConstantFloat *);
+  bool visit(SpirvConstantComposite *);
+  bool visit(SpirvConstantNull *);
   bool visit(SpirvComposite *);
   bool visit(SpirvCompositeExtract *);
   bool visit(SpirvCompositeInsert *);

+ 3 - 2
tools/clang/include/clang/SPIRV/SpirvBuilder.h

@@ -9,6 +9,7 @@
 #ifndef LLVM_CLANG_SPIRV_SPIRVBUILDER_H
 #define LLVM_CLANG_SPIRV_SPIRVBUILDER_H
 
+#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/SPIRVContext.h"
 #include "clang/SPIRV/SpirvBasicBlock.h"
 #include "clang/SPIRV/SpirvFunction.h"
@@ -32,8 +33,8 @@ namespace spirv {
 /// module.
 class SpirvBuilder {
 public:
-  explicit SpirvBuilder(ASTContext &ac, SpirvContext &c, FeatureManager *,
-                        const SpirvCodeGenOptions &);
+  SpirvBuilder(ASTContext &ac, SpirvContext &c, FeatureManager *,
+               const SpirvCodeGenOptions &);
   ~SpirvBuilder() = default;
 
   // Forbid copy construction and assignment

+ 137 - 12
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -44,9 +44,15 @@ public:
     IK_ModuleProcessed, // OpModuleProcessed (debug)
     IK_Decoration,      // Op*Decorate
     IK_Type,            // OpType*
-    IK_Constant,        // OpConstant*
     IK_Variable,        // OpVariable
 
+    // Different kind of constants. Order matters.
+    IK_ConstantBoolean,
+    IK_ConstantInteger,
+    IK_ConstantFloat,
+    IK_ConstantComposite,
+    IK_ConstantNull,
+
     // Function structure kinds
 
     IK_FunctionParameter, // OpFunctionParameter
@@ -414,8 +420,6 @@ protected:
            inst->getKind() == IK_SelectionMerge;
   }
 
-  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvMerge)
-
 private:
   SpirvBasicBlock *mergeBlock;
 };
@@ -476,8 +480,6 @@ public:
     return inst->getKind() >= IK_Branch && inst->getKind() <= IK_Unreachable;
   }
 
-  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvTerminator)
-
 protected:
   SpirvTerminator(Kind kind, spv::Op opcode, SourceLocation loc);
 };
@@ -491,8 +493,6 @@ public:
            inst->getKind() <= IK_BranchConditional;
   }
 
-  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBranching)
-
   virtual llvm::ArrayRef<SpirvBasicBlock *> getTargetBranches() const = 0;
 
 protected:
@@ -839,8 +839,6 @@ public:
            inst->getKind() == IK_BitFieldInsert;
   }
 
-  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvBitField)
-
   virtual SpirvInstruction *getBase() const { return base; }
   virtual SpirvInstruction *getOffset() const { return offset; }
   virtual SpirvInstruction *getCount() const { return count; }
@@ -895,6 +893,136 @@ private:
   SpirvInstruction *insert;
 };
 
+class SpirvConstant : public SpirvInstruction {
+public:
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() >= IK_ConstantBoolean &&
+           inst->getKind() <= IK_ConstantNull;
+  }
+
+protected:
+  SpirvConstant(Kind, spv::Op, QualType resultType, uint32_t resultId,
+                SourceLocation);
+};
+
+class SpirvConstantBoolean : public SpirvConstant {
+public:
+  SpirvConstantBoolean(bool value, QualType resultType, uint32_t resultId,
+                       SourceLocation loc);
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_ConstantBoolean;
+  }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantBoolean)
+
+  bool getValue() const { return value; }
+
+private:
+  bool value;
+};
+
+/// \brief Represent OpConstant for integer values.
+class SpirvConstantInteger : public SpirvConstant {
+public:
+  SpirvConstantInteger(uint16_t value, QualType resultType, uint32_t resultId,
+                       SourceLocation loc);
+  SpirvConstantInteger(int16_t value, QualType resultType, uint32_t resultId,
+                       SourceLocation loc);
+  SpirvConstantInteger(uint32_t value, QualType resultType, uint32_t resultId,
+                       SourceLocation loc);
+  SpirvConstantInteger(int32_t value, QualType resultType, uint32_t resultId,
+                       SourceLocation loc);
+  SpirvConstantInteger(uint64_t value, QualType resultType, uint32_t resultId,
+                       SourceLocation loc);
+  SpirvConstantInteger(int64_t value, QualType resultType, uint32_t resultId,
+                       SourceLocation loc);
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_ConstantInteger;
+  }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantInteger)
+
+  uint16_t getUnsignedInt16Value() const;
+  int16_t getSignedInt16Value() const;
+  uint32_t getUnsignedInt32Value() const;
+  int32_t getSignedInt32Value() const;
+  uint64_t getUnsignedInt64Value() const;
+  int64_t getSignedInt64Value() const;
+
+  uint32_t getBitwidth() const { return bitwidth; }
+  void setBitwidth(uint32_t width) { bitwidth = width; }
+  bool isSigned() const { return getResultType()->isSignedIntegerType(); }
+
+private:
+  uint32_t bitwidth;
+  uint64_t value;
+};
+
+class SpirvConstantFloat : public SpirvConstant {
+public:
+  SpirvConstantFloat(uint16_t value, QualType resultType, uint32_t resultId,
+                     SourceLocation loc);
+  SpirvConstantFloat(float value, QualType resultType, uint32_t resultId,
+                     SourceLocation loc);
+  SpirvConstantFloat(double value, QualType resultType, uint32_t resultId,
+                     SourceLocation loc);
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_ConstantFloat;
+  }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantFloat)
+
+  uint16_t getValue16() const;
+  float getValue32() const;
+  double getValue64() const;
+  uint32_t getBitwidth() const { return bitwidth; }
+  void setBitwidth(uint32_t width) { bitwidth = width; }
+
+private:
+  uint32_t bitwidth;
+  uint64_t value;
+};
+
+class SpirvConstantComposite : public SpirvConstant {
+public:
+  SpirvConstantComposite(llvm::ArrayRef<SpirvConstant *> constituents,
+                         QualType resultType, uint32_t resultId,
+                         SourceLocation loc);
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_ConstantComposite;
+  }
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantComposite)
+
+  llvm::ArrayRef<SpirvConstant *> getConstituents() const {
+    return constituents;
+  }
+
+private:
+  std::vector<SpirvConstant *> constituents;
+};
+
+class SpirvConstantNull : public SpirvConstant {
+public:
+  SpirvConstantNull(QualType resultType, uint32_t resultId, SourceLocation loc);
+
+  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantNull)
+
+  // For LLVM-style RTTI
+  static bool classof(const SpirvInstruction *inst) {
+    return inst->getKind() == IK_ConstantNull;
+  }
+};
+
 /// \brief Composition instructions
 ///
 /// This class includes OpConstantComposite, OpSpecConstantComposite,
@@ -986,7 +1114,6 @@ public:
   DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvEmitVertex)
 };
 
-
 /// \brief EndPrimitive instruction
 class SpirvEndPrimitive : public SpirvInstruction {
 public:
@@ -1055,8 +1182,6 @@ public:
            inst->getKind() <= IK_GroupNonUniformUnaryOp;
   }
 
-  DECLARE_INVOKE_VISITOR_FOR_CLASS(SpirvGroupNonUniformOp)
-
   spv::Scope getExecutionScope() const { return execScope; }
 
 protected:

+ 5 - 0
tools/clang/include/clang/SPIRV/SpirvVisitor.h

@@ -83,6 +83,11 @@ public:
   DEFINE_VISIT_METHOD(SpirvBinaryOp)
   DEFINE_VISIT_METHOD(SpirvBitFieldExtract)
   DEFINE_VISIT_METHOD(SpirvBitFieldInsert)
+  DEFINE_VISIT_METHOD(SpirvConstantBoolean)
+  DEFINE_VISIT_METHOD(SpirvConstantInteger)
+  DEFINE_VISIT_METHOD(SpirvConstantFloat)
+  DEFINE_VISIT_METHOD(SpirvConstantComposite)
+  DEFINE_VISIT_METHOD(SpirvConstantNull)
   DEFINE_VISIT_METHOD(SpirvComposite)
   DEFINE_VISIT_METHOD(SpirvCompositeExtract)
   DEFINE_VISIT_METHOD(SpirvCompositeInsert)

+ 133 - 1
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -8,12 +8,44 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/SPIRV/EmitVisitor.h"
+#include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/SpirvBasicBlock.h"
 #include "clang/SPIRV/SpirvFunction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
 #include "clang/SPIRV/SpirvModule.h"
 #include "clang/SPIRV/String.h"
 
+namespace {
+uint32_t zeroExtendTo32Bits(uint16_t value) {
+  // TODO: The ordering of the 2 words depends on the endian-ness of the host
+  // machine. Assuming Little Endian at the moment.
+  struct two16Bits {
+    uint16_t low;
+    uint16_t high;
+  };
+
+  two16Bits result = {value, 0};
+  return clang::spirv::cast::BitwiseCast<uint32_t, two16Bits>(result);
+}
+
+uint32_t signExtendTo32Bits(int16_t value) {
+  // TODO: The ordering of the 2 words depends on the endian-ness of the host
+  // machine. Assuming Little Endian at the moment.
+  struct two16Bits {
+    int16_t low;
+    uint16_t high;
+  };
+
+  two16Bits result = {value, 0};
+
+  // Sign bit is 1
+  if (value >> 15) {
+    result.high = 0xffff;
+  }
+  return clang::spirv::cast::BitwiseCast<uint32_t, two16Bits>(result);
+}
+} // anonymous namespace
+
 namespace clang {
 namespace spirv {
 
@@ -66,7 +98,8 @@ void EmitVisitor::finalizeInstruction() {
   case spv::Op::OpDecorationGroup:
   case spv::Op::OpDecorateStringGOOGLE:
   case spv::Op::OpMemberDecorateStringGOOGLE:
-    annotationsBinary.insert(annotationsBinary.end(), curInst.begin(), curInst.end());
+    annotationsBinary.insert(annotationsBinary.end(), curInst.begin(),
+                             curInst.end());
     break;
   default:
     mainBinary.insert(mainBinary.end(), curInst.begin(), curInst.end());
@@ -396,6 +429,105 @@ bool EmitVisitor::visit(SpirvBitFieldInsert *inst) {
   return true;
 }
 
+bool EmitVisitor::visit(SpirvConstantBoolean *inst) {
+  initInstruction(inst->getopcode());
+  curInst.push_back(inst->getResultTypeId());
+  curInst.push_back(inst->getResultId());
+  finalizeInstruction();
+  emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
+  return true;
+}
+
+bool EmitVisitor::visit(SpirvConstantInteger *inst) {
+  initInstruction(inst->getopcode());
+  curInst.push_back(inst->getResultTypeId());
+  curInst.push_back(inst->getResultId());
+  // 16-bit cases
+  if (inst->getBitwidth() == 16) {
+    if (inst->isSigned()) {
+      curInst.push_back(signExtendTo32Bits(inst->getSignedInt16Value()));
+    } else {
+      curInst.push_back(zeroExtendTo32Bits(inst->getUnsignedInt16Value()));
+    }
+  }
+  // 32-bit cases
+  else if (inst->getBitwidth() == 32) {
+    if (inst->isSigned()) {
+      curInst.push_back(
+          cast::BitwiseCast<uint32_t, int32_t>(inst->getSignedInt32Value()));
+    } else {
+      curInst.push_back(inst->getUnsignedInt32Value());
+    }
+  }
+  // 64-bit cases
+  else {
+    struct wideInt {
+      uint32_t word0;
+      uint32_t word1;
+    };
+    wideInt words;
+    if (inst->isSigned()) {
+      words = cast::BitwiseCast<wideInt, int64_t>(inst->getSignedInt64Value());
+    } else {
+      words =
+          cast::BitwiseCast<wideInt, uint64_t>(inst->getUnsignedInt64Value());
+    }
+    curInst.push_back(words.word0);
+    curInst.push_back(words.word1);
+  }
+  finalizeInstruction();
+  emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
+  return true;
+}
+
+bool EmitVisitor::visit(SpirvConstantFloat *inst) {
+  initInstruction(inst->getopcode());
+  curInst.push_back(inst->getResultTypeId());
+  curInst.push_back(inst->getResultId());
+  if (inst->getBitwidth() == 16) {
+    // According to the SPIR-V Spec:
+    // When the type's bit width is less than 32-bits, the literal's value
+    // appears in the low-order bits of the word, and the high-order bits must
+    // be 0 for a floating-point type.
+    curInst.push_back(zeroExtendTo32Bits(inst->getValue16()));
+  } else if (inst->getBitwidth() == 32) {
+    curInst.push_back(cast::BitwiseCast<uint32_t, float>(inst->getValue32()));
+  } else {
+    // TODO: The ordering of the 2 words depends on the endian-ness of the host
+    // machine.
+    struct wideFloat {
+      uint32_t word0;
+      uint32_t word1;
+    };
+    wideFloat words = cast::BitwiseCast<wideFloat, double>(inst->getValue64());
+    curInst.push_back(words.word0);
+    curInst.push_back(words.word1);
+  }
+  finalizeInstruction();
+  emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
+  return true;
+}
+
+bool EmitVisitor::visit(SpirvConstantComposite *inst) {
+  initInstruction(inst->getopcode());
+  curInst.push_back(inst->getResultTypeId());
+  curInst.push_back(inst->getResultId());
+  for (const auto constituent : inst->getConstituents())
+    curInst.push_back(constituent->getResultId());
+  finalizeInstruction();
+  emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
+  return true;
+}
+
+bool EmitVisitor::visit(SpirvConstantNull *inst) {
+  initInstruction(inst->getopcode());
+  curInst.push_back(inst->getResultTypeId());
+  curInst.push_back(inst->getResultId());
+  finalizeInstruction();
+  emitDebugNameForInstruction(inst->getResultId(), inst->getDebugName());
+  return true;
+}
+
 bool EmitVisitor::visit(SpirvComposite *inst) {
   initInstruction(inst->getopcode());
   curInst.push_back(inst->getResultTypeId());

+ 1 - 3
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -7,10 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "TypeTranslator.h"
-#include "clang/SPIRV/FeatureManager.h"
 #include "clang/SPIRV/SpirvBuilder.h"
-#include "llvm/Support/MathExtras.h"
+#include "TypeTranslator.h"
 
 namespace clang {
 namespace spirv {

+ 154 - 0
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "clang/SPIRV/BitwiseCast.h"
 #include "clang/SPIRV/SpirvBasicBlock.h"
 #include "clang/SPIRV/SpirvFunction.h"
 #include "clang/SPIRV/SpirvInstruction.h"
@@ -50,6 +51,11 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBarrier)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBinaryOp)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldExtract)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvBitFieldInsert)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantBoolean)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantInteger)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantFloat)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantComposite)
+DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvConstantNull)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvComposite)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeExtract)
 DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvCompositeInsert)
@@ -345,6 +351,154 @@ SpirvComposite::SpirvComposite(
                        resultType, resultId, loc),
       consituents(constituentsVec.begin(), constituentsVec.end()) {}
 
+SpirvConstant::SpirvConstant(Kind kind, spv::Op op, QualType resultType,
+                             uint32_t resultId, SourceLocation loc)
+    : SpirvInstruction(kind, op, resultType, resultId, loc) {}
+
+SpirvConstantBoolean::SpirvConstantBoolean(bool val, QualType resultType,
+                                           uint32_t resultId,
+                                           SourceLocation loc)
+    : SpirvConstant(IK_ConstantBoolean,
+                    val ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse,
+                    resultType, resultId, loc),
+      value(val) {}
+
+SpirvConstantInteger::SpirvConstantInteger(uint16_t val, QualType resultType,
+                                           uint32_t resultId,
+                                           SourceLocation loc)
+    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, resultType,
+                    resultId, loc),
+      bitwidth(16), value(static_cast<uint64_t>(val)) {
+  assert(resultType->isUnsignedIntegerType());
+}
+
+SpirvConstantInteger::SpirvConstantInteger(int16_t val, QualType resultType,
+                                           uint32_t resultId,
+                                           SourceLocation loc)
+    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, resultType,
+                    resultId, loc),
+      bitwidth(16), value(static_cast<uint64_t>(val)) {
+  assert(resultType->isSignedIntegerType());
+}
+
+SpirvConstantInteger::SpirvConstantInteger(uint32_t val, QualType resultType,
+                                           uint32_t resultId,
+                                           SourceLocation loc)
+    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, resultType,
+                    resultId, loc),
+      bitwidth(32), value(static_cast<uint64_t>(val)) {
+  assert(resultType->isUnsignedIntegerType());
+}
+
+SpirvConstantInteger::SpirvConstantInteger(int32_t val, QualType resultType,
+                                           uint32_t resultId,
+                                           SourceLocation loc)
+    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, resultType,
+                    resultId, loc),
+      bitwidth(32), value(static_cast<uint64_t>(val)) {
+  assert(resultType->isSignedIntegerType());
+}
+
+SpirvConstantInteger::SpirvConstantInteger(uint64_t val, QualType resultType,
+                                           uint32_t resultId,
+                                           SourceLocation loc)
+    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, resultType,
+                    resultId, loc),
+      bitwidth(64), value(val) {
+  assert(resultType->isUnsignedIntegerType());
+}
+
+SpirvConstantInteger::SpirvConstantInteger(int64_t val, QualType resultType,
+                                           uint32_t resultId,
+                                           SourceLocation loc)
+    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, resultType,
+                    resultId, loc),
+      bitwidth(64), value(static_cast<uint64_t>(val)) {
+  assert(resultType->isSignedIntegerType());
+}
+
+uint16_t SpirvConstantInteger::getUnsignedInt16Value() const {
+  assert(!isSigned());
+  assert(bitwidth == 16);
+  return static_cast<uint16_t>(value);
+}
+
+int16_t SpirvConstantInteger::getSignedInt16Value() const {
+  assert(isSigned());
+  assert(bitwidth == 16);
+  return static_cast<int16_t>(value);
+}
+
+uint32_t SpirvConstantInteger::getUnsignedInt32Value() const {
+  assert(!isSigned());
+  assert(bitwidth == 32);
+  return static_cast<uint32_t>(value);
+}
+
+int32_t SpirvConstantInteger::getSignedInt32Value() const {
+  assert(isSigned());
+  assert(bitwidth == 32);
+  return static_cast<int32_t>(value);
+}
+
+uint64_t SpirvConstantInteger::getUnsignedInt64Value() const {
+  assert(!isSigned());
+  assert(bitwidth == 64);
+  return value;
+}
+
+int64_t SpirvConstantInteger::getSignedInt64Value() const {
+  assert(isSigned());
+  assert(bitwidth == 64);
+  return static_cast<int64_t>(value);
+}
+
+SpirvConstantFloat::SpirvConstantFloat(uint16_t val, QualType resultType,
+                                       uint32_t resultId, SourceLocation loc)
+    : SpirvConstant(IK_ConstantFloat, spv::Op::OpConstant, resultType, resultId,
+                    loc),
+      bitwidth(16), value(static_cast<uint64_t>(val)) {}
+
+SpirvConstantFloat::SpirvConstantFloat(float val, QualType resultType,
+                                       uint32_t resultId, SourceLocation loc)
+    : SpirvConstant(IK_ConstantFloat, spv::Op::OpConstant, resultType, resultId,
+                    loc),
+      bitwidth(32),
+      value(static_cast<uint64_t>(cast::BitwiseCast<uint32_t, float>(val))) {}
+
+SpirvConstantFloat::SpirvConstantFloat(double val, QualType resultType,
+                                       uint32_t resultId, SourceLocation loc)
+    : SpirvConstant(IK_ConstantFloat, spv::Op::OpConstant, resultType, resultId,
+                    loc),
+      bitwidth(64), value(cast::BitwiseCast<uint64_t, double>(val)) {}
+
+uint16_t SpirvConstantFloat::getValue16() const {
+  assert(bitwidth == 16);
+  return static_cast<uint16_t>(value);
+}
+
+float SpirvConstantFloat::getValue32() const {
+  assert(bitwidth == 32);
+  return cast::BitwiseCast<float, uint32_t>(static_cast<uint32_t>(value));
+}
+
+double SpirvConstantFloat::getValue64() const {
+  assert(bitwidth == 64);
+  return cast::BitwiseCast<double, uint64_t>(value);
+}
+
+SpirvConstantComposite::SpirvConstantComposite(
+    llvm::ArrayRef<SpirvConstant *> constituentsVec, QualType resultType,
+    uint32_t resultId, SourceLocation loc)
+    : SpirvConstant(IK_ConstantComposite, spv::Op::OpConstantComposite,
+                    resultType, resultId, loc),
+      constituents(constituentsVec) {}
+
+SpirvConstantNull::SpirvConstantNull(QualType resultType, uint32_t resultId,
+                                     SourceLocation loc)
+    : SpirvConstant(IK_ConstantNull, spv::Op::OpConstantNull, resultType,
+                    resultId, loc) {}
+
 SpirvCompositeExtract::SpirvCompositeExtract(QualType resultType,
                                              uint32_t resultId,
                                              SourceLocation loc,

+ 1 - 0
tools/clang/unittests/SPIRV/CMakeLists.txt

@@ -18,6 +18,7 @@ add_clang_unittest(clang-spirv-tests
   SPIRVTestOptions.cpp
   StructureTest.cpp
   TestMain.cpp
+  SpirvConstantTest.cpp
   StringTest.cpp
   TypeTest.cpp
   WholeFileTestFixture.cpp

+ 36 - 0
tools/clang/unittests/SPIRV/SpirvConstantTest.cpp

@@ -0,0 +1,36 @@
+//===- unittests/SPIRV/SpirvConstantTest.cpp --- SPIR-V Constant tests ----===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include "clang/SPIRV/SpirvInstruction.h"
+
+namespace {
+using namespace clang::spirv;
+
+TEST(SpirvConstant, Float16) {
+  const uint16_t f16 = 12;
+  SpirvConstantFloat constant(f16, {}, 0, {});
+  EXPECT_EQ(f16, constant.getValue16());
+}
+
+TEST(SpirvConstant, Float32) {
+  const float f32 = 1.5;
+  SpirvConstantFloat constant(f32, {}, 0, {});
+  EXPECT_EQ(f32, constant.getValue32());
+}
+
+TEST(SpirvConstant, Float64) {
+  const double f64 = 3.14;
+  SpirvConstantFloat constant(f64, {}, 0, {});
+  EXPECT_EQ(f64, constant.getValue64());
+}
+
+} // anonymous namespace