Jaebaek Seo 6 лет назад
Родитель
Сommit
66b9e355b9

+ 3 - 0
tools/clang/include/clang/SPIRV/AstTypeProbe.h

@@ -38,6 +38,9 @@ bool isScalarType(QualType type, QualType *scalarType = nullptr);
 bool isVectorType(QualType type, QualType *elemType = nullptr,
                   uint32_t *elemCount = nullptr);
 
+/// Returns true if the given type is enum type based on AST parse.
+bool isEnumType(QualType type);
+
 /// Returns true if the given type is a 1x1 matrix type.
 ///
 /// If elemType is not nullptr, writes the element type to *elemType.

+ 3 - 0
tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp

@@ -143,6 +143,9 @@ std::pair<uint32_t, uint32_t> AlignmentSizeCalculator::getAlignmentAndSize(
     return result;
   }
 
+  if (isEnumType(type))
+    type = astContext.IntTy;
+
   { // Rule 1
     QualType ty = {};
     if (isScalarType(type, &ty))

+ 19 - 1
tools/clang/lib/SPIRV/AstTypeProbe.cpp

@@ -97,7 +97,7 @@ bool isScalarType(QualType type, QualType *scalarType) {
   bool isScalar = false;
   QualType ty = {};
 
-  if (type->isBuiltinType()) {
+  if (type->isBuiltinType() || isEnumType(type)) {
     isScalar = true;
     ty = type;
   } else if (hlsl::IsHLSLVecType(type) && hlsl::GetHLSLVecSize(type) == 1) {
@@ -152,6 +152,17 @@ bool isVectorType(QualType type, QualType *elemType, uint32_t *elemCount) {
   return isVec;
 }
 
+bool isEnumType(QualType type) {
+  if (isa<EnumType>(type.getTypePtr()))
+    return true;
+
+  if (const auto *elaboratedType = type->getAs<ElaboratedType>())
+    if (isa<EnumType>(elaboratedType->desugar().getTypePtr()))
+      return true;
+
+  return false;
+}
+
 bool is1x1Matrix(QualType type, QualType *elemType) {
   if (!hlsl::IsHLSLMatType(type))
     return false;
@@ -385,6 +396,10 @@ uint32_t getElementSpirvBitwidth(const ASTContext &astContext, QualType type,
     return getElementSpirvBitwidth(astContext, ptrType->getPointeeType(),
                                    is16BitTypeEnabled);
 
+  // Enum types
+  if (isEnumType(type))
+    return 32;
+
   // Scalar types
   QualType ty = {};
   const bool isScalar = isScalarType(type, &ty);
@@ -1007,6 +1022,9 @@ bool isBoolOrVecOfBoolType(QualType type) {
 /// Returns true if the given type is a signed integer or vector of signed
 /// integer type.
 bool isSintOrVecOfSintType(QualType type) {
+  if (isEnumType(type))
+    return true;
+
   QualType elemType = {};
   return (isScalarType(type, &elemType) || isVectorType(type, &elemType)) &&
          elemType->isSignedIntegerType();

+ 10 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -815,6 +815,16 @@ SpirvVariable *DeclResultIdMapper::createStructOrStructArrayVarOfExplicitLayout(
   return var;
 }
 
+void DeclResultIdMapper::createEnumConstant(const EnumConstantDecl *decl) {
+  const auto *valueDecl = dyn_cast<ValueDecl>(decl);
+  const auto enumConstant =
+      spvBuilder.getConstantInt(astContext.IntTy, decl->getInitVal());
+  SpirvVariable *varInstr = spvBuilder.addModuleVar(
+      astContext.IntTy, spv::StorageClass::Private, /*isPrecise*/ false,
+      decl->getName(), enumConstant, decl->getLocation());
+  astDecls[valueDecl] = DeclSpirvInfo(varInstr);
+}
+
 SpirvVariable *DeclResultIdMapper::createCTBuffer(const HLSLBufferDecl *decl) {
   const auto usageKind =
       decl->isCBuffer() ? ContextUsageKind::CBuffer : ContextUsageKind::TBuffer;

+ 3 - 0
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -327,6 +327,9 @@ public:
   /// \brief Creates an external-visible variable and returns its instruction.
   SpirvVariable *createExternVar(const VarDecl *var);
 
+  /// \brief Creates an Enum constant.
+  void createEnumConstant(const EnumConstantDecl *decl);
+
   /// \brief Creates a cbuffer/tbuffer from the given decl.
   ///
   /// In the AST, cbuffer/tbuffer is represented as a HLSLBufferDecl, which is

+ 5 - 0
tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

@@ -444,6 +444,11 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
     return lowerType(ptrType->getPointeeType(), rule, isRowMajor, srcLoc);
   }
 
+  // Enum types
+  if (isEnumType(type)) {
+    return spvContext.getSIntType(32);
+  }
+
   emitError("lower type %0 unimplemented", srcLoc) << type->getTypeClassName();
   type->dump();
   return 0;

+ 13 - 3
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -725,6 +725,8 @@ void SpirvEmitter::doDecl(const Decl *decl) {
     doHLSLBufferDecl(bufferDecl);
   } else if (const auto *recordDecl = dyn_cast<RecordDecl>(decl)) {
     doRecordDecl(recordDecl);
+  } else if (const auto *enumDecl = dyn_cast<EnumDecl>(decl)) {
+    doEnumDecl(enumDecl);
   } else {
     emitError("decl type %0 unimplemented", decl->getLocation())
         << decl->getDeclKindName();
@@ -1251,6 +1253,11 @@ void SpirvEmitter::doRecordDecl(const RecordDecl *recordDecl) {
         doVarDecl(varDecl);
 }
 
+void SpirvEmitter::doEnumDecl(const EnumDecl *decl) {
+  for (auto it = decl->enumerator_begin(); it != decl->enumerator_end(); ++it)
+    declIdMapper.createEnumConstant(*it);
+}
+
 void SpirvEmitter::doVarDecl(const VarDecl *decl) {
   if (!validateVKAttributes(decl))
     return;
@@ -5429,7 +5436,7 @@ void SpirvEmitter::initOnce(QualType varType, std::string varName,
     var->setStorageClass(spv::StorageClass::Private);
     storeValue(
         // Static function variable are of private storage class
-        var, doExpr(varInit), varInit->getType(), varInit->getLocEnd());
+        var, loadIfGLValue(varInit), varInit->getType(), varInit->getLocEnd());
   } else {
     spvBuilder.createStore(var, spvBuilder.getConstantNull(varType), loc);
   }
@@ -6565,6 +6572,9 @@ SpirvInstruction *SpirvEmitter::castToBool(SpirvInstruction *fromVal,
 SpirvInstruction *SpirvEmitter::castToInt(SpirvInstruction *fromVal,
                                           QualType fromType, QualType toIntType,
                                           SourceLocation srcLoc) {
+  if (isEnumType(fromType))
+    fromType = astContext.IntTy;
+
   if (isSameType(astContext, fromType, toIntType))
     return fromVal;
 
@@ -10253,7 +10263,7 @@ bool SpirvEmitter::emitEntryFunctionWrapperForRayTracing(
     const auto varInfo =
         declIdMapper.getDeclEvalInfo(varDecl, varDecl->getLocation());
     if (const auto *init = varDecl->getInit()) {
-      storeValue(varInfo, doExpr(init), varDecl->getType(),
+      storeValue(varInfo, loadIfGLValue(init), varDecl->getType(),
                  init->getLocStart());
 
       // Update counter variable associated with global variables
@@ -10623,7 +10633,7 @@ bool SpirvEmitter::emitEntryFunctionWrapper(const FunctionDecl *decl,
     const auto varInfo =
         declIdMapper.getDeclEvalInfo(varDecl, varDecl->getLocation());
     if (const auto *init = varDecl->getInit()) {
-      storeValue(varInfo, doExpr(init), varDecl->getType(),
+      storeValue(varInfo, loadIfGLValue(init), varDecl->getType(),
                  init->getLocStart());
 
       // Update counter variable associated with global variables

+ 1 - 0
tools/clang/lib/SPIRV/SpirvEmitter.h

@@ -77,6 +77,7 @@ private:
   void doFunctionDecl(const FunctionDecl *decl);
   void doVarDecl(const VarDecl *decl);
   void doRecordDecl(const RecordDecl *decl);
+  void doEnumDecl(const EnumDecl *decl);
   void doHLSLBufferDecl(const HLSLBufferDecl *decl);
   void doImplicitDecl(const Decl *decl);
 

+ 48 - 0
tools/clang/test/CodeGenSPIRV/raytracing.nv.enum.hlsl

@@ -0,0 +1,48 @@
+// Run: %dxc -T lib_6_3
+// CHECK:  OpCapability RayTracingNV
+// CHECK:  OpExtension "SPV_NV_ray_tracing"
+// CHECK:  OpDecorate [[a:%\d+]] BuiltIn LaunchIdNV
+// CHECK:  OpDecorate [[b:%\d+]] BuiltIn LaunchSizeNV
+
+// CHECK-COUNT-1: [[rstype:%\d+]] = OpTypeAccelerationStructureNV
+RaytracingAccelerationStructure rs;
+
+struct Payload
+{
+  float4 color;
+};
+struct CallData
+{
+  float4 data;
+};
+
+//CHECK:      %First = OpVariable %_ptr_Private_int Private %int_0
+//CHECK-NEXT: %Second = OpVariable %_ptr_Private_int Private %int_1
+enum Number {
+  First,
+  Second,
+};
+
+//CHECK:      [[first:%\d+]] = OpLoad %int %First
+//CHECK-NEXT:                  OpStore %foo [[first]]
+static ::Number foo = First;
+
+[shader("raygeneration")]
+void main() {
+//CHECK:      [[second:%\d+]] = OpLoad %int %Second
+//CHECK-NEXT:                   OpStore %bar [[second]]
+  static ::Number bar = Second;
+
+  uint3 a = DispatchRaysIndex();
+  uint3 b = DispatchRaysDimensions();
+
+  Payload myPayload = { float4(0.0f,0.0f,0.0f,0.0f) };
+  CallData myCallData = { float4(0.0f,0.0f,0.0f,0.0f) };
+  RayDesc rayDesc;
+  rayDesc.Origin = float3(0.0f, 0.0f, 0.0f);
+  rayDesc.Direction = float3(0.0f, 0.0f, -1.0f);
+  rayDesc.TMin = 0.0f;
+  rayDesc.TMax = 1000.0f;
+  TraceRay(rs, 0x0, 0xff, 0, 1, 0, rayDesc, myPayload);
+  CallShader(0, myCallData);
+}

+ 107 - 0
tools/clang/test/CodeGenSPIRV/type.enum.hlsl

@@ -0,0 +1,107 @@
+// Run: %dxc -T ps_6_0 -E main
+
+//CHECK:      %First = OpVariable %_ptr_Private_int Private %int_0
+//CHECK-NEXT: %Second = OpVariable %_ptr_Private_int Private %int_1
+//CHECK-NEXT: %Third = OpVariable %_ptr_Private_int Private %int_3
+//CHECK-NEXT: %Fourth = OpVariable %_ptr_Private_int Private %int_n1
+enum Number {
+  First,
+  Second,
+  Third = 3,
+  Fourth = -1,
+};
+
+//CHECK:      %a = OpVariable %_ptr_Private_int Private
+//CHECK-NEXT: %b = OpVariable %_ptr_Workgroup_int Workgroup
+//CHECK-NEXT: %c = OpVariable %_ptr_Uniform_type_AppendStructuredBuffer_ Uniform
+
+//CHECK:      [[second:%\d+]] = OpLoad %int %Second
+//CHECK-NEXT:                   OpStore %a [[second]]
+static ::Number a = Second;
+groupshared Number b;
+AppendStructuredBuffer<Number> c;
+
+void testParam(Number param) {}
+void testParamTypeCast(int param) {}
+
+void main() {
+//CHECK:      [[a:%\d+]] = OpLoad %int %a
+//CHECK-NEXT:              OpStore %foo [[a]]
+  int foo = a;
+
+//CHECK:      [[fourth:%\d+]] = OpLoad %int %Fourth
+//CHECK-NEXT:                   OpStore %b [[fourth]]
+  b = Fourth;
+
+//CHECK:          [[c:%\d+]] = OpAccessChain %_ptr_Uniform_int %c %uint_0
+//CHECK-NEXT: [[third:%\d+]] = OpLoad %int %Third
+//CHECK-NEXT:                  OpStore [[c]] [[third]]
+  c.Append(Third);
+
+//CHECK:          [[c:%\d+]] = OpAccessChain %_ptr_Uniform_int %c %uint_0 %57
+//CHECK-NEXT: [[third:%\d+]] = OpLoad %int %Third
+//CHECK-NEXT:                  OpStore [[c]] [[third]]
+  c.Append(Number::Third);
+
+  Number d;
+//CHECK:      [[d:%\d+]] = OpLoad %int %d
+//CHECK-NEXT:              OpSelectionMerge %switch_merge None
+//CHECK-NEXT:              OpSwitch [[d]] %switch_default 0 %switch_0 1 %switch_1
+  switch (d) {
+    case First:
+      d = Second;
+      break;
+    case Second:
+      d = First;
+      break;
+    default:
+      d = Third;
+      break;
+  }
+
+//CHECK:      [[fourth:%\d+]] = OpLoad %int %Fourth
+//CHECK-NEXT:                   OpStore %e [[fourth]]
+  static ::Number e = Fourth;
+
+//CHECK:          [[d:%\d+]] = OpLoad %int %d
+//CHECK-NEXT: [[third:%\d+]] = OpLoad %int %Third
+//CHECK-NEXT:                  OpSLessThan %bool [[d]] [[third]]
+  if (d < Third) {
+//CHECK:       [[first:%\d+]] = OpLoad %int %First
+//CHECK-NEXT: [[second:%\d+]] = OpLoad %int %Second
+//CHECK-NEXT:    [[add:%\d+]] = OpIAdd %int [[first]] [[second]]
+//CHECK-NEXT:                   OpStore %d [[add]]
+    d = First + Second;
+  }
+
+//CHECK:      [[foo:%\d+]] = OpLoad %int %foo
+//CHECK-NEXT: [[foo:%\d+]] = OpBitcast %int [[foo]]
+//CHECK-NEXT:                OpStore %d [[foo]]
+  if (First < Third)
+    d = (Number)foo;
+
+//CHECK:      [[a:%\d+]] = OpLoad %int %a
+//CHECK-NEXT:              OpStore %param_var_param [[a]]
+//CHECK-NEXT:              OpFunctionCall %void %testParam %param_var_param
+  testParam(a);
+
+//CHECK:      [[second:%\d+]] = OpLoad %int %Second
+//CHECK-NEXT:                   OpStore %param_var_param_0 [[second]]
+//CHECK-NEXT:                   OpFunctionCall %void %testParam %param_var_param_0
+  testParam(Second);
+
+//CHECK:      [[a:%\d+]] = OpLoad %int %a
+//CHECK-NEXT:              OpStore %param_var_param_1 [[a]]
+//CHECK-NEXT:              OpFunctionCall %void %testParamTypeCast %param_var_param_1
+  testParamTypeCast(a);
+
+//CHECK:      OpStore %param_var_param_2 %int_1
+//CHECK-NEXT: OpFunctionCall %void %testParamTypeCast %param_var_param_2
+  testParamTypeCast(Second);
+
+//CHECK:        [[a:%\d+]] = OpLoad %int %a
+//CHECK-NEXT:   [[a:%\d+]] = OpBitcast %float [[a]]
+//CHECK-NEXT: [[sin:%\d+]] = OpExtInst %float {{%\d+}} Sin [[a]]
+//CHECK-NEXT:                OpStore %bar [[sin]]
+  float bar = sin(a);
+}

+ 2 - 0
tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

@@ -82,6 +82,7 @@ TEST_F(FileTest, CBufferType) { runFileTest("type.cbuffer.hlsl"); }
 TEST_F(FileTest, ConstantBufferType) {
   runFileTest("type.constant-buffer.hlsl");
 }
+TEST_F(FileTest, EnumType) { runFileTest("type.enum.hlsl"); }
 TEST_F(FileTest, TBufferType) { runFileTest("type.tbuffer.hlsl"); }
 TEST_F(FileTest, TextureBufferType) { runFileTest("type.texture-buffer.hlsl"); }
 TEST_F(FileTest, StructuredBufferType) {
@@ -1942,6 +1943,7 @@ TEST_F(FileTest, PreprocessorError) {
 TEST_F(FileTest, RayTracingNVRaygen) {
   runFileTest("raytracing.nv.raygen.hlsl");
 }
+TEST_F(FileTest, RayTracingNVEnum) { runFileTest("raytracing.nv.enum.hlsl"); }
 TEST_F(FileTest, RayTracingNVIntersection) {
   runFileTest("raytracing.nv.intersection.hlsl");
 }