Explorar el Código

[spirv] Do not add a constant to the module while visitor in progress.

If you visit a constant instruction whose return type is a an array,
that will require us to add a constant for the array length. If we add
that constant to the module itself, it will invalidate the iterators
iterating over the module constants.
Ehsan Nasiri hace 6 años
padre
commit
f06b956ac9

+ 17 - 10
tools/clang/include/clang/SPIRV/EmitVisitor.h

@@ -20,7 +20,6 @@ namespace spirv {
 class SpirvFunction;
 class SpirvBasicBlock;
 class SpirvType;
-class SpirvBuilder;
 
 class EmitTypeHandler {
 public:
@@ -82,12 +81,12 @@ public:
   };
 
 public:
-  EmitTypeHandler(ASTContext &astCtx, SpirvBuilder &builder,
+  EmitTypeHandler(ASTContext &astCtx, SpirvContext &spvContext,
                   std::vector<uint32_t> *debugVec,
                   std::vector<uint32_t> *decVec,
                   std::vector<uint32_t> *typesVec,
                   const std::function<uint32_t()> &takeNextIdFn)
-      : astContext(astCtx), spirvBuilder(builder), debugBinary(debugVec),
+      : astContext(astCtx), context(spvContext), debugBinary(debugVec),
         annotationsBinary(decVec), typeConstantBinary(typesVec),
         takeNextIdFunction(takeNextIdFn) {
     assert(decVec);
@@ -109,6 +108,12 @@ public:
   // instructions into the annotationsBinary.
   uint32_t emitType(const SpirvType *, SpirvLayoutRule);
 
+  // Emits an OpConstant instruction with uint32 type and returns its result-id.
+  // If such constant has already been emitted, just returns its resutl-id.
+  // Modifies the curTypeInst. Do not call in the middle of construction of
+  // another instruction.
+  uint32_t getOrCreateConstantUint32(uint32_t value);
+
 private:
   void initTypeInstruction(spv::Op op);
   void finalizeTypeInstruction();
@@ -174,7 +179,7 @@ private:
 
 private:
   ASTContext &astContext;
-  SpirvBuilder &spirvBuilder;
+  SpirvContext &context;
   std::vector<uint32_t> curTypeInst;
   std::vector<uint32_t> curDecorationInst;
   std::vector<uint32_t> *debugBinary;
@@ -182,6 +187,11 @@ private:
   std::vector<uint32_t> *typeConstantBinary;
   std::function<uint32_t()> takeNextIdFunction;
 
+  // The array type requires the result-id of an OpConstant for its length. In
+  // order to avoid duplicate OpConstant instructions, we keep a map of constant
+  // uint value to the result-id of the OpConstant for that value.
+  llvm::DenseMap<uint32_t, uint32_t> UintConstantValueToResultIdMap;
+
   // emittedTypes is a map that caches the result-id of types with a given list
   // of decorations in order to avoid emitting an identical type multiple times.
   using DecorationSetToTypeIdMap =
@@ -210,12 +220,11 @@ public:
 
 public:
   EmitVisitor(ASTContext &astCtx, SpirvContext &spvCtx,
-              const SpirvCodeGenOptions &opts, SpirvBuilder &builder)
+              const SpirvCodeGenOptions &opts)
       : Visitor(opts, spvCtx), id(0),
-        typeHandler(astCtx, builder, &debugBinary, &annotationsBinary,
+        typeHandler(astCtx, spvCtx, &debugBinary, &annotationsBinary,
                     &typeConstantBinary,
-                    [this]() -> uint32_t { return takeNextId(); }),
-        spirvBuilder(builder) {}
+                    [this]() -> uint32_t { return takeNextId(); }) {}
 
   // Visit different SPIR-V constructs for emitting.
   bool visit(SpirvModule *, Phase phase);
@@ -320,8 +329,6 @@ private:
   uint32_t id;
   // Handler for emitting types and their related instructions.
   EmitTypeHandler typeHandler;
-  // Use spirvBuilder in case we need to create constants.
-  SpirvBuilder &spirvBuilder;
   // Current instruction being built
   SmallVector<uint32_t, 16> curInst;
   // All preamble instructions in the following order:

+ 39 - 34
tools/clang/lib/SPIRV/EmitVisitor.cpp

@@ -543,31 +543,15 @@ bool EmitVisitor::visit(SpirvAtomic *inst) {
 }
 
 bool EmitVisitor::visit(SpirvBarrier *inst) {
-  // Note: do not invoke this lambda in the middle of creating an instruction.
-  // This lambda changes the curInst variable
-  auto emitConstant = [this](uint32_t value) {
-    SpirvConstant *constInstr = spirvBuilder.getConstantUint32(value);
-    // This constant has never been emitted
-    if (constInstr->getResultId() == 0) {
-      const uint32_t uint32TypeId = typeHandler.emitType(
-          constInstr->getResultType(), SpirvLayoutRule::Void);
-      initInstruction(spv::Op::OpConstant);
-      curInst.push_back(uint32TypeId);
-      curInst.push_back(getOrAssignResultId<SpirvInstruction>(constInstr));
-      curInst.push_back(value);
-      finalizeInstruction();
-    }
-    return constInstr->getResultId();
-  };
-
   const uint32_t executionScopeId =
       inst->isControlBarrier()
-          ? emitConstant(static_cast<uint32_t>(inst->getExecutionScope()))
+          ? typeHandler.getOrCreateConstantUint32(
+                static_cast<uint32_t>(inst->getExecutionScope()))
           : 0;
-  const uint32_t memoryScopeId =
-      emitConstant(static_cast<uint32_t>(inst->getMemoryScope()));
-  const uint32_t memorySemanticsId =
-      emitConstant(static_cast<uint32_t>(inst->getMemorySemantics()));
+  const uint32_t memoryScopeId = typeHandler.getOrCreateConstantUint32(
+      static_cast<uint32_t>(inst->getMemoryScope()));
+  const uint32_t memorySemanticsId = typeHandler.getOrCreateConstantUint32(
+      static_cast<uint32_t>(inst->getMemorySemantics()));
 
   initInstruction(inst);
   if (inst->isControlBarrier())
@@ -628,6 +612,15 @@ bool EmitVisitor::visit(SpirvConstantBoolean *inst) {
 }
 
 bool EmitVisitor::visit(SpirvConstantInteger *inst) {
+  // Note: Since array types need to create uint 32-bit constants for result-id
+  // of array length, the typeHandler keeps track of uint32 constant uniqueness.
+  // Therefore emitting uint32 constants should be handled by the typeHandler.
+  if (!inst->isSigned() && inst->getBitwidth() == 32) {
+    inst->setResultId(
+        typeHandler.getOrCreateConstantUint32(inst->getUnsignedInt32Value()));
+    return true;
+  }
+
   initInstruction(inst);
   curInst.push_back(inst->getResultTypeId());
   curInst.push_back(getOrAssignResultId<SpirvInstruction>(inst));
@@ -645,7 +638,9 @@ bool EmitVisitor::visit(SpirvConstantInteger *inst) {
       curInst.push_back(
           cast::BitwiseCast<uint32_t, int32_t>(inst->getSignedInt32Value()));
     } else {
-      curInst.push_back(inst->getUnsignedInt32Value());
+      // Unsigned 32-bit integers are special cases that are handled above.
+      assert(false && "typeHandler should handle creation of unsigned 32-bit "
+                      "integer constants");
     }
   }
   // 64-bit cases
@@ -1070,6 +1065,25 @@ uint32_t EmitTypeHandler::getResultIdForType(const SpirvType *type,
   return id;
 }
 
+uint32_t EmitTypeHandler::getOrCreateConstantUint32(uint32_t value) {
+  // If this constant has already been emitted, return its result-id.
+  auto foundResultId = UintConstantValueToResultIdMap.find(value);
+  if (foundResultId != UintConstantValueToResultIdMap.end())
+    return foundResultId->second;
+
+  const uint32_t constantResultId = takeNextIdFunction();
+  const SpirvType *uintType = context.getUIntType(32);
+  const uint32_t uint32TypeId = emitType(uintType, SpirvLayoutRule::Void);
+  initTypeInstruction(spv::Op::OpConstant);
+  curTypeInst.push_back(uint32TypeId);
+  curTypeInst.push_back(constantResultId);
+  curTypeInst.push_back(value);
+  finalizeTypeInstruction();
+
+  UintConstantValueToResultIdMap[value] = constantResultId;
+  return constantResultId;
+}
+
 void EmitTypeHandler::getDecorationsForType(const SpirvType *type,
                                             SpirvLayoutRule rule,
                                             DecorationList *decs) {
@@ -1211,23 +1225,14 @@ uint32_t EmitTypeHandler::emitType(const SpirvType *type,
   else if (const auto *arrayType = dyn_cast<ArrayType>(type)) {
     // Emit the OpConstant instruction that is needed to get the result-id for
     // the array length.
-    SpirvConstant *constant =
-        spirvBuilder.getConstantUint32(arrayType->getElementCount());
-    if (constant->getResultId() == 0) {
-      const uint32_t uint32TypeId = emitType(constant->getResultType(), rule);
-      initTypeInstruction(spv::Op::OpConstant);
-      curTypeInst.push_back(uint32TypeId);
-      curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(constant));
-      curTypeInst.push_back(arrayType->getElementCount());
-      finalizeTypeInstruction();
-    }
+    const auto length = getOrCreateConstantUint32(arrayType->getElementCount());
 
     // Emit the OpTypeArray instruction
     const uint32_t elemTypeId = emitType(arrayType->getElementType(), rule);
     initTypeInstruction(spv::Op::OpTypeArray);
     curTypeInst.push_back(id);
     curTypeInst.push_back(elemTypeId);
-    curTypeInst.push_back(getOrAssignResultId<SpirvInstruction>(constant));
+    curTypeInst.push_back(length);
     finalizeTypeInstruction();
   }
   // RuntimeArray types

+ 1 - 86
tools/clang/lib/SPIRV/SpirvBuilder.cpp

@@ -940,91 +940,6 @@ void SpirvBuilder::decorateNonUniformEXT(SpirvInstruction *target,
   module->addDecoration(decor);
 }
 
-/*
-SpirvConstant *SpirvBuilder::getConstantUint16(uint16_t value, bool specConst) {
-  SpirvConstant *result = context.getConstantUint16(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantInt16(int16_t value, bool specConst) {
-  SpirvConstant *result = context.getConstantInt16(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantUint32(uint32_t value, bool specConst) {
-  SpirvConstant *result = context.getConstantUint32(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantInt32(int32_t value, bool specConst) {
-  SpirvConstant *result = context.getConstantInt32(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantUint64(uint64_t value, bool specConst) {
-  SpirvConstant *result = context.getConstantUint64(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantInt64(int64_t value, bool specConst) {
-  SpirvConstant *result = context.getConstantInt64(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantFloat16(uint16_t value,
-                                                bool specConst) {
-  SpirvConstant *result = context.getConstantFloat16(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantFloat32(float value, bool specConst) {
-  SpirvConstant *result = context.getConstantFloat32(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantFloat64(double value, bool specConst) {
-  SpirvConstant *result = context.getConstantFloat64(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantBool(bool value, bool specConst) {
-  SpirvConstant *result = context.getConstantBool(value, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *
-SpirvBuilder::getConstantComposite(QualType compositeType,
-                                   llvm::ArrayRef<SpirvConstant *> constituents,
-                                   bool specConst) {
-  SpirvConstant *result =
-      context.getConstantComposite(compositeType, constituents, specConst);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantNull(const SpirvType *type) {
-  SpirvConstant *result = context.getConstantNull(type);
-  module->addConstant(result);
-  return result;
-}
-
-SpirvConstant *SpirvBuilder::getConstantNull(QualType type) {
-  SpirvConstant *result = context.getConstantNull(type);
-  module->addConstant(result);
-  return result;
-}
-*/
-
 SpirvConstant *SpirvBuilder::getConstantUint16(uint16_t value, bool specConst) {
   return getConstantInt<uint16_t>(value, /*isSigned*/ false, 16, specConst);
 }
@@ -1124,7 +1039,7 @@ std::vector<uint32_t> SpirvBuilder::takeModule() {
   // Run necessary visitor passes first
   LowerTypeVisitor lowerTypeVisitor(astContext, context, spirvOptions);
   CapabilityVisitor capabilityVisitor(context, spirvOptions, *this);
-  EmitVisitor emitVisitor(astContext, context, spirvOptions, *this);
+  EmitVisitor emitVisitor(astContext, context, spirvOptions);
 
   // Lower types
   module->invokeVisitor(&lowerTypeVisitor);

+ 1 - 1
tools/clang/test/CodeGenSPIRV/bezier.hull.hlsl2spv

@@ -138,6 +138,7 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 //                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 5 Offset 192
 //                OpMemberDecorate %HS_CONSTANT_DATA_OUTPUT 6 Offset 256
 //        %uint = OpTypeInt 32 0
+//      %uint_0 = OpConstant %uint 0
 //       %float = OpTypeFloat 32
 //         %int = OpTypeInt 32 1
 //      %uint_3 = OpConstant %uint 3
@@ -178,7 +179,6 @@ BEZIER_CONTROL_POINT SubDToBezierHS(InputPatch<VS_CONTROL_POINT_OUTPUT, MAX_POIN
 // %_ptr_Function_VS_CONTROL_POINT_OUTPUT = OpTypePointer Function %VS_CONTROL_POINT_OUTPUT
 // %_ptr_Function_BEZIER_CONTROL_POINT = OpTypePointer Function %BEZIER_CONTROL_POINT
 // %_ptr_Function_v3float = OpTypePointer Function %v3float
-//      %uint_0 = OpConstant %uint 0
 //     %float_1 = OpConstant %float 1
 //       %int_0 = OpConstant %int 0
 //     %float_2 = OpConstant %float 2

+ 5 - 5
tools/clang/test/CodeGenSPIRV/empty-struct-interface.vs.hlsl2spv

@@ -17,13 +17,13 @@ VSOut main(VSIn input)
 // OpMemoryModel Logical GLSL450
 // OpEntryPoint Vertex %main "main"
 // OpSource HLSL 600
-// OpName %bb_entry "bb.entry"
-// OpName %src_main "src.main"
 // OpName %main "main"
 // OpName %VSIn "VSIn"
 // OpName %param_var_input "param.var.input"
 // OpName %VSOut "VSOut"
+// OpName %src_main "src.main"
 // OpName %input "input"
+// OpName %bb_entry "bb.entry"
 // OpName %result "result"
 // %void = OpTypeVoid
 // %3 = OpTypeFunction %void
@@ -33,10 +33,10 @@ VSOut main(VSIn input)
 // %12 = OpTypeFunction %VSOut %_ptr_Function_VSIn
 // %_ptr_Function_VSOut = OpTypePointer Function %VSOut
 // %main = OpFunction %void None %3
-// %5 = OpLabel
+// %4 = OpLabel
 // %param_var_input = OpVariable %_ptr_Function_VSIn Function
-// %9 = OpCompositeConstruct %VSIn
-// %11 = OpFunctionCall %VSOut %src_main %param_var_input
+// %8 = OpCompositeConstruct %VSIn
+// %10 = OpFunctionCall %VSOut %src_main %param_var_input
 // OpReturn
 // OpFunctionEnd
 // %src_main = OpFunction %VSOut None %12

+ 3 - 3
tools/clang/test/CodeGenSPIRV/passthru-cs.hlsl2spv

@@ -51,6 +51,9 @@ void main( uint3 DTid : SV_DispatchThreadID )
 //                OpDecorate %type_RWByteAddressBuffer BufferBlock
 //         %int = OpTypeInt 32 1
 //        %uint = OpTypeInt 32 0
+//      %uint_4 = OpConstant %uint 4
+//      %uint_2 = OpConstant %uint 2
+//      %uint_0 = OpConstant %uint 0
 // %_runtimearr_uint = OpTypeRuntimeArray %uint
 // %type_ByteAddressBuffer = OpTypeStruct %_runtimearr_uint
 // %_ptr_Uniform_type_ByteAddressBuffer = OpTypePointer Uniform %type_ByteAddressBuffer
@@ -65,9 +68,6 @@ void main( uint3 DTid : SV_DispatchThreadID )
 // %_ptr_Function_uint = OpTypePointer Function %uint
 // %_ptr_Uniform_uint = OpTypePointer Uniform %uint
 //       %int_0 = OpConstant %int 0
-//      %uint_4 = OpConstant %uint 4
-//      %uint_2 = OpConstant %uint 2
-//      %uint_0 = OpConstant %uint 0
 //     %Buffer0 = OpVariable %_ptr_Uniform_type_ByteAddressBuffer Uniform
 //   %BufferOut = OpVariable %_ptr_Uniform_type_RWByteAddressBuffer Uniform
 // %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input