瀏覽代碼

More fix for rewriter. (#2939)

1. print space for resource.
2. check use of global in init.
3. support nested struct.
Xiang Li 5 年之前
父節點
當前提交
0082ce0476

+ 3 - 0
tools/clang/lib/AST/DeclPrinter.cpp

@@ -1469,6 +1469,9 @@ void DeclPrinter::VisitHLSLUnusualAnnotation(const hlsl::UnusualAnnotation *UA)
       if (ra->RegisterOffset) {
         Out << "[" << ra->RegisterOffset << "]";
       }
+      if (ra->RegisterSpace.hasValue() != 0) {
+        Out << ", space" << ra->RegisterSpace.getValue();
+      }
       Out << ")";
     }
     break;

+ 1 - 1
tools/clang/test/HLSL/rewriter/correct_rewrites/attributes_gold.hlsl

@@ -62,7 +62,7 @@ int uav() {
 struct HSFoo {
   float3 pos : POSITION;
 };
-Texture2D<float4> tex1[10] : register(t20);
+Texture2D<float4> tex1[10] : register(t20, space10);
 [domain("quad")]
 [partitioning("integer")]
 [outputtopology("triangle_cw")]

+ 1 - 1
tools/clang/test/HLSL/rewriter/correct_rewrites/packreg_gold.hlsl

@@ -26,7 +26,7 @@ sampler myVar_1 : register(ps, s0[1]);
 sampler myVar_11 : register(ps, s0[2]);
 sampler myVar_16 : register(ps, s0[15]);
 sampler myVar_n1p5 : register(ps, s0);
-sampler myVar_s1 : register(ps, s0[1]);
+sampler myVar_s1 : register(ps, s0[1], space1);
 cbuffer MyBuffer {
   const float4 Element1 : packoffset(c0);
   const float1 Element2 : packoffset(c1);

+ 27 - 0
tools/clang/test/HLSLFileCheck/rewriter/init_use.hlsl

@@ -0,0 +1,27 @@
+// RUN: %dxr -E main -remove-unused-globals %s | FileCheck %s
+
+// Make sure global used for init is not removed.
+// CHECK:float c;
+// CHECK:float a;
+// CHECK:float d;
+// CHECK:int e;
+
+float c;
+float a;
+
+struct S {
+  float x;
+  float b;
+};
+
+static S s = {c, a};
+
+float d;
+static uint cast = d;
+
+int e;
+static int init = e;
+
+float main() : SV_Target {
+  return s.x + s.b + cast + init;
+}

+ 29 - 0
tools/clang/test/HLSLFileCheck/rewriter/nested_struct.hlsl

@@ -0,0 +1,29 @@
+// RUN: %dxr -E main -remove-unused-globals %s | FileCheck %s
+
+// Makre sure nested struct is not removed.
+
+// CHECK:struct A
+// CHECK:struct B
+// CHECK-NOT:Get(
+// CHECK:StructuredBuffer<C> buf : register(t0, space6)
+
+struct A {
+  float a;
+};
+
+struct B : A {
+  float b;
+};
+
+struct C {
+  B b;
+  float c;
+  float Get() { return c + b.b + b.a; }
+};
+
+StructuredBuffer<C> buf : register(t0, space6);
+
+float main(uint i:I) : SV_Target {
+  return buf[i].c;
+}
+

+ 117 - 11
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -81,6 +81,47 @@ static FunctionDecl *getFunctionWithBody(FunctionDecl *F) {
   return nullptr;
 }
 
+static void SaveTypeDecl(TagDecl *tagDecl,
+                          SmallPtrSetImpl<TypeDecl *> &visitedTypes) {
+  if (visitedTypes.count(tagDecl))
+    return;
+  visitedTypes.insert(tagDecl);
+  if (CXXRecordDecl *recordDecl = dyn_cast<CXXRecordDecl>(tagDecl)) {
+    // If template, save template args
+    if (const ClassTemplateSpecializationDecl *templateSpecializationDecl =
+            dyn_cast<ClassTemplateSpecializationDecl>(recordDecl)) {
+      const clang::TemplateArgumentList &args =
+          templateSpecializationDecl->getTemplateInstantiationArgs();
+      for (unsigned i = 0; i < args.size(); ++i) {
+        const clang::TemplateArgument &arg = args[i];
+        switch (arg.getKind()) {
+        case clang::TemplateArgument::ArgKind::Type:
+          if (TagDecl *tagDecl = arg.getAsType()->getAsTagDecl()) {
+            SaveTypeDecl(tagDecl, visitedTypes);
+          };
+          break;
+        default:
+          break;
+        }
+      }
+    }
+    // Add field types.
+    for (FieldDecl *fieldDecl : recordDecl->fields()) {
+      if (TagDecl *tagDecl = fieldDecl->getType()->getAsTagDecl()) {
+        SaveTypeDecl(tagDecl, visitedTypes);
+      }
+    }
+    // Add base types.
+    if (recordDecl->getNumBases()) {
+      for (auto &I : recordDecl->bases()) {
+        CXXRecordDecl *BaseDecl =
+            cast<CXXRecordDecl>(I.getType()->castAs<RecordType>()->getDecl());
+        SaveTypeDecl(BaseDecl, visitedTypes);
+      }
+    }
+  }
+}
+
 class VarReferenceVisitor : public RecursiveASTVisitor<VarReferenceVisitor> {
 private:
   SmallPtrSetImpl<VarDecl*>& m_unusedGlobals;
@@ -88,6 +129,10 @@ private:
   SmallVectorImpl<FunctionDecl*>& m_pendingFunctions;
   SmallPtrSetImpl<TypeDecl *> &m_visitedTypes;
 
+  void AddRecordType(TagDecl *tagDecl) {
+    SaveTypeDecl(tagDecl, m_visitedTypes);
+  }
+
 public:
   VarReferenceVisitor(
     SmallPtrSetImpl<VarDecl*>& unusedGlobals,
@@ -111,16 +156,33 @@ public:
       }
       if (fnDeclWithbody && fnDeclWithbody != fnDecl) {
         // In case fnDecl is only a decl, setDecl to fnDeclWithbody.
-        // fnDecl will be removed.
         ref->setDecl(fnDeclWithbody);
+        // Keep the fnDecl for now, since it might be predecl.
+        m_visitedFunctions.insert(fnDecl);
       }
     }
     else if (VarDecl* varDecl = dyn_cast_or_null<VarDecl>(valueDecl)) {
       m_unusedGlobals.erase(varDecl);
       if (TagDecl *tagDecl = varDecl->getType()->getAsTagDecl()) {
-        m_visitedTypes.insert(tagDecl);
+        AddRecordType(tagDecl);
       }
-      varDecl->getType();
+      if (Expr *initExp = varDecl->getInit()) {
+        if (InitListExpr *initList =
+                dyn_cast<InitListExpr>(initExp)) {
+          TraverseInitListExpr(initList);
+        } else if (ImplicitCastExpr *initCast = dyn_cast<ImplicitCastExpr>(initExp)) {
+          TraverseImplicitCastExpr(initCast);
+        } else if (DeclRefExpr *initRef = dyn_cast<DeclRefExpr>(initExp)) {
+          TraverseDeclRefExpr(initRef);
+        }
+      }
+    }
+    return true;
+  }
+  bool VisitMemberExpr(MemberExpr *expr) {
+    // Save nested struct type.
+    if (TagDecl *tagDecl = expr->getType()->getAsTagDecl()) {
+      m_visitedTypes.insert(tagDecl);
     }
     return true;
   }
@@ -132,7 +194,30 @@ public:
       }
     }
     if (CXXRecordDecl *recordDecl = expr->getRecordDecl()) {
-      m_visitedTypes.insert(recordDecl);
+      AddRecordType(recordDecl);
+    }
+    return true;
+  }
+  bool VisitHLSLBufferDecl(HLSLBufferDecl *bufDecl) {
+    if (!bufDecl->isCBuffer())
+      return false;
+    for (Decl *decl : bufDecl->decls()) {
+      if (VarDecl *constDecl = dyn_cast<VarDecl>(decl)) {
+        if (TagDecl *tagDecl = constDecl->getType()->getAsTagDecl()) {
+          AddRecordType(tagDecl);
+        }
+      } else if (isa<EmptyDecl>(decl)) {
+        // Nothing to do for this declaration.
+      } else if (CXXRecordDecl *recordDecl = dyn_cast<CXXRecordDecl>(decl)) {
+        m_visitedTypes.insert(recordDecl);
+      } else if (isa<FunctionDecl>(decl)) {
+        // A function within an cbuffer is effectively a top-level function,
+        // as it only refers to globally scoped declarations.
+        // Nothing to do for this declaration.
+      } else {
+        HLSLBufferDecl *inner = cast<HLSLBufferDecl>(decl);
+        VisitHLSLBufferDecl(inner);
+      }
     }
     return true;
   }
@@ -493,7 +578,8 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
   DenseMap<RecordDecl*, unsigned> anonymousRecordRefCounts;
   SmallPtrSet<FunctionDecl*, 128> unusedFunctions;
   SmallPtrSet<TypeDecl*, 32> unusedTypes;
-  SmallVector<VarDecl*, 32> nonStaticGlobals;
+  SmallVector<VarDecl *, 32> nonStaticGlobals;
+  SmallVector<HLSLBufferDecl *, 16> cbufferDecls;
   for (Decl *tuDecl : tu->decls()) {
     if (tuDecl->isImplicit()) continue;
 
@@ -518,6 +604,13 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
       continue;
     }
 
+    if (HLSLBufferDecl *CB = dyn_cast<HLSLBufferDecl>(tuDecl)) {
+      if (!CB->isCBuffer())
+        continue;
+      cbufferDecls.emplace_back(CB);
+      continue;
+    }
+
     FunctionDecl* fnDecl = dyn_cast_or_null<FunctionDecl>(tuDecl);
     if (fnDecl != nullptr) {
       FunctionDecl *fnDeclWithbody = getFunctionWithBody(fnDecl);
@@ -529,6 +622,11 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
 
     if (TagDecl *tagDecl = dyn_cast<TagDecl>(tuDecl)) {
       unusedTypes.insert(tagDecl);
+      if (CXXRecordDecl *recordDecl = dyn_cast<CXXRecordDecl>(tagDecl)) {
+        for (CXXMethodDecl *methodDecl : recordDecl->methods()) {
+          unusedFunctions.insert(methodDecl);
+        }
+      }
     }
   }
 
@@ -561,6 +659,10 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
     visitedFunctions.insert(pendingDecl);
     visitor.TraverseDecl(pendingDecl);
   }
+  // Traverse cbuffers to save types for cbuffer constant.
+  for (auto *CBDecl : cbufferDecls) {
+    visitor.TraverseDecl(CBDecl);
+  }
 
   // Don't bother doing work if there are no globals to remove.
   if (unusedGlobals.empty() && unusedFunctions.empty() && unusedTypes.empty()) {
@@ -575,15 +677,15 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
   }
   w << "//found " << unusedFunctions.size() << " functions to remove\n";
 
-  for (TypeDecl *typeDecl : visitedTypes) {
-    unusedTypes.erase(typeDecl);
-  }
-
   for (VarDecl *varDecl : nonStaticGlobals) {
     if (TagDecl *tagDecl = varDecl->getType()->getAsTagDecl()) {
-      unusedTypes.erase(tagDecl);
+      SaveTypeDecl(tagDecl, visitedTypes);
     }
   }
+  for (TypeDecl *typeDecl : visitedTypes) {
+    unusedTypes.erase(typeDecl);
+  }
+
   w << "//found " << unusedTypes.size() << " types to remove\n";
 
 
@@ -618,7 +720,11 @@ static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
   for (FunctionDecl *unusedFn : unusedFunctions) {
     // remove name of function to workaround assert when update lookup table.
     unusedFn->setDeclName(DeclarationName());
-    tu->removeDecl(unusedFn);
+    if (CXXMethodDecl *methodDecl = dyn_cast<CXXMethodDecl>(unusedFn)) {
+      methodDecl->getParent()->removeDecl(unusedFn);
+    } else {
+      tu->removeDecl(unusedFn);
+    }
   }
 
   for (TypeDecl *unusedTy : unusedTypes) {