ExtensionTest.cpp 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // Copyright (C) Microsoft Corporation. All rights reserved. //
  4. // ExtensionTest.cpp //
  5. // //
  6. // Provides tests for the language extension APIs. //
  7. // //
  8. ///////////////////////////////////////////////////////////////////////////////
  9. #include "dxc/Test/CompilationResult.h"
  10. #include "dxc/Test/HlslTestUtils.h"
  11. #include "dxc/Test/DxcTestUtils.h"
  12. #include "dxc/Support/microcom.h"
  13. #include "dxc/dxcapi.internal.h"
  14. #include "dxc/HLSL/HLOperationLowerExtension.h"
  15. #include "dxc/HlslIntrinsicOp.h"
  16. #include "dxc/DXIL/DxilOperations.h"
  17. #include "llvm/Support/Regex.h"
  18. ///////////////////////////////////////////////////////////////////////////////
  19. // Support for test intrinsics.
  20. // $result = test_fn(any_vector<any_cardinality> value)
  21. static const HLSL_INTRINSIC_ARGUMENT TestFnArgs[] = {
  22. { "test_fn", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
  23. { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
  24. };
  25. // void test_proc(any_vector<any_cardinality> value)
  26. static const HLSL_INTRINSIC_ARGUMENT TestProcArgs[] = {
  27. { "test_proc", 0, 0, LITEMPLATE_VOID, 0, LICOMPTYPE_VOID, 0, 0 },
  28. { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
  29. };
  30. // $result = test_poly(any_vector<any_cardinality> value)
  31. static const HLSL_INTRINSIC_ARGUMENT TestFnCustomArgs[] = {
  32. { "test_poly", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
  33. { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
  34. };
  35. // $result = test_int(int<any_cardinality> value)
  36. static const HLSL_INTRINSIC_ARGUMENT TestFnIntArgs[] = {
  37. { "test_int", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_INT, 1, IA_C },
  38. { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_INT, 1, IA_C }
  39. };
  40. // $result = test_nolower(any_vector<any_cardinality> value)
  41. static const HLSL_INTRINSIC_ARGUMENT TestFnNoLowerArgs[] = {
  42. { "test_nolower", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
  43. { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
  44. };
  45. // void test_pack_0(any_vector<any_cardinality> value)
  46. static const HLSL_INTRINSIC_ARGUMENT TestFnPack0[] = {
  47. { "test_pack_0", 0, 0, LITEMPLATE_VOID, 0, LICOMPTYPE_VOID, 0, 0 },
  48. { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
  49. };
  50. // $result = test_pack_1()
  51. static const HLSL_INTRINSIC_ARGUMENT TestFnPack1[] = {
  52. { "test_pack_1", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_FLOAT, 1, 2 },
  53. };
  54. // $result = test_pack_2(any_vector<any_cardinality> value1, any_vector<any_cardinality> value2)
  55. static const HLSL_INTRINSIC_ARGUMENT TestFnPack2[] = {
  56. { "test_pack_2", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
  57. { "value1", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
  58. { "value2", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
  59. };
  60. // $scalar = test_pack_3(any_vector<any_cardinality> value)
  61. static const HLSL_INTRINSIC_ARGUMENT TestFnPack3[] = {
  62. { "test_pack_3", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_FLOAT, 1, 1 },
  63. { "value1", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_FLOAT, 1, 2},
  64. };
  65. // float<2> = test_pack_4(float<3> value)
  66. static const HLSL_INTRINSIC_ARGUMENT TestFnPack4[] = {
  67. { "test_pack_4", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_FLOAT, 1, 2 },
  68. { "value", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_FLOAT, 1, 3},
  69. };
  70. // float<2> = test_rand()
  71. static const HLSL_INTRINSIC_ARGUMENT TestRand[] = {
  72. { "test_rand", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_FLOAT, 1, 2 },
  73. };
  74. // uint = test_rand(uint x)
  75. static const HLSL_INTRINSIC_ARGUMENT TestUnsigned[] = {
  76. { "test_unsigned", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_UINT, 1, 1 },
  77. { "x", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 1},
  78. };
  79. // float2 = MyBufferOp(uint2 addr)
  80. static const HLSL_INTRINSIC_ARGUMENT TestMyBufferOp[] = {
  81. { "MyBufferOp", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_FLOAT, 1, 2 },
  82. { "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
  83. };
  84. // bool<> = test_isinf(float<> x)
  85. static const HLSL_INTRINSIC_ARGUMENT TestIsInf[] = {
  86. { "test_isinf", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_BOOL, 1, IA_C },
  87. { "x", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_FLOAT, 1, IA_C},
  88. };
  89. // int = test_ibfe(uint width, uint offset, uint val)
  90. static const HLSL_INTRINSIC_ARGUMENT TestIBFE[] = {
  91. { "test_ibfe", AR_QUAL_OUT, 0, LITEMPLATE_SCALAR, 0, LICOMPTYPE_INT, 1, 1 },
  92. { "width", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
  93. { "offset", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
  94. { "val", AR_QUAL_IN, 1, LITEMPLATE_SCALAR, 1, LICOMPTYPE_UINT, 1, 1},
  95. };
  96. // float2 = MySamplerOp(uint2 addr)
  97. static const HLSL_INTRINSIC_ARGUMENT TestMySamplerOp[] = {
  98. { "MySamplerOp", AR_QUAL_OUT, 0, LITEMPLATE_VECTOR, 0, LICOMPTYPE_FLOAT, 1, 2 },
  99. { "addr", AR_QUAL_IN, 1, LITEMPLATE_VECTOR, 1, LICOMPTYPE_UINT, 1, 2},
  100. };
  101. // $result = wave_proc(any_vector<any_cardinality> value)
  102. static const HLSL_INTRINSIC_ARGUMENT WaveProcArgs[] = {
  103. { "wave_proc", AR_QUAL_OUT, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C },
  104. { "value", AR_QUAL_IN, 1, LITEMPLATE_ANY, 1, LICOMPTYPE_NUMERIC, 1, IA_C }
  105. };
  106. struct Intrinsic {
  107. LPCWSTR hlslName;
  108. const char *dxilName;
  109. const char *strategy;
  110. HLSL_INTRINSIC hlsl;
  111. };
  112. const char * DEFAULT_NAME = "";
  113. // llvm::array_lengthof that returns a UINT instead of size_t
  114. template <class T, std::size_t N>
  115. UINT countof(T(&)[N]) { return static_cast<UINT>(N); }
  116. Intrinsic Intrinsics[] = {
  117. {L"test_fn", DEFAULT_NAME, "r", { 1, false, true, false, -1, countof(TestFnArgs), TestFnArgs }},
  118. {L"test_proc", DEFAULT_NAME, "r", { 2, false, false, false,-1, countof(TestProcArgs), TestProcArgs }},
  119. {L"test_poly", "test_poly.$o", "r", { 3, false, true, false, -1, countof(TestFnCustomArgs), TestFnCustomArgs }},
  120. {L"test_int", "test_int", "r", { 4, false, true, false, -1, countof(TestFnIntArgs), TestFnIntArgs}},
  121. {L"test_nolower", "test_nolower.$o", "n", { 5, false, true, false, -1, countof(TestFnNoLowerArgs), TestFnNoLowerArgs}},
  122. {L"test_pack_0", "test_pack_0.$o", "p", { 6, false, false, false,-1, countof(TestFnPack0), TestFnPack0}},
  123. {L"test_pack_1", "test_pack_1.$o", "p", { 7, false, true, false, -1, countof(TestFnPack1), TestFnPack1}},
  124. {L"test_pack_2", "test_pack_2.$o", "p", { 8, false, true, false, -1, countof(TestFnPack2), TestFnPack2}},
  125. {L"test_pack_3", "test_pack_3.$o", "p", { 9, false, true, false, -1, countof(TestFnPack3), TestFnPack3}},
  126. {L"test_pack_4", "test_pack_4.$o", "p", { 10, false, false, false,-1, countof(TestFnPack4), TestFnPack4}},
  127. {L"test_rand", "test_rand", "r", { 11, false, false, false,-1, countof(TestRand), TestRand}},
  128. {L"test_isinf", "test_isinf", "d", { 13, true, true, false, -1, countof(TestIsInf), TestIsInf}},
  129. {L"test_ibfe", "test_ibfe", "d", { 14, true, true, false, -1, countof(TestIBFE), TestIBFE}},
  130. // Make this intrinsic have the same opcode as an hlsl intrinsic with an unsigned
  131. // counterpart for testing purposes.
  132. {L"test_unsigned","test_unsigned", "n", { static_cast<unsigned>(hlsl::IntrinsicOp::IOP_min), false, true, false, -1, countof(TestUnsigned), TestUnsigned}},
  133. {L"wave_proc", DEFAULT_NAME, "r", { 16, false, true, true, -1, countof(WaveProcArgs), WaveProcArgs }},
  134. };
  135. Intrinsic BufferIntrinsics[] = {
  136. {L"MyBufferOp", "MyBufferOp", "m", { 12, false, true, false, -1, countof(TestMyBufferOp), TestMyBufferOp}},
  137. };
  138. // Test adding a method to an object that normally has no methods (SamplerState will do).
  139. Intrinsic SamplerIntrinsics[] = {
  140. {L"MySamplerOp", "MySamplerOp", "m", { 15, false, true, false, -1, countof(TestMySamplerOp), TestMySamplerOp}},
  141. };
  142. class IntrinsicTable {
  143. public:
  144. IntrinsicTable(const wchar_t *ns, Intrinsic *begin, Intrinsic *end)
  145. : m_namespace(ns), m_begin(begin), m_end(end)
  146. { }
  147. struct SearchResult {
  148. Intrinsic *intrinsic;
  149. uint64_t index;
  150. SearchResult() : SearchResult(nullptr, 0) {}
  151. SearchResult(Intrinsic *i, uint64_t n) : intrinsic(i), index(n) {}
  152. operator bool() { return intrinsic != nullptr; }
  153. };
  154. SearchResult Search(const wchar_t *name, std::ptrdiff_t startIndex) const {
  155. Intrinsic *begin = m_begin + startIndex;
  156. assert(std::distance(begin, m_end) >= 0);
  157. if (IsStar(name))
  158. return BuildResult(begin);
  159. Intrinsic *found = std::find_if(begin, m_end, [name](const Intrinsic &i) {
  160. return wcscmp(i.hlslName, name) == 0;
  161. });
  162. return BuildResult(found);
  163. }
  164. SearchResult Search(unsigned opcode) const {
  165. Intrinsic *begin = m_begin;
  166. assert(std::distance(begin, m_end) >= 0);
  167. Intrinsic *found = std::find_if(begin, m_end, [opcode](const Intrinsic &i) {
  168. return i.hlsl.Op == opcode;
  169. });
  170. return BuildResult(found);
  171. }
  172. bool MatchesNamespace(const wchar_t *ns) const {
  173. return wcscmp(m_namespace, ns) == 0;
  174. }
  175. private:
  176. const wchar_t *m_namespace;
  177. Intrinsic *m_begin;
  178. Intrinsic *m_end;
  179. bool IsStar(const wchar_t *name) const {
  180. return wcscmp(name, L"*") == 0;
  181. }
  182. SearchResult BuildResult(Intrinsic *found) const {
  183. if (found == m_end)
  184. return SearchResult{ nullptr, std::numeric_limits<uint64_t>::max() };
  185. return SearchResult{ found, static_cast<uint64_t>(std::distance(m_begin, found)) };
  186. }
  187. };
  188. class TestIntrinsicTable : public IDxcIntrinsicTable {
  189. private:
  190. DXC_MICROCOM_REF_FIELD(m_dwRef)
  191. std::vector<IntrinsicTable> m_tables;
  192. public:
  193. TestIntrinsicTable() : m_dwRef(0) {
  194. m_tables.push_back(IntrinsicTable(L"", std::begin(Intrinsics), std::end(Intrinsics)));
  195. m_tables.push_back(IntrinsicTable(L"Buffer", std::begin(BufferIntrinsics), std::end(BufferIntrinsics)));
  196. m_tables.push_back(IntrinsicTable(L"SamplerState", std::begin(SamplerIntrinsics), std::end(SamplerIntrinsics)));
  197. }
  198. DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
  199. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override {
  200. return DoBasicQueryInterface<IDxcIntrinsicTable>(this, iid, ppvObject);
  201. }
  202. HRESULT STDMETHODCALLTYPE
  203. GetTableName(_Outptr_ LPCSTR *pTableName) override {
  204. *pTableName = "test";
  205. return S_OK;
  206. }
  207. HRESULT STDMETHODCALLTYPE LookupIntrinsic(
  208. LPCWSTR typeName, LPCWSTR functionName, const HLSL_INTRINSIC **pIntrinsic,
  209. _Inout_ UINT64 *pLookupCookie) override {
  210. if (typeName == nullptr)
  211. return E_FAIL;
  212. // Search for matching intrinsic name in matching namespace.
  213. IntrinsicTable::SearchResult result;
  214. for (const IntrinsicTable &table : m_tables) {
  215. if (table.MatchesNamespace(typeName)) {
  216. result = table.Search(functionName, *pLookupCookie);
  217. break;
  218. }
  219. }
  220. if (result) {
  221. *pIntrinsic = &result.intrinsic->hlsl;
  222. *pLookupCookie = result.index + 1;
  223. }
  224. else {
  225. *pIntrinsic = nullptr;
  226. *pLookupCookie = 0;
  227. }
  228. return result.intrinsic ? S_OK : E_FAIL;
  229. }
  230. HRESULT STDMETHODCALLTYPE
  231. GetLoweringStrategy(UINT opcode, _Outptr_ LPCSTR *pStrategy) override {
  232. Intrinsic *intrinsic = FindByOpcode(opcode);
  233. if (!intrinsic)
  234. return E_FAIL;
  235. *pStrategy = intrinsic->strategy;
  236. return S_OK;
  237. }
  238. HRESULT STDMETHODCALLTYPE
  239. GetIntrinsicName(UINT opcode, _Outptr_ LPCSTR *pName) override {
  240. Intrinsic *intrinsic = FindByOpcode(opcode);
  241. if (!intrinsic)
  242. return E_FAIL;
  243. *pName = intrinsic->dxilName;
  244. return S_OK;
  245. }
  246. HRESULT STDMETHODCALLTYPE
  247. GetDxilOpCode(UINT opcode, _Outptr_ UINT *pDxilOpcode) override {
  248. if (opcode == 13) {
  249. *pDxilOpcode = static_cast<UINT>(hlsl::OP::OpCode::IsInf);
  250. return S_OK;
  251. }
  252. else if (opcode == 14) {
  253. *pDxilOpcode = static_cast<UINT>(hlsl::OP::OpCode::Ibfe);
  254. return S_OK;
  255. }
  256. return E_FAIL;
  257. }
  258. Intrinsic *FindByOpcode(UINT opcode) {
  259. IntrinsicTable::SearchResult result;
  260. for (const IntrinsicTable &table : m_tables) {
  261. result = table.Search(opcode);
  262. if (result)
  263. break;
  264. }
  265. return result.intrinsic;
  266. }
  267. };
  268. // A class to test semantic define validation.
  269. // It takes a list of defines that when present should cause errors
  270. // and defines that should cause warnings. A more realistic validator
  271. // would look at the values and make sure (for example) they are
  272. // the correct type (integer, string, etc).
  273. class TestSemanticDefineValidator : public IDxcSemanticDefineValidator {
  274. private:
  275. DXC_MICROCOM_REF_FIELD(m_dwRef)
  276. std::vector<std::string> m_errorDefines;
  277. std::vector<std::string> m_warningDefines;
  278. public:
  279. TestSemanticDefineValidator(const std::vector<std::string> &errorDefines, const std::vector<std::string> &warningDefines)
  280. : m_dwRef(0)
  281. , m_errorDefines(errorDefines)
  282. , m_warningDefines(warningDefines)
  283. { }
  284. DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
  285. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void** ppvObject) override {
  286. return DoBasicQueryInterface<IDxcSemanticDefineValidator>(this, iid, ppvObject);
  287. }
  288. virtual HRESULT STDMETHODCALLTYPE GetSemanticDefineWarningsAndErrors(LPCSTR pName, LPCSTR pValue, IDxcBlobEncoding **ppWarningBlob, IDxcBlobEncoding **ppErrorBlob) override {
  289. if (!pName || !pValue || !ppWarningBlob || !ppErrorBlob)
  290. return E_FAIL;
  291. auto Check = [pName](const std::vector<std::string> &errors, IDxcBlobEncoding **blob) {
  292. if (std::find(errors.begin(), errors.end(), pName) != errors.end()) {
  293. dxc::DxcDllSupport dllSupport;
  294. VERIFY_SUCCEEDED(dllSupport.Initialize());
  295. std::string error("bad define: ");
  296. error.append(pName);
  297. Utf8ToBlob(dllSupport, error.c_str(), blob);
  298. }
  299. };
  300. Check(m_errorDefines, ppErrorBlob);
  301. Check(m_warningDefines, ppWarningBlob);
  302. return S_OK;
  303. }
  304. };
  305. static void CheckOperationFailed(IDxcOperationResult *pResult) {
  306. HRESULT status;
  307. VERIFY_SUCCEEDED(pResult->GetStatus(&status));
  308. VERIFY_FAILED(status);
  309. }
  310. static std::string GetCompileErrors(IDxcOperationResult *pResult) {
  311. CComPtr<IDxcBlobEncoding> pErrors;
  312. VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&pErrors));
  313. if (!pErrors)
  314. return "";
  315. return BlobToUtf8(pErrors);
  316. }
  317. class Compiler {
  318. public:
  319. Compiler(dxc::DxcDllSupport &dll) : m_dllSupport(dll) {
  320. VERIFY_SUCCEEDED(m_dllSupport.Initialize());
  321. VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
  322. VERIFY_SUCCEEDED(pCompiler.QueryInterface(&pLangExtensions));
  323. }
  324. void RegisterSemanticDefine(LPCWSTR define) {
  325. VERIFY_SUCCEEDED(pLangExtensions->RegisterSemanticDefine(define));
  326. }
  327. void RegisterSemanticDefineExclusion(LPCWSTR define) {
  328. VERIFY_SUCCEEDED(pLangExtensions->RegisterSemanticDefineExclusion(define));
  329. }
  330. void SetSemanticDefineValidator(IDxcSemanticDefineValidator *validator) {
  331. pTestSemanticDefineValidator = validator;
  332. VERIFY_SUCCEEDED(pLangExtensions->SetSemanticDefineValidator(pTestSemanticDefineValidator));
  333. }
  334. void SetSemanticDefineMetaDataName(const char *name) {
  335. VERIFY_SUCCEEDED(pLangExtensions->SetSemanticDefineMetaDataName("test.defs"));
  336. }
  337. void RegisterIntrinsicTable(IDxcIntrinsicTable *table) {
  338. pTestIntrinsicTable = table;
  339. VERIFY_SUCCEEDED(pLangExtensions->RegisterIntrinsicTable(pTestIntrinsicTable));
  340. }
  341. IDxcOperationResult *Compile(const char *program) {
  342. return Compile(program, {}, {});
  343. }
  344. IDxcOperationResult *Compile(const char *program, const std::vector<LPCWSTR> &arguments, const std::vector<DxcDefine> defs ) {
  345. Utf8ToBlob(m_dllSupport, program, &pCodeBlob);
  346. VERIFY_SUCCEEDED(pCompiler->Compile(pCodeBlob, L"hlsl.hlsl", L"main",
  347. L"ps_6_0",
  348. const_cast<LPCWSTR *>(arguments.data()), arguments.size(),
  349. defs.data(), defs.size(),
  350. nullptr, &pCompileResult));
  351. return pCompileResult;
  352. }
  353. std::string Disassemble() {
  354. CComPtr<IDxcBlob> pBlob;
  355. CheckOperationSucceeded(pCompileResult, &pBlob);
  356. return DisassembleProgram(m_dllSupport, pBlob);
  357. }
  358. dxc::DxcDllSupport &m_dllSupport;
  359. CComPtr<IDxcCompiler> pCompiler;
  360. CComPtr<IDxcLangExtensions> pLangExtensions;
  361. CComPtr<IDxcBlobEncoding> pCodeBlob;
  362. CComPtr<IDxcOperationResult> pCompileResult;
  363. CComPtr<IDxcSemanticDefineValidator> pTestSemanticDefineValidator;
  364. CComPtr<IDxcIntrinsicTable> pTestIntrinsicTable;
  365. };
  366. ///////////////////////////////////////////////////////////////////////////////
  367. // Extension unit tests.
  368. #ifdef _WIN32
  369. class ExtensionTest {
  370. #else
  371. class ExtensionTest : public ::testing::Test {
  372. #endif
  373. public:
  374. BEGIN_TEST_CLASS(ExtensionTest)
  375. TEST_CLASS_PROPERTY(L"Parallel", L"true")
  376. TEST_METHOD_PROPERTY(L"Priority", L"0")
  377. END_TEST_CLASS()
  378. dxc::DxcDllSupport m_dllSupport;
  379. TEST_METHOD(DefineWhenRegisteredThenPreserved)
  380. TEST_METHOD(DefineValidationError)
  381. TEST_METHOD(DefineValidationWarning)
  382. TEST_METHOD(DefineNoValidatorOk)
  383. TEST_METHOD(DefineFromMacro)
  384. TEST_METHOD(IntrinsicWhenAvailableThenUsed)
  385. TEST_METHOD(CustomIntrinsicName)
  386. TEST_METHOD(NoLowering)
  387. TEST_METHOD(PackedLowering)
  388. TEST_METHOD(ReplicateLoweringWhenOnlyVectorIsResult)
  389. TEST_METHOD(UnsignedOpcodeIsUnchanged)
  390. TEST_METHOD(ResourceExtensionIntrinsic)
  391. TEST_METHOD(NameLoweredWhenNoReplicationNeeded)
  392. TEST_METHOD(DxilLoweringVector1)
  393. TEST_METHOD(DxilLoweringVector2)
  394. TEST_METHOD(DxilLoweringScalar)
  395. TEST_METHOD(SamplerExtensionIntrinsic)
  396. TEST_METHOD(WaveIntrinsic)
  397. };
  398. TEST_F(ExtensionTest, DefineWhenRegisteredThenPreserved) {
  399. Compiler c(m_dllSupport);
  400. c.RegisterSemanticDefine(L"FOO*");
  401. c.RegisterSemanticDefineExclusion(L"FOOBAR");
  402. c.SetSemanticDefineValidator(new TestSemanticDefineValidator({ "FOOLALA" }, {}));
  403. c.SetSemanticDefineMetaDataName("test.defs");
  404. c.Compile(
  405. "#define FOOTBALL AWESOME\n"
  406. "#define FOOTLOOSE TOO\n"
  407. "#define FOOBAR 123\n"
  408. "#define FOOD\n"
  409. "#define FOO 1 2 3\n"
  410. "float4 main() : SV_Target {\n"
  411. " return 0;\n"
  412. "}\n",
  413. {L"/Vd"},
  414. { { L"FOODEF", L"1"} }
  415. );
  416. std::string disassembly = c.Disassemble();
  417. // Check for root named md node. It contains pointers to md nodes for each define.
  418. VERIFY_IS_TRUE(
  419. disassembly.npos !=
  420. disassembly.find("!test.defs"));
  421. // #define FOODEF 1
  422. VERIFY_IS_TRUE(
  423. disassembly.npos !=
  424. disassembly.find("!{!\"FOODEF\", !\"1\"}"));
  425. // #define FOOTBALL AWESOME
  426. VERIFY_IS_TRUE(
  427. disassembly.npos !=
  428. disassembly.find("!{!\"FOOTBALL\", !\"AWESOME\"}"));
  429. // #define FOOTLOOSE TOO
  430. VERIFY_IS_TRUE(
  431. disassembly.npos !=
  432. disassembly.find("!{!\"FOOTLOOSE\", !\"TOO\"}"));
  433. // #define FOOD
  434. VERIFY_IS_TRUE(
  435. disassembly.npos !=
  436. disassembly.find("!{!\"FOOD\", !\"\"}"));
  437. // #define FOO 1 2 3
  438. VERIFY_IS_TRUE(
  439. disassembly.npos !=
  440. disassembly.find("!{!\"FOO\", !\"1 2 3\"}"));
  441. // FOOBAR should be excluded.
  442. VERIFY_IS_TRUE(
  443. disassembly.npos ==
  444. disassembly.find("!{!\"FOOBAR\""));
  445. }
  446. TEST_F(ExtensionTest, DefineValidationError) {
  447. Compiler c(m_dllSupport);
  448. c.RegisterSemanticDefine(L"FOO*");
  449. c.SetSemanticDefineValidator(new TestSemanticDefineValidator({ "FOO" }, {}));
  450. IDxcOperationResult *pCompileResult = c.Compile(
  451. "#define FOO 1\n"
  452. "float4 main() : SV_Target {\n"
  453. " return 0;\n"
  454. "}\n",
  455. {L"/Vd"}, {}
  456. );
  457. // Check that validation error causes compile failure.
  458. CheckOperationFailed(pCompileResult);
  459. std::string errors = GetCompileErrors(pCompileResult);
  460. // Check that the error message is for the validation failure.
  461. VERIFY_IS_TRUE(
  462. errors.npos !=
  463. errors.find("hlsl.hlsl:1:9: error: bad define: FOO"));
  464. }
  465. TEST_F(ExtensionTest, DefineValidationWarning) {
  466. Compiler c(m_dllSupport);
  467. c.RegisterSemanticDefine(L"FOO*");
  468. c.SetSemanticDefineValidator(new TestSemanticDefineValidator({}, { "FOO" }));
  469. IDxcOperationResult *pCompileResult = c.Compile(
  470. "#define FOO 1\n"
  471. "float4 main() : SV_Target {\n"
  472. " return 0;\n"
  473. "}\n",
  474. { L"/Vd" }, {}
  475. );
  476. std::string errors = GetCompileErrors(pCompileResult);
  477. // Check that the error message is for the validation failure.
  478. VERIFY_IS_TRUE(
  479. errors.npos !=
  480. errors.find("hlsl.hlsl:1:9: warning: bad define: FOO"));
  481. // Check the define is still emitted.
  482. std::string disassembly = c.Disassemble();
  483. // Check for root named md node. It contains pointers to md nodes for each define.
  484. VERIFY_IS_TRUE(
  485. disassembly.npos !=
  486. disassembly.find("!hlsl.semdefs"));
  487. // #define FOO 1
  488. VERIFY_IS_TRUE(
  489. disassembly.npos !=
  490. disassembly.find("!{!\"FOO\", !\"1\"}"));
  491. }
  492. TEST_F(ExtensionTest, DefineNoValidatorOk) {
  493. Compiler c(m_dllSupport);
  494. c.RegisterSemanticDefine(L"FOO*");
  495. c.Compile(
  496. "#define FOO 1\n"
  497. "float4 main() : SV_Target {\n"
  498. " return 0;\n"
  499. "}\n",
  500. { L"/Vd" }, {}
  501. );
  502. std::string disassembly = c.Disassemble();
  503. // Check the define is emitted.
  504. // #define FOO 1
  505. VERIFY_IS_TRUE(
  506. disassembly.npos !=
  507. disassembly.find("!{!\"FOO\", !\"1\"}"));
  508. }
  509. TEST_F(ExtensionTest, DefineFromMacro) {
  510. Compiler c(m_dllSupport);
  511. c.RegisterSemanticDefine(L"FOO*");
  512. c.Compile(
  513. "#define BAR 1\n"
  514. "#define FOO BAR\n"
  515. "float4 main() : SV_Target {\n"
  516. " return 0;\n"
  517. "}\n",
  518. { L"/Vd" }, {}
  519. );
  520. std::string disassembly = c.Disassemble();
  521. // Check the define is emitted.
  522. // #define FOO 1
  523. VERIFY_IS_TRUE(
  524. disassembly.npos !=
  525. disassembly.find("!{!\"FOO\", !\"1\"}"));
  526. }
  527. TEST_F(ExtensionTest, IntrinsicWhenAvailableThenUsed) {
  528. Compiler c(m_dllSupport);
  529. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  530. c.Compile(
  531. "float2 main(float2 v : V, int2 i : I) : SV_Target {\n"
  532. " test_proc(v);\n"
  533. " float2 a = test_fn(v);\n"
  534. " int2 b = test_fn(i);\n"
  535. " return a + b;\n"
  536. "}\n",
  537. { L"/Vd" }, {}
  538. );
  539. std::string disassembly = c.Disassemble();
  540. // Things to call out:
  541. // - result is float, not a vector
  542. // - mangled name contains the 'test' and '.r' parts
  543. // - opcode is first i32 argument
  544. // - second argument is float, ie it got scalarized
  545. VERIFY_IS_TRUE(
  546. disassembly.npos !=
  547. disassembly.find("call void @\"test.\\01?test_proc@hlsl@@YAXV?$vector@M$01@@@Z.r\"(i32 2, float"));
  548. VERIFY_IS_TRUE(
  549. disassembly.npos !=
  550. disassembly.find("call float @\"test.\\01?test_fn@hlsl@@YA?AV?$vector@M$01@@V2@@Z.r\"(i32 1, float"));
  551. VERIFY_IS_TRUE(
  552. disassembly.npos !=
  553. disassembly.find("call i32 @\"test.\\01?test_fn@hlsl@@YA?AV?$vector@H$01@@V2@@Z.r\"(i32 1, i32"));
  554. // - attributes are added to the declaration (the # at the end of the decl)
  555. // TODO: would be nice to check for the actual attribute (e.g. readonly)
  556. VERIFY_IS_TRUE(
  557. disassembly.npos !=
  558. disassembly.find("declare float @\"test.\\01?test_fn@hlsl@@YA?AV?$vector@M$01@@V2@@Z.r\"(i32, float) #"));
  559. }
  560. TEST_F(ExtensionTest, CustomIntrinsicName) {
  561. Compiler c(m_dllSupport);
  562. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  563. c.Compile(
  564. "float2 main(float2 v : V, int2 i : I) : SV_Target {\n"
  565. " float2 a = test_poly(v);\n"
  566. " int2 b = test_poly(i);\n"
  567. " int2 c = test_int(i);\n"
  568. " return a + b + c;\n"
  569. "}\n",
  570. { L"/Vd" }, {}
  571. );
  572. std::string disassembly = c.Disassemble();
  573. // - custom name works for polymorphic function
  574. VERIFY_IS_TRUE(
  575. disassembly.npos !=
  576. disassembly.find("call float @test_poly.float(i32 3, float"));
  577. VERIFY_IS_TRUE(
  578. disassembly.npos !=
  579. disassembly.find("call i32 @test_poly.i32(i32 3, i32"));
  580. // - custom name works for non-polymorphic function
  581. VERIFY_IS_TRUE(
  582. disassembly.npos !=
  583. disassembly.find("call i32 @test_int(i32 4, i32"));
  584. }
  585. TEST_F(ExtensionTest, NoLowering) {
  586. Compiler c(m_dllSupport);
  587. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  588. c.Compile(
  589. "float2 main(float2 v : V, int2 i : I) : SV_Target {\n"
  590. " float2 a = test_nolower(v);\n"
  591. " float2 b = test_nolower(i);\n"
  592. " return a + b;\n"
  593. "}\n",
  594. { L"/Vd" }, {}
  595. );
  596. std::string disassembly = c.Disassemble();
  597. // - custom name works for non-lowered function
  598. // - non-lowered function has vector type as argument
  599. VERIFY_IS_TRUE(
  600. disassembly.npos !=
  601. disassembly.find("call <2 x float> @test_nolower.float(i32 5, <2 x float>"));
  602. VERIFY_IS_TRUE(
  603. disassembly.npos !=
  604. disassembly.find("call <2 x i32> @test_nolower.i32(i32 5, <2 x i32>"));
  605. }
  606. TEST_F(ExtensionTest, PackedLowering) {
  607. Compiler c(m_dllSupport);
  608. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  609. c.Compile(
  610. "float2 main(float2 v1 : V1, float2 v2 : V2, float3 v3 : V3) : SV_Target {\n"
  611. " test_pack_0(v1);\n"
  612. " int2 a = test_pack_1();\n"
  613. " float2 b = test_pack_2(v1, v2);\n"
  614. " float c = test_pack_3(v1);\n"
  615. " float2 d = test_pack_4(v3);\n"
  616. " return a + b + float2(c, c);\n"
  617. "}\n",
  618. { L"/Vd" }, {}
  619. );
  620. std::string disassembly = c.Disassemble();
  621. // - pack strategy changes vectors into structs
  622. VERIFY_IS_TRUE(
  623. disassembly.npos !=
  624. disassembly.find("call void @test_pack_0.float(i32 6, { float, float }"));
  625. VERIFY_IS_TRUE(
  626. disassembly.npos !=
  627. disassembly.find("call { float, float } @test_pack_1.float(i32 7)"));
  628. VERIFY_IS_TRUE(
  629. disassembly.npos !=
  630. disassembly.find("call { float, float } @test_pack_2.float(i32 8, { float, float }"));
  631. VERIFY_IS_TRUE(
  632. disassembly.npos !=
  633. disassembly.find("call float @test_pack_3.float(i32 9, { float, float }"));
  634. VERIFY_IS_TRUE(
  635. disassembly.npos !=
  636. disassembly.find("call { float, float } @test_pack_4.float(i32 10, { float, float, float }"));
  637. }
  638. TEST_F(ExtensionTest, ReplicateLoweringWhenOnlyVectorIsResult) {
  639. Compiler c(m_dllSupport);
  640. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  641. c.Compile(
  642. "float2 main(float2 v1 : V1, float2 v2 : V2, float3 v3 : V3) : SV_Target {\n"
  643. " return test_rand();\n"
  644. "}\n",
  645. { L"/Vd" }, {}
  646. );
  647. std::string disassembly = c.Disassemble();
  648. // - replicate strategy works for vector results
  649. VERIFY_IS_TRUE(
  650. disassembly.npos !=
  651. disassembly.find("call float @test_rand(i32 11)"));
  652. }
  653. TEST_F(ExtensionTest, UnsignedOpcodeIsUnchanged) {
  654. Compiler c(m_dllSupport);
  655. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  656. c.Compile(
  657. "uint main(uint v1 : V1) : SV_Target {\n"
  658. " return test_unsigned(v1);\n"
  659. "}\n",
  660. { L"/Vd" }, {}
  661. );
  662. std::string disassembly = c.Disassemble();
  663. // - opcode is unchanged when it matches an hlsl intrinsic with
  664. // an unsigned version.
  665. // This should use the same value as IOP_min.
  666. std::string matchStr;
  667. std::ostringstream ss(matchStr);
  668. ss << "call i32 @test_unsigned(i32 "
  669. << (unsigned)hlsl::IntrinsicOp::IOP_min
  670. << ", ";
  671. VERIFY_IS_TRUE(
  672. disassembly.npos !=
  673. disassembly.find(ss.str()));
  674. }
  675. TEST_F(ExtensionTest, ResourceExtensionIntrinsic) {
  676. Compiler c(m_dllSupport);
  677. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  678. c.Compile(
  679. "Buffer<float2> buf;"
  680. "float2 main(uint2 v1 : V1) : SV_Target {\n"
  681. " return buf.MyBufferOp(uint2(1, 2));\n"
  682. "}\n",
  683. { L"/Vd" }, {}
  684. );
  685. std::string disassembly = c.Disassemble();
  686. // Things to check
  687. // - return type is translated to dx.types.ResRet
  688. // - buffer is translated to dx.types.Handle
  689. // - vector is exploded
  690. llvm::Regex regex("call %dx.types.ResRet.f32 @MyBufferOp\\(i32 12, %dx.types.Handle %.*, i32 1, i32 2\\)");
  691. std::string regexErrors;
  692. VERIFY_IS_TRUE(regex.isValid(regexErrors));
  693. VERIFY_IS_TRUE(regex.match(disassembly));
  694. }
  695. TEST_F(ExtensionTest, NameLoweredWhenNoReplicationNeeded) {
  696. Compiler c(m_dllSupport);
  697. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  698. c.Compile(
  699. "int main(int v1 : V1) : SV_Target {\n"
  700. " return test_int(v1);\n"
  701. "}\n",
  702. { L"/Vd" }, {}
  703. );
  704. std::string disassembly = c.Disassemble();
  705. // Make sure the name is still lowered even when no replication
  706. // is needed because a non-vector overload of the function
  707. // was used.
  708. VERIFY_IS_TRUE(
  709. disassembly.npos !=
  710. disassembly.find("call i32 @test_int("));
  711. }
  712. TEST_F(ExtensionTest, DxilLoweringVector1) {
  713. Compiler c(m_dllSupport);
  714. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  715. c.Compile(
  716. "int main(float v1 : V1) : SV_Target {\n"
  717. " return test_isinf(v1);\n"
  718. "}\n",
  719. { L"/Vd" }, {}
  720. );
  721. std::string disassembly = c.Disassemble();
  722. // Check that the extension was lowered to the correct dxil intrinsic.
  723. static_assert(9 == (unsigned)hlsl::OP::OpCode::IsInf, "isinf opcode changed?");
  724. VERIFY_IS_TRUE(
  725. disassembly.npos !=
  726. disassembly.find("call i1 @dx.op.isSpecialFloat.f32(i32 9"));
  727. }
  728. TEST_F(ExtensionTest, DxilLoweringVector2) {
  729. Compiler c(m_dllSupport);
  730. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  731. c.Compile(
  732. "int2 main(float2 v1 : V1) : SV_Target {\n"
  733. " return test_isinf(v1);\n"
  734. "}\n",
  735. { L"/Vd" }, {}
  736. );
  737. std::string disassembly = c.Disassemble();
  738. // Check that the extension was lowered to the correct dxil intrinsic.
  739. static_assert(9 == (unsigned)hlsl::OP::OpCode::IsInf, "isinf opcode changed?");
  740. VERIFY_IS_TRUE(
  741. disassembly.npos !=
  742. disassembly.find("call i1 @dx.op.isSpecialFloat.f32(i32 9"));
  743. }
  744. TEST_F(ExtensionTest, DxilLoweringScalar) {
  745. Compiler c(m_dllSupport);
  746. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  747. c.Compile(
  748. "int main(uint v1 : V1, uint v2 : V2, uint v3 : V3) : SV_Target {\n"
  749. " return test_ibfe(v1, v2, v3);\n"
  750. "}\n",
  751. { L"/Vd" }, {}
  752. );
  753. std::string disassembly = c.Disassemble();
  754. // Check that the extension was lowered to the correct dxil intrinsic.
  755. static_assert(51 == (unsigned)hlsl::OP::OpCode::Ibfe, "ibfe opcode changed?");
  756. VERIFY_IS_TRUE(
  757. disassembly.npos !=
  758. disassembly.find("call i32 @dx.op.tertiary.i32(i32 51"));
  759. }
  760. TEST_F(ExtensionTest, SamplerExtensionIntrinsic) {
  761. // Test adding methods to objects that don't have any methods normally,
  762. // and therefore have null default intrinsic table.
  763. Compiler c(m_dllSupport);
  764. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  765. auto result = c.Compile(
  766. "SamplerState samp;"
  767. "float2 main(uint2 v1 : V1) : SV_Target {\n"
  768. " return samp.MySamplerOp(uint2(1, 2));\n"
  769. "}\n",
  770. { L"/Vd" }, {}
  771. );
  772. CheckOperationResultMsgs(result, {}, true, false);
  773. std::string disassembly = c.Disassemble();
  774. // Things to check
  775. // - works when SamplerState normally has no methods
  776. // - return type is translated to dx.types.ResRet
  777. // - buffer is translated to dx.types.Handle
  778. // - vector is exploded
  779. LPCSTR expected[] = {
  780. "call %dx.types.ResRet.f32 @MySamplerOp\\(i32 15, %dx.types.Handle %.*, i32 1, i32 2\\)"
  781. };
  782. CheckMsgs(disassembly.c_str(), disassembly.length(), expected, 1, true);
  783. }
  784. TEST_F(ExtensionTest, WaveIntrinsic) {
  785. // Test wave-sensitive intrinsic in breaked loop
  786. Compiler c(m_dllSupport);
  787. c.RegisterIntrinsicTable(new TestIntrinsicTable());
  788. c.Compile(
  789. "StructuredBuffer<int> buf[]: register(t2);"
  790. "float2 main(float2 a : A, int b : B) : SV_Target {"
  791. " int res = 0;"
  792. " float2 u = {0,0};"
  793. " for (;;) {"
  794. " u += wave_proc(a);"
  795. " if (a.x == u.x) {"
  796. " res += buf[b][(int)u.y];"
  797. " break;"
  798. " }"
  799. " }"
  800. " return res;"
  801. "}",
  802. { L"/Vd" }, {}
  803. );
  804. std::string disassembly = c.Disassemble();
  805. // Check that the wave op causes the break block to be retained
  806. VERIFY_IS_TRUE(
  807. disassembly.npos !=
  808. disassembly.find("@dx.break.cond = internal constant [1 x i32] zeroinitializer"));
  809. VERIFY_IS_TRUE(
  810. disassembly.npos !=
  811. disassembly.find("%1 = load i32, i32* getelementptr inbounds ([1 x i32], [1 x i32]* @dx.break.cond"));
  812. VERIFY_IS_TRUE(
  813. disassembly.npos !=
  814. disassembly.find("%2 = icmp eq i32 %1, 0"));
  815. VERIFY_IS_TRUE(
  816. disassembly.npos !=
  817. disassembly.find("call float @\"test.\\01?wave_proc@hlsl@@YA?AV?$vector@M$01@@V2@@Z.r\"(i32 16, float"));
  818. VERIFY_IS_TRUE(
  819. disassembly.npos !=
  820. disassembly.find("br i1 %2"));
  821. }