DxcLangExtensionsHelper.h 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // DxcLangExtensionsHelper.h //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides a helper class to implement language extensions to HLSL. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #ifndef __DXCLANGEXTENSIONSHELPER_H__
  12. #define __DXCLANGEXTENSIONSHELPER_H__
#include "dxc/Support/Unicode.h"
#include "dxc/Support/FileIOHelper.h"
#include <cstring>
#include <string>
#include <utility>
#include <vector>
  16. namespace llvm {
  17. class raw_string_ostream;
  18. class CallInst;
  19. class Value;
  20. }
  21. namespace clang {
  22. class CompilerInstance;
  23. }
  24. namespace hlsl {
  25. class DxcLangExtensionsHelper : public DxcLangExtensionsHelperApply {
  26. private:
  27. llvm::SmallVector<std::string, 2> m_semanticDefines;
  28. llvm::SmallVector<std::string, 2> m_semanticDefineExclusions;
  29. llvm::SmallVector<std::string, 2> m_defines;
  30. llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2> m_intrinsicTables;
  31. CComPtr<IDxcSemanticDefineValidator> m_semanticDefineValidator;
  32. std::string m_semanticDefineMetaDataName;
  33. HRESULT STDMETHODCALLTYPE RegisterIntoVector(LPCWSTR name, llvm::SmallVector<std::string, 2>& here)
  34. {
  35. try {
  36. IFTPTR(name);
  37. std::string s;
  38. if (!Unicode::UTF16ToUTF8String(name, &s)) {
  39. throw ::hlsl::Exception(E_INVALIDARG);
  40. }
  41. here.push_back(s);
  42. return S_OK;
  43. }
  44. CATCH_CPP_RETURN_HRESULT();
  45. }
  46. public:
  47. const llvm::SmallVector<std::string, 2>& GetSemanticDefines() const { return m_semanticDefines; }
  48. const llvm::SmallVector<std::string, 2>& GetSemanticDefineExclusions() const { return m_semanticDefineExclusions; }
  49. const llvm::SmallVector<std::string, 2>& GetDefines() const { return m_defines; }
  50. llvm::SmallVector<CComPtr<IDxcIntrinsicTable>, 2>& GetIntrinsicTables(){ return m_intrinsicTables; }
  51. const std::string &GetSemanticDefineMetadataName() { return m_semanticDefineMetaDataName; }
  52. HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name)
  53. {
  54. return RegisterIntoVector(name, m_semanticDefines);
  55. }
  56. HRESULT STDMETHODCALLTYPE RegisterSemanticDefineExclusion(LPCWSTR name)
  57. {
  58. return RegisterIntoVector(name, m_semanticDefineExclusions);
  59. }
  60. HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name)
  61. {
  62. return RegisterIntoVector(name, m_defines);
  63. }
  64. HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable* pTable)
  65. {
  66. try {
  67. IFTPTR(pTable);
  68. LPCSTR tableName = nullptr;
  69. IFT(pTable->GetTableName(&tableName));
  70. IFTPTR(tableName);
  71. IFTARG(strcmp(tableName, "op") != 0); // "op" is reserved for builtin intrinsics
  72. for (auto &&table : m_intrinsicTables) {
  73. LPCSTR otherTableName = nullptr;
  74. IFT(table->GetTableName(&otherTableName));
  75. IFTPTR(otherTableName);
  76. IFTARG(strcmp(tableName, otherTableName) != 0); // Added a duplicate table name
  77. }
  78. m_intrinsicTables.push_back(pTable);
  79. return S_OK;
  80. }
  81. CATCH_CPP_RETURN_HRESULT();
  82. }
  83. // Set the validator used to validate semantic defines.
  84. // Only one validator stored and used to run validation.
  85. HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) {
  86. if (pValidator == nullptr)
  87. return E_POINTER;
  88. m_semanticDefineValidator = pValidator;
  89. return S_OK;
  90. }
  91. HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) {
  92. try {
  93. m_semanticDefineMetaDataName = name;
  94. return S_OK;
  95. }
  96. CATCH_CPP_RETURN_HRESULT();
  97. }
  98. // Get the name of the dxil intrinsic function.
  99. std::string GetIntrinsicName(UINT opcode) {
  100. LPCSTR pName = "";
  101. for (IDxcIntrinsicTable *table : m_intrinsicTables) {
  102. if (SUCCEEDED(table->GetIntrinsicName(opcode, &pName))) {
  103. return pName;
  104. }
  105. }
  106. return "";
  107. }
  108. // Get the dxil opcode for the extension opcode if one exists.
  109. // Return true if the opcode was mapped successfully.
  110. bool GetDxilOpCode(UINT opcode, UINT &dxilOpcode) {
  111. for (IDxcIntrinsicTable *table : m_intrinsicTables) {
  112. if (SUCCEEDED(table->GetDxilOpCode(opcode, &dxilOpcode))) {
  113. return true;
  114. }
  115. }
  116. return false;
  117. }
  118. // Result of validating a semantic define.
  119. // Stores any warning or error messages produced by the validator.
  120. // Successful validation means that there are no warning or error messages.
  121. struct SemanticDefineValidationResult {
  122. std::string Warning;
  123. std::string Error;
  124. bool HasError() { return Error.size() > 0; }
  125. bool HasWarning() { return Warning.size() > 0; }
  126. static SemanticDefineValidationResult Success() {
  127. return SemanticDefineValidationResult();
  128. }
  129. };
  130. // Use the contained semantice define validator to validate the given semantic define.
  131. SemanticDefineValidationResult ValidateSemanticDefine(const std::string &name, const std::string &value) {
  132. if (!m_semanticDefineValidator)
  133. return SemanticDefineValidationResult::Success();
  134. // Blobs for getting restul from validator. Strings for returning results to caller.
  135. CComPtr<IDxcBlobEncoding> pError;
  136. CComPtr<IDxcBlobEncoding> pWarning;
  137. std::string error;
  138. std::string warning;
  139. // Run semantic define validator.
  140. HRESULT result = m_semanticDefineValidator->GetSemanticDefineWarningsAndErrors(name.c_str(), value.c_str(), &pWarning, &pError);
  141. if (FAILED(result)) {
  142. // Failure indicates it was not able to even run validation so
  143. // we cannot say whether the define is invalid or not. Return a
  144. // generic error message about failure to run the valiadator.
  145. error = "failed to run semantic define validator for: ";
  146. error.append(name); error.append("="); error.append(value);
  147. return SemanticDefineValidationResult{ warning, error };
  148. }
  149. // Define a little function to convert encoded blob into a string.
  150. auto GetErrorAsString = [&name](const CComPtr<IDxcBlobEncoding> &pBlobString) -> std::string {
  151. CComPtr<IDxcBlobEncoding> pUTF8BlobStr;
  152. if (SUCCEEDED(hlsl::DxcGetBlobAsUtf8(pBlobString, &pUTF8BlobStr)))
  153. return std::string(static_cast<char*>(pUTF8BlobStr->GetBufferPointer()), pUTF8BlobStr->GetBufferSize());
  154. else
  155. return std::string("invalid semantic define " + name);
  156. };
  157. // Check to see if any warnings or errors were produced.
  158. if (pError && pError->GetBufferSize()) {
  159. error = GetErrorAsString(pError);
  160. }
  161. if (pWarning && pWarning->GetBufferSize()) {
  162. warning = GetErrorAsString(pWarning);
  163. }
  164. return SemanticDefineValidationResult{ warning, error };
  165. }
  166. __override void SetupSema(clang::Sema &S) {
  167. clang::ExternalASTSource *astSource = S.getASTContext().getExternalSource();
  168. if (clang::ExternalSemaSource *externalSema =
  169. llvm::dyn_cast_or_null<clang::ExternalSemaSource>(astSource)) {
  170. for (auto &&table : m_intrinsicTables) {
  171. hlsl::RegisterIntrinsicTable(externalSema, table);
  172. }
  173. }
  174. }
  175. __override void SetupPreprocessorOptions(clang::PreprocessorOptions &PPOpts) {
  176. for (const auto & define : m_defines) {
  177. PPOpts.addMacroDef(llvm::StringRef(define.c_str()));
  178. }
  179. }
  180. __override DxcLangExtensionsHelper *GetDxcLangExtensionsHelper() {
  181. return this;
  182. }
  183. DxcLangExtensionsHelper()
  184. : m_semanticDefineMetaDataName("hlsl.semdefs")
  185. {}
  186. };
  187. // Use this macro to embed an implementation that will delegate to a field.
  188. // Note that QueryInterface still needs to return the vtable.
  189. #define DXC_LANGEXTENSIONS_HELPER_IMPL(_helper_field_) \
  190. __override HRESULT STDMETHODCALLTYPE RegisterIntrinsicTable(_In_ IDxcIntrinsicTable *pTable) { \
  191. DxcThreadMalloc TM(m_pMalloc); \
  192. return (_helper_field_).RegisterIntrinsicTable(pTable); \
  193. } \
  194. __override HRESULT STDMETHODCALLTYPE RegisterSemanticDefine(LPCWSTR name) { \
  195. DxcThreadMalloc TM(m_pMalloc); \
  196. return (_helper_field_).RegisterSemanticDefine(name); \
  197. } \
  198. __override HRESULT STDMETHODCALLTYPE RegisterSemanticDefineExclusion(LPCWSTR name) { \
  199. DxcThreadMalloc TM(m_pMalloc); \
  200. return (_helper_field_).RegisterSemanticDefineExclusion(name); \
  201. } \
  202. __override HRESULT STDMETHODCALLTYPE RegisterDefine(LPCWSTR name) { \
  203. DxcThreadMalloc TM(m_pMalloc); \
  204. return (_helper_field_).RegisterDefine(name); \
  205. } \
  206. __override HRESULT STDMETHODCALLTYPE SetSemanticDefineValidator(_In_ IDxcSemanticDefineValidator* pValidator) { \
  207. DxcThreadMalloc TM(m_pMalloc); \
  208. return (_helper_field_).SetSemanticDefineValidator(pValidator); \
  209. } \
  210. __override HRESULT STDMETHODCALLTYPE SetSemanticDefineMetaDataName(LPCSTR name) { \
  211. DxcThreadMalloc TM(m_pMalloc); \
  212. return (_helper_field_).SetSemanticDefineMetaDataName(name); \
  213. } \
  214. // A parsed semantic define is a semantic define that has actually been
  215. // parsed by the compiler. It has a name (required), a value (could be
  216. // the empty string), and a location. We use an encoded clang::SourceLocation
  217. // for the location to avoid a clang include dependency.
  218. struct ParsedSemanticDefine{
  219. std::string Name;
  220. std::string Value;
  221. unsigned Location;
  222. };
  223. typedef std::vector<ParsedSemanticDefine> ParsedSemanticDefineList;
  224. // Return the collection of semantic defines parsed by the compiler instance.
  225. ParsedSemanticDefineList
  226. CollectSemanticDefinesParsedByCompiler(clang::CompilerInstance &compiler,
  227. _In_ DxcLangExtensionsHelper *helper);
  228. } // namespace hlsl
  229. #endif