瀏覽代碼

[spirv] Add isSpecConst boolean to SpirvContext APIs.

Ehsan Nasiri 6 年之前
父節點
當前提交
c61ee3081d

+ 6 - 4
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -209,10 +209,10 @@ public:
   const StructType *getByteAddressBufferType(bool isWritable);
   const StructType *getACSBufferCounterType();
 
-  SpirvConstant *getConstantUint32(uint32_t value);
-  SpirvConstant *getConstantInt32(int32_t value);
-  SpirvConstant *getConstantFloat32(float value);
-  SpirvConstant *getConstantBool(bool value);
+  SpirvConstant *getConstantUint32(uint32_t value, bool specConst = false);
+  SpirvConstant *getConstantInt32(int32_t value, bool specConst = false);
+  SpirvConstant *getConstantFloat32(float value, bool specConst = false);
+  SpirvConstant *getConstantBool(bool value, bool specConst = false);
   // TODO: Add getConstant* methods for other types.
 
 private:
@@ -271,6 +271,8 @@ private:
   llvm::SmallVector<SpirvConstantFloat *, 8> floatConstants;
   SpirvConstantBoolean *boolTrueConstant;
   SpirvConstantBoolean *boolFalseConstant;
+  SpirvConstantBoolean *boolTrueSpecConstant;
+  SpirvConstantBoolean *boolFalseSpecConstant;
   // TODO: Add vectors of other constant types here.
 };
 

+ 4 - 1
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -942,7 +942,10 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
 
 void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
                                               SpirvInstruction *specConstant) {
-  specConstant->setSpecConstant();
+  // TODO(ehsan): Remove the following line once we are sure the instruction
+  // that was created prior to calling this function was in fact a spec constant
+  // instruction.
+  // specConstant->setSpecConstant();
   specConstant->setRValue();
   astDecls[decl] = DeclSpirvInfo(specConstant);
 }

+ 39 - 22
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -77,7 +77,8 @@ const Decoration *SPIRVContext::registerDecoration(const Decoration &d) {
 SpirvContext::SpirvContext()
     : allocator(), voidType(nullptr), boolType(nullptr), sintTypes({}),
       uintTypes({}), floatTypes({}), samplerType(nullptr),
-      boolTrueConstant(nullptr), boolFalseConstant(nullptr) {
+      boolTrueConstant(nullptr), boolFalseConstant(nullptr),
+      boolTrueSpecConstant(nullptr), boolFalseSpecConstant(nullptr) {
   voidType = new (this) VoidType;
   boolType = new (this) BoolType;
   samplerType = new (this) SamplerType;
@@ -329,9 +330,9 @@ const StructType *SpirvContext::getACSBufferCounterType() {
   return type;
 }
 
-SpirvConstant *SpirvContext::getConstantUint32(uint32_t value) {
+SpirvConstant *SpirvContext::getConstantUint32(uint32_t value, bool specConst) {
   const IntegerType *intType = getUIntType(32);
-  SpirvConstantInteger tempConstant(intType, value);
+  SpirvConstantInteger tempConstant(intType, value, specConst);
 
   auto found =
       std::find_if(integerConstants.begin(), integerConstants.end(),
@@ -343,14 +344,14 @@ SpirvConstant *SpirvContext::getConstantUint32(uint32_t value) {
     return *found;
 
   // Couldn't find the constant. Create one.
-  auto *intConst = new (this) SpirvConstantInteger(intType, value);
+  auto *intConst = new (this) SpirvConstantInteger(intType, value, specConst);
   integerConstants.push_back(intConst);
   return intConst;
 }
 
-SpirvConstant *SpirvContext::getConstantInt32(int32_t value) {
+SpirvConstant *SpirvContext::getConstantInt32(int32_t value, bool specConst) {
   const IntegerType *intType = getSIntType(32);
-  SpirvConstantInteger tempConstant(intType, value);
+  SpirvConstantInteger tempConstant(intType, value, specConst);
 
   auto found =
       std::find_if(integerConstants.begin(), integerConstants.end(),
@@ -362,14 +363,14 @@ SpirvConstant *SpirvContext::getConstantInt32(int32_t value) {
     return *found;
 
   // Couldn't find the constant. Create one.
-  auto *intConst = new (this) SpirvConstantInteger(intType, value);
+  auto *intConst = new (this) SpirvConstantInteger(intType, value, specConst);
   integerConstants.push_back(intConst);
   return intConst;
 }
 
-SpirvConstant *SpirvContext::getConstantFloat32(float value) {
+SpirvConstant *SpirvContext::getConstantFloat32(float value, bool specConst) {
   const FloatType *floatType = getFloatType(32);
-  SpirvConstantFloat tempConstant(floatType, value);
+  SpirvConstantFloat tempConstant(floatType, value, specConst);
 
   auto found =
       std::find_if(floatConstants.begin(), floatConstants.end(),
@@ -381,25 +382,41 @@ SpirvConstant *SpirvContext::getConstantFloat32(float value) {
     return *found;
 
   // Couldn't find the constant. Create one.
-  auto *floatConst = new (this) SpirvConstantFloat(floatType, value);
+  auto *floatConst = new (this) SpirvConstantFloat(floatType, value, specConst);
   floatConstants.push_back(floatConst);
   return floatConst;
 }
 
-SpirvConstant *SpirvContext::getConstantBool(bool value) {
-  if (value && boolTrueConstant)
-    return boolTrueConstant;
-
-  if (!value && boolFalseConstant)
-    return boolFalseConstant;
+SpirvConstant *SpirvContext::getConstantBool(bool value, bool specConst) {
+  if (value) {
+    if (specConst) {
+      return boolTrueSpecConstant;
+    } else {
+      return boolTrueConstant;
+    }
+  } else {
+    if (specConst) {
+      return boolFalseSpecConstant;
+    } else {
+      return boolFalseConstant;
+    }
+  }
 
   // Couldn't find the constant. Create one.
-  auto *boolConst = new (this) SpirvConstantBoolean(getBoolType(), value);
-
-  if (value)
-    boolTrueConstant = boolConst;
-  else
-    boolFalseConstant = boolConst;
+  auto *boolConst =
+      new (this) SpirvConstantBoolean(getBoolType(), value, specConst);
+
+  if (value) {
+    if (specConst)
+      boolTrueSpecConstant = boolConst;
+    else
+      boolTrueConstant = boolConst;
+  } else {
+    if (specConst)
+      boolFalseSpecConstant = boolConst;
+    else
+      boolFalseConstant = boolConst;
+  }
 
   return boolConst;
 }

+ 1 - 2
tools/clang/lib/SPIRV/SpirvInstruction.cpp

@@ -88,8 +88,7 @@ SpirvInstruction::SpirvInstruction(Kind k, spv::Op op, QualType astType,
       debugName(), resultType(nullptr), resultTypeId(0),
       layoutRule(SpirvLayoutRule::Void), containsAlias(false),
       storageClass(spv::StorageClass::Max), isRValue_(false),
-      isConstant_(false), isSpecConstant_(false), isRelaxedPrecision_(false),
-      isNonUniform_(false) {}
+      isRelaxedPrecision_(false), isNonUniform_(false) {}
 
 SpirvCapability::SpirvCapability(SourceLocation loc, spv::Capability cap)
     : SpirvInstruction(IK_Capability, spv::Op::OpCapability, QualType(),