2
0
Эх сурвалжийг харах

[spirv] Add support for specialization constants.

Ehsan 6 жил өмнө
parent
commit
3f63b41fb6

+ 21 - 7
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -72,14 +72,28 @@ public:
   // instructions into the annotationsBinary.
   uint32_t emitType(const SpirvType *);
 
-  // Emits an inter OpConstant instruction and returns its result-id.
-  // If such constant has already been emitted, just returns its resutl-id.
-  // Modifies the curTypeInst. Do not call in the middle of construction of
-  // another instruction.
   uint32_t getOrCreateConstant(SpirvConstant *);
-  uint32_t getOrCreateConstantInt(llvm::APInt value, const SpirvType *type);
-  uint32_t getOrCreateConstantFloat(llvm::APFloat value, const SpirvType *type);
-  uint32_t getOrCreateConstantComposite(SpirvConstantComposite *inst);
+
+  // Emits an OpConstant instruction and returns its result-id.
+  // For non-specialization constants, if an identical constant has already been
+  // emitted, returns the existing constant's result-id.
+  //
+  // Note1: This method modifies the curTypeInst. Do not call in the middle of
+  // construction of another instruction.
+  //
+  // Note 2: Integer constants may need to be generated for cases where there is
+  // no SpirvConstantInteger instruction in the module. For example, we need to
+  // emit an integer in order to create an array type. Therefore,
+  // 'getOrCreateConstantInt' has a different signature than others. If a
+  // constant instruction is provided, and it already has a result-id assigned,
+  // it will be used. Otherwise a new result-id will be allocated for the
+  // instruction.
+  uint32_t
+  getOrCreateConstantInt(llvm::APInt value, const SpirvType *type,
+                         bool isSpecConst,
+                         SpirvInstruction *constantInstruction = nullptr);
+  uint32_t getOrCreateConstantFloat(SpirvConstantFloat *);
+  uint32_t getOrCreateConstantComposite(SpirvConstantComposite *);
   uint32_t getOrCreateConstantNull(SpirvConstantNull *);
   uint32_t getOrCreateConstantBool(SpirvConstantBoolean *);
 

+ 103 - 58
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -141,6 +141,7 @@ void EmitVisitor::finalizeInstruction() {
   case spv::Op::OpSpecConstantTrue:
   case spv::Op::OpSpecConstantFalse:
   case spv::Op::OpSpecConstant:
+  case spv::Op::OpSpecConstantOp:
     typeConstantBinary.insert(typeConstantBinary.end(), curInst.begin(),
                               curInst.end());
     break;
@@ -499,17 +500,17 @@ bool EmitVisitor::visit(SpirvAtomic *inst) {
 
   curInst.push_back(typeHandler.getOrCreateConstantInt(
       llvm::APInt(32, static_cast<uint32_t>(inst->getScope())),
-      context.getUIntType(32)));
+      context.getUIntType(32), /*isSpecConst */ false));
 
   curInst.push_back(typeHandler.getOrCreateConstantInt(
       llvm::APInt(32, static_cast<uint32_t>(inst->getMemorySemantics())),
-      context.getUIntType(32)));
+      context.getUIntType(32), /*isSpecConst */ false));
 
   if (inst->hasComparator())
     curInst.push_back(typeHandler.getOrCreateConstantInt(
         llvm::APInt(32,
                     static_cast<uint32_t>(inst->getMemorySemanticsUnequal())),
-        context.getUIntType(32)));
+        context.getUIntType(32), /*isSpecConst */ false));
 
   if (inst->hasValue())
     curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getValue()));
@@ -528,16 +529,16 @@ bool EmitVisitor::visit(SpirvBarrier *inst) {
           ? typeHandler.getOrCreateConstantInt(
                 llvm::APInt(32,
                             static_cast<uint32_t>(inst->getExecutionScope())),
-                context.getUIntType(32))
+                context.getUIntType(32), /*isSpecConst */ false)
           : 0;
 
   const uint32_t memoryScopeId = typeHandler.getOrCreateConstantInt(
       llvm::APInt(32, static_cast<uint32_t>(inst->getMemoryScope())),
-      context.getUIntType(32));
+      context.getUIntType(32), /*isSpecConst */ false);
 
   const uint32_t memorySemanticsId = typeHandler.getOrCreateConstantInt(
       llvm::APInt(32, static_cast<uint32_t>(inst->getMemorySemantics())),
-      context.getUIntType(32));
+      context.getUIntType(32), /* isSpecConst */ false);
 
   initInstruction(inst);
   if (inst->isControlBarrier())
@@ -712,7 +713,7 @@ bool EmitVisitor::visit(SpirvNonUniformBinaryOp *inst) {
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
   curInst.push_back(typeHandler.getOrCreateConstantInt(
       llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
-      context.getUIntType(32)));
+      context.getUIntType(32), /* isSpecConst */ false));
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg1()));
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg2()));
   finalizeInstruction();
@@ -727,7 +728,7 @@ bool EmitVisitor::visit(SpirvNonUniformElect *inst) {
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
   curInst.push_back(typeHandler.getOrCreateConstantInt(
       llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
-      context.getUIntType(32)));
+      context.getUIntType(32), /* isSpecConst */ false));
   finalizeInstruction();
   emitDebugNameForInstruction(getOrAssignResultId<SpirvInstruction>(inst),
                               inst->getDebugName());
@@ -740,7 +741,7 @@ bool EmitVisitor::visit(SpirvNonUniformUnaryOp *inst) {
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
   curInst.push_back(typeHandler.getOrCreateConstantInt(
       llvm::APInt(32, static_cast<uint32_t>(inst->getExecutionScope())),
-      context.getUIntType(32)));
+      context.getUIntType(32), /* isSpecConst */ false));
   if (inst->hasGroupOp())
     curInst.push_back(static_cast<uint32_t>(inst->getGroupOp()));
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst->getArg()));
@@ -991,15 +992,11 @@ uint32_t EmitTypeHandler::getResultIdForType(const SpirvType *type,
 
 uint32_t EmitTypeHandler::getOrCreateConstant(SpirvConstant *inst) {
   if (auto *constInt = dyn_cast<SpirvConstantInteger>(inst)) {
-    const uint32_t resultId =
-        getOrCreateConstantInt(constInt->getValue(), constInt->getResultType());
-    inst->setResultId(resultId);
-    return resultId;
+    return getOrCreateConstantInt(constInt->getValue(),
+                                  constInt->getResultType(),
+                                  inst->isSpecConstant(), inst);
   } else if (auto *constFloat = dyn_cast<SpirvConstantFloat>(inst)) {
-    const uint32_t resultId = getOrCreateConstantFloat(
-        constFloat->getValue(), constFloat->getResultType());
-    inst->setResultId(resultId);
-    return resultId;
+    return getOrCreateConstantFloat(constFloat);
   } else if (auto *constComposite = dyn_cast<SpirvConstantComposite>(inst)) {
     return getOrCreateConstantComposite(constComposite);
   } else if (auto *constNull = dyn_cast<SpirvConstantNull>(inst)) {
@@ -1013,8 +1010,11 @@ uint32_t EmitTypeHandler::getOrCreateConstant(SpirvConstant *inst) {
 
 uint32_t EmitTypeHandler::getOrCreateConstantBool(SpirvConstantBoolean *inst) {
   const auto index = static_cast<uint32_t>(inst->getValue());
+  const bool isSpecConst = inst->isSpecConstant();
 
-  if (emittedConstantBools[index]) {
+  // SpecConstants are not unique. We should not reuse them. e.g. it is possible
+  // to have multiple OpSpecConstantTrue instructions.
+  if (!isSpecConst && emittedConstantBools[index]) {
     // Already emitted this constant. Reuse.
     inst->setResultId(emittedConstantBools[index]->getResultId());
   } else {
@@ -1024,8 +1024,9 @@ uint32_t EmitTypeHandler::getOrCreateConstantBool(SpirvConstantBoolean *inst) {
     curTypeInst.push_back(typeId);
     curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
     finalizeTypeInstruction();
-    // Remember this constant for the future
-    emittedConstantBools[index] = inst;
+    // Remember this constant for the future (if not a spec constant)
+    if (!isSpecConst)
+      emittedConstantBools[index] = inst;
   }
 
   return inst->getResultId();
@@ -1055,8 +1056,11 @@ uint32_t EmitTypeHandler::getOrCreateConstantNull(SpirvConstantNull *inst) {
   return inst->getResultId();
 }
 
-uint32_t EmitTypeHandler::getOrCreateConstantFloat(llvm::APFloat value,
-                                                   const SpirvType *type) {
+uint32_t EmitTypeHandler::getOrCreateConstantFloat(SpirvConstantFloat *inst) {
+  llvm::APFloat value = inst->getValue();
+  const SpirvType *type = inst->getResultType();
+  const bool isSpecConst = inst->isSpecConstant();
+
   assert(isa<FloatType>(type));
   const auto *floatType = dyn_cast<FloatType>(type);
   const auto typeBitwidth = floatType->getBitwidth();
@@ -1086,18 +1090,26 @@ uint32_t EmitTypeHandler::getOrCreateConstantFloat(llvm::APFloat value,
     }
   }
 
-  // If this constant has already been emitted, return its result-id.
   auto valueTypePair = std::pair<uint64_t, const SpirvType *>(
       valueToUse.bitcastToAPInt().getZExtValue(), type);
-  auto foundResultId = emittedConstantFloats.find(valueTypePair);
-  if (foundResultId != emittedConstantFloats.end())
-    return foundResultId->second;
+
+  // SpecConstant instructions are not unique, so we should not re-use existing
+  // spec constants.
+  if (!isSpecConst) {
+    // If this constant has already been emitted, return its result-id.
+    auto foundResultId = emittedConstantFloats.find(valueTypePair);
+    if (foundResultId != emittedConstantFloats.end()) {
+      const uint32_t existingConstantResultId = foundResultId->second;
+      inst->setResultId(existingConstantResultId);
+      return existingConstantResultId;
+    }
+  }
 
   // Start constructing the instruction
-  const uint32_t constantResultId = takeNextIdFunction();
   const uint32_t typeId = emitType(type);
-  initTypeInstruction(spv::Op::OpConstant);
+  initTypeInstruction(inst->getopcode());
   curTypeInst.push_back(typeId);
+  const uint32_t constantResultId = getOrAssignResultId<SpirvInstruction>(inst);
   curTypeInst.push_back(constantResultId);
 
   // Start constructing the value word / words
@@ -1127,19 +1139,32 @@ uint32_t EmitTypeHandler::getOrCreateConstantFloat(llvm::APFloat value,
 
   finalizeTypeInstruction();
 
-  // Remember this constant for future
-  emittedConstantFloats[valueTypePair] = constantResultId;
+  // Remember this constant for future (if not a SpecConstant)
+  if (!isSpecConst)
+    emittedConstantFloats[valueTypePair] = constantResultId;
+
   return constantResultId;
 }
 
-uint32_t EmitTypeHandler::getOrCreateConstantInt(llvm::APInt value,
-                                                 const SpirvType *type) {
-  // If this constant has already been emitted, return its result-id.
+uint32_t
+EmitTypeHandler::getOrCreateConstantInt(llvm::APInt value,
+                                        const SpirvType *type, bool isSpecConst,
+                                        SpirvInstruction *constantInstruction) {
   auto valueTypePair =
       std::pair<uint64_t, const SpirvType *>(value.getZExtValue(), type);
-  auto foundResultId = emittedConstantInts.find(valueTypePair);
-  if (foundResultId != emittedConstantInts.end())
-    return foundResultId->second;
+
+  // SpecConstant instructions are not unique, so we should not re-use existing
+  // spec constants.
+  if (!isSpecConst) {
+    // If this constant has already been emitted, return its result-id.
+    auto foundResultId = emittedConstantInts.find(valueTypePair);
+    if (foundResultId != emittedConstantInts.end()) {
+      const uint32_t existingConstantResultId = foundResultId->second;
+      if (constantInstruction)
+        constantInstruction->setResultId(existingConstantResultId);
+      return existingConstantResultId;
+    }
+  }
 
   assert(isa<IntegerType>(type));
   const auto *intType = dyn_cast<IntegerType>(type);
@@ -1147,10 +1172,19 @@ uint32_t EmitTypeHandler::getOrCreateConstantInt(llvm::APInt value,
   const auto isSigned = intType->isSignedInt();
 
   // Start constructing the instruction
-  const uint32_t constantResultId = takeNextIdFunction();
   const uint32_t typeId = emitType(type);
-  initTypeInstruction(spv::Op::OpConstant);
+  initTypeInstruction(isSpecConst ? spv::Op::OpSpecConstant
+                                  : spv::Op::OpConstant);
   curTypeInst.push_back(typeId);
+
+  // Assign a result-id if one has not been provided.
+  uint32_t constantResultId = 0;
+  if (constantInstruction)
+    constantResultId =
+        getOrAssignResultId<SpirvInstruction>(constantInstruction);
+  else
+    constantResultId = takeNextIdFunction();
+
   curTypeInst.push_back(constantResultId);
 
   // Start constructing the value word / words
@@ -1181,8 +1215,10 @@ uint32_t EmitTypeHandler::getOrCreateConstantInt(llvm::APInt value,
 
   finalizeTypeInstruction();
 
-  // Remember this constant for future
-  emittedConstantInts[valueTypePair] = constantResultId;
+  // Remember this constant for future (not needed for SpecConstants)
+  if (!isSpecConst)
+    emittedConstantInts[valueTypePair] = constantResultId;
+
   return constantResultId;
 }
 
@@ -1192,23 +1228,30 @@ EmitTypeHandler::getOrCreateConstantComposite(SpirvConstantComposite *inst) {
   for (auto constituent : inst->getConstituents())
     getOrCreateConstant(constituent);
 
-  auto found = std::find_if(
-      emittedConstantComposites.begin(), emittedConstantComposites.end(),
-      [inst](SpirvConstantComposite *cachedConstant) {
-        if (inst->getopcode() != cachedConstant->getopcode())
-          return false;
-        auto instConstituents = inst->getConstituents();
-        auto cachedConstituents = cachedConstant->getConstituents();
-        if (instConstituents.size() != cachedConstituents.size())
-          return false;
-        for (size_t i = 0; i < instConstituents.size(); ++i)
-          if (instConstituents[i]->getResultId() !=
-              cachedConstituents[i]->getResultId())
+  // SpecConstant instructions are not unique, so we should not re-use existing
+  // spec constants.
+  const bool isSpecConst = inst->isSpecConstant();
+  SpirvConstantComposite **found = nullptr;
+
+  if (!isSpecConst) {
+    found = std::find_if(
+        emittedConstantComposites.begin(), emittedConstantComposites.end(),
+        [inst](SpirvConstantComposite *cachedConstant) {
+          if (inst->getopcode() != cachedConstant->getopcode())
+            return false;
+          auto instConstituents = inst->getConstituents();
+          auto cachedConstituents = cachedConstant->getConstituents();
+          if (instConstituents.size() != cachedConstituents.size())
             return false;
-        return true;
-      });
+          for (size_t i = 0; i < instConstituents.size(); ++i)
+            if (instConstituents[i]->getResultId() !=
+                cachedConstituents[i]->getResultId())
+              return false;
+          return true;
+        });
+  }
 
-  if (found != emittedConstantComposites.end()) {
+  if (!isSpecConst && found != emittedConstantComposites.end()) {
     // We have already emitted this constant. Reuse.
     inst->setResultId((*found)->getResultId());
   } else {
@@ -1221,8 +1264,9 @@ EmitTypeHandler::getOrCreateConstantComposite(SpirvConstantComposite *inst) {
       curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(constituent));
     finalizeTypeInstruction();
 
-    // Remember this constant for the future
-    emittedConstantComposites.push_back(inst);
+    // Remember this constant for the future (if not a spec constant)
+    if (!isSpecConst)
+      emittedConstantComposites.push_back(inst);
   }
 
   return inst->getResultId();
@@ -1320,7 +1364,8 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
     // Emit the OpConstant instruction that is needed to get the result-id for
     // the array length.
     const auto length = getOrCreateConstantInt(
-        llvm::APInt(32, arrayType->getElementCount()), context.getUIntType(32));
+        llvm::APInt(32, arrayType->getElementCount()), context.getUIntType(32),
+        /* isSpecConst */ false);
 
     // Emit the OpTypeArray instruction
     const uint32_t elemTypeId = emitType(arrayType->getElementType());

+ 2 - 1
tools/clang/lib/SPIRV/LiteralTypeVisitor.cpp

@@ -289,8 +289,9 @@ bool LiteralTypeVisitor::visit(SpirvCompositeExtract *inst) {
     const QualType newType =
         getTypeWithCustomBitwidth(astContext, baseType, resultTypeBitwidth);
     updateTypeForInstruction(base, newType);
-    return true;
   }
+
+  return true;
 }
 
 bool LiteralTypeVisitor::updateTypeForCompositeMembers(

+ 0 - 2
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -815,7 +815,6 @@ SpirvVariable *SpirvBuilder::addStageBuiltinVar(const SpirvType *type,
                                                 spv::BuiltIn builtin,
                                                 SourceLocation loc) {
   // Note: We store the underlying type in the variable, *not* the pointer type.
-  // TODO(ehsan): type pointer should be added in lowering the type.
   auto *var =
       new (context) SpirvVariable(/*QualType*/ {}, /*id*/ 0, loc, storageClass);
   var->setResultType(type);
@@ -834,7 +833,6 @@ SpirvVariable *SpirvBuilder::addStageBuiltinVar(QualType type,
                                                 spv::BuiltIn builtin,
                                                 SourceLocation loc) {
   // Note: We store the underlying type in the variable, *not* the pointer type.
-  // TODO(ehsan): type pointer should be added in lowering the type.
   auto *var = new (context) SpirvVariable(type, /*id*/ 0, loc, storageClass);
   module->addVariable(var);