Browse Source

[spirv] Improve constant API using templates.

Ehsan 6 years ago
parent
commit
2a4957eda7

+ 50 - 0
tools/clang/include/clang/SPIRV/SPIRVContext.h

@@ -209,12 +209,62 @@ public:
   const StructType *getByteAddressBufferType(bool isWritable);
   const StructType *getByteAddressBufferType(bool isWritable);
   const StructType *getACSBufferCounterType();
   const StructType *getACSBufferCounterType();
 
 
+  SpirvConstant *getConstantUint16(uint16_t value, bool specConst = false);
+  SpirvConstant *getConstantInt16(int16_t value, bool specConst = false);
   SpirvConstant *getConstantUint32(uint32_t value, bool specConst = false);
   SpirvConstant *getConstantUint32(uint32_t value, bool specConst = false);
   SpirvConstant *getConstantInt32(int32_t value, bool specConst = false);
   SpirvConstant *getConstantInt32(int32_t value, bool specConst = false);
+  SpirvConstant *getConstantUint64(uint64_t value, bool specConst = false);
+  SpirvConstant *getConstantInt64(int64_t value, bool specConst = false);
+  SpirvConstant *getConstantFloat16(uint16_t value, bool specConst = false);
   SpirvConstant *getConstantFloat32(float value, bool specConst = false);
   SpirvConstant *getConstantFloat32(float value, bool specConst = false);
+  SpirvConstant *getConstantFloat64(double value, bool specConst = false);
   SpirvConstant *getConstantBool(bool value, bool specConst = false);
   SpirvConstant *getConstantBool(bool value, bool specConst = false);
   // TODO: Add getConstant* methods for other types.
   // TODO: Add getConstant* methods for other types.
 
 
+private:
+  template <class T>
+  SpirvConstant *getConstantInt(T value, bool isSigned, uint32_t bitwidth,
+                                bool specConst) {
+    const IntegerType *intType =
+        isSigend ? getSIntType(bitwidth) : getUIntType(bitwidth);
+    SpirvConstantInteger tempConstant(intType, value, specConst);
+
+    auto found =
+        std::find_if(integerConstants.begin(), integerConstants.end(),
+                     [&tempConstant](SpirvConstantInteger *cachedConstant) {
+                       return tempConstant == *cachedConstant;
+                     });
+
+    if (found != integerConstants.end())
+      return *found;
+
+    // Couldn't find the constant. Create one.
+    auto *intConst = new (this) SpirvConstantInteger(intType, value, specConst);
+    integerConstants.push_back(intConst);
+    return intConst;
+  }
+
+  template <class T>
+  SpirvConstant *getConstantFloat(T value, uint32_t bitwidth, bool specConst) {
+    const FloatType *floatType = getFloatType(bitwidth);
+    SpirvConstantFloat tempConstant(floatType, value, specConst);
+
+    auto found =
+        std::find_if(floatConstants.begin(), floatConstants.end(),
+                     [&tempConstant](SpirvConstantFloat *cachedConstant) {
+                       return tempConstant == *cachedConstant;
+                     });
+
+    if (found != floatConstants.end())
+      return *found;
+
+    // Couldn't find the constant. Create one.
+    auto *floatConst =
+        new (this) SpirvConstantFloat(floatType, value, specConst);
+    floatConstants.push_back(floatConst);
+    return floatConst;
+  }
+
 private:
 private:
   /// \brief The allocator used to create SPIR-V entity objects.
   /// \brief The allocator used to create SPIR-V entity objects.
   ///
   ///

+ 22 - 49
tools/clang/lib/SPIRV/SPIRVContext.cpp

@@ -330,61 +330,34 @@ const StructType *SpirvContext::getACSBufferCounterType() {
   return type;
   return type;
 }
 }
 
 
+SpirvConstant *SpirvContext::getConstantUint16(uint16_t value, bool specConst) {
+  return getConstantInt<uint16_t>(value, /*isSigned*/ false, 16, specConst);
+}
+SpirvConstant *SpirvContext::getConstantInt16(int16_t value, bool specConst) {
+  return getConstantInt<int16_t>(value, /*isSigned*/ true, 16, specConst);
+}
 SpirvConstant *SpirvContext::getConstantUint32(uint32_t value, bool specConst) {
 SpirvConstant *SpirvContext::getConstantUint32(uint32_t value, bool specConst) {
-  const IntegerType *intType = getUIntType(32);
-  SpirvConstantInteger tempConstant(intType, value, specConst);
-
-  auto found =
-      std::find_if(integerConstants.begin(), integerConstants.end(),
-                   [&tempConstant](SpirvConstantInteger *cachedConstant) {
-                     return tempConstant == *cachedConstant;
-                   });
-
-  if (found != integerConstants.end())
-    return *found;
-
-  // Couldn't find the constant. Create one.
-  auto *intConst = new (this) SpirvConstantInteger(intType, value, specConst);
-  integerConstants.push_back(intConst);
-  return intConst;
+  return getConstantInt<uint32_t>(value, /*isSigned*/ false, 32, specConst);
 }
 }
-
 SpirvConstant *SpirvContext::getConstantInt32(int32_t value, bool specConst) {
 SpirvConstant *SpirvContext::getConstantInt32(int32_t value, bool specConst) {
-  const IntegerType *intType = getSIntType(32);
-  SpirvConstantInteger tempConstant(intType, value, specConst);
-
-  auto found =
-      std::find_if(integerConstants.begin(), integerConstants.end(),
-                   [&tempConstant](SpirvConstantInteger *cachedConstant) {
-                     return tempConstant == *cachedConstant;
-                   });
-
-  if (found != integerConstants.end())
-    return *found;
-
-  // Couldn't find the constant. Create one.
-  auto *intConst = new (this) SpirvConstantInteger(intType, value, specConst);
-  integerConstants.push_back(intConst);
-  return intConst;
+  return getConstantInt<int32_t>(value, /*isSigned*/ true, 32, specConst);
+}
+SpirvConstant *SpirvContext::getConstantUint64(uint64_t value, bool specConst) {
+  return getConstantInt<uint64_t>(value, /*isSigned*/ false, 64, specConst);
+}
+SpirvConstant *SpirvContext::getConstantInt64(int64_t value, bool specConst) {
+  return getConstantInt<int64_t>(value, /*isSigned*/ true, 64, specConst);
 }
 }
 
 
+SpirvConstant *SpirvContext::getConstantFloat16(uint16_t value,
+                                                bool specConst) {
+  return getConstantFloat<uint16_t>(value, 16, specConst);
+}
 SpirvConstant *SpirvContext::getConstantFloat32(float value, bool specConst) {
 SpirvConstant *SpirvContext::getConstantFloat32(float value, bool specConst) {
-  const FloatType *floatType = getFloatType(32);
-  SpirvConstantFloat tempConstant(floatType, value, specConst);
-
-  auto found =
-      std::find_if(floatConstants.begin(), floatConstants.end(),
-                   [&tempConstant](SpirvConstantFloat *cachedConstant) {
-                     return tempConstant == *cachedConstant;
-                   });
-
-  if (found != floatConstants.end())
-    return *found;
-
-  // Couldn't find the constant. Create one.
-  auto *floatConst = new (this) SpirvConstantFloat(floatType, value, specConst);
-  floatConstants.push_back(floatConst);
-  return floatConst;
+  return getConstantFloat<float>(value, 32, specConst);
+}
+SpirvConstant *SpirvContext::getConstantFloat64(double value, bool specConst) {
+  return getConstantFloat<double>(value, 64, specConst);
 }
 }
 
 
 SpirvConstant *SpirvContext::getConstantBool(bool value, bool specConst) {
 SpirvConstant *SpirvContext::getConstantBool(bool value, bool specConst) {