2
0
Эх сурвалжийг харах

Rewriter: improvements plus extract uniforms to global scope (#2730)

- Added rewriter functions for extracting uniforms into global scope
- Will not work if name collides (namespaces have bugs)
- Values under cbuffer _Params, resources outside (more bugs)
- Added rewriter HLSLOptions and support all through RewriteWithOptions
- Refactored dxr.exe to use HLSLOptions and RewriteWithOptions
- Exposed Write*VersionInfo functions from dxclib
- Fixed issue with External[Lib/Fn] for version printing.
Tex Riddell 5 жил өмнө
parent
commit
2f4f6f6bcd

+ 15 - 1
include/dxc/Support/HLSLOptions.h

@@ -41,6 +41,7 @@ enum HlslFlags {
   NoArgumentUnused = (1 << 14),
   CoreOption = (1 << 15),
   ISenseOption = (1 << 16),
+  RewriteOption = (1 << 17),
 };
 
 enum ID {
@@ -64,7 +65,7 @@ static const unsigned CompilerFlags = HlslFlags::CoreOption;
 /// Flags for dxc.exe command-line tool.
 static const unsigned DxcFlags = HlslFlags::CoreOption | HlslFlags::DriverOption;
 /// Flags for dxr.exe command-line tool.
-static const unsigned DxrFlags = HlslFlags::CoreOption | HlslFlags::DriverOption;
+static const unsigned DxrFlags = HlslFlags::RewriteOption | HlslFlags::DriverOption;
 /// Flags for IDxcIntelliSense APIs.
 static const unsigned ISenseFlags = HlslFlags::CoreOption | HlslFlags::ISenseOption;
 
@@ -85,6 +86,16 @@ public:
   unsigned size() const { return DefineVector.size(); }
 };
 
+struct RewriterOpts {
+  bool Unchanged = false;                   // OPT_rw_unchanged
+  bool SkipFunctionBody = false;            // OPT_rw_skip_function_body
+  bool SkipStatic = false;                  // OPT_rw_skip_static
+  bool GlobalExternByDefault = false;       // OPT_rw_global_extern_by_default
+  bool KeepUserMacro = false;               // OPT_rw_keep_user_macro
+  bool ExtractEntryUniforms = false;        // OPT_rw_extract_entry_uniforms
+  bool RemoveUnusedGlobals = false;         // OPT_rw_remove_unused_globals
+};
+
 /// Use this class to capture all options.
 class DxcOpts {
 public:
@@ -174,6 +185,9 @@ public:
   unsigned long ValVerMajor = UINT_MAX, ValVerMinor = UINT_MAX; // OPT_validator_version
   unsigned ScanLimit = 0; // OPT_memdep_block_scan_limit
 
+  // Rewriter Options
+  RewriterOpts RWOpt;
+
   std::vector<std::string> Warnings;
 
   bool IsRootSignatureProfile();

+ 33 - 11
include/dxc/Support/HLSLOptions.td

@@ -26,6 +26,9 @@ def CoreOption : OptionFlag;
 // ISenseOption - This option is only supported for IntelliSense.
 def ISenseOption : OptionFlag;
 
+// RewriteOption - This is considered a "rewriter" HLSL option.
+def RewriteOption : OptionFlag;
+
 //////////////////////////////////////////////////////////////////////////////
 // Groups
 
@@ -66,6 +69,7 @@ def hlslcomp_Group : OptionGroup<"HLSL Compilation">, HelpText<"Compilation Opti
 def hlsloptz_Group : OptionGroup<"HLSL Optimization">, HelpText<"Optimization Options">;
 def hlslutil_Group : OptionGroup<"HLSL Utility">, HelpText<"Utility Options">;
 def hlslcore_Group : OptionGroup<"HLSL Core">, HelpText<"Common Options">;
+def hlslrewrite_Group : OptionGroup<"HLSL Rewriter">, HelpText<"Rewriter Options">;
 
 def spirv_Group : OptionGroup<"SPIR-V CodeGen">, HelpText<"SPIR-V CodeGen Options">; // SPIRV Change
 
@@ -82,11 +86,11 @@ def spirv_Group : OptionGroup<"SPIR-V CodeGen">, HelpText<"SPIR-V CodeGen Option
 // The general approach is to include only things that are in use, in the
 // same order as in Options.td.
 
-def D : JoinedOrSeparate<["-", "/"], "D">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
+def D : JoinedOrSeparate<["-", "/"], "D">, Group<hlslcomp_Group>, Flags<[CoreOption, RewriteOption]>,
     HelpText<"Define macro">;
 def H : Flag<["-"], "H">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
     HelpText<"Show header includes and nesting depth">;
-def I : JoinedOrSeparate<["-", "/"], "I">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
+def I : JoinedOrSeparate<["-", "/"], "I">, Group<hlslcomp_Group>, Flags<[CoreOption, RewriteOption]>,
     HelpText<"Add directory to include search path">;
 def O0 : Flag<["-", "/"], "O0">, Group<hlsloptz_Group>, Flags<[CoreOption]>,
     HelpText<"Optimization Level 0">;
@@ -212,13 +216,13 @@ def _help_question : Flag<["-", "/"], "?">, Flags<[DriverOption]>, Alias<help>;
 
 def ast_dump : Flag<["-", "/"], "ast-dump">, Flags<[CoreOption, DriverOption, HelpHidden]>,
   HelpText<"Dumps the parsed Abstract Syntax Tree.">; // should not be core, but handy workaround until explicit API written
-def external_lib : Separate<["-", "/"], "external">, Group<hlslcore_Group>, Flags<[DriverOption, HelpHidden]>,
+def external_lib : Separate<["-", "/"], "external">, Group<hlslcore_Group>, Flags<[DriverOption, RewriteOption, HelpHidden]>,
   HelpText<"External DLL name to load for compiler support">;
-def external_fn : Separate<["-", "/"], "external-fn">, Group<hlslcore_Group>, Flags<[DriverOption, HelpHidden]>,
+def external_fn : Separate<["-", "/"], "external-fn">, Group<hlslcore_Group>, Flags<[DriverOption, RewriteOption, HelpHidden]>,
   HelpText<"External function name to load for compiler support">;
 def fcgl : Flag<["-", "/"], "fcgl">, Group<hlslcore_Group>, Flags<[CoreOption, HelpHidden]>,
   HelpText<"Generate high-level code only">;
-def flegacy_macro_expansion : Flag<["-", "/"], "flegacy-macro-expansion">, Group<hlslcomp_Group>, Flags<[CoreOption, DriverOption]>,
+def flegacy_macro_expansion : Flag<["-", "/"], "flegacy-macro-expansion">, Group<hlslcomp_Group>, Flags<[CoreOption, RewriteOption, DriverOption]>,
     HelpText<"Expand the operands before performing token-pasting operation (fxc behavior)">;
 def flegacy_resource_reservation : Flag<["-", "/"], "flegacy-resource-reservation">, Group<hlslcomp_Group>, Flags<[CoreOption, DriverOption]>,
     HelpText<"Reserve unused explicit register assignments for compatibility with shader model 5.0 and below">;
@@ -234,13 +238,13 @@ def pack_optimized : Flag<["-", "/"], "pack-optimized">, Group<hlslcomp_Group>,
   HelpText<"Optimize signature packing assuming identical signature provided for each connecting stage">;
 def pack_optimized_ : Flag<["-", "/"], "pack_optimized">, Group<hlslcomp_Group>, Flags<[CoreOption, HelpHidden]>,
   HelpText<"Optimize signature packing assuming identical signature provided for each connecting stage">;
-def hlsl_version : Separate<["-", "/"], "HV">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
+def hlsl_version : Separate<["-", "/"], "HV">, Group<hlslcomp_Group>, Flags<[CoreOption, RewriteOption]>,
   HelpText<"HLSL version (2016, 2017, 2018). Default is 2018">;
-def no_warnings : Flag<["-", "/"], "no-warnings">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
+def no_warnings : Flag<["-", "/"], "no-warnings">, Group<hlslcomp_Group>, Flags<[CoreOption, RewriteOption]>,
   HelpText<"Suppress warnings">;
 def rootsig_define : Separate<["-", "/"], "rootsig-define">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
   HelpText<"Read root signature from a #define">;
-def enable_16bit_types: Flag<["-", "/"], "enable-16bit-types">, Flags<[CoreOption, DriverOption]>, Group<hlslcomp_Group>,
+def enable_16bit_types: Flag<["-", "/"], "enable-16bit-types">, Flags<[CoreOption, RewriteOption, DriverOption]>, Group<hlslcomp_Group>,
   HelpText<"Enable 16bit types and disable min precision types. Available in HLSL 2018 and shader model 6.2">;
 def ignore_line_directives : Flag<["-", "/"], "ignore-line-directives">, HelpText<"Ignore line directives">, Flags<[CoreOption]>, Group<hlslcomp_Group>;
 def auto_binding_space : Separate<["-", "/"], "auto-binding-space">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
@@ -251,7 +255,7 @@ def export_shaders_only : Flag<["-", "/"], "export-shaders-only">, Group<hlslcom
   HelpText<"Only export shaders when compiling a library">;
 def default_linkage : Separate<["-", "/"], "default-linkage">, Group<hlslcomp_Group>, Flags<[CoreOption]>,
   HelpText<"Set default linkage for non-shader functions when compiling or linking to a library target (internal, external)">;
-def encoding : Separate<["-", "/"], "encoding">, Group<hlslcomp_Group>, Flags<[CoreOption, DriverOption]>,
+def encoding : Separate<["-", "/"], "encoding">, Group<hlslcomp_Group>, Flags<[CoreOption, RewriteOption, DriverOption]>,
   HelpText<"Set default encoding for text outputs (utf8|utf16) default=utf8">;
 def validator_version : Separate<["-", "/"], "validator-version">, Group<hlslcomp_Group>, Flags<[CoreOption, HelpHidden]>,
   HelpText<"Override validator version for module.  Format: <major.minor> ; Default: DXIL.dll version or current internal version.">;
@@ -314,7 +318,7 @@ def target_profile : JoinedOrSeparate<["-", "/"], "T">, Flags<[CoreOption]>, Gro
   // VALRULE-TEXT:BEGIN
   HelpText<"Set target profile. \n\t<profile>: ps_6_0, ps_6_1, ps_6_2, ps_6_3, ps_6_4, ps_6_5, \n\t\t vs_6_0, vs_6_1, vs_6_2, vs_6_3, vs_6_4, vs_6_5, \n\t\t gs_6_0, gs_6_1, gs_6_2, gs_6_3, gs_6_4, gs_6_5, \n\t\t hs_6_0, hs_6_1, hs_6_2, hs_6_3, hs_6_4, hs_6_5, \n\t\t ds_6_0, ds_6_1, ds_6_2, ds_6_3, ds_6_4, ds_6_5, \n\t\t cs_6_0, cs_6_1, cs_6_2, cs_6_3, cs_6_4, cs_6_5, \n\t\t lib_6_1, lib_6_2, lib_6_3, lib_6_4, lib_6_5, \n\t\t ms_6_5, \n\t\t as_6_5, \n\t\t ">;
   // VALRULE-TEXT:END
-def entrypoint :  JoinedOrSeparate<["-", "/"], "E">, Flags<[CoreOption]>, Group<hlslcomp_Group>,
+def entrypoint :  JoinedOrSeparate<["-", "/"], "E">, Flags<[CoreOption, RewriteOption]>, Group<hlslcomp_Group>,
   HelpText<"Entry point name">;
 // /I <include> - already defined above
 def _vi : Flag<["-", "/"], "Vi">, Alias<H>, Flags<[CoreOption]>, Group<hlslcomp_Group>,
@@ -348,7 +352,7 @@ def Gis : Flag<["-", "/"], "Gis">, HelpText<"Force IEEE strictness">, Flags<[Cor
 
 def denorm : JoinedOrSeparate<["-", "/"], "denorm">, HelpText<"select denormal value options (any, preserve, ftz). any is the default.">, Flags<[CoreOption]>, Group<hlslcomp_Group>;
 
-def Fo : JoinedOrSeparate<["-", "/"], "Fo">, MetaVarName<"<file>">, HelpText<"Output object file">, Flags<[CoreOption, DriverOption]>, Group<hlslcomp_Group>;
+def Fo : JoinedOrSeparate<["-", "/"], "Fo">, MetaVarName<"<file>">, HelpText<"Output object file">, Flags<[CoreOption, RewriteOption, DriverOption]>, Group<hlslcomp_Group>;
 // def Fl : JoinedOrSeparate<["-", "/"], "Fl">, MetaVarName<"<file>">, HelpText<"Output a library">;
 def Fc : JoinedOrSeparate<["-", "/"], "Fc">, MetaVarName<"<file>">, HelpText<"Output assembly code listing file">, Flags<[DriverOption]>, Group<hlslcomp_Group>;
 //def Fx : JoinedOrSeparate<["-", "/"], "Fx">, MetaVarName<"<file>">, HelpText<"Output assembly code and hex listing file">;
@@ -426,5 +430,23 @@ def getprivate : JoinedOrSeparate<["-", "/"], "getprivate">, Flags<[DriverOption
 def nologo : Flag<["-", "/"], "nologo">, Group<hlslcore_Group>, Flags<[DriverOption]>,
   HelpText<"Suppress copyright message">;
 
+//////////////////////////////////////////////////////////////////////////////
+// Rewriter Options
+
+def rw_unchanged : Flag<["-", "/"], "unchanged">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Rewrite HLSL, without changes.">;
+def rw_skip_function_body : Flag<["-", "/"], "skip-fn-body">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Translate function definitions to declarations">;
+def rw_skip_static : Flag<["-", "/"], "skip-static">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Remove static functions and globals when used with -skip-fn-body">;
+def rw_global_extern_by_default : Flag<["-", "/"], "global-extern-by-default">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Set extern on non-static globals">;
+def rw_keep_user_macro : Flag<["-", "/"], "keep-user-macro">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Write out user defines after rewritten HLSL">;
+def rw_extract_entry_uniforms : Flag<["-", "/"], "extract-entry-uniforms">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Move uniform parameters from entry point to global scope">;
+def rw_remove_unused_globals : Flag<["-", "/"], "remove-unused-globals">, Group<hlslrewrite_Group>, Flags<[RewriteOption]>,
+  HelpText<"Remove unused static globals and functions">;
+
 // Also removed: compress, decompress, /Gch (child effect), /Gpp (partial precision)
 // /Op - no support for preshaders.

+ 1 - 1
include/llvm/Option/OptTable.h

@@ -42,7 +42,7 @@ public:
     unsigned ID;
     unsigned char Kind;
     unsigned char Param;
-    unsigned short Flags;
+    unsigned long Flags;
     unsigned short GroupID;
     unsigned short AliasID;
     const char *AliasArgs;

+ 20 - 1
lib/DxcSupport/HLSLOptions.cpp

@@ -680,7 +680,9 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
 
   // XXX TODO: Sort this out, since it's required for new API, but a separate argument for old APIs.
   if ((flagsToInclude & hlsl::options::DriverOption) &&
-      opts.TargetProfile.empty() && !opts.DumpBin && opts.Preprocess.empty() && !opts.RecompileFromBinary) {
+      !(flagsToInclude & hlsl::options::RewriteOption) &&
+      opts.TargetProfile.empty() && !opts.DumpBin && opts.Preprocess.empty() && !opts.RecompileFromBinary
+      ) {
     // Target profile is required in arguments only for drivers when compiling;
     // APIs take this through an argument.
     errors << "Target profile argument is missing";
@@ -880,6 +882,23 @@ int ReadDxcOpts(const OptTable *optionTable, unsigned flagsToInclude,
     return 1;
   }
 
+  // Rewriter Options
+  if (flagsToInclude & hlsl::options::RewriteOption) {
+    opts.RWOpt.Unchanged = Args.hasFlag(OPT_rw_unchanged, OPT_INVALID, false);
+    opts.RWOpt.SkipFunctionBody = Args.hasFlag(OPT_rw_skip_function_body, OPT_INVALID, false);
+    opts.RWOpt.SkipStatic = Args.hasFlag(OPT_rw_skip_static, OPT_INVALID, false);
+    opts.RWOpt.GlobalExternByDefault = Args.hasFlag(OPT_rw_global_extern_by_default, OPT_INVALID, false);
+    opts.RWOpt.KeepUserMacro = Args.hasFlag(OPT_rw_keep_user_macro, OPT_INVALID, false);
+    opts.RWOpt.ExtractEntryUniforms = Args.hasFlag(OPT_rw_extract_entry_uniforms, OPT_INVALID, false);
+    opts.RWOpt.RemoveUnusedGlobals = Args.hasFlag(OPT_rw_remove_unused_globals, OPT_INVALID, false);
+
+    if (opts.EntryPoint.empty() &&
+        (opts.RWOpt.RemoveUnusedGlobals || opts.RWOpt.ExtractEntryUniforms)) {
+      errors << "-rw-remove-unused-globals and -rw-extract-entry-uniforms requires entry point (-E) to be specified.";
+      return 1;
+    }
+  }
+
   opts.Args = std::move(Args);
   return 0;
 }

+ 7 - 1
tools/clang/include/clang/AST/PrettyPrinter.h

@@ -43,7 +43,8 @@ struct PrintingPolicy {
       Bool(LO.Bool), TerseOutput(false), PolishForDeclaration(false),
       Half(LO.HLSL || LO.Half), // HLSL Change - always print 'half' for HLSL
       MSWChar(LO.MicrosoftExt && !LO.WChar),
-      IncludeNewlines(true) { }
+      IncludeNewlines(true),
+      HLSLSuppressUniformParameters(false) { }
 
   /// \brief What language we're printing.
   LangOptions LangOpts;
@@ -164,6 +165,11 @@ struct PrintingPolicy {
 
   /// \brief When true, include newlines after statements like "break", etc.
   unsigned IncludeNewlines : 1;
+
+  // HLSL Change Begin
+  /// \brief When true, exclude uniform function parameters
+  unsigned HLSLSuppressUniformParameters : 1;
+  // HLSL Change Ends
 };
 
 } // end namespace clang

+ 4 - 0
tools/clang/lib/AST/DeclPrinter.cpp

@@ -510,6 +510,10 @@ void DeclPrinter::VisitFunctionDecl(FunctionDecl *D) {
       llvm::raw_string_ostream POut(Proto);
       DeclPrinter ParamPrinter(POut, SubPolicy, Indentation);
       for (unsigned i = 0, e = D->getNumParams(); i != e; ++i) {
+        if (Policy.HLSLSuppressUniformParameters &&
+            Policy.LangOpts.HLSL &&
+            D->getParamDecl(i)->hasAttr<HLSLUniformAttr>())  // HLSL Change
+          continue;
         if (i) POut << ", ";
         ParamPrinter.VisitParmVarDecl(D->getParamDecl(i));
       }

+ 1 - 1
tools/clang/lib/Sema/SemaHLSL.cpp

@@ -12405,7 +12405,7 @@ void hlsl::CustomPrintHLSLAttr(const clang::Attr *A, llvm::raw_ostream &Out, con
     Attr * noconst = const_cast<Attr*>(A);
     HLSLRootSignatureAttr *ACast = static_cast<HLSLRootSignatureAttr*>(noconst);
     Indent(Indentation, Out);
-    Out << "[RootSignature(" << ACast->getSignatureName() << ")]\n";
+    Out << "[RootSignature(\"" << ACast->getSignatureName() << "\")]\n";
     break;
   }
 

+ 1 - 1
tools/clang/test/HLSL/rewriter/correct_rewrites/array-length-rw_gold.hlsl

@@ -1,6 +1,6 @@
 // Rewrite unchanged result:
 const float4 planes[8];
-[RootSignature(CBV(b0, space=0, visibility=SHADER_VISIBILITY_ALL))]
+[RootSignature("CBV(b0, space=0, visibility=SHADER_VISIBILITY_ALL)")]
 float main() {
   float4 x = float4(1., 1., 1., 1.);
   for (uint i = 0; i < planes.Length; ++i) {

+ 15 - 0
tools/clang/test/HLSL/rewriter/rewrite-uniforms.hlsl

@@ -0,0 +1,15 @@
+#define RS "RootFlags(0),DescriptorTable(UAV(u0, numDescriptors = 1), CBV(b0, numDescriptors = 1))"
+
+[RootSignature(RS)]
+[numthreads(4,8,16)]
+void FloatFunc(uint3 id : SV_DispatchThreadID, uniform RWStructuredBuffer<float> buf, uniform uint ui)
+{
+    buf[id.x+id.y+id.z] = id.x;
+}
+
+[RootSignature(RS)]
+[numthreads(4,8,16)]
+void IntFunc(uint3 id : SV_DispatchThreadID, uniform RWStructuredBuffer<int> buf, uniform uint ui)
+{
+    buf[id.x+id.y+id.z] = id.x + ui;
+}

+ 38 - 12
tools/clang/tools/dxclib/dxc.cpp

@@ -1070,9 +1070,14 @@ bool GetDLLProductVersionInfo(const char *dllPath, std::string &productVersion)
   return false;
 }
 
-// Collects compiler/validator version info
-void DxcContext::GetCompilerVersionInfo(llvm::raw_string_ostream &OS) {
-  if (m_dxcSupport.IsEnabled()) {
+namespace dxc {
+
+// Writes compiler version info to stream
+void WriteDxCompilerVersionInfo(llvm::raw_ostream &OS,
+                                const char *ExternalLib,
+                                const char *ExternalFn,
+                                DxcDllSupport &DxcSupport) {
+  if (DxcSupport.IsEnabled()) {
     UINT32 compilerMajor = 1;
     UINT32 compilerMinor = 0;
     CComPtr<IDxcVersionInfo> VerInfo;
@@ -1083,10 +1088,12 @@ void DxcContext::GetCompilerVersionInfo(llvm::raw_string_ostream &OS) {
     CComPtr<IDxcVersionInfo2> VerInfo2;
 #endif // SUPPORT_QUERY_GIT_COMMIT_INFO
 
-    const char *compilerName =
-      m_Opts.ExternalFn.empty() ? "dxcompiler.dll" : m_Opts.ExternalFn.data();
+    const char *dllName = !ExternalLib ? "dxcompiler.dll" : ExternalLib;
+    std::string compilerName(dllName);
+    if (ExternalFn)
+      compilerName = compilerName + "!" + ExternalFn;
 
-    if (SUCCEEDED(CreateInstance(CLSID_DxcCompiler, &VerInfo))) {
+    if (SUCCEEDED(DxcSupport.CreateInstance(CLSID_DxcCompiler, &VerInfo))) {
       VerInfo->GetVersion(&compilerMajor, &compilerMinor);
 #ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
       if (SUCCEEDED(VerInfo->QueryInterface(&VerInfo2)))
@@ -1095,13 +1102,16 @@ void DxcContext::GetCompilerVersionInfo(llvm::raw_string_ostream &OS) {
       OS << compilerName << ": " << compilerMajor << "." << compilerMinor;
     }
     // compiler.dll 1.0 did not support IdxcVersionInfo
-    else if (m_Opts.ExternalFn.empty()) {
+    else if (!ExternalLib) {
       OS << compilerName << ": " << 1 << "." << 0;
+    } else {
+      // ExternalLib/ExternalFn, no version info:
+      OS << compilerName;
     }
 
 #ifdef _WIN32
     unsigned int version[4];
-    if (GetDLLFileVersionInfo(compilerName, version)) {
+    if (GetDLLFileVersionInfo(dllName, version)) {
       // back-compat - old dev buidls had version 3.7.0.0
       if (version[0] == 3 && version[1] == 7 && version[2] == 0 && version[3] == 0) {
 #endif
@@ -1115,17 +1125,18 @@ void DxcContext::GetCompilerVersionInfo(llvm::raw_string_ostream &OS) {
       }
       else {
         std::string productVersion;
-        if (GetDLLProductVersionInfo(compilerName, productVersion)) {
+        if (GetDLLProductVersionInfo(dllName, productVersion)) {
           OS << " - " << productVersion;
         }
       }
     }
 #endif
   }
+}
 
-  // Print validator if exists
-  DxcDllSupport DxilSupport;
-  DxilSupport.InitializeForDll(L"dxil.dll", "DxcCreateInstance");
+// Writes compiler version info to stream
+void WriteDXILVersionInfo(llvm::raw_ostream &OS,
+                          DxcDllSupport &DxilSupport) {
   if (DxilSupport.IsEnabled()) {
     CComPtr<IDxcVersionInfo> VerInfo;
     if (SUCCEEDED(DxilSupport.CreateInstance(CLSID_DxcValidator, &VerInfo))) {
@@ -1149,6 +1160,21 @@ void DxcContext::GetCompilerVersionInfo(llvm::raw_string_ostream &OS) {
   }
 }
 
+} // namespace dxc
+
+// Collects compiler/validator version info
+void DxcContext::GetCompilerVersionInfo(llvm::raw_string_ostream &OS) {
+  WriteDxCompilerVersionInfo(OS,
+    m_Opts.ExternalLib.empty() ? nullptr : m_Opts.ExternalLib.data(),
+    m_Opts.ExternalFn.empty() ? nullptr : m_Opts.ExternalFn.data(),
+    m_dxcSupport);
+
+  // Print validator if exists
+  DxcDllSupport DxilSupport;
+  DxilSupport.InitializeForDll(L"dxil.dll", "DxcCreateInstance");
+  WriteDXILVersionInfo(OS, DxilSupport);
+}
+
 #ifndef VERSION_STRING_SUFFIX
 #define VERSION_STRING_SUFFIX ""
 #endif

+ 14 - 0
tools/clang/tools/dxclib/dxc.h

@@ -13,8 +13,22 @@
 #ifndef __DXC_DXCLIB__
 #define __DXC_DXCLIB__
 
+namespace llvm {
+class raw_ostream;
+}
+
 namespace dxc
 {
+class DxcDllSupport;
+
+// Writes compiler version info to stream
+void WriteDxCompilerVersionInfo(llvm::raw_ostream &OS,
+                                const char *ExternalLib,
+                                const char *ExternalFn,
+                                dxc::DxcDllSupport &DxcSupport);
+void WriteDXILVersionInfo(llvm::raw_ostream &OS,
+                          dxc::DxcDllSupport &DxilSupport);
+
 #ifdef _WIN32
 int main(int argc, const wchar_t **argv_);
 #else

+ 5 - 1
tools/clang/tools/dxr/CMakeLists.txt

@@ -5,6 +5,7 @@
 set( LLVM_LINK_COMPONENTS
   ${LLVM_TARGETS_TO_BUILD}
   dxcsupport
+  Option     # option library
   Support    # For Atomic increment/decrement
   )
 
@@ -14,13 +15,16 @@ add_clang_executable(dxr
   )
 
 target_link_libraries(dxr
+  dxclib
   dxcompiler
   )
 
 set_target_properties(dxr PROPERTIES VERSION ${CLANG_EXECUTABLE_VERSION})
 # set_target_properties(dxr PROPERTIES ENABLE_EXPORTS 1)
 
-add_dependencies(dxr dxcompiler)
+include_directories(${LLVM_SOURCE_DIR}/tools/clang/tools)
+
+add_dependencies(dxr dxclib dxcompiler)
 
 if(UNIX)
   set(CLANGXX_LINK_OR_COPY create_symlink)

+ 87 - 169
tools/clang/tools/dxr/dxr.cpp

@@ -9,69 +9,27 @@
 //                                                                           //
 ///////////////////////////////////////////////////////////////////////////////
 
-#include "dxc/Support/WinIncludes.h"
-
 #include "dxc/Support/Global.h"
 #include "dxc/Support/Unicode.h"
+#include "dxc/Support/WinIncludes.h"
+#include "dxc/Support/WinFunctions.h"
 #include "dxc/Support/microcom.h"
+#include "dxclib/dxc.h"
 #include <vector>
 #include <string>
 
 #include "dxc/dxcapi.h"
 #include "dxc/dxctools.h"
 #include "dxc/Support/dxcapi.use.h"
+#include "dxc/Support/HLSLOptions.h"
+#include "llvm/Support/raw_ostream.h"
 
 inline bool wcsieq(LPCWSTR a, LPCWSTR b) { return _wcsicmp(a, b) == 0; }
 
 using namespace dxc;
+using namespace llvm::opt;
+using namespace hlsl::options;
 
-class DxrContext {
-
-private:
-  bool m_outputWarnings;
-  LPCWSTR m_pEntryPoint;
-  LPCWSTR m_pName;
-  DxcDefine *m_pDefines;
-  UINT32 m_definesCount;
-  DxcDllSupport& m_dxcSupport;
-
-public:
-  DxrContext(LPCWSTR pName, LPCWSTR pEntryPoint, DxcDefine *pDefines,
-             UINT32 definesCount, bool outputWarnings,
-             DxcDllSupport& dxcSupport) :
-    m_pName(pName), m_pEntryPoint(pEntryPoint), m_pDefines(pDefines),
-    m_definesCount(definesCount), m_outputWarnings(outputWarnings),
-    m_dxcSupport(dxcSupport) {
-  }
-
-  void RunRemoveUnusedGlobals();
-  void RunRewriteUnchanged();
-  HRESULT ReadFromFile(LPCWSTR pFileName, _COM_Outptr_ IDxcBlobEncoding** pBlobEncoding);
-};
-
-void DxrContext::RunRemoveUnusedGlobals() {
-  CComPtr<IDxcRewriter> pRewriter;
-  CComPtr<IDxcOperationResult> pRewriteResult;
-  CComPtr<IDxcBlobEncoding> pBlobEncoding;
-
-  IFT_Data(ReadFromFile(m_pName, &pBlobEncoding), m_pName);
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcRewriter, &pRewriter));
-  IFT(pRewriter->RemoveUnusedGlobals(pBlobEncoding, m_pEntryPoint, m_pDefines, m_definesCount, &pRewriteResult));
-
-  WriteOperationResultToConsole(pRewriteResult, m_outputWarnings);
-}
-
-void DxrContext::RunRewriteUnchanged() {
-  CComPtr<IDxcRewriter> pRewriter;
-  CComPtr<IDxcOperationResult> pRewriteResult;
-  CComPtr<IDxcBlobEncoding> pBlobEncoding;
-
-  IFT(ReadFromFile(m_pName, &pBlobEncoding));
-  IFT(m_dxcSupport.CreateInstance(CLSID_DxcRewriter, &pRewriter));
-  IFT(pRewriter->RewriteUnchanged(pBlobEncoding, m_pDefines, m_definesCount, &pRewriteResult));
-
-  WriteOperationResultToConsole(pRewriteResult, m_outputWarnings);
-}
 
 class FileMapDxcBlobEncoding : public IDxcBlobEncoding {
 private:
@@ -152,141 +110,101 @@ public:
   }
 };
 
-_Use_decl_annotations_
-HRESULT DxrContext::ReadFromFile(LPCWSTR pFileName, IDxcBlobEncoding** pBlobEncoding) {
-  return FileMapDxcBlobEncoding::CreateForFile(pFileName, pBlobEncoding);
-}
-
-void PrintUsage() {
-  wprintf(L"Usage: dxr.exe MODE FILE OPTIONS\n");
-  wprintf(L"MODE can be either -unchanged or -remove-unused-globals.\n");
-  wprintf(L"FILE is the .hlsl file to be rewritten.\n");
-  wprintf(L"  Note that this file will be read using the system default Windows ANSI code page.\n");
-  wprintf(L"OPTIONS currently supports:\n"
-          L"  -E<entry point>\n"
-          L"  -D<define-name>\n"
-          L"  -D<define-name>=<define-value>\n"
-          L"  -external <dxcompiler-path> <entry-point>\n"
-          L"  -no-warnings\n");
-}
-
 int __cdecl wmain(int argc, const wchar_t **argv_) {
-  if (argc < 2 || wcsieq(argv_[1], L"-?") || wcsieq(argv_[1], L"/?")) {
-    PrintUsage();
-    return 0;
-  }
-
-  if (argc < 3) {
-    PrintUsage();
-    return 1;
-  }
-
-  // Determine type of rewrite
-  LPCWSTR pModeName = argv_[1];
-  int modeNum = 0;
-
-  if (wcsieq(pModeName, L"-unchanged")) {
-    modeNum = 0;
-  }
-  else if (wcsieq(pModeName, L"-remove-unused-globals")) {
-    modeNum = 1;
-  }
-  else {
-    printf("Mode does not exist: [%S].\n", pModeName);
-    return 1;
-  }
-
-  // Get the file name
-  LPCWSTR pFileName = argv_[2];
-
-  // Parse command line options
+  if (FAILED(DxcInitThreadMalloc())) return 1;
+  DxcSetThreadMallocToDefault();
   try {
-    DxcDllSupport dxcSupport;
-    bool outputWarnings = true;
-    LPCWSTR pEntryPoint = nullptr;
-    int definesCount = 0;
-    std::vector<DxcDefine> definesVector;
-    std::vector<std::wstring> definesPieces; // This ensures that the memory needed for the DXCDefine ptrs won't get freed too soon
-
-    for (int i = 2; i < argc; i++) {
-      std::wstring argString(argv_[i]);
-      std::wstring start = argString.substr(0, 2);
+    if (initHlslOptTable()) throw std::bad_alloc();
 
-      if (wcsieq(argv_[i], L"-external")) {
-        if (dxcSupport.IsEnabled()) {
-          printf("-external already specified\n");
-          return 1;
-        }
-        if (i+2 >= argc) {
-          printf("-external requires a DLL name and function entry point\n");
-          return 1;
-        }
-        ++i;
-        CW2A entryFnName(argv_[i+1]);
-        HRESULT hrLoad = dxcSupport.InitializeForDll(argv_[i], entryFnName);
-        if (FAILED(hrLoad)) {
-          wprintf(L"Unable to load support for external DLL %s with function %s - 0x%08x\n", argv_[i], argv_[i+1], hrLoad);
-          return 1;
-        }
-        ++i; // also consumed the function name
-        continue;
-      }
-
-      if (wcsieq(argv_[i], L"-no-warnings")) {
-        outputWarnings = false;
-        continue;
-      }
+    // Parse command line options.
+    const OptTable *optionTable = getHlslOptTable();
+    MainArgs argStrings(argc, argv_);
+    DxcOpts dxcOpts;
+    DxcDllSupport dxcSupport;
 
-      if (wcsieq(start.c_str(), L"-E") || wcsieq(start.c_str(), L"/E")) {
-         pEntryPoint = argv_[i];
-         pEntryPoint += 2;
-         continue;
+    // Read options and check errors.
+    {
+      std::string errorString;
+      llvm::raw_string_ostream errorStream(errorString);
+      int optResult =
+          ReadDxcOpts(optionTable, DxrFlags, argStrings, dxcOpts, errorStream);
+      errorStream.flush();
+      if (errorString.size()) {
+        fprintf(stderr, "dxc failed : %s\n", errorString.data());
       }
-
-      if (wcsieq(start.c_str(), L"-D") || wcsieq(start.c_str(), L"/D")) {
-        std::wstring thisDefine = argv_[i];
-        int index = thisDefine.find(L"=");
-        if (index != std::wstring::npos) {
-          std::wstring tempName = thisDefine.substr(2, index - 2);
-          std::wstring tempVal = thisDefine.substr(index + 1, std::wstring::npos);
-
-          definesPieces.push_back(tempName);
-          definesPieces.push_back(tempVal);
-        }
-        else {
-          LPCWSTR name = argv_[i];
-          name += 2;
-          LPCWSTR value = nullptr;
-          definesVector.push_back(GetDefine(name, value));
-        }
-        definesCount++;
+      if (optResult != 0) {
+        return optResult;
       }
     }
 
-    // Now that definesPieces will no longer be changed, use it to construct the DXCDefines
-    for (size_t i = 0; i < definesPieces.size(); i += 2) {
-      definesVector.push_back(GetDefine(definesPieces[i].c_str(), definesPieces[i + 1].c_str()));
+    // Apply defaults.
+    if (dxcOpts.EntryPoint.empty() && !dxcOpts.RecompileFromBinary) {
+      dxcOpts.EntryPoint = "main";
     }
 
-    DxcDefine *pDefinesArray = nullptr;
-    if (!definesVector.empty())
-      pDefinesArray = definesVector.data(); 
+    // Setup a helper DLL.
+    {
+      std::string dllErrorString;
+      llvm::raw_string_ostream dllErrorStream(dllErrorString);
+      int dllResult = SetupDxcDllSupport(dxcOpts, dxcSupport, dllErrorStream);
+      dllErrorStream.flush();
+      if (dllErrorString.size()) {
+        fprintf(stderr, "%s\n", dllErrorString.data());
+      }
+      if (dllResult)
+        return dllResult;
+    }
 
     EnsureEnabled(dxcSupport);
-    DxrContext context(pFileName, pEntryPoint, pDefinesArray, definesCount, outputWarnings, dxcSupport);
+    // Handle help request, which overrides any other processing.
+    if (dxcOpts.ShowHelp) {
+      std::string helpString;
+      llvm::raw_string_ostream helpStream(helpString);
+      std::string version;
+      llvm::raw_string_ostream versionStream(version);
+      WriteDxCompilerVersionInfo(versionStream,
+        dxcOpts.ExternalLib.empty() ? (LPCSTR)nullptr : dxcOpts.ExternalLib.data(),
+        dxcOpts.ExternalFn.empty() ? (LPCSTR)nullptr : dxcOpts.ExternalFn.data(),
+        dxcSupport);
+      versionStream.flush();
+      optionTable->PrintHelp(helpStream, "dxr.exe", "HLSL Rewriter",
+                             version.c_str(),
+                             hlsl::options::RewriteOption,
+                             (dxcOpts.ShowHelpHidden ? 0 : HelpHidden));
+      helpStream.flush();
+      WriteUtf8ToConsoleSizeT(helpString.data(), helpString.size());
+      return 0;
+    }
 
-    switch (modeNum) {
-    case 0:
-      context.RunRewriteUnchanged();
-      break;
-    case 1:
-      if (pEntryPoint == nullptr) {
-        printf("Cannot use -remove-unused-globals without specifying an entry point.\n");
-        return 1;
+    CComPtr<IDxcRewriter2> pRewriter;
+    CComPtr<IDxcOperationResult> pRewriteResult;
+    CComPtr<IDxcBlobEncoding> pSource;
+    std::wstring wName(CA2W(dxcOpts.InputFile.empty()? "" : dxcOpts.InputFile.data()));
+    if (!dxcOpts.InputFile.empty()) {
+      IFT_Data(FileMapDxcBlobEncoding::CreateForFile(wName.c_str(), &pSource), wName.c_str());
+    }
+    IFT(dxcSupport.CreateInstance(CLSID_DxcRewriter, &pRewriter));
+    IFT(pRewriter->RewriteWithOptions(pSource, wName.c_str(),
+                                      argv_, argc,
+                                      nullptr, 0, nullptr,
+                                      &pRewriteResult));
+
+    if (dxcOpts.OutputObject.empty()) {
+      // No -Fo, print to console
+      WriteOperationResultToConsole(pRewriteResult, !dxcOpts.OutputWarnings);
+    } else {
+      WriteOperationErrorsToConsole(pRewriteResult, !dxcOpts.OutputWarnings);
+      HRESULT hr;
+      IFT(pRewriteResult->GetStatus(&hr));
+      if (SUCCEEDED(hr)) {
+        CA2W wOutputObject(dxcOpts.OutputObject.data());
+        CComPtr<IDxcBlob> pObject;
+        IFT(pRewriteResult->GetResult(&pObject));
+        WriteBlobToFile(pObject, wOutputObject.m_psz, dxcOpts.DefaultTextCodePage);
+        printf("Rewrite output: %s", dxcOpts.OutputObject.data());
       }
-      context.RunRemoveUnusedGlobals();
-      break;
     }
+
   }
   catch (const ::hlsl::Exception& hlslException) {
     try {

+ 198 - 100
tools/clang/tools/libclang/dxcrewriteunused.cpp

@@ -329,17 +329,16 @@ void WriteUserMacroDefines(CompilerInstance &compiler, raw_string_ostream &o) {
 }
 
 static
-HRESULT ReadOptsAndValidate(LPCWSTR *pArguments, _In_ UINT32 argCount,
+HRESULT ReadOptsAndValidate(hlsl::options::MainArgs &mainArgs,
                             hlsl::options::DxcOpts &opts,
                             _COM_Outptr_ IDxcOperationResult **ppResult) {
-  hlsl::options::MainArgs mainArgs(argCount, pArguments, 0);
   const llvm::opt::OptTable *table = ::options::getHlslOptTable();
 
   CComPtr<AbstractMemoryStream> pOutputStream;
   IFT(CreateMemoryStream(GetGlobalHeapMalloc(), &pOutputStream));
   raw_stream_ostream outStream(pOutputStream);
 
-  if (0 != hlsl::options::ReadDxcOpts(table, hlsl::options::CompilerFlags,
+  if (0 != hlsl::options::ReadDxcOpts(table, hlsl::options::HlslFlags::RewriteOption,
                                       mainArgs, opts, outStream)) {
     CComPtr<IDxcBlob> pErrorBlob;
     IFT(pOutputStream->QueryInterface(&pErrorBlob));
@@ -350,40 +349,70 @@ HRESULT ReadOptsAndValidate(LPCWSTR *pArguments, _In_ UINT32 argCount,
       }, ppResult));
     return S_OK;
   }
-  DXASSERT(opts.HLSLVersion > 2015,
-           "else ReadDxcOpts didn't fail for non-isense");
   return S_OK;
 }
 
-
 static
-HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
-                     _In_ LPCSTR pFileName,
-                     _In_ ASTUnit::RemappedFile *pRemap,
-                     _In_ LPCSTR pEntryPoint,
-                     _In_ LPCSTR pDefines,
-                     std::string &warnings,
-                     std::string &result) {
-
-  raw_string_ostream o(result);
-  raw_string_ostream w(warnings);
+bool HasUniformParams(FunctionDecl *FD) {
+  for (auto PD : FD->params()) {
+    if (PD->hasAttr<HLSLUniformAttr>())
+      return true;
+  }
+  return false;
+}
 
-  // Setup a compiler instance.
-  CompilerInstance compiler;
-  std::unique_ptr<TextDiagnosticPrinter> diagPrinter =
-      llvm::make_unique<TextDiagnosticPrinter>(w, &compiler.getDiagnosticOpts());  
+static
+void WriteUniformParamsAsGlobals(FunctionDecl *FD,
+                                 raw_ostream &o,
+                                 PrintingPolicy &p) {
+  // Extract resources first, to avoid placing in cbuffer _Params
+  for (auto PD : FD->params()) {
+    if (PD->hasAttr<HLSLUniformAttr>() &&
+        hlsl::IsHLSLResourceType(PD->getType())) {
+      PD->print(o, p);
+      o << ";\n";
+    }
+  }
+  // Extract any non-resource uniforms into cbuffer _Params
+  bool startedParams = false;
+  for (auto PD : FD->params()) {
+    if (PD->hasAttr<HLSLUniformAttr>() &&
+        !hlsl::IsHLSLResourceType(PD->getType())) {
+      if (!startedParams) {
+        o << "cbuffer _Params {\n";
+        startedParams = true;
+      }
+      PD->print(o, p);
+      o << ";\n";
+    }
+  }
+  if (startedParams) {
+    o << "}\n";
+  }
+}
 
-  hlsl::options::DxcOpts opts;
-  opts.HLSLVersion = 2015;
+static
+void PrintTranslationUnitWithTranslatedUniformParams(
+    TranslationUnitDecl *tu,
+    FunctionDecl *entryFnDecl,
+    raw_ostream &o,
+    PrintingPolicy &p) {
+  // Print without the entry function
+  entryFnDecl->setImplicit(true); // Prevent printing of this decl
+  tu->print(o, p);
+  entryFnDecl->setImplicit(false);
 
-  SetupCompilerForRewrite(compiler, pHelper, pFileName, diagPrinter.get(), pRemap, opts, pDefines);
+  WriteUniformParamsAsGlobals(entryFnDecl, o, p);
 
-  // Parse the source file.
-  compiler.getDiagnosticClient().BeginSourceFile(compiler.getLangOpts(), &compiler.getPreprocessor());
-  ParseAST(compiler.getSema(), false, false);
+  PrintingPolicy SubPolicy(p);
+  SubPolicy.HLSLSuppressUniformParameters = true;
+  entryFnDecl->print(o, SubPolicy);
+}
 
-  ASTContext& C = compiler.getASTContext();
-  TranslationUnitDecl *tu = C.getTranslationUnitDecl();
+static HRESULT DoRewriteUnused( TranslationUnitDecl *tu,
+                                LPCSTR pEntryPoint,
+                                raw_ostream &w) {
+  ASTContext& C = tu->getASTContext();
 
   // Gather all global variables that are not in cbuffers and all functions.
   SmallPtrSet<VarDecl*, 128> unusedGlobals;
@@ -418,82 +447,123 @@ HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
   DeclContext::lookup_result l = tu->lookup(DeclarationName(&C.Idents.get(StringRef(pEntryPoint))));
   if (l.empty()) {
     w << "//entry point not found\n";
+    return E_FAIL;
   }
-  else {
-    w << "//entry point found\n";
-    NamedDecl *entryDecl = l.front();
-    FunctionDecl *entryFnDecl = dyn_cast_or_null<FunctionDecl>(entryDecl);
-    if (entryFnDecl == nullptr) {
-      o << "//entry point found but is not a function declaration\n";
-    }
-    else {
-      // Traverse reachable functions and variables.
-      SmallPtrSet<FunctionDecl*, 128> visitedFunctions;
-      SmallVector<FunctionDecl*, 32> pendingFunctions;
-      VarReferenceVisitor visitor(unusedGlobals, visitedFunctions, pendingFunctions);
-      pendingFunctions.push_back(entryFnDecl);
-      while (!pendingFunctions.empty() && !unusedGlobals.empty()) {
-        FunctionDecl* pendingDecl = pendingFunctions.pop_back_val();
-        visitedFunctions.insert(pendingDecl);
-        visitor.TraverseDecl(pendingDecl);
-      }
 
-      // Don't bother doing work if there are no globals to remove.
-      if (unusedGlobals.empty()) {
-        w << "//no unused globals found - no work to be done\n";
-        StringRef contents = C.getSourceManager().getBufferData(C.getSourceManager().getMainFileID());
-        o << contents;
-      }
-      else {
-        w << "//found " << unusedGlobals.size() << " globals to remove\n";
+  w << "//entry point found\n";
+  NamedDecl *entryDecl = l.front();
+  FunctionDecl *entryFnDecl = dyn_cast_or_null<FunctionDecl>(entryDecl);
+  if (entryFnDecl == nullptr) {
+    w << "//entry point found but is not a function declaration\n";
+    return E_FAIL;
+  }
 
-        // Don't remove visited functions.
-        for (FunctionDecl *visitedFn : visitedFunctions) {
-          unusedFunctions.erase(visitedFn);
-        }
-        w << "//found " << unusedFunctions.size() << " functions to remove\n";
-
-        // Remove all unused variables and functions.
-        for (VarDecl *unusedGlobal : unusedGlobals) {
-          if (const RecordType *recordTy = unusedGlobal->getType()->getAs<RecordType>()) {
-            RecordDecl *recordDecl = recordTy->getDecl();
-            if (recordDecl && recordDecl->getName().empty()) {
-              // Anonymous structs can only be referenced by the variable they declare.
-              // If we've removed all declared variables of such a struct, remove it too,
-              // because anonymous structs without variable declarations in global scope are illegal.
-              auto recordRefCountIter = anonymousRecordRefCounts.find(recordDecl);
-              DXASSERT_NOMSG(recordRefCountIter != anonymousRecordRefCounts.end() && recordRefCountIter->second > 0);
-              recordRefCountIter->second--;
-              if (recordRefCountIter->second == 0) {
-                tu->removeDecl(recordDecl);
-                anonymousRecordRefCounts.erase(recordRefCountIter);
-              }
-            }
-          }
-
-          tu->removeDecl(unusedGlobal);
-        }
+  // Traverse reachable functions and variables.
+  SmallPtrSet<FunctionDecl*, 128> visitedFunctions;
+  SmallVector<FunctionDecl*, 32> pendingFunctions;
+  VarReferenceVisitor visitor(unusedGlobals, visitedFunctions, pendingFunctions);
+  pendingFunctions.push_back(entryFnDecl);
+  while (!pendingFunctions.empty() && !unusedGlobals.empty()) {
+    FunctionDecl* pendingDecl = pendingFunctions.pop_back_val();
+    visitedFunctions.insert(pendingDecl);
+    visitor.TraverseDecl(pendingDecl);
+  }
 
-        for (FunctionDecl *unusedFn : unusedFunctions) {
-          tu->removeDecl(unusedFn);
-        }
+  // Don't bother doing work if there are no globals to remove.
+  if (unusedGlobals.empty()) {
+    return S_FALSE;
+  }
 
-        o << "// Rewrite unused globals result:\n";
-        PrintingPolicy p = PrintingPolicy(C.getPrintingPolicy());
-        p.Indentation = 1;
-        tu->print(o, p);
+  w << "//found " << unusedGlobals.size() << " globals to remove\n";
 
-        WriteSemanticDefines(compiler, pHelper, o);
+  // Don't remove visited functions.
+  for (FunctionDecl *visitedFn : visitedFunctions) {
+    unusedFunctions.erase(visitedFn);
+  }
+  w << "//found " << unusedFunctions.size() << " functions to remove\n";
+
+  // Remove all unused variables and functions.
+  for (VarDecl *unusedGlobal : unusedGlobals) {
+    if (const RecordType *recordTy = unusedGlobal->getType()->getAs<RecordType>()) {
+      RecordDecl *recordDecl = recordTy->getDecl();
+      if (recordDecl && recordDecl->getName().empty()) {
+        // Anonymous structs can only be referenced by the variable they declare.
+        // If we've removed all declared variables of such a struct, remove it too,
+        // because anonymous structs without variable declarations in global scope are illegal.
+        auto recordRefCountIter = anonymousRecordRefCounts.find(recordDecl);
+        DXASSERT_NOMSG(recordRefCountIter != anonymousRecordRefCounts.end() && recordRefCountIter->second > 0);
+        recordRefCountIter->second--;
+        if (recordRefCountIter->second == 0) {
+          tu->removeDecl(recordDecl);
+          anonymousRecordRefCounts.erase(recordRefCountIter);
+        }
       }
     }
+
+    tu->removeDecl(unusedGlobal);
+  }
+
+  for (FunctionDecl *unusedFn : unusedFunctions) {
+    tu->removeDecl(unusedFn);
   }
 
   // Flush and return results.
-  o.flush();
   w.flush();
+  return S_OK;
+}
+
+static
+HRESULT DoRewriteUnused(_In_ DxcLangExtensionsHelper *pHelper,
+                     _In_ LPCSTR pFileName,
+                     _In_ ASTUnit::RemappedFile *pRemap,
+                     _In_ LPCSTR pEntryPoint,
+                     _In_ LPCSTR pDefines,
+                     std::string &warnings,
+                     std::string &result) {
+
+  raw_string_ostream o(result);
+  raw_string_ostream w(warnings);
+
+  // Setup a compiler instance.
+  CompilerInstance compiler;
+  std::unique_ptr<TextDiagnosticPrinter> diagPrinter =
+      llvm::make_unique<TextDiagnosticPrinter>(w, &compiler.getDiagnosticOpts());
+
+  hlsl::options::DxcOpts opts;
+  opts.HLSLVersion = 2015;
+
+  SetupCompilerForRewrite(compiler, pHelper, pFileName, diagPrinter.get(), pRemap, opts, pDefines);
+
+  // Parse the source file.
+  compiler.getDiagnosticClient().BeginSourceFile(compiler.getLangOpts(), &compiler.getPreprocessor());
+  ParseAST(compiler.getSema(), false, false);
+
+  ASTContext& C = compiler.getASTContext();
+  TranslationUnitDecl *tu = C.getTranslationUnitDecl();
 
   if (compiler.getDiagnosticClient().getNumErrors() > 0)
     return E_FAIL;
+
+  HRESULT hr = DoRewriteUnused(tu, pEntryPoint, w);
+  if (FAILED(hr))
+    return hr;
+
+  if (hr == S_FALSE) {
+    w << "//no unused globals found - no work to be done\n";
+    StringRef contents = C.getSourceManager().getBufferData(C.getSourceManager().getMainFileID());
+    o << contents;
+  } else {
+    PrintingPolicy p = PrintingPolicy(C.getPrintingPolicy());
+    p.Indentation = 1;
+    tu->print(o, p);
+  }
+
+  WriteSemanticDefines(compiler, pHelper, o);
+
+  // Flush and return results.
+  o.flush();
+  w.flush();
+
   return S_OK;
 }
 
@@ -546,10 +616,10 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
                std::string &warnings,
                std::string &result) {
 
-  bool bSkipFunctionBody = rewriteOption & RewriterOptionMask::SkipFunctionBody;
-  bool bSkipStatic = rewriteOption & RewriterOptionMask::SkipStatic;
-  bool bGlobalExternByDefault = rewriteOption & RewriterOptionMask::GlobalExternByDefault;
-  bool bKeepUserMacro = rewriteOption & RewriterOptionMask::KeepUserMacro;
+  opts.RWOpt.SkipFunctionBody |= rewriteOption & RewriterOptionMask::SkipFunctionBody;
+  opts.RWOpt.SkipStatic |= rewriteOption & RewriterOptionMask::SkipStatic;
+  opts.RWOpt.GlobalExternByDefault |= rewriteOption & RewriterOptionMask::GlobalExternByDefault;
+  opts.RWOpt.KeepUserMacro |= rewriteOption & RewriterOptionMask::KeepUserMacro;
 
   raw_string_ostream o(result);
   raw_string_ostream w(warnings);
@@ -563,27 +633,54 @@ HRESULT DoSimpleReWrite(_In_ DxcLangExtensionsHelper *pHelper,
   // Parse the source file.
   compiler.getDiagnosticClient().BeginSourceFile(compiler.getLangOpts(), &compiler.getPreprocessor());
 
-  ParseAST(compiler.getSema(), false, bSkipFunctionBody);
+  ParseAST(compiler.getSema(), false, opts.RWOpt.SkipFunctionBody);
 
   ASTContext& C = compiler.getASTContext();
   TranslationUnitDecl *tu = C.getTranslationUnitDecl();
 
-  if (bSkipStatic && bSkipFunctionBody) {
+  if (opts.RWOpt.SkipStatic && opts.RWOpt.SkipFunctionBody) {
     // Remove static functions and globals.
     RemoveStaticDecls(*tu);
   }
 
-  if (bGlobalExternByDefault) {
+  if (opts.RWOpt.GlobalExternByDefault) {
     GlobalVariableAsExternByDefault(*tu);
   }
 
-  o << "// Rewrite unchanged result:\n";
+  if (opts.EntryPoint.empty())
+    opts.EntryPoint = "main";
+
+  if (opts.RWOpt.RemoveUnusedGlobals) {
+    HRESULT hr = DoRewriteUnused(tu, opts.EntryPoint.data(), w);
+    if (FAILED(hr))
+      return hr;
+  } else {
+    o << "// Rewrite unchanged result:\n";
+  }
+
+  FunctionDecl *entryFnDecl = nullptr;
+  if (opts.RWOpt.ExtractEntryUniforms) {
+    DeclContext::lookup_result l = tu->lookup(DeclarationName(&C.Idents.get(opts.EntryPoint)));
+    if (l.empty()) {
+      w << "//entry point not found\n";
+      return E_FAIL;
+    }
+    entryFnDecl = dyn_cast_or_null<FunctionDecl>(l.front());
+    if (!HasUniformParams(entryFnDecl))
+      entryFnDecl = nullptr;
+  }
+
   PrintingPolicy p = PrintingPolicy(C.getPrintingPolicy());
   p.Indentation = 1;
-  tu->print(o, p);
+
+  if (entryFnDecl) {
+    PrintTranslationUnitWithTranslatedUniformParams(tu, entryFnDecl, o, p);
+  } else {
+    tu->print(o, p);
+  }
 
   WriteSemanticDefines(compiler, pHelper, o);
-  if (bKeepUserMacro)
+  if (opts.RWOpt.KeepUserMacro)
     WriteUserMacroDefines(compiler, o);
 
   // Flush and return results.
@@ -802,8 +899,9 @@ public:
 
       std::string definesStr = DefinesToString(pDefines, defineCount);
 
+      hlsl::options::MainArgs mainArgs(argCount, pArguments, 0);
       hlsl::options::DxcOpts opts;
-      IFR(ReadOptsAndValidate(pArguments, argCount, opts, ppResult));
+      IFR(ReadOptsAndValidate(mainArgs, opts, ppResult));
       HRESULT hr;
       if (*ppResult && SUCCEEDED((*ppResult)->GetStatus(&hr)) && FAILED(hr)) {
         // Looks odd, but this call succeeded enough to allocate a result

+ 20 - 20
tools/clang/unittests/HLSL/OptionsTest.cpp

@@ -115,14 +115,14 @@ TEST_F(OptionsTest, ReadOptionsWhenExtensionsThenOK) {
     L"exe.exe",   L"/E",        L"main",    L"/T",           L"ps_6_0",
     L"hlsl.hlsl", L"-external", L"foo.dll" };
   MainArgsArr ArgsArr(Args);
-  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxrFlags);
+  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxcFlags);
   VERIFY_ARE_EQUAL_STR("CreateObj", o->ExternalFn.data());
   VERIFY_ARE_EQUAL_STR("foo.dll", o->ExternalLib.data());
 
   MainArgsArr ArgsNoLibArr(ArgsNoLib);
-  ReadOptsTest(ArgsNoLibArr, DxrFlags, true, true);
+  ReadOptsTest(ArgsNoLibArr, DxcFlags, true, true);
   MainArgsArr ArgsNoFnArr(ArgsNoFn);
-  ReadOptsTest(ArgsNoFnArr, DxrFlags, true, true);
+  ReadOptsTest(ArgsNoFnArr, DxcFlags, true, true);
 }
 
 TEST_F(OptionsTest, ReadOptionsForOutputObject) {
@@ -130,7 +130,7 @@ TEST_F(OptionsTest, ReadOptionsForOutputObject) {
       L"exe.exe",   L"/E",        L"main",    L"/T",           L"ps_6_0",
       L"hlsl.hlsl", L"-Fo", L"hlsl.dxbc"};
   MainArgsArr ArgsArr(Args);
-  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxrFlags);
+  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxcFlags);
   VERIFY_ARE_EQUAL_STR("hlsl.dxbc", o->OutputObject.data());  
 }
 
@@ -140,33 +140,33 @@ TEST_F(OptionsTest, ReadOptionsConflict) {
       L"-Zpr", L"-Zpc",
       L"hlsl.hlsl"};
   MainArgsArr ArgsArr(matrixArgs);
-  ReadOptsTest(ArgsArr, DxrFlags, "Cannot specify /Zpr and /Zpc together, use /? to get usage information");
+  ReadOptsTest(ArgsArr, DxcFlags, "Cannot specify /Zpr and /Zpc together, use /? to get usage information");
 
   const wchar_t *controlFlowArgs[] = {
       L"exe.exe",   L"/E",        L"main",    L"/T",           L"ps_6_0",
       L"-Gfa", L"-Gfp",
       L"hlsl.hlsl"};
   MainArgsArr controlFlowArr(controlFlowArgs);
-  ReadOptsTest(controlFlowArr, DxrFlags, "Cannot specify /Gfa and /Gfp together, use /? to get usage information");
+  ReadOptsTest(controlFlowArr, DxcFlags, "Cannot specify /Gfa and /Gfp together, use /? to get usage information");
 
   const wchar_t *libArgs[] = {
       L"exe.exe",   L"/E",        L"main",    L"/T",           L"lib_6_1",
       L"hlsl.hlsl"};
   MainArgsArr libArr(libArgs);
-  ReadOptsTest(libArr, DxrFlags, "Must disable validation for unsupported lib_6_1 or lib_6_2 targets.");
+  ReadOptsTest(libArr, DxcFlags, "Must disable validation for unsupported lib_6_1 or lib_6_2 targets.");
 }
 
 TEST_F(OptionsTest, ReadOptionsWhenHelpThenShortcut) {
   const wchar_t *Args[] = { L"exe.exe", L"--help", L"--unknown-flag" };
   MainArgsArr ArgsArr(Args);
-  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxrFlags);
+  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxcFlags);
   EXPECT_EQ(true, o->ShowHelp);
 }
 
 TEST_F(OptionsTest, ReadOptionsWhenValidThenOK) {
   const wchar_t *Args[] = { L"exe.exe", L"/E", L"main", L"/T", L"ps_6_0", L"hlsl.hlsl" };
   MainArgsArr ArgsArr(Args);
-  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxrFlags);
+  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxcFlags);
   VERIFY_ARE_EQUAL_STR("main", o->EntryPoint.data());
   VERIFY_ARE_EQUAL_STR("ps_6_0", o->TargetProfile.data());
   VERIFY_ARE_EQUAL_STR("hlsl.hlsl", o->InputFile.data());
@@ -175,7 +175,7 @@ TEST_F(OptionsTest, ReadOptionsWhenValidThenOK) {
 TEST_F(OptionsTest, ReadOptionsWhenJoinedThenOK) {
   const wchar_t *Args[] = { L"exe.exe", L"/Emain", L"/Tps_6_0", L"hlsl.hlsl" };
   MainArgsArr ArgsArr(Args);
-  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxrFlags);
+  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxcFlags);
   VERIFY_ARE_EQUAL_STR("main", o->EntryPoint.data());
   VERIFY_ARE_EQUAL_STR("ps_6_0", o->TargetProfile.data());
   VERIFY_ARE_EQUAL_STR("hlsl.hlsl", o->InputFile.data());
@@ -186,7 +186,7 @@ TEST_F(OptionsTest, ReadOptionsWhenNoEntryThenOK) {
   // set to 'main' on behalf of callers either.
   const wchar_t *Args[] = { L"exe.exe", L"/T", L"ps_6_0", L"hlsl.hlsl" };
   MainArgsArr ArgsArr(Args);
-  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxrFlags);
+  std::unique_ptr<DxcOpts> o = ReadOptsTest(ArgsArr, DxcFlags);
   VERIFY_IS_TRUE(o->EntryPoint.empty());
 }
 
@@ -202,11 +202,11 @@ TEST_F(OptionsTest, ReadOptionsWhenInvalidThenFail) {
   MainArgsArr ArgsNoTargetArr(ArgsNoTarget),
       ArgsNoInputArr(ArgsNoInput), ArgsNoArgArr(ArgsNoArg),
     ArgsUnknownArr(ArgsUnknown), ArgsUnknownButIgnoreArr(ArgsUnknownButIgnore);
-  ReadOptsTest(ArgsNoTargetArr, DxrFlags, true, true);
-  ReadOptsTest(ArgsNoInputArr, DxrFlags, true, true);
-  ReadOptsTest(ArgsNoArgArr, DxrFlags, true, true);
-  ReadOptsTest(ArgsUnknownArr, DxrFlags, true, true);
-  ReadOptsTest(ArgsUnknownButIgnoreArr, DxrFlags);
+  ReadOptsTest(ArgsNoTargetArr, DxcFlags, true, true);
+  ReadOptsTest(ArgsNoInputArr, DxcFlags, true, true);
+  ReadOptsTest(ArgsNoArgArr, DxcFlags, true, true);
+  ReadOptsTest(ArgsUnknownArr, DxcFlags, true, true);
+  ReadOptsTest(ArgsUnknownButIgnoreArr, DxcFlags);
 }
 
 TEST_F(OptionsTest, ReadOptionsWhenDefinesThenInit) {
@@ -219,22 +219,22 @@ TEST_F(OptionsTest, ReadOptionsWhenDefinesThenInit) {
       ArgsTwoDefinesArr(ArgsTwoDefines), ArgsEmptyDefineArr(ArgsEmptyDefine);
 
   std::unique_ptr<DxcOpts> o;
-  o = ReadOptsTest(ArgsNoDefinesArr, DxrFlags);
+  o = ReadOptsTest(ArgsNoDefinesArr, DxcFlags);
   EXPECT_EQ(0U, o->Defines.size());
   
-  o = ReadOptsTest(ArgsOneDefineArr, DxrFlags);
+  o = ReadOptsTest(ArgsOneDefineArr, DxcFlags);
   EXPECT_EQ(1U, o->Defines.size());
   EXPECT_STREQW(L"NAME1", o->Defines.data()[0].Name);
   EXPECT_STREQW(L"1", o->Defines.data()[0].Value);
 
-  o = ReadOptsTest(ArgsTwoDefinesArr, DxrFlags);
+  o = ReadOptsTest(ArgsTwoDefinesArr, DxcFlags);
   EXPECT_EQ(2U, o->Defines.size());
   EXPECT_STREQW(L"NAME1", o->Defines.data()[0].Name);
   EXPECT_STREQW(L"1", o->Defines.data()[0].Value);
   EXPECT_STREQW(L"NAME2", o->Defines.data()[1].Name);
   EXPECT_STREQW(L"2", o->Defines.data()[1].Value);
 
-  o = ReadOptsTest(ArgsEmptyDefineArr, DxrFlags);
+  o = ReadOptsTest(ArgsEmptyDefineArr, DxcFlags);
   EXPECT_EQ(1U, o->Defines.size());
   EXPECT_STREQW(L"NAME1", o->Defines.data()[0].Name);
   EXPECT_EQ(nullptr, o->Defines.data()[0].Value);

+ 55 - 2
tools/clang/unittests/HLSL/RewriterTest.cpp

@@ -86,6 +86,7 @@ public:
   TEST_METHOD(RunNoFunctionBodyInclude);
   TEST_METHOD(RunNoStatic);
   TEST_METHOD(RunKeepUserMacro);
+  TEST_METHOD(RunExtractUniforms);
   TEST_METHOD(RunRewriterFails)
 
   dxc::DxcDllSupport m_dllSupport;
@@ -219,9 +220,15 @@ public:
     DxcDefine myDefines[myDefinesCount] = {
         {L"myDefine", L"2"}, {L"myDefine3", L"1994"}, {L"myDefine4", nullptr}};
 
+    LPCWSTR args[] = {L"-HV", L"2016"};
+
+    CComPtr<IDxcRewriter2> rewriter2;
+    VERIFY_SUCCEEDED(rewriter->QueryInterface(&rewriter2));
     // Run rewrite unchanged on the source code
-    VERIFY_SUCCEEDED(rewriter->RewriteUnchanged(source.BlobEncoding, myDefines,
-                                                myDefinesCount, ppResult));
+    VERIFY_SUCCEEDED(rewriter2->RewriteWithOptions( source.BlobEncoding, path,
+                                                    args, _countof(args),
+                                                    myDefines, myDefinesCount,
+                                                    nullptr, ppResult));
 
     // check for compilation errors
     HRESULT hrStatus;
@@ -634,6 +641,52 @@ float test(float a, float b) {\n\
 ") == 0);
 }
 
+TEST_F(RewriterTest, RunExtractUniforms) {
+  CComPtr<IDxcRewriter> pRewriter;
+  CComPtr<IDxcRewriter2> pRewriter2;
+  VERIFY_SUCCEEDED(CreateRewriter(&pRewriter));
+  VERIFY_SUCCEEDED(pRewriter->QueryInterface(&pRewriter2));
+  CComPtr<IDxcOperationResult> pRewriteResult;
+
+  // Get the source text from a file
+  FileWithBlob source(
+      m_dllSupport,
+      GetPathToHlslDataFile(L"rewriter\\rewrite-uniforms.hlsl")
+          .c_str());
+
+  LPCWSTR compileOptions[] = {L"-E", L"FloatFunc", L"-extract-entry-uniforms"};
+
+  // Run rewrite on the source code to move uniform params to globals
+  VERIFY_SUCCEEDED(pRewriter2->RewriteWithOptions(
+    source.BlobEncoding, L"rewrite-uniforms.hlsl",
+    compileOptions, _countof(compileOptions),
+    nullptr, 0, nullptr, &pRewriteResult));
+
+  CComPtr<IDxcBlob> result;
+  VERIFY_SUCCEEDED(pRewriteResult->GetResult(&result));
+
+  VERIFY_IS_TRUE(strcmp(BlobToUtf8(result).c_str(),
+"// Rewrite unchanged result:\n\
+[RootSignature(\"RootFlags(0),DescriptorTable(UAV(u0, numDescriptors = 1), CBV(b0, numDescriptors = 1))\")]\n\
+[numthreads(4, 8, 16)]\n\
+void IntFunc(uint3 id : SV_DispatchThreadID, uniform RWStructuredBuffer<int> buf, uniform uint ui) {\n\
+  buf[id.x + id.y + id.z] = id.x + ui;\n\
+}\n\
+\n\
+\n\
+uniform RWStructuredBuffer<float> buf;\n\
+cbuffer _Params {\n\
+uniform uint ui;\n\
+}\n\
+[RootSignature(\"RootFlags(0),DescriptorTable(UAV(u0, numDescriptors = 1), CBV(b0, numDescriptors = 1))\")]\n\
+[numthreads(4, 8, 16)]\n\
+void FloatFunc(uint3 id : SV_DispatchThreadID) {\n\
+  buf[id.x + id.y + id.z] = id.x;\n\
+}\n\
+\n\
+") == 0);
+}
+
 TEST_F(RewriterTest, RunRewriterFails) {
   CComPtr<IDxcRewriter> pRewriter;
   CComPtr<IDxcRewriter2> pRewriter2;