test_DxrFallback.cpp 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974
#include "dxc/Support/Global.h"
#include "dxc/Support/Unicode.h"
#include "dxc/Support/WinIncludes.h"
#include "dxc/dxcapi.h"
#include "dxc/Support/dxcapi.use.h"
#include "dxc/Support/FileIOHelper.h"
#include "dxc/Support/dxcapi.impl.h"
#include "dxc/dxcdxrfallbackcompiler.h"
#include "dxc/support/dxcapi.use.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MSFileSystem.h"
#include "defaultTestFilePath.h"
#include "ShaderTester.h"
#undef IGNORE
#undef OPAQUE
#include "testFiles/testTraversal.h"
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <exception>
#include <iomanip>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
using namespace dxc;
using namespace llvm;
using namespace hlsl;
  26. const int DEBUG_OUTPUT_LEVEL = 1;
  27. std::string ws2s(const std::wstring& wide)
  28. {
  29. return std::string(wide.begin(), wide.end());
  30. }
  31. std::wstring s2ws(const std::string& str)
  32. {
  33. return std::wstring(str.begin(), str.end());
  34. }
  35. void printErrors(CComPtr<IDxcOperationResult> pResult)
  36. {
  37. CComPtr<IDxcBlobEncoding> pErrorBuffer;
  38. IFT(pResult->GetErrorBuffer(&pErrorBuffer));
  39. const char *pStart = (const char *)pErrorBuffer->GetBufferPointer();
  40. std::string msg(pStart);
  41. std::cerr << msg;
  42. HRESULT status;
  43. pResult->GetStatus(&status);
  44. //IFTMSG(status, msg);
  45. }
  46. void CompileToDxilFromFile(DxcDllSupport& dxcSupport, LPCWSTR pShaderTextFilePath, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, LPCWSTR* pArgs, UINT32 argCount, const DxcDefine *pDefines, UINT32 defineCount, IDxcBlob **ppBlob)
  47. {
  48. CComPtr<IDxcLibrary> pLibrary;
  49. IFT(dxcSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary));
  50. CComPtr<IDxcIncludeHandler> dxcIncludeHandler;
  51. IFT(pLibrary->CreateIncludeHandler(&dxcIncludeHandler));
  52. UINT32 codePage(0);
  53. CComPtr<IDxcBlobEncoding> pTextBlob(nullptr);
  54. IFT(pLibrary->CreateBlobFromFile(pShaderTextFilePath, &codePage, &pTextBlob));
  55. CComPtr<IDxcCompiler> pCompiler;
  56. IFT(dxcSupport.CreateInstance(CLSID_DxcCompiler, &pCompiler));
  57. CComPtr<IDxcOperationResult> pResult;
  58. IFT(pCompiler->Compile(pTextBlob, pShaderTextFilePath, pEntryPoint, pTargetProfile, pArgs, argCount, pDefines, defineCount, dxcIncludeHandler, &pResult));
  59. HRESULT resultCode;
  60. CComPtr<IDxcBlobEncoding> pErrorBuffer;
  61. IFT(pResult->GetStatus(&resultCode));
  62. IFT(pResult->GetErrorBuffer(&pErrorBuffer));
  63. if (SUCCEEDED(resultCode))
  64. {
  65. IFT(pResult->GetResult((IDxcBlob **)ppBlob));
  66. }
  67. else
  68. {
  69. printErrors(pResult);
  70. }
  71. }
  72. bool DxrCompile(
  73. DxcDllSupport& dxrFallbackSupport,
  74. const std::string& entryName,
  75. std::vector<IDxcBlob*>& libs,
  76. const std::vector<std::string>& shaderNames,
  77. std::vector<DxcShaderInfo>& shaderIds,
  78. bool findCalledShaders,
  79. IDxcBlob** ppResultBlob)
  80. {
  81. CComPtr<IDxcDxrFallbackCompiler> pCompiler;
  82. IFT(dxrFallbackSupport.CreateInstance(CLSID_DxcDxrFallbackCompiler, &pCompiler));
  83. std::vector<std::wstring> shaderNamesW(shaderNames.size());
  84. std::vector<LPCWSTR> shaderNamePtrs(shaderNames.size());
  85. for (size_t i = 0; i < shaderNames.size(); ++i)
  86. {
  87. shaderNamesW[i] = s2ws(shaderNames[i]);
  88. shaderNamePtrs[i] = shaderNamesW[i].c_str();
  89. }
  90. const UINT maxAttributeSize = 32;
  91. shaderIds.resize(shaderNames.size());
  92. CComPtr<IDxcOperationResult> pCompileResult;
  93. CComPtr<IDxcBlob> pCompiledCollection;
  94. std::vector<DxcShaderBytecode> bytecode(libs.size());
  95. for (UINT i = 0; i < libs.size(); i++)
  96. {
  97. bytecode[i] = { (LPBYTE)libs[i]->GetBufferPointer(), (UINT32)libs[i]->GetBufferSize() };
  98. }
  99. IFT(pCompiler->SetFindCalledShaders(findCalledShaders));
  100. IFT(pCompiler->SetDebugOutput(DEBUG_OUTPUT_LEVEL));
  101. IFT(pCompiler->Compile(
  102. bytecode.data(), libs.size(),
  103. shaderNamePtrs.data(), shaderIds.data(), shaderNamePtrs.size(), maxAttributeSize,
  104. &pCompileResult));
  105. pCompileResult->GetResult(&pCompiledCollection);
  106. IDxcBlob *compiledCollections[] = { pCompiledCollection };
  107. CComPtr<IDxcOperationResult> pResult;
  108. IFT(pCompiler->Link(
  109. s2ws(entryName).c_str(),
  110. compiledCollections, ARRAYSIZE(compiledCollections),
  111. shaderNamePtrs.data(), shaderIds.data(), shaderNamePtrs.size(),
  112. maxAttributeSize,
  113. 1024,
  114. &pResult));
  115. HRESULT status;
  116. IFT(pResult->GetStatus(&status));
  117. IFT(pResult->GetResult(ppResultBlob));
  118. if (SUCCEEDED(status))
  119. {
  120. return true;
  121. }
  122. else
  123. {
  124. std::cerr << "Compile errors\n";
  125. printErrors(pResult);
  126. return false;
  127. }
  128. }
  129. class Tester
  130. {
  131. public:
  132. Tester(const std::string& deviceName, const std::string& path)
  133. : m_deviceName(s2ws(deviceName))
  134. , m_path(path)
  135. {
  136. dxc::EnsureEnabled(m_dxcSupport);
  137. m_dxrFallbackSupport.InitializeForDll(L"DxrFallbackCompiler.dll", "DxcCreateDxrFallbackCompiler");
  138. }
  139. void setFiles(const std::vector<std::string>& files)
  140. {
  141. std::vector<std::string> filesWithLib(files);
  142. filesWithLib.push_back(m_testLibFilename);
  143. m_inputBlobs.clear();
  144. m_inputBlobPtrs.clear();
  145. for (auto& filename : filesWithLib)
  146. {
  147. CComPtr<IDxcBlob> pInput;
  148. LPCWSTR args[] = { L"-O3" };
  149. CompileToDxilFromFile(m_dxcSupport, s2ws(m_path + filename).c_str(), L"", L"lib_6_3", args, _countof(args), nullptr, 0, &pInput);
  150. m_inputBlobs.push_back(pInput);
  151. m_inputBlobPtrs.push_back(pInput);
  152. }
  153. }
  154. protected:
  155. DxcDllSupport m_dxcSupport;
  156. DxcDllSupport m_dxrFallbackSupport;
  157. std::wstring m_deviceName;
  158. std::vector<CComPtr<IDxcBlob>> m_inputBlobs;
  159. std::vector<IDxcBlob*> m_inputBlobPtrs;
  160. std::string m_path;
  161. std::string m_testLibFilename = "testLib.hlsl";
  162. std::string m_entryName = "CSMain";
  163. int runTest(CComPtr<IDxcBlob> pShader, int initialShaderId, const std::vector<int>& input, const std::vector<int>& expectedOutput)
  164. {
  165. std::vector<int> output;
  166. std::unique_ptr<ShaderTester> tester(ShaderTester::New(pShader));
  167. tester->setDevice(m_deviceName);
  168. tester->runShader(initialShaderId, input, output);
  169. int numFailed = checkResult(output, expectedOutput) ? 0 : 1;
  170. if (numFailed)
  171. {
  172. std::cout << "input:";
  173. for (size_t i = 0; i < input.size(); ++i)
  174. std::cout << " " << input[i];
  175. std::cout << "\n";
  176. }
  177. std::cout << "\n";
  178. return numFailed;
  179. }
  180. bool checkResult(const std::vector<int>& output, const std::vector<int>& expectedOutput)
  181. {
  182. int count = output.empty() ? 0 : output[0];
  183. std::cout << count << ": ";
  184. // print result
  185. for (int i = 0; i < count; ++i)
  186. std::cout << output[i + 1] << " ";
  187. std::cout << "\n";
  188. bool passed = false;
  189. if (count == expectedOutput.size())
  190. {
  191. passed = true;
  192. for (size_t i = 0; i < expectedOutput.size(); ++i)
  193. {
  194. if (output[i + 1] != expectedOutput[i])
  195. {
  196. passed = false;
  197. break;
  198. }
  199. }
  200. }
  201. if (!passed)
  202. {
  203. std::cout << expectedOutput.size() << ": ";
  204. for (size_t i = 0; i < expectedOutput.size(); ++i)
  205. std::cout << expectedOutput[i] << " ";
  206. std::cout << "\n";
  207. }
  208. std::cout << (passed ? "PASSED" : "FAILED") << "\n";
  209. return passed;
  210. }
  211. };
  212. class RtCompilerTester : public Tester
  213. {
  214. public:
  215. struct TestWithEntryPoint
  216. {
  217. std::string entryPoint;
  218. std::vector<int> expectedOutput;
  219. };
  220. RtCompilerTester(const std::string& deviceName, const std::string& path)
  221. : Tester(deviceName, path)
  222. {}
  223. // Returns the number of failures
  224. int runTestsWithEntryPoints(const std::vector<TestWithEntryPoint>& tests)
  225. {
  226. int numFailed = 0;
  227. for (auto& test : tests)
  228. {
  229. std::cout << test.entryPoint << "\n";
  230. std::vector<std::string> shaderNames = { test.entryPoint };
  231. std::vector<int> input;
  232. std::vector<DxcShaderInfo> shaderIds;
  233. CComPtr<IDxcBlob> pComputeShader;
  234. if (DxrCompile(m_dxrFallbackSupport, m_entryName, m_inputBlobPtrs, shaderNames, shaderIds, true, &pComputeShader))
  235. numFailed += runTest(pComputeShader, shaderIds[0].Identifier, input, test.expectedOutput);
  236. }
  237. return numFailed;
  238. }
  239. // The first shader is the entry shader. The shaderId of the shader at
  240. // indirectShaderIdx is placed in const memory.
  241. //
  242. // Returns the number of failures.
  243. int runSingleTest(const std::vector<std::string>& shaderNames, const std::vector<int>& input, const std::vector<int>& expectedOutput)
  244. {
  245. std::vector<DxcShaderInfo> shaderIds(shaderNames.size());
  246. CComPtr<IDxcBlob> pComputeShader;
  247. if (!DxrCompile(m_dxrFallbackSupport, m_entryName, m_inputBlobPtrs, shaderNames, shaderIds, false, &pComputeShader))
  248. return 1;
  249. for (size_t i = 0; i < shaderNames.size(); ++i)
  250. std::cout << shaderNames[i] << ":" << shaderIds[i].Identifier << " ";
  251. std::cout << "\n";
  252. return runTest(pComputeShader, shaderIds[0].Identifier, input, expectedOutput);
  253. }
  254. void compileTest(const std::vector<std::string>& shaderNames, const std::string& entryName)
  255. {
  256. std::vector<DxcShaderInfo> shaderIds(shaderNames.size());
  257. CComPtr<IDxcBlob> pOutput;
  258. if (DxrCompile(m_dxrFallbackSupport, entryName, m_inputBlobPtrs, shaderNames, shaderIds, false, &pOutput))
  259. std::cout << "Compile succeeded\n";
  260. else
  261. std::cout << "Compile failed\n";
  262. for (size_t i = 0; i < shaderNames.size(); ++i)
  263. std::cout << shaderNames[i] << ":" << shaderIds[i].Identifier << " ";
  264. std::cout << "\n";
  265. }
  266. };
  267. int asint(float v)
  268. {
  269. return *(int*)&v;
  270. }
  271. float asfloat(int v)
  272. {
  273. return *(float*)&v;
  274. }
  275. class Leaf
  276. {
  277. public:
  278. int leafType;
  279. };
  280. class Instance : public Leaf
  281. {
  282. public:
  283. int instIdx;
  284. int instId;
  285. int instFlags;
  286. };
  287. class Primitive : public Leaf
  288. {
  289. public:
  290. int primIdx;
  291. int geomIdx;
  292. int geomOpaque;
  293. };
  294. class Triangle : public Primitive
  295. {
  296. public:
  297. float t, u, v, d;
  298. int anyHitRet;
  299. };
  300. class Custom : public Primitive
  301. {
  302. public:
  303. struct Hit
  304. {
  305. float t;
  306. int hitKind;
  307. int attr0, attr1;
  308. int anyHitRet;
  309. };
  310. std::vector<Hit> hits;
  311. };
  312. Instance* inst(int instFlags = 0, int instIdx = 0, int instId = 0)
  313. {
  314. Instance* inst = new Instance; // TODO: make this not leak
  315. inst->leafType = LEAF_INST;
  316. inst->instFlags = instFlags;
  317. inst->instIdx = instIdx;
  318. inst->instId = instId;
  319. return inst;
  320. }
  321. Triangle* tri(float t, float u, float v, int anyHitRet = OPAQUE, float d = 1, int primIdx = 0, int geomIdx = 0)
  322. {
  323. Triangle* tri = new Triangle; // TODO: make this not leak
  324. tri->leafType = LEAF_TRIS;
  325. tri->t = t;
  326. tri->u = u;
  327. tri->v = v;
  328. tri->d = d;
  329. tri->anyHitRet = anyHitRet;
  330. tri->primIdx = primIdx;
  331. tri->geomIdx = geomIdx;
  332. tri->geomOpaque = (anyHitRet == OPAQUE);
  333. return tri;
  334. }
  335. Custom* custom(const std::vector<Custom::Hit>& hits, int geomOpaque = 0, int primIdx = 0, int geomIdx = 0)
  336. {
  337. Custom* c = new Custom;
  338. c->hits = hits;
  339. c->leafType = LEAF_CUSTOM;
  340. c->primIdx = primIdx;
  341. c->geomIdx = geomIdx;
  342. c->geomOpaque = geomOpaque;
  343. return c;
  344. }
  345. struct Payload
  346. {
  347. int val;
  348. int primIdx;
  349. float t;
  350. bool operator!=(const Payload& other)
  351. {
  352. return this->val != other.val || this->primIdx != other.primIdx || this->t != other.t;
  353. }
  354. };
  355. std::ostream& operator<<(std::ostream& out, Payload payload)
  356. {
  357. out << "{" << payload.val << "," << payload.primIdx << "}";
  358. return out;
  359. }
  360. class TestData
  361. {
  362. public:
  363. std::string name;
  364. std::vector<int> input;
  365. std::vector<int> expected;
  366. std::vector<std::string> shaders;
  367. std::map<std::string, std::vector<int> > shaderIdSlots;
  368. static int count;
  369. TestData(const std::string& name) : name(name) { count++; }
  370. void setShaderIds(std::vector<DxcShaderInfo>& shaderIds)
  371. {
  372. for (size_t i = 0; i < shaders.size(); ++i)
  373. {
  374. for (auto& slot : shaderIdSlots[shaders[i]])
  375. input[slot] = shaderIds[i].Identifier;
  376. }
  377. }
  378. struct CommittedPrim
  379. {
  380. const Primitive* prim;
  381. float t;
  382. const Custom::Hit* hit;
  383. };
  384. void simulate(Payload expectedPayload, const std::vector<Leaf*>& leaves, int rayFlags = 0)
  385. {
  386. shaders = { "raygen", "chTri", "ahTri", "intersection", "ahCustom", "chCustom", "miss", "Fallback_TraceRay" };
  387. expect(RAYGEN);
  388. Payload payload = { 1000, -1 };
  389. traceRay(rayFlags);
  390. expect(TRACERAY);
  391. bool terminate = false;
  392. CommittedPrim committed = { nullptr, -1, nullptr };
  393. int instIdx = -1, instId = 0, instFlags = 0;
  394. for (Leaf* leaf : leaves)
  395. {
  396. if (leaf->leafType == LEAF_INST)
  397. {
  398. const Instance* i = (Instance*)(leaf); // TODO: Why does dynamic_cast<Instance*> not work here?
  399. instIdx = i->instIdx;
  400. instId = i->instId;
  401. instFlags = i->instFlags;
  402. leafInst(instIdx, instId, instFlags);
  403. }
  404. else
  405. {
  406. const Primitive* prim = (Primitive*)leaf;
  407. leafPrim(prim);
  408. bool opaque = isOpaque(prim->geomOpaque, instFlags, rayFlags);
  409. if (cull(opaque, rayFlags))
  410. continue;
  411. if (leaf->leafType == LEAF_TRIS)
  412. {
  413. const Triangle* tri = (Triangle*)leaf;
  414. triangle(tri);
  415. float d = (instFlags & INSTANCE_FLAG_TRIANGLE_FRONT_COUNTERCLOCKWISE) ? -tri->d : tri->d;
  416. if (committed.prim && (tri->t >= committed.t) || -d * computeCullFaceDir(instFlags, rayFlags) < 0)
  417. continue;
  418. if (opaque)
  419. {
  420. committed = { tri, tri->t, nullptr };
  421. }
  422. else
  423. {
  424. shader("ahTri");
  425. expect({ ANYHIT, (int)tri->u, (int)tri->v });
  426. payload.val += 100;
  427. anyHitRet(tri->anyHitRet);
  428. if (tri->anyHitRet == TERMINATE)
  429. {
  430. committed = { tri, tri->t, nullptr };
  431. terminate = true;
  432. break;
  433. }
  434. else if (tri->anyHitRet == IGNORE)
  435. {
  436. // do nothing
  437. }
  438. else // ACCEPT)
  439. {
  440. committed = { tri, tri->t, nullptr };
  441. }
  442. }
  443. }
  444. else if (leaf->leafType == LEAF_CUSTOM)
  445. {
  446. const Custom* c = (Custom*)leaf;
  447. shader("ahCustom");
  448. shader("intersection");
  449. expect(INTERSECT + 1);
  450. for (auto& hit : c->hits)
  451. {
  452. customHit(hit);
  453. if (committed.prim && hit.t >= committed.t)
  454. continue;
  455. if (!opaque)
  456. {
  457. expect({ ANYHIT + 1, hit.attr0, hit.attr1 });
  458. payload.val += 100;
  459. anyHitRet(hit.anyHitRet);
  460. if (hit.anyHitRet == TERMINATE)
  461. {
  462. committed = { c, hit.t, &hit };
  463. terminate = true;
  464. break;
  465. }
  466. else if (hit.anyHitRet == IGNORE)
  467. {
  468. // do nothing
  469. continue;
  470. }
  471. // ACCEPT - fall through
  472. }
  473. committed = { c, hit.t, &hit };
  474. if (rayFlags & RAY_FLAG_TERMINATE_ON_FIRST_HIT)
  475. {
  476. terminate = true;
  477. break;
  478. }
  479. }
  480. if (!terminate)
  481. endHits();
  482. }
  483. }
  484. if ((rayFlags & RAY_FLAG_TERMINATE_ON_FIRST_HIT) && committed.prim)
  485. {
  486. terminate = true;
  487. break;
  488. }
  489. }
  490. if (!terminate)
  491. {
  492. if (instIdx != -1)
  493. endAccel();
  494. endAccel();
  495. }
  496. if (!(committed.prim && (rayFlags & RAY_FLAG_SKIP_CLOSEST_HIT_SHADER)))
  497. {
  498. const char* ch = "chTri";
  499. if (committed.prim && committed.prim->leafType == LEAF_CUSTOM)
  500. ch = "chCustom";
  501. shade(ch, "miss");
  502. if (committed.prim)
  503. {
  504. if (committed.prim->leafType == LEAF_TRIS)
  505. {
  506. Triangle* tri = (Triangle*)committed.prim;
  507. expect({ CLOSESTHIT, (int)tri->u, (int)tri->v });
  508. }
  509. else
  510. {
  511. expect({ CLOSESTHIT + 1, committed.hit->attr0, committed.hit->attr1 });
  512. }
  513. payload.val += 10;
  514. payload.primIdx = committed.prim->primIdx;
  515. }
  516. else
  517. {
  518. expect(MISS);
  519. payload.val += 1;
  520. }
  521. }
  522. expect(payload.val);
  523. expect(payload.primIdx);
  524. if (payload != expectedPayload)
  525. std::cout << count << ": simulated payload " << payload << " does not match expected " << expectedPayload << "\n";
  526. }
  527. void traceRay(unsigned rayFlags)
  528. {
  529. input.push_back(rayFlags);
  530. }
  531. void shade(const std::string& closestHit, const std::string& miss)
  532. {
  533. shader(closestHit);
  534. shader(miss);
  535. }
  536. void leafPrim(const Primitive* prim)
  537. {
  538. input.push_back(prim->leafType);
  539. input.push_back(pack(prim->primIdx, prim->geomIdx, prim->geomOpaque));
  540. }
  541. void triangle(const Triangle* tr)
  542. {
  543. input.push_back(asint(tr->t));
  544. input.push_back(asint(tr->u));
  545. input.push_back(asint(tr->v));
  546. input.push_back(asint(tr->d));
  547. }
  548. void customHit(const Custom::Hit& hit)
  549. {
  550. input.push_back(asint(hit.t));
  551. input.push_back(hit.hitKind);
  552. input.push_back(hit.attr0);
  553. input.push_back(hit.attr1);
  554. }
  555. void leafInst(int instIdx, int instId, int instFlags)
  556. {
  557. input.push_back(LEAF_INST);
  558. input.push_back(instIdx);
  559. input.push_back(instId);
  560. input.push_back(instFlags);
  561. }
  562. void shader(const std::string& shaderName)
  563. {
  564. shaderIdSlots[shaderName].push_back(input.size()); // fix up later
  565. input.push_back(-1);
  566. }
  567. void anyHitRet(int val)
  568. {
  569. input.push_back(val);
  570. }
  571. void endHits()
  572. {
  573. input.push_back(-1);
  574. }
  575. void endAccel()
  576. {
  577. input.push_back(LEAF_DONE);
  578. }
  579. void expect(int val)
  580. {
  581. expected.push_back(val);
  582. }
  583. void expect(const std::vector<int>& vals)
  584. {
  585. expected.insert(expected.end(), vals.begin(), vals.end());
  586. }
  587. void expect(float val)
  588. {
  589. expected.push_back(asint(val));
  590. }
  591. };
  592. int TestData::count = -1;
  593. class TraversalTester : public Tester
  594. {
  595. public:
  596. TraversalTester(const std::string& deviceName, const std::string& path)
  597. : Tester(deviceName, path)
  598. {
  599. setFiles({ "testTraversal.hlsl", "testTraversal2.hlsl" });
  600. }
  601. int run(const std::vector<TestData*>& tests)
  602. {
  603. int failedTests = 0;
  604. int testIndex = 0;
  605. for (auto td : tests)
  606. {
  607. std::cout << testIndex++ << " " << td->name << std::endl;
  608. std::vector<DxcShaderInfo> shaderIds(td->shaders.size());
  609. CComPtr<IDxcBlob> pComputeShader;
  610. if (!DxrCompile(m_dxrFallbackSupport, m_entryName, m_inputBlobPtrs, td->shaders, shaderIds, false, &pComputeShader))
  611. {
  612. failedTests++;
  613. continue;
  614. }
  615. td->setShaderIds(shaderIds);
  616. for (size_t i = 0; i < td->shaders.size(); ++i)
  617. std::cout << td->shaders[i] << ":" << shaderIds[i].Identifier << " ";
  618. std::cout << "\n";
  619. failedTests += runTest(pComputeShader, shaderIds[0].Identifier, td->input, td->expected);
  620. delete td;
  621. }
  622. return failedTests;
  623. }
  624. };
  625. TestData* test_nohit(Payload expectedPayload)
  626. {
  627. TestData* td = new TestData("nohit");
  628. td->simulate(expectedPayload, {});
  629. return td;
  630. }
  631. TestData* test_instance_nohit(Payload expectedPayload)
  632. {
  633. TestData* td = new TestData("instance_nohit");
  634. td->simulate(expectedPayload, { inst() });
  635. return td;
  636. }
  637. TestData* test_tri(Payload expectedPayload, int anyHitRet, int instFlags = 0, int rayFlags = 0, float d = 1)
  638. {
  639. TestData* td = new TestData("trihit");
  640. td->simulate(
  641. expectedPayload,
  642. {
  643. inst(instFlags),
  644. tri(1, 55, 66, anyHitRet, d),
  645. },
  646. rayFlags
  647. );
  648. return td;
  649. }
  650. struct TriHit
  651. {
  652. int anyHitRet;
  653. float t;
  654. };
  655. TestData* test_2tri(Payload expectedPayload, const TriHit& tri0, const TriHit& tri1, int rayFlags = 0)
  656. {
  657. TestData* td = new TestData("trihit2");
  658. td->simulate(
  659. expectedPayload,
  660. {
  661. inst(),
  662. tri(tri0.t, (expectedPayload.primIdx == 0) ? 5555 : 55, 66, tri0.anyHitRet, 1, 0),
  663. tri(tri1.t, (expectedPayload.primIdx == 1) ? 5555 : 56, 67, tri1.anyHitRet, 1, 1),
  664. },
  665. rayFlags
  666. );
  667. return td;
  668. }
  669. struct CustomHit2
  670. {
  671. int anyHitRet;
  672. float t;
  673. };
  674. TestData* test_custom(Payload expectedPayload, int geomOpaque, const std::vector<CustomHit2> hits, int instFlags = 0, int rayFlags = 0)
  675. {
  676. TestData* td = new TestData("custom");
  677. std::vector<Custom::Hit> customHits;
  678. for (size_t i = 0; i < hits.size(); ++i)
  679. {
  680. const CustomHit2& h = hits[i];
  681. customHits.push_back({ h.t, 33, int(55 + i), int(66 + i), h.anyHitRet });
  682. }
  683. td->simulate(
  684. expectedPayload,
  685. {
  686. inst(instFlags),
  687. custom(customHits, geomOpaque),
  688. },
  689. rayFlags
  690. );
  691. return td;
  692. }
  693. void printUsageAndExit()
  694. {
  695. std::cerr
  696. << "Options:\n"
  697. << " -h | --help Print this message\n"
  698. << " -d | --device <name> Name of device to use. Can be a prefix, e.g. WARP, AMD, etc.\n"
  699. << " -p | --path <directory> Base path for test input files.\n"
  700. << std::endl;
  701. exit(1);
  702. }
  703. int main(int argc, const char* argv[])
  704. {
  705. std::string deviceName = "";
  706. std::string basePath = DEFAULT_TEST_FILE_PATH;
  707. // Parse arguments
  708. std::vector<std::string> args;
  709. for (int i = 1; i < argc; ++i)
  710. args.push_back(argv[i]);
  711. for (size_t i = 0; i < args.size(); ++i)
  712. {
  713. if (args[i] == "-h" || args[i] == "--help")
  714. {
  715. printUsageAndExit();
  716. }
  717. else if (args[i] == "-d" || args[i] == "--device")
  718. {
  719. deviceName = args[++i];
  720. }
  721. else if (args[i] == "-p" || args[i] == "--path")
  722. {
  723. basePath = args[++i];
  724. }
  725. else
  726. {
  727. std::cerr << "Bad arg:" << args[i] << std::endl;
  728. printUsageAndExit();
  729. }
  730. }
  731. try
  732. {
  733. if (!deviceName.empty())
  734. std::cout << "Testing on device " << deviceName << std::endl;
  735. int numFailed = 0;
  736. if (1)
  737. {
  738. RtCompilerTester tester(deviceName, basePath);
  739. tester.setFiles({ "testShader1.hlsl", "testShader2.hlsl" });
  740. numFailed += tester.runTestsWithEntryPoints({
  741. {"no_call", {1, 1}},
  742. {"no_live_values", {1, 1, -99, 2, 2}},
  743. {"single_call", {-99, 1, 1}},
  744. {"single_call_in", {10}},
  745. {"single_call_out", {-99, 64, 64}},
  746. {"single_call_inout", {10, 64, 64}},
  747. {"single_call_inout_passthru", {-98, 10, 64, 64}},
  748. {"types", {-99, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8}},
  749. {"multiple_calls", {-99, 1, 1, -99, 4, 4, 2, 2}},
  750. {"multiple_calls_with_args", {1, 1, -99, 1, 1, -99, 2, 2, 3, 3}},
  751. {"branch", {-99, 64, 64}},
  752. {"no_branch", {10, 10}},
  753. {"loop", {-99, 1, 1, -99, -99, -99, -99, 5, 5}},
  754. {"recursive", {5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0, 1, 1}},
  755. {"use_buffer", {-99, 10, 10}},
  756. {"lower_intrinsics", {-99, 0, 0}},
  757. {"local_array", {-99, 4, 4}},
  758. {"dispatch_idx_and_dims", {0, 0, 1, 1}},
  759. });
  760. numFailed += tester.runSingleTest({ "indirect", "indirect_callee" }, { 1002 }, { -99 });
  761. numFailed += tester.runSingleTest({ "raygen_tri", "chTri", "intersection", "continuation", "Fallback_TraceRay" }, { 1002 }, { -98, -97, 555, 666, -99, 1010 });
  762. numFailed += tester.runSingleTest({ "raygen_custom", "chCustom1", "chCustom2", "intersection", "continuation", "Fallback_TraceRay" }, { 1003, 1005 }, { -98, -95, 19, 10, 11, 12, 13, -100, -99, 500, -96, 333, 444, -99, 1010, -98, -95, 59, 50, 51, 52, 53, -100, -99, 500, -96, 333, 444, -99, 1110 });
  763. tester.setFiles({ "testShader3.hlsl" });
  764. numFailed += tester.runSingleTest({ "pass_struct", "Fallback_TraceRay" }, {}, { -99, 1, 2, 3, 4, 5, 6, 7, 8, 11 });
  765. tester.setFiles({ "testShader4.hlsl" });
  766. numFailed += tester.runSingleTest({ "full_trace_ray", "Fallback_TraceRay" }, {}, { 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15 });
  767. tester.setFiles({ "testShader5.hlsl" });
  768. numFailed += tester.runSingleTest({ "raygen", "ch1", "ch2", "miss1", "miss2", "Fallback_TraceRay" }, {1002, 1005, 1007, 1009, 1009}, {-99, 100,0, -99,101,1, -99,102,2, -99,103,3, 2, 1, 0, -99,103,4, 0, 21111});
  769. }
  770. if (1)
  771. {
  772. // Expected payload is the number of invocations in the following shader types:
  773. // RG AH CH MS
  774. // These counts are store in each digit.
  775. TraversalTester tester(deviceName, basePath);
  776. numFailed += tester.run({
  777. test_nohit({1001,-1}),
  778. test_instance_nohit({1001,-1}),
  779. test_tri({1010, 0}, OPAQUE),
  780. test_tri({1110, 0}, ACCEPT),
  781. test_tri({1101,-1}, IGNORE),
  782. test_tri({1110, 0}, TERMINATE),
  783. test_tri({1001,-1}, OPAQUE, 0, RAY_FLAG_CULL_OPAQUE), // culling
  784. test_tri({1010, 0}, OPAQUE, 0, 0, -1), // flipping direction doesn't matter without culling flags
  785. test_tri({1001,-1}, OPAQUE, 0, RAY_FLAG_CULL_FRONT_FACING_TRIANGLES, 1), // triangle culling
  786. test_tri({1001,-1}, OPAQUE, 0, RAY_FLAG_CULL_BACK_FACING_TRIANGLES, -1), // triangle culling
  787. test_tri({1010, 0}, OPAQUE, INSTANCE_FLAG_TRIANGLE_CULL_DISABLE, RAY_FLAG_CULL_BACK_FACING_TRIANGLES, -1), // disable triangle culling
  788. test_tri({1010, 0}, OPAQUE, INSTANCE_FLAG_TRIANGLE_FRONT_COUNTERCLOCKWISE, 1), // flip winding
  789. test_tri({1001,-1}, OPAQUE, INSTANCE_FLAG_TRIANGLE_FRONT_COUNTERCLOCKWISE, RAY_FLAG_CULL_BACK_FACING_TRIANGLES, 1),
  790. test_tri({1010, 0}, ACCEPT, INSTANCE_FLAG_FORCE_OPAQUE),
  791. test_tri({1010, 0}, ACCEPT, 0, RAY_FLAG_FORCE_OPAQUE),
  792. test_tri({1110, 0}, OPAQUE, INSTANCE_FLAG_FORCE_NON_OPAQUE),
  793. test_tri({1110, 0}, OPAQUE, 0, RAY_FLAG_FORCE_NON_OPAQUE),
  794. test_tri({1010, 0}, ACCEPT, INSTANCE_FLAG_FORCE_NON_OPAQUE, RAY_FLAG_FORCE_OPAQUE), // ray flags opaque overrides instance
  795. test_tri({1110, 0}, OPAQUE, INSTANCE_FLAG_FORCE_OPAQUE, RAY_FLAG_FORCE_NON_OPAQUE), // ray flags opaque overrides instance
  796. test_tri({1010, 0}, OPAQUE, INSTANCE_FLAG_TRIANGLE_CULL_DISABLE, 0, -1), // disable cull
  797. test_tri({1000,-1}, OPAQUE, 0, RAY_FLAG_SKIP_CLOSEST_HIT_SHADER),
  798. test_2tri({1010, 0}, {OPAQUE, 1}, {OPAQUE, 2}), // pick closest (first)
  799. test_2tri({1010, 1}, {OPAQUE, 2}, {OPAQUE, 1}), // pick closest (second)
  800. test_2tri({1010, 0}, {OPAQUE, 2}, {OPAQUE, 1}, RAY_FLAG_TERMINATE_ON_FIRST_HIT),
  801. test_2tri({1110, 0}, {ACCEPT, 1}, {ACCEPT, 2}), // pick closest (first)
  802. test_2tri({1210, 1}, {ACCEPT, 2}, {ACCEPT, 1}), // pick closest (second)
  803. test_2tri({1110, 0}, {ACCEPT, 2}, {ACCEPT, 1}, RAY_FLAG_TERMINATE_ON_FIRST_HIT),
  804. test_2tri({1210, 1}, {IGNORE, 1}, {ACCEPT, 2}), // ignore first (even though closer)
  805. test_2tri({1210, 0}, {ACCEPT, 2}, {IGNORE, 1}), // ignore second (even though closer)
  806. test_2tri({1110, 0}, {TERMINATE, 2}, {ACCEPT, 1}),
  807. test_custom({1010, 0}, 1, {{ACCEPT, 1}}),
  808. test_custom({1110, 0}, 0, {{ACCEPT, 1}}),
  809. test_custom({1101,-1}, 0, {{IGNORE, 1}}),
  810. test_custom({1110, 0}, 0, {{TERMINATE, 1}}),
  811. test_custom({1110, 0}, 0, {{ACCEPT, 1}, {ACCEPT, 2}}), // closest first - no anyhit for second
  812. test_custom({1210, 0}, 0, {{ACCEPT, 2}, {ACCEPT, 1}}), // closest second
  813. test_custom({1210, 0}, 0, {{IGNORE, 1}, {ACCEPT, 2}}), // ignore closer hit
  814. test_custom({1201,-1}, 0, {{IGNORE, 2}, {IGNORE, 1}}), // ignore both
  815. test_custom({1201,-1}, 0, {{IGNORE, 1}, {IGNORE, 2}}), // ignore both - anyhit for both
  816. test_custom({1110, 0}, 0, {{TERMINATE, 2}, {ACCEPT, 1}}), // terminate ==> don't handle second
  817. test_custom({1001,-1}, 1, {{ACCEPT, 1}}, 0, RAY_FLAG_CULL_OPAQUE),
  818. test_custom({1110, 0}, 0, {{ACCEPT, 1}}, 0, RAY_FLAG_CULL_OPAQUE), // no effect on non-opaque
  819. test_custom({1010, 0}, 1, {{ACCEPT, 1}}, 0, RAY_FLAG_CULL_NON_OPAQUE), // no effect on non-opaque
  820. test_custom({1001,-1}, 0, {{ACCEPT, 1}}, 0, RAY_FLAG_CULL_NON_OPAQUE),
  821. test_custom({1010, 0}, 0, {{IGNORE, 1}}, INSTANCE_FLAG_FORCE_OPAQUE), //no anyhit
  822. test_custom({1010, 0}, 0, {{IGNORE, 1}}, 0, RAY_FLAG_FORCE_OPAQUE),
  823. test_custom({1101,-1}, 1, {{IGNORE, 1}}, INSTANCE_FLAG_FORCE_NON_OPAQUE), // anyhit drops the hit
  824. test_custom({1101,-1}, 1, {{IGNORE, 1}}, 0, RAY_FLAG_FORCE_NON_OPAQUE),
  825. test_custom({1010, 0}, 0, {{IGNORE, 1}}, INSTANCE_FLAG_FORCE_NON_OPAQUE, RAY_FLAG_FORCE_OPAQUE), // ray flags opaque overrides instance
  826. test_custom({1101,-1}, 1, {{IGNORE, 1}}, INSTANCE_FLAG_FORCE_OPAQUE, RAY_FLAG_FORCE_NON_OPAQUE), // ray flags opaque overrides instance
  827. test_custom({1001,-1}, 0, {{ACCEPT, 1}}, INSTANCE_FLAG_FORCE_OPAQUE, RAY_FLAG_CULL_OPAQUE),
  828. test_custom({1100,-1}, 0, {{ACCEPT, 1}}, 0, RAY_FLAG_SKIP_CLOSEST_HIT_SHADER),
  829. test_custom({1210, 0}, 0, {{IGNORE, 3}, {ACCEPT,2}, {ACCEPT, 1}}, 0, RAY_FLAG_TERMINATE_ON_FIRST_HIT),
  830. });
  831. }
  832. std::cout << "===============================================\n";
  833. if (numFailed == 0)
  834. std::cout << "PASSED\n";
  835. else
  836. {
  837. std::cout << "FAILED\n";
  838. std::cout << numFailed << " tests failed\n";
  839. }
  840. }
  841. catch (...)
  842. {
  843. printf("Failed - unknown error.\n");
  844. return 1;
  845. }
  846. }