Quellcode durchsuchen

[SPIRV] Fix missing implicit decls (#3054)

* Add astcontext const decls to the astDecl table if not inserted before

Fix [SPIRV] Inline ray tracing doesn't compile #3047
1)Trying to add astcontext decls to the astDecls tables if decl is not
inserted before
2) Add the unit test of previously failed cs shader.

* [spirv] Create implicit constant VarDecls lazily.

Co-authored-by: Ehsan Nasiri <[email protected]>
JiaoluAMD vor 5 Jahren
Ursprung
Commit
a17bd0e347

+ 19 - 3
tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

@@ -636,7 +636,17 @@ DeclResultIdMapper::getDeclSpirvInfo(const ValueDecl *decl) const {
 
 SpirvInstruction *DeclResultIdMapper::getDeclEvalInfo(const ValueDecl *decl,
                                                       SourceLocation loc) {
-  if (const auto *info = getDeclSpirvInfo(decl)) {
+  const DeclSpirvInfo *info = getDeclSpirvInfo(decl);
+
+  // If DeclSpirvInfo is not found for this decl, it might be because it is an
+  // implicit VarDecl. All implicit VarDecls are lazily created in order to
+  // avoid creating large number of unused variables/constants/enums.
+  if (!info) {
+    tryToCreateImplicitConstVar(decl);
+    info = getDeclSpirvInfo(decl);
+  }
+
+  if (info) {
     if (info->indexInCTBuffer >= 0) {
       // If this is a VarDecl inside a HLSLBufferDecl, we need to do an extra
       // OpAccessChain to get the pointer to the variable since we created
@@ -3545,9 +3555,15 @@ DeclResultIdMapper::createRayTracingNVStageVar(spv::StorageClass sc,
   return retVal;
 }
 
-void DeclResultIdMapper::createRayTracingNVImplicitVar(const VarDecl *varDecl) {
+void DeclResultIdMapper::tryToCreateImplicitConstVar(const ValueDecl *decl) {
+  const VarDecl *varDecl = dyn_cast<VarDecl>(decl);
+  if (!varDecl || !varDecl->isImplicit())
+    return;
+
   APValue *val = varDecl->evaluateValue();
-  assert(val);
+  if(!val)
+    return;
+
   SpirvInstruction *constVal =
       spvBuilder.getConstantInt(astContext.UnsignedIntTy, val->getInt());
   constVal->setRValue(true);

+ 8 - 4
tools/clang/lib/SPIRV/DeclResultIdMapper.h

@@ -394,11 +394,15 @@ public:
   /// \brief Sets the entry function.
   void setEntryFunction(SpirvFunction *fn) { entryFunction = fn; }
 
-  /// Raytracing specific functions
-  /// \brief Handle specific implicit declarations present only in raytracing
-  /// stages.
-  void createRayTracingNVImplicitVar(const VarDecl *varDecl);
+  /// \brief If the given decl is an implicit VarDecl that evaluates to a
+  /// constant, it evaluates the constant and registers the resulting SPIR-V
+  /// instruction in the astDecls map. Otherwise returns without doing anything.
+  ///
+  /// Note: There are many cases where the front-end might create such implicit
+  /// VarDecls (such as some ray tracing enums).
+  void tryToCreateImplicitConstVar(const ValueDecl *);
 
+  /// Raytracing specific functions
   /// \brief Creates a ShaderRecordBufferNV block from the given decl.
   SpirvVariable *createShaderRecordBufferNV(const VarDecl *decl);
   SpirvVariable *createShaderRecordBufferNV(const HLSLBufferDecl *decl);

+ 5 - 18
tools/clang/lib/SPIRV/SpirvEmitter.cpp

@@ -632,8 +632,8 @@ void SpirvEmitter::doDecl(const Decl *decl) {
   if (isa<EmptyDecl>(decl) || isa<TypedefDecl>(decl))
     return;
 
+  // Implicit decls are lazily created when needed.
   if (decl->isImplicit()) {
-    doImplicitDecl(decl);
     return;
   }
 
@@ -1143,19 +1143,6 @@ void SpirvEmitter::doHLSLBufferDecl(const HLSLBufferDecl *bufferDecl) {
   }
 }
 
-void SpirvEmitter::doImplicitDecl(const Decl *decl) {
-  // We only handle specific implicit declaration for raytracing
-  // which are RayFlag/HitKind constant unsigned integers
-  // Ignore others
-  if (spvContext.isLib() || spvContext.isRay()) {
-    const VarDecl *implDecl = dyn_cast<VarDecl>(decl);
-    if (implDecl && (implDecl->getName().startswith(StringRef("RAY_FLAG")) ||
-                     implDecl->getName().startswith(StringRef("HIT_KIND")))) {
-      (void)declIdMapper.createRayTracingNVImplicitVar(implDecl);
-    }
-  }
-}
-
 void SpirvEmitter::doRecordDecl(const RecordDecl *recordDecl) {
   // Ignore implict records
   // Somehow we'll have implicit records with:
@@ -11291,8 +11278,8 @@ void SpirvEmitter::addFunctionToWorkQueue(hlsl::DXIL::ShaderKind shaderKind,
 
 SpirvInstruction *
 SpirvEmitter::processTraceRayInline(const CXXMemberCallExpr *expr) {
-  emitWarning("SPV_KHR_ray_query is currently a provisional extension and might"
-              "change in ways that are not backwards compatible",
+  emitWarning("SPV_KHR_ray_query is currently a provisional extension and "
+              "might change in ways that are not backwards compatible",
               expr->getExprLoc());
   const auto object = expr->getImplicitObjectArgument();
   uint32_t templateFlags = hlsl::GetHLSLResourceTemplateUInt(object->getType());
@@ -11374,8 +11361,8 @@ SpirvEmitter::processTraceRayInline(const CXXMemberCallExpr *expr) {
 SpirvInstruction *
 SpirvEmitter::processRayQueryIntrinsics(const CXXMemberCallExpr *expr,
                                         hlsl::IntrinsicOp opcode) {
-  emitWarning("SPV_KHR_ray_query is currently a provisional extension and might"
-              "change in ways that are not backwards compatible",
+  emitWarning("SPV_KHR_ray_query is currently a provisional extension and "
+              "might change in ways that are not backwards compatible",
               expr->getExprLoc());
   const auto object = expr->getImplicitObjectArgument();
   SpirvInstruction *rayqueryObj = loadIfAliasVarRef(object);

+ 35 - 0
tools/clang/test/CodeGenSPIRV/rayquery_equal.cs.hlsl

@@ -0,0 +1,35 @@
+// Run: %dxc -T cs_6_5 -E main
+RaytracingAccelerationStructure g_topLevel : register(t0, space0);
+RWTexture2D<float4> g_output : register(u1, space0);
+
+[numthreads(64, 1, 1)]
+void main(uint2 launchIndex: SV_DispatchThreadID)
+{
+    float3 T = (float3)0;
+    float sampleCount = 0;
+    RayDesc ray;
+
+    ray.Origin = float3(0, 0, 0);
+    ray.Direction = float3(0, 1, 0);
+    ray.TMin = 0.0;
+    ray.TMax = 1000.0;
+
+    RayQuery<RAY_FLAG_FORCE_OPAQUE> q;
+
+    q.TraceRayInline(g_topLevel, 0, 0xff, ray);
+// CHECK:  [[rayquery:%\d+]] = OpVariable %_ptr_Function_rayQueryProvisionalKHR Function
+    q.Proceed();
+// CHECK:  OpRayQueryProceedKHR %bool [[rayquery]]
+    if(q.CommittedStatus() == COMMITTED_TRIANGLE_HIT)
+// CHECK:  [[status:%\d+]] = OpRayQueryGetIntersectionTypeKHR %uint [[rayquery]] %uint_1
+// CHECK:  OpIEqual %bool [[status]] %uint_1
+    {
+        T += float3(1, 0, 1);
+    }
+    else
+    {
+        T += float3(0, 1, 0);
+    }
+
+    g_output[launchIndex] += float4(T, 1);
+}