瀏覽代碼

[spirv] Handle Spec constants in SpirvConstant class.

Ehsan 6 年之前
父節點
當前提交
cf4ea8065c

+ 24 - 11
tools/clang/include/clang/SPIRV/SpirvInstruction.h

@@ -913,13 +913,16 @@ public:
            inst->getKind() <= IK_ConstantNull;
            inst->getKind() <= IK_ConstantNull;
   }
   }
 
 
+  bool isSpecConstant() const;
+
 protected:
 protected:
   SpirvConstant(Kind, spv::Op, const SpirvType *);
   SpirvConstant(Kind, spv::Op, const SpirvType *);
 };
 };
 
 
 class SpirvConstantBoolean : public SpirvConstant {
 class SpirvConstantBoolean : public SpirvConstant {
 public:
 public:
-  SpirvConstantBoolean(const BoolType *type, bool value);
+  SpirvConstantBoolean(const BoolType *type, bool value,
+                       bool isSpecConst = false);
 
 
   // For LLVM-style RTTI
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
   static bool classof(const SpirvInstruction *inst) {
@@ -939,12 +942,18 @@ private:
 /// \brief Represent OpConstant for integer values.
 /// \brief Represent OpConstant for integer values.
 class SpirvConstantInteger : public SpirvConstant {
 class SpirvConstantInteger : public SpirvConstant {
 public:
 public:
-  SpirvConstantInteger(const IntegerType *type, uint16_t value);
-  SpirvConstantInteger(const IntegerType *type, int16_t value);
-  SpirvConstantInteger(const IntegerType *type, uint32_t value);
-  SpirvConstantInteger(const IntegerType *type, int32_t value);
-  SpirvConstantInteger(const IntegerType *type, uint64_t value);
-  SpirvConstantInteger(const IntegerType *type, int64_t value);
+  SpirvConstantInteger(const IntegerType *type, uint16_t value,
+                       bool isSpecConst = false);
+  SpirvConstantInteger(const IntegerType *type, int16_t value,
+                       bool isSpecConst = false);
+  SpirvConstantInteger(const IntegerType *type, uint32_t value,
+                       bool isSpecConst = false);
+  SpirvConstantInteger(const IntegerType *type, int32_t value,
+                       bool isSpecConst = false);
+  SpirvConstantInteger(const IntegerType *type, uint64_t value,
+                       bool isSpecConst = false);
+  SpirvConstantInteger(const IntegerType *type, int64_t value,
+                       bool isSpecConst = false);
 
 
   // For LLVM-style RTTI
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
   static bool classof(const SpirvInstruction *inst) {
@@ -974,9 +983,12 @@ private:
 
 
 class SpirvConstantFloat : public SpirvConstant {
 class SpirvConstantFloat : public SpirvConstant {
 public:
 public:
-  SpirvConstantFloat(const FloatType *type, uint16_t value);
-  SpirvConstantFloat(const FloatType *type, float value);
-  SpirvConstantFloat(const FloatType *type, double value);
+  SpirvConstantFloat(const FloatType *type, uint16_t value,
+                     bool isSpecConst = false);
+  SpirvConstantFloat(const FloatType *type, float value,
+                     bool isSpecConst = false);
+  SpirvConstantFloat(const FloatType *type, double value,
+                     bool isSpecConst = false);
 
 
   // For LLVM-style RTTI
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
   static bool classof(const SpirvInstruction *inst) {
@@ -1002,7 +1014,8 @@ private:
 class SpirvConstantComposite : public SpirvConstant {
 class SpirvConstantComposite : public SpirvConstant {
 public:
 public:
   SpirvConstantComposite(const SpirvType *type,
   SpirvConstantComposite(const SpirvType *type,
-                         llvm::ArrayRef<const SpirvConstant *> constituents);
+                         llvm::ArrayRef<const SpirvConstant *> constituents,
+                         bool isSpecConst = false);
 
 
   // For LLVM-style RTTI
   // For LLVM-style RTTI
   static bool classof(const SpirvInstruction *inst) {
   static bool classof(const SpirvInstruction *inst) {

+ 70 - 25
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -360,55 +360,83 @@ SpirvConstant::SpirvConstant(Kind kind, spv::Op op, const SpirvType *spvType)
   setResultType(spvType);
   setResultType(spvType);
 }
 }
 
 
-SpirvConstantBoolean::SpirvConstantBoolean(const BoolType *type, bool val)
+bool SpirvConstant::isSpecConstant() const {
+  return opcode == spv::Op::OpSpecConstant ||
+         opcode == spv::Op::OpSpecConstantTrue ||
+         opcode == spv::Op::OpSpecConstantFalse ||
+         opcode == spv::Op::OpSpecConstantComposite;
+}
+
+SpirvConstantBoolean::SpirvConstantBoolean(const BoolType *type, bool val,
+                                           bool isSpecConst)
     : SpirvConstant(IK_ConstantBoolean,
     : SpirvConstant(IK_ConstantBoolean,
-                    val ? spv::Op::OpConstantTrue : spv::Op::OpConstantFalse,
+                    val ? (isSpecConst ? spv::Op::OpSpecConstantTrue
+                                       : spv::Op::OpConstantTrue)
+                        : (isSpecConst ? spv::Op::OpSpecConstantFalse
+                                       : spv::Op::OpConstantFalse),
                     type),
                     type),
       value(val) {}
       value(val) {}
 
 
 bool SpirvConstantBoolean::operator==(const SpirvConstantBoolean &that) const {
 bool SpirvConstantBoolean::operator==(const SpirvConstantBoolean &that) const {
-  return resultType == that.getResultType() && value == that.getValue();
+  return resultType == that.getResultType() && value == that.getValue() &&
+         opcode == that.getopcode();
 }
 }
 
 
 SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type,
 SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type,
-                                           uint16_t val)
-    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, type),
+                                           uint16_t val, bool isSpecConst)
+    : SpirvConstant(IK_ConstantInteger,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(static_cast<uint64_t>(val)) {
       value(static_cast<uint64_t>(val)) {
   assert(type->getBitwidth() == 16);
   assert(type->getBitwidth() == 16);
   assert(!type->isSignedInt());
   assert(!type->isSignedInt());
 }
 }
 
 
-SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type, int16_t val)
-    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, type),
+SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type, int16_t val,
+                                           bool isSpecConst)
+    : SpirvConstant(IK_ConstantInteger,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(static_cast<uint64_t>(val)) {
       value(static_cast<uint64_t>(val)) {
   assert(type->getBitwidth() == 16);
   assert(type->getBitwidth() == 16);
   assert(type->isSignedInt());
   assert(type->isSignedInt());
 }
 }
 
 
 SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type,
 SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type,
-                                           uint32_t val)
-    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, type),
+                                           uint32_t val, bool isSpecConst)
+    : SpirvConstant(IK_ConstantInteger,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(static_cast<uint64_t>(val)) {
       value(static_cast<uint64_t>(val)) {
   assert(type->getBitwidth() == 32);
   assert(type->getBitwidth() == 32);
   assert(!type->isSignedInt());
   assert(!type->isSignedInt());
 }
 }
 
 
-SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type, int32_t val)
-    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, type),
+SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type, int32_t val,
+                                           bool isSpecConst)
+    : SpirvConstant(IK_ConstantInteger,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(static_cast<uint64_t>(val)) {
       value(static_cast<uint64_t>(val)) {
   assert(type->getBitwidth() == 32);
   assert(type->getBitwidth() == 32);
   assert(type->isSignedInt());
   assert(type->isSignedInt());
 }
 }
 
 
 SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type,
 SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type,
-                                           uint64_t val)
-    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, type), value(val) {
+                                           uint64_t val, bool isSpecConst)
+    : SpirvConstant(IK_ConstantInteger,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
+      value(val) {
   assert(type->getBitwidth() == 64);
   assert(type->getBitwidth() == 64);
   assert(!type->isSignedInt());
   assert(!type->isSignedInt());
 }
 }
 
 
-SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type, int64_t val)
-    : SpirvConstant(IK_ConstantInteger, spv::Op::OpConstant, type),
+SpirvConstantInteger::SpirvConstantInteger(const IntegerType *type, int64_t val,
+                                           bool isSpecConst)
+    : SpirvConstant(IK_ConstantInteger,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(static_cast<uint64_t>(val)) {
       value(static_cast<uint64_t>(val)) {
   assert(type->getBitwidth() == 64);
   assert(type->getBitwidth() == 64);
   assert(type->isSignedInt());
   assert(type->isSignedInt());
@@ -461,23 +489,33 @@ int64_t SpirvConstantInteger::getSignedInt64Value() const {
 }
 }
 
 
 bool SpirvConstantInteger::operator==(const SpirvConstantInteger &that) const {
 bool SpirvConstantInteger::operator==(const SpirvConstantInteger &that) const {
-  return resultType == that.getResultType() && value == that.getValueBits();
+  return resultType == that.getResultType() && value == that.getValueBits() &&
+         opcode == that.getopcode();
 }
 }
 
 
-SpirvConstantFloat::SpirvConstantFloat(const FloatType *type, uint16_t val)
-    : SpirvConstant(IK_ConstantFloat, spv::Op::OpConstant, type),
+SpirvConstantFloat::SpirvConstantFloat(const FloatType *type, uint16_t val,
+                                       bool isSpecConst)
+    : SpirvConstant(IK_ConstantFloat,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(static_cast<uint64_t>(val)) {
       value(static_cast<uint64_t>(val)) {
   assert(type->getBitwidth() == 16);
   assert(type->getBitwidth() == 16);
 }
 }
 
 
-SpirvConstantFloat::SpirvConstantFloat(const FloatType *type, float val)
-    : SpirvConstant(IK_ConstantFloat, spv::Op::OpConstant, type),
+SpirvConstantFloat::SpirvConstantFloat(const FloatType *type, float val,
+                                       bool isSpecConst)
+    : SpirvConstant(IK_ConstantFloat,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(static_cast<uint64_t>(cast::BitwiseCast<uint32_t, float>(val))) {
       value(static_cast<uint64_t>(cast::BitwiseCast<uint32_t, float>(val))) {
   assert(type->getBitwidth() == 32);
   assert(type->getBitwidth() == 32);
 }
 }
 
 
-SpirvConstantFloat::SpirvConstantFloat(const FloatType *type, double val)
-    : SpirvConstant(IK_ConstantFloat, spv::Op::OpConstant, type),
+SpirvConstantFloat::SpirvConstantFloat(const FloatType *type, double val,
+                                       bool isSpecConst)
+    : SpirvConstant(IK_ConstantFloat,
+                    isSpecConst ? spv::Op::OpSpecConstant : spv::Op::OpConstant,
+                    type),
       value(cast::BitwiseCast<uint64_t, double>(val)) {
       value(cast::BitwiseCast<uint64_t, double>(val)) {
   assert(type->getBitwidth() == 64);
   assert(type->getBitwidth() == 64);
 }
 }
@@ -503,17 +541,24 @@ double SpirvConstantFloat::getValue64() const {
 }
 }
 
 
 bool SpirvConstantFloat::operator==(const SpirvConstantFloat &that) const {
 bool SpirvConstantFloat::operator==(const SpirvConstantFloat &that) const {
-  return resultType == that.getResultType() && value == that.getValueBits();
+  return resultType == that.getResultType() && value == that.getValueBits() &&
+         opcode == that.getopcode();
 }
 }
 
 
 SpirvConstantComposite::SpirvConstantComposite(
 SpirvConstantComposite::SpirvConstantComposite(
     const SpirvType *type,
     const SpirvType *type,
-    llvm::ArrayRef<const SpirvConstant *> constituentsVec)
-    : SpirvConstant(IK_ConstantComposite, spv::Op::OpConstantComposite, type),
+    llvm::ArrayRef<const SpirvConstant *> constituentsVec, bool isSpecConst)
+    : SpirvConstant(IK_ConstantComposite,
+                    isSpecConst ? spv::Op::OpSpecConstantComposite
+                                : spv::Op::OpConstantComposite,
+                    type),
       constituents(constituentsVec.begin(), constituentsVec.end()) {}
       constituents(constituentsVec.begin(), constituentsVec.end()) {}
 
 
 bool SpirvConstantComposite::
 bool SpirvConstantComposite::
 operator==(const SpirvConstantComposite &other) const {
 operator==(const SpirvConstantComposite &other) const {
+  if (opcode != other.getopcode())
+    return false;
+
   if (resultType != other.getResultType())
   if (resultType != other.getResultType())
     return false;
     return false;
 
 

+ 42 - 0
tools/clang/unittests/SPIRV/SpirvConstantTest.cpp

@@ -194,4 +194,46 @@ TEST(SpirvConstant, CheckOperatorEqualOnComposite2) {
   EXPECT_FALSE(arrayConstant1 == arrayConstant2);
   EXPECT_FALSE(arrayConstant1 == arrayConstant2);
 }
 }
 
 
+TEST(SpirvConstant, BoolConstNotEqualSpecConst) {
+  SpirvContext ctx;
+  SpirvConstantBoolean constant1(ctx.getBoolType(), true, /*SpecConst*/ true);
+  SpirvConstantBoolean constant2(ctx.getBoolType(), false, /*SpecConst*/ false);
+  EXPECT_FALSE(constant1 == constant2);
+}
+
+TEST(SpirvConstant, IntConstNotEqualSpecConst) {
+  SpirvContext ctx;
+  SpirvConstantInteger constant1(ctx.getSIntType(32), 5, /*SpecConst*/ true);
+  SpirvConstantInteger constant2(ctx.getSIntType(32), 7, /*SpecConst*/ false);
+  EXPECT_FALSE(constant1 == constant2);
+}
+
+TEST(SpirvConstant, FloatConstNotEqualSpecConst) {
+  SpirvContext ctx;
+  SpirvConstantFloat constant1(ctx.getFloatType(64), 3.14, /*SpecConst*/ true);
+  SpirvConstantFloat constant2(ctx.getFloatType(64), 3.15, /*SpecConst*/ false);
+  EXPECT_FALSE(constant1 == constant2);
+}
+
+TEST(SpirvConstant, CompositeConstNotEqualSpecConstComposite) {
+  // Make a constant array of size 2.
+  // Each array element is a vector of 4 floats.
+  SpirvContext ctx;
+  const FloatType *f32Type = ctx.getFloatType(32);
+  const VectorType *vecType = ctx.getVectorType(f32Type, 4);
+  const ArrayType *arrType = ctx.getArrayType(vecType, 2);
+  SpirvConstantFloat f1(f32Type, 3.14);
+  SpirvConstantFloat f2(f32Type, 5.f);
+  SpirvConstantFloat f3(f32Type, -1.f);
+  SpirvConstantFloat f4(f32Type, 0.f);
+  llvm::SmallVector<SpirvConstant *, 4> vectorValues = {&f1, &f2, &f3, &f4};
+  SpirvConstantComposite vec4(vecType, vectorValues);
+  llvm::SmallVector<SpirvConstant *, 2> arrayValues = {&vec4, &vec4};
+  SpirvConstantComposite arrayConstant1(arrType, arrayValues,
+                                        /*SpecConst*/ true);
+  SpirvConstantComposite arrayConstant2(arrType, arrayValues,
+                                        /*SpecConst*/ false);
+  EXPECT_FALSE(arrayConstant1 == arrayConstant2);
+}
+
 } // anonymous namespace
 } // anonymous namespace