2
0
Эх сурвалжийг харах

hlslFlags is now a null-terminator separated list of compiler options (excluding target, entry, and defines) (#2698)

Adam Yang 5 жил өмнө
parent
commit
bfb8143b27

+ 39 - 1
lib/DxilDia/DxilDiaSymbolManager.cpp

@@ -814,7 +814,45 @@ HRESULT dxil_dia::hlsl_symbols::CompilandDetailsSymbol::GetChildren(std::vector<
 HRESULT dxil_dia::hlsl_symbols::CompilandEnvSymbol::CreateFlags(IMalloc *pMalloc, Session *pSession, Symbol **ppSym) {
   IFR(AllocAndInit(pMalloc, pSession, HlslCompilandEnvFlagsId, SymTagCompilandEnv, (CompilandEnvSymbol**)ppSym));
   (*ppSym)->SetName(L"hlslFlags");
-  (*ppSym)->SetValue(pSession->DxilModuleRef().GetGlobalFlags());
+
+  const char *specialCases[] = { "/T", "-T", "-D", "/D", "-E", "/E", };
+
+  llvm::MDNode *argsNode = pSession->Arguments()->getOperand(0);
+  // Construct a double null terminated string for defines with L"\0" as a delimiter
+  CComBSTR pBSTR;
+  for (llvm::MDNode::op_iterator it = argsNode->op_begin(); it != argsNode->op_end(); ++it) {
+    llvm::StringRef strRef = llvm::dyn_cast<llvm::MDString>(*it)->getString();
+
+    bool skip = false;
+    bool skipTwice = false;
+    for (unsigned i = 0; i < _countof(specialCases); i++) {
+      if (strRef == specialCases[i]) {
+        skipTwice = true;
+        skip = true;
+        break;
+      }
+      else if (strRef.startswith(specialCases[i])) {
+        skip = true;
+        break;
+      }
+    }
+
+    if (skip) {
+      if (skipTwice)
+        ++it;
+      continue;
+    }
+
+    std::string str(strRef.begin(), strRef.size());
+    CA2W cv(str.c_str(), CP_UTF8);
+    pBSTR.Append(cv);
+    pBSTR.Append(L"\0", 1);
+  }
+  pBSTR.Append(L"\0", 1);
+  VARIANT Variant;
+  Variant.bstrVal = pBSTR;
+  Variant.vt = VARENUM::VT_BSTR;
+  (*ppSym)->SetValue(&Variant);
   return S_OK;
 }
 

+ 164 - 0
tools/clang/unittests/HLSL/CompilerTest.cpp

@@ -262,6 +262,7 @@ public:
   TEST_METHOD(DiaLoadDebugSubrangeNegativeThenOK)
   TEST_METHOD(DiaLoadRelocatedBitcode)
   TEST_METHOD(DiaLoadBitcodePlusExtraData)
+  TEST_METHOD(DiaCompileArgs)
 
   TEST_METHOD(CodeGenFloatingPointEnvironment)
   TEST_METHOD(CodeGenInclude)
@@ -2554,6 +2555,169 @@ TEST_F(CompilerTest, DiaLoadRelocatedBitcode) {
   VERIFY_SUCCEEDED(pDiaDataSource->loadDataFromIStream(pNewProgramStream));
 }
 
+TEST_F(CompilerTest, DiaCompileArgs) {
+  static const char source[] = R"(
+    SamplerState  samp0 : register(s0);
+    Texture2DArray tex0 : register(t0);
+
+    float4 foo(Texture2DArray textures[], int idx, SamplerState samplerState, float3 uvw) {
+      return textures[NonUniformResourceIndex(idx)].Sample(samplerState, uvw);
+    }
+
+    [RootSignature( "DescriptorTable(SRV(t0)), DescriptorTable(Sampler(s0)) " )]
+    float4 main(int index : INDEX, float3 uvw : TEXCOORD) : SV_Target {
+      Texture2DArray textures[] = {
+        tex0,
+      };
+      return foo(textures, index, samp0, uvw);
+    }
+  )";
+
+  CComPtr<IDxcBlob> pPart;
+  CComPtr<IDiaDataSource> pDiaSource;
+  CComPtr<IStream> pStream;
+
+  CComPtr<IDxcLibrary> pLib;
+  VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcLibrary, &pLib));
+
+  const WCHAR *FlagList[] = {
+    L"/Zi",
+    L"-Zpr",
+    L"/Qembed_debug",
+    L"/Fd", L"F:\\my dir\\",
+    L"-Fo", L"F:\\my dir\\file.dxc",
+  };
+  const WCHAR *DefineList[] = {
+    L"MY_SPECIAL_DEFINE",
+    L"MY_OTHER_SPECIAL_DEFINE=\"MY_STRING\"",
+  };
+
+  std::vector<LPCWSTR> args;
+  for (unsigned i = 0; i < _countof(FlagList); i++) {
+    args.push_back(FlagList[i]);
+  }
+  for (unsigned i = 0; i < _countof(DefineList); i++) {
+    args.push_back(L"/D");
+    args.push_back(DefineList[i]);
+  }
+
+  auto CompileAndGetDebugPart = [&args](dxc::DxcDllSupport &dllSupport, const char *source, wchar_t *profile, IDxcBlob **ppDebugPart) {
+    CComPtr<IDxcBlob> pContainer;
+    CComPtr<IDxcLibrary> pLib;
+    CComPtr<IDxcContainerReflection> pReflection;
+    UINT32 index;
+
+    VerifyCompileOK(dllSupport, source, profile, args, &pContainer);
+    VERIFY_SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcLibrary, &pLib));
+    VERIFY_SUCCEEDED(dllSupport.CreateInstance(CLSID_DxcContainerReflection, &pReflection));
+    VERIFY_SUCCEEDED(pReflection->Load(pContainer));
+    VERIFY_SUCCEEDED(pReflection->FindFirstPartKind(hlsl::DFCC_ShaderDebugInfoDXIL, &index));
+    VERIFY_SUCCEEDED(pReflection->GetPartContent(index, ppDebugPart));
+  };
+
+  CompileAndGetDebugPart(m_dllSupport, source, L"ps_6_0", &pPart);
+
+  CComPtr<IStream> pNewProgramStream;
+  VERIFY_SUCCEEDED(pLib->CreateStreamFromBlobReadOnly(pPart, &pNewProgramStream));
+
+  CComPtr<IDiaDataSource> pDiaDataSource;
+  VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcDiaDataSource, &pDiaDataSource));
+
+  VERIFY_SUCCEEDED(pDiaDataSource->loadDataFromIStream(pNewProgramStream));
+
+  CComPtr<IDiaSession> pSession;
+  VERIFY_SUCCEEDED(pDiaDataSource->openSession(&pSession));
+
+  CComPtr<IDiaEnumTables> pEnumTables;
+  VERIFY_SUCCEEDED(pSession->getEnumTables(&pEnumTables));
+
+  CComPtr<IDiaTable> pSymbolTable;
+
+  LONG uCount = 0;
+  VERIFY_SUCCEEDED(pEnumTables->get_Count(&uCount));
+  for (int i = 0; i < uCount; i++) {
+    CComPtr<IDiaTable> pTable;
+    VARIANT index = {};
+    index.vt = VT_I4;
+    index.intVal = i;
+    VERIFY_SUCCEEDED(pEnumTables->Item(index, &pTable));
+
+    CComBSTR pName;
+    VERIFY_SUCCEEDED(pTable->get_name(&pName));
+
+    if (pName == "Symbols") {
+      pSymbolTable = pTable;
+      break;
+    }
+  }
+
+  std::wstring Args;
+  std::wstring Entry;
+  std::wstring Target;
+  std::vector<std::wstring> Defines;
+  std::vector<std::wstring> Flags;
+
+  auto ReadNullSeparatedTokens = [](BSTR Str) -> std::vector<std::wstring> {
+    std::vector<std::wstring> Result;
+    while (*Str) {
+      Result.push_back(std::wstring(Str));
+      Str += wcslen(Str)+1;
+    }
+    return Result;
+  };
+
+  VERIFY_SUCCEEDED(pSymbolTable->get_Count(&uCount));
+  for (int i = 0; i < uCount; i++) {
+    CComPtr<IUnknown> pSymbolUnk;
+    CComPtr<IDiaSymbol> pSymbol;
+    CComVariant pValue;
+    CComBSTR pName;
+    VERIFY_SUCCEEDED(pSymbolTable->Item(i, &pSymbolUnk));
+    VERIFY_SUCCEEDED(pSymbolUnk->QueryInterface(&pSymbol));
+    VERIFY_SUCCEEDED(pSymbol->get_name(&pName));
+    VERIFY_SUCCEEDED(pSymbol->get_value(&pValue));
+    if (pName == "hlslTarget") {
+      if (pValue.vt == VT_BSTR)
+        Target = pValue.bstrVal;
+    }
+    else if (pName == "hlslEntry") {
+      if (pValue.vt == VT_BSTR)
+        Entry = pValue.bstrVal;
+    }
+    else if (pName == "hlslFlags") {
+      if (pValue.vt == VT_BSTR)
+        Flags = ReadNullSeparatedTokens(pValue.bstrVal);
+    }
+    else if (pName == "hlslArguments") {
+      if (pValue.vt == VT_BSTR)
+        Args = pValue.bstrVal;
+    }
+    else if (pName == "hlslDefines") {
+      if (pValue.vt == VT_BSTR)
+        Defines = ReadNullSeparatedTokens(pValue.bstrVal);
+    }
+  }
+
+  auto VectorContains = [](std::vector<std::wstring> &Tokens, std::wstring Sub) {
+    for (unsigned i = 0; i < Tokens.size(); i++) {
+      if (Tokens[i].find(Sub) != std::wstring::npos)
+        return true;
+    }
+    return false;
+  };
+
+  VERIFY_IS_TRUE(Target == L"ps_6_0");
+  VERIFY_IS_TRUE(Entry == L"main");
+
+  VERIFY_IS_TRUE(_countof(FlagList) == Flags.size());
+  for (unsigned i = 0; i < _countof(FlagList); i++) {
+    VERIFY_IS_TRUE(Flags[i] == FlagList[i]);
+  }
+  for (unsigned i = 0; i < _countof(DefineList); i++) {
+    VERIFY_IS_TRUE(VectorContains(Defines, DefineList[i]));
+  }
+}
+
 TEST_F(CompilerTest, DiaLoadBitcodePlusExtraData) {
   // Test that dia doesn't crash when bitcode has unused extra data at the end