ExecutionTest.cpp 89 KB


  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // ExecutionTest.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // These tests run by executing compiled programs, and thus involve more //
  9. // moving parts, like the runtime and drivers. //
  10. // //
  11. ///////////////////////////////////////////////////////////////////////////////
  12. #include <algorithm>
  13. #include <memory>
  14. #include <vector>
  15. #include <string>
  16. #include <map>
  17. #include <unordered_set>
  18. #include <strstream>
  19. #include <iomanip>
  20. #include "CompilationResult.h"
  21. #include "HLSLTestData.h"
  22. #include <Shlwapi.h>
  23. #include <atlcoll.h>
  24. #undef _read
  25. #include "WexTestClass.h"
  26. #include "HlslTestUtils.h"
  27. #include "DxcTestUtils.h"
  28. #include "dxc/Support/Global.h"
  29. #include "dxc/Support/WinIncludes.h"
  30. #include "dxc/Support/FileIOHelper.h"
  31. #include "dxc/Support/Unicode.h"
  32. //
  33. // d3d12.h and dxgi1_4.h are included in the Windows 10 SDK
  34. // https://msdn.microsoft.com/en-us/library/windows/desktop/dn899120(v=vs.85).aspx
  35. // https://developer.microsoft.com/en-US/windows/downloads/windows-10-sdk
  36. //
  37. #include <d3d12.h>
  38. #include <dxgi1_4.h>
  39. #include <DXGIDebug.h>
  40. #include <D3dx12.h>
  41. #include <DirectXMath.h>
  42. #include <strsafe.h>
  43. #include <d3dcompiler.h>
  44. #include <wincodec.h>
  45. #include "ShaderOpTest.h"
  46. #pragma comment(lib, "d3dcompiler.lib")
  47. #pragma comment(lib, "windowscodecs.lib")
  48. #pragma comment(lib, "dxguid.lib")
  49. // A more recent Windows SDK than currently required is needed for these.
  50. typedef HRESULT(WINAPI *D3D12EnableExperimentalFeaturesFn)(
  51. UINT NumFeatures,
  52. __in_ecount(NumFeatures) const IID* pIIDs,
  53. __in_ecount_opt(NumFeatures) void* pConfigurationStructs,
  54. __in_ecount_opt(NumFeatures) UINT* pConfigurationStructSizes);
  55. static const GUID D3D12ExperimentalShaderModelsID = { /* 76f5573e-f13a-40f5-b297-81ce9e18933f */
  56. 0x76f5573e,
  57. 0xf13a,
  58. 0x40f5,
  59. { 0xb2, 0x97, 0x81, 0xce, 0x9e, 0x18, 0x93, 0x3f }
  60. };
  61. using namespace DirectX;
  62. using namespace hlsl_test;
  63. template <typename TSequence, typename T>
  64. static bool contains(TSequence s, const T &val) {
  65. return std::cend(s) != std::find(std::cbegin(s), std::cend(s), val);
  66. }
  67. template <typename InputIterator, typename T>
  68. static bool contains(InputIterator b, InputIterator e, const T &val) {
  69. return e != std::find(b, e, val);
  70. }
  71. static HRESULT EnableExperimentalShaderModels() {
  72. HMODULE hRuntime = LoadLibraryW(L"d3d12.dll");
  73. if (hRuntime == NULL) {
  74. return HRESULT_FROM_WIN32(GetLastError());
  75. }
  76. D3D12EnableExperimentalFeaturesFn pD3D12EnableExperimentalFeatures =
  77. (D3D12EnableExperimentalFeaturesFn)GetProcAddress(hRuntime, "D3D12EnableExperimentalFeatures");
  78. if (pD3D12EnableExperimentalFeatures == nullptr) {
  79. FreeLibrary(hRuntime);
  80. return HRESULT_FROM_WIN32(GetLastError());
  81. }
  82. HRESULT hr = pD3D12EnableExperimentalFeatures(1, &D3D12ExperimentalShaderModelsID, nullptr, nullptr);
  83. FreeLibrary(hRuntime);
  84. return hr;
  85. }
  86. static HRESULT ReportLiveObjects() {
  87. CComPtr<IDXGIDebug1> pDebug;
  88. IFR(DXGIGetDebugInterface1(0, IID_PPV_ARGS(&pDebug)));
  89. IFR(pDebug->ReportLiveObjects(DXGI_DEBUG_ALL, DXGI_DEBUG_RLO_ALL));
  90. return S_OK;
  91. }
  92. static void WriteInfoQueueMessages(void *pStrCtx, st::OutputStringFn pOutputStrFn, ID3D12InfoQueue *pInfoQueue) {
  93. bool allMessagesOK = true;
  94. UINT64 count = pInfoQueue->GetNumStoredMessages();
  95. CAtlArray<BYTE> message;
  96. for (UINT64 i = 0; i < count; ++i) {
  97. // 'GetMessageA' rather than 'GetMessage' is an artifact of user32 headers.
  98. SIZE_T msgLen = 0;
  99. if (FAILED(pInfoQueue->GetMessageA(i, nullptr, &msgLen))) {
  100. allMessagesOK = false;
  101. continue;
  102. }
  103. if (message.GetCount() < msgLen) {
  104. if (!message.SetCount(msgLen)) {
  105. allMessagesOK = false;
  106. continue;
  107. }
  108. }
  109. D3D12_MESSAGE *pMessage = (D3D12_MESSAGE *)message.GetData();
  110. if (FAILED(pInfoQueue->GetMessageA(i, pMessage, &msgLen))) {
  111. allMessagesOK = false;
  112. continue;
  113. }
  114. CA2W msgW(pMessage->pDescription, CP_ACP);
  115. pOutputStrFn(pStrCtx, msgW.m_psz);
  116. pOutputStrFn(pStrCtx, L"\r\n");
  117. }
  118. if (!allMessagesOK) {
  119. pOutputStrFn(pStrCtx, L"Failed to retrieve some messages.\r\n");
  120. }
  121. }
  122. class CComContext {
  123. private:
  124. bool m_init;
  125. public:
  126. CComContext() : m_init(false) {}
  127. ~CComContext() { Dispose(); }
  128. void Dispose() { if (!m_init) return; m_init = false; CoUninitialize(); }
  129. HRESULT Init() { HRESULT hr = CoInitializeEx(0, COINIT_MULTITHREADED); if (SUCCEEDED(hr)) { m_init = true; } return hr; }
  130. };
// Encodes a pixel buffer as a BMP image through WIC and writes it to pFileName.
// Parameters:
//   pPixels   - pixel data, read with a row stride of m_width * PixelSize.
//               NOTE(review): this assumes tightly packed rows; confirm
//               callers do not pass padded rows.
//   format    - DXGI format of pPixels. Only DXGI_FORMAT_R8G8B8A8_UNORM is
//               mapped; any other value fails the VERIFY_ARE_NOT_EQUAL check.
//   m_width, m_height - image dimensions in pixels. Despite the m_ prefix,
//               these are plain parameters, not members.
//   pFileName - destination path passed to hlsl::WriteBinaryFile.
// Failures are reported through the VERIFY_* macros.
static void SavePixelsToFile(LPCVOID pPixels, DXGI_FORMAT format, UINT32 m_width, UINT32 m_height, LPCWSTR pFileName) {
  // Declared first so it is destroyed last: CoUninitialize runs only after
  // every COM pointer below has been released.
  CComContext ctx;
  CComPtr<IWICImagingFactory> pFactory;
  CComPtr<IWICBitmap> pBitmap;
  CComPtr<IWICBitmapEncoder> pEncoder;
  CComPtr<IWICBitmapFrameEncode> pFrameEncode;
  CComPtr<hlsl::AbstractMemoryStream> pStream;
  CComPtr<IMalloc> pMalloc;
  // Maps a DXGI format to the matching WIC pixel format GUID and its size in
  // bytes per pixel. operator== lets std::find search by DXGI format.
  struct PF {
    DXGI_FORMAT Format;
    GUID PixelFormat;
    UINT32 PixelSize;
    bool operator==(DXGI_FORMAT F) const {
      return F == Format;
    }
  } Vals[] = {
    // Add more pixel format mappings as needed.
    { DXGI_FORMAT_R8G8B8A8_UNORM, GUID_WICPixelFormat32bppRGBA, 4 }
  };
  // Points one past the end of Vals when the format is unmapped (verified below).
  PF *pFormat = std::find(Vals, Vals + _countof(Vals), format);
  VERIFY_SUCCEEDED(ctx.Init());
  VERIFY_SUCCEEDED(CoCreateInstance(CLSID_WICImagingFactory, NULL, CLSCTX_INPROC_SERVER, IID_IWICImagingFactory, (LPVOID*)&pFactory));
  VERIFY_SUCCEEDED(CoGetMalloc(1, &pMalloc));
  // The BMP is encoded into an in-memory stream and written to disk at the end.
  VERIFY_SUCCEEDED(hlsl::CreateMemoryStream(pMalloc, &pStream));
  VERIFY_ARE_NOT_EQUAL(pFormat, Vals + _countof(Vals));
  // Wraps pPixels without copying: stride = m_width * PixelSize,
  // buffer size = m_width * m_height * PixelSize.
  VERIFY_SUCCEEDED(pFactory->CreateBitmapFromMemory(m_width, m_height, pFormat->PixelFormat, m_width * pFormat->PixelSize, m_width * m_height * pFormat->PixelSize, (BYTE *)pPixels, &pBitmap));
  VERIFY_SUCCEEDED(pFactory->CreateEncoder(GUID_ContainerFormatBmp, nullptr, &pEncoder));
  VERIFY_SUCCEEDED(pEncoder->Initialize(pStream, WICBitmapEncoderNoCache));
  VERIFY_SUCCEEDED(pEncoder->CreateNewFrame(&pFrameEncode, nullptr));
  VERIFY_SUCCEEDED(pFrameEncode->Initialize(nullptr));
  VERIFY_SUCCEEDED(pFrameEncode->WriteSource(pBitmap, nullptr));
  VERIFY_SUCCEEDED(pFrameEncode->Commit());
  VERIFY_SUCCEEDED(pEncoder->Commit());
  hlsl::WriteBinaryFile(pFileName, pStream->GetPtr(), pStream->GetPtrSize());
}
  166. class ExecutionTest {
  167. public:
  168. // By default, ignore these tests, which require a recent build to run properly.
  169. BEGIN_TEST_CLASS(ExecutionTest)
  170. TEST_CLASS_PROPERTY(L"Ignore", L"true")
  171. TEST_METHOD_PROPERTY(L"Priority", L"0")
  172. END_TEST_CLASS()
  173. TEST_CLASS_SETUP(ExecutionTestClassSetup)
  174. TEST_METHOD(BasicComputeTest);
  175. TEST_METHOD(BasicTriangleTest);
  176. TEST_METHOD(BasicTriangleOpTest);
  177. TEST_METHOD(MinMaxTest);
  178. TEST_METHOD(OutOfBoundsTest);
  179. TEST_METHOD(SaturateTest);
  180. TEST_METHOD(SignTest);
  181. TEST_METHOD(Int64Test);
  182. TEST_METHOD(WaveIntrinsicsTest);
  183. TEST_METHOD(WaveIntrinsicsInPSTest);
  184. TEST_METHOD(DoShaderOpArithTest);
  185. dxc::DxcDllSupport m_support;
  186. bool m_ExperimentalModeEnabled = false;
  187. static const float ClearColor[4];
  188. bool UseDxbc() {
  189. return GetTestParamBool(L"DXBC");
  190. }
  191. bool UseDebugIfaces() {
  192. return true;
  193. }
  194. bool SaveImages() {
  195. return GetTestParamBool(L"SaveImages");
  196. }
  197. void CompileFromText(LPCSTR pText, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, ID3DBlob **ppBlob) {
  198. VERIFY_SUCCEEDED(m_support.Initialize());
  199. CComPtr<IDxcCompiler> pCompiler;
  200. CComPtr<IDxcLibrary> pLibrary;
  201. CComPtr<IDxcBlobEncoding> pTextBlob;
  202. CComPtr<IDxcOperationResult> pResult;
  203. HRESULT resultCode;
  204. VERIFY_SUCCEEDED(m_support.CreateInstance(CLSID_DxcCompiler, &pCompiler));
  205. VERIFY_SUCCEEDED(m_support.CreateInstance(CLSID_DxcLibrary, &pLibrary));
  206. VERIFY_SUCCEEDED(pLibrary->CreateBlobWithEncodingFromPinned((LPBYTE)pText, strlen(pText), CP_UTF8, &pTextBlob));
  207. VERIFY_SUCCEEDED(pCompiler->Compile(pTextBlob, L"hlsl.hlsl", pEntryPoint, pTargetProfile, nullptr, 0, nullptr, 0, nullptr, &pResult));
  208. VERIFY_SUCCEEDED(pResult->GetStatus(&resultCode));
  209. if (FAILED(resultCode)) {
  210. CComPtr<IDxcBlobEncoding> errors;
  211. VERIFY_SUCCEEDED(pResult->GetErrorBuffer(&errors));
  212. LogCommentFmt(L"Failed to compile shader: %s", BlobToUtf16(errors).data());
  213. }
  214. VERIFY_SUCCEEDED(resultCode);
  215. VERIFY_SUCCEEDED(pResult->GetResult((IDxcBlob **)ppBlob));
  216. }
  217. void CreateComputeCommandQueue(ID3D12Device *pDevice, LPCWSTR pName, ID3D12CommandQueue **ppCommandQueue) {
  218. D3D12_COMMAND_QUEUE_DESC queueDesc = {};
  219. queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
  220. queueDesc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
  221. VERIFY_SUCCEEDED(pDevice->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(ppCommandQueue)));
  222. VERIFY_SUCCEEDED((*ppCommandQueue)->SetName(pName));
  223. }
  224. void CreateComputePSO(ID3D12Device *pDevice, ID3D12RootSignature *pRootSignature, LPCSTR pShader, ID3D12PipelineState **ppComputeState) {
  225. CComPtr<ID3DBlob> pComputeShader;
  226. // Load and compile shaders.
  227. if (UseDxbc()) {
  228. DXBCFromText(pShader, L"main", L"cs_6_0", &pComputeShader);
  229. }
  230. else {
  231. CompileFromText(pShader, L"main", L"cs_6_0", &pComputeShader);
  232. }
  233. // Describe and create the compute pipeline state object (PSO).
  234. D3D12_COMPUTE_PIPELINE_STATE_DESC computePsoDesc = {};
  235. computePsoDesc.pRootSignature = pRootSignature;
  236. computePsoDesc.CS = CD3DX12_SHADER_BYTECODE(pComputeShader);
  237. VERIFY_SUCCEEDED(pDevice->CreateComputePipelineState(&computePsoDesc, IID_PPV_ARGS(ppComputeState)));
  238. }
  239. bool CreateDevice(_COM_Outptr_ ID3D12Device **ppDevice) {
  240. const D3D_FEATURE_LEVEL FeatureLevelRequired = D3D_FEATURE_LEVEL_11_0;
  241. CComPtr<IDXGIFactory4> factory;
  242. CComPtr<ID3D12Device> pDevice;
  243. *ppDevice = nullptr;
  244. VERIFY_SUCCEEDED(CreateDXGIFactory1(IID_PPV_ARGS(&factory)));
  245. if (GetTestParamUseWARP(true)) {
  246. CComPtr<IDXGIAdapter> warpAdapter;
  247. VERIFY_SUCCEEDED(factory->EnumWarpAdapter(IID_PPV_ARGS(&warpAdapter)));
  248. HRESULT createHR = D3D12CreateDevice(warpAdapter, FeatureLevelRequired,
  249. IID_PPV_ARGS(&pDevice));
  250. if (FAILED(createHR)) {
  251. LogCommentFmt(L"The available version of WARP does not support d3d12.");
  252. WEX::Logging::Log::Result(WEX::Logging::TestResults::Blocked);
  253. return false;
  254. }
  255. } else {
  256. CComPtr<IDXGIAdapter1> hardwareAdapter;
  257. WEX::Common::String AdapterValue;
  258. IFT(WEX::TestExecution::RuntimeParameters::TryGetValue(L"Adapter",
  259. AdapterValue));
  260. GetHardwareAdapter(factory, AdapterValue, &hardwareAdapter);
  261. if (hardwareAdapter == nullptr) {
  262. WEX::Logging::Log::Error(
  263. L"Unable to find hardware adapter with D3D12 support.");
  264. return false;
  265. }
  266. VERIFY_SUCCEEDED(D3D12CreateDevice(hardwareAdapter, FeatureLevelRequired,
  267. IID_PPV_ARGS(&pDevice)));
  268. }
  269. if (pDevice == nullptr)
  270. return false;
  271. if (!UseDxbc()) {
  272. // Check for DXIL support.
  273. // This is defined in d3d.h for Windows 10 Anniversary Edition SDK, but we only
  274. // require the Windows 10 SDK.
  275. typedef enum D3D_SHADER_MODEL {
  276. D3D_SHADER_MODEL_5_1 = 0x51,
  277. D3D_SHADER_MODEL_6_0 = 0x60
  278. } D3D_SHADER_MODEL;
  279. typedef struct D3D12_FEATURE_DATA_SHADER_MODEL {
  280. _Inout_ D3D_SHADER_MODEL HighestShaderModel;
  281. } D3D12_FEATURE_DATA_SHADER_MODEL;
  282. const UINT D3D12_FEATURE_SHADER_MODEL = 7;
  283. D3D12_FEATURE_DATA_SHADER_MODEL SMData;
  284. SMData.HighestShaderModel = D3D_SHADER_MODEL_6_0;
  285. VERIFY_SUCCEEDED(pDevice->CheckFeatureSupport(
  286. (D3D12_FEATURE)D3D12_FEATURE_SHADER_MODEL, &SMData, sizeof(SMData)));
  287. if (SMData.HighestShaderModel != D3D_SHADER_MODEL_6_0) {
  288. LogCommentFmt(L"The selected device does not support "
  289. L"shader model 6 (required for DXIL).");
  290. WEX::Logging::Log::Result(WEX::Logging::TestResults::Blocked);
  291. return false;
  292. }
  293. }
  294. if (UseDebugIfaces()) {
  295. CComPtr<ID3D12InfoQueue> pInfoQueue;
  296. if (SUCCEEDED(pDevice->QueryInterface(&pInfoQueue))) {
  297. pInfoQueue->SetMuteDebugOutput(FALSE);
  298. }
  299. }
  300. *ppDevice = pDevice.Detach();
  301. return true;
  302. }
  303. void CreateGraphicsCommandQueue(ID3D12Device *pDevice, ID3D12CommandQueue **ppCommandQueue) {
  304. D3D12_COMMAND_QUEUE_DESC queueDesc = {};
  305. queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
  306. queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;;
  307. VERIFY_SUCCEEDED(pDevice->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(ppCommandQueue)));
  308. }
  309. void CreateGraphicsCommandQueueAndList(
  310. ID3D12Device *pDevice, ID3D12CommandQueue **ppCommandQueue,
  311. ID3D12CommandAllocator **ppAllocator,
  312. ID3D12GraphicsCommandList **ppCommandList, ID3D12PipelineState *pPSO) {
  313. CreateGraphicsCommandQueue(pDevice, ppCommandQueue);
  314. VERIFY_SUCCEEDED(pDevice->CreateCommandAllocator(
  315. D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(ppAllocator)));
  316. VERIFY_SUCCEEDED(pDevice->CreateCommandList(
  317. 0, D3D12_COMMAND_LIST_TYPE_DIRECT, *ppAllocator, pPSO,
  318. IID_PPV_ARGS(ppCommandList)));
  319. }
  320. void CreateGraphicsPSO(ID3D12Device *pDevice,
  321. D3D12_INPUT_LAYOUT_DESC *pInputLayout,
  322. ID3D12RootSignature *pRootSignature, LPCSTR pShaders,
  323. ID3D12PipelineState **ppPSO) {
  324. CComPtr<ID3DBlob> vertexShader;
  325. CComPtr<ID3DBlob> pixelShader;
  326. if (UseDxbc()) {
  327. DXBCFromText(pShaders, L"VSMain", L"vs_6_0", &vertexShader);
  328. DXBCFromText(pShaders, L"PSMain", L"ps_6_0", &pixelShader);
  329. } else {
  330. CompileFromText(pShaders, L"VSMain", L"vs_6_0", &vertexShader);
  331. CompileFromText(pShaders, L"PSMain", L"ps_6_0", &pixelShader);
  332. }
  333. // Describe and create the graphics pipeline state object (PSO).
  334. D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc = {};
  335. psoDesc.InputLayout = *pInputLayout;
  336. psoDesc.pRootSignature = pRootSignature;
  337. psoDesc.VS = CD3DX12_SHADER_BYTECODE(vertexShader);
  338. psoDesc.PS = CD3DX12_SHADER_BYTECODE(pixelShader);
  339. psoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT);
  340. psoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);
  341. psoDesc.DepthStencilState.DepthEnable = FALSE;
  342. psoDesc.DepthStencilState.StencilEnable = FALSE;
  343. psoDesc.SampleMask = UINT_MAX;
  344. psoDesc.PrimitiveTopologyType = D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE;
  345. psoDesc.NumRenderTargets = 1;
  346. psoDesc.RTVFormats[0] = DXGI_FORMAT_R8G8B8A8_UNORM;
  347. psoDesc.SampleDesc.Count = 1;
  348. VERIFY_SUCCEEDED(
  349. pDevice->CreateGraphicsPipelineState(&psoDesc, IID_PPV_ARGS(ppPSO)));
  350. }
  351. void CreateRenderTargetAndReadback(ID3D12Device *pDevice,
  352. ID3D12DescriptorHeap *pHeap, UINT width,
  353. UINT height,
  354. ID3D12Resource **ppRenderTarget,
  355. ID3D12Resource **ppBuffer) {
  356. const DXGI_FORMAT format = DXGI_FORMAT_R8G8B8A8_UNORM;
  357. const size_t formatElementSize = 4;
  358. CComPtr<ID3D12Resource> pRenderTarget;
  359. CComPtr<ID3D12Resource> pBuffer;
  360. CD3DX12_CPU_DESCRIPTOR_HANDLE rtvHandle(
  361. pHeap->GetCPUDescriptorHandleForHeapStart());
  362. CD3DX12_HEAP_PROPERTIES rtHeap(D3D12_HEAP_TYPE_DEFAULT);
  363. CD3DX12_RESOURCE_DESC rtDesc(
  364. CD3DX12_RESOURCE_DESC::Tex2D(format, width, height));
  365. CD3DX12_CLEAR_VALUE rtClearVal(format, ClearColor);
  366. rtDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_RENDER_TARGET;
  367. VERIFY_SUCCEEDED(pDevice->CreateCommittedResource(
  368. &rtHeap, D3D12_HEAP_FLAG_NONE, &rtDesc, D3D12_RESOURCE_STATE_COPY_DEST,
  369. &rtClearVal, IID_PPV_ARGS(&pRenderTarget)));
  370. pDevice->CreateRenderTargetView(pRenderTarget, nullptr, rtvHandle);
  371. // rtvHandle.Offset(1, rtvDescriptorSize); // Not needed for a single
  372. // resource.
  373. CD3DX12_HEAP_PROPERTIES readHeap(D3D12_HEAP_TYPE_READBACK);
  374. CD3DX12_RESOURCE_DESC readDesc(
  375. CD3DX12_RESOURCE_DESC::Buffer(width * height * formatElementSize));
  376. VERIFY_SUCCEEDED(pDevice->CreateCommittedResource(
  377. &readHeap, D3D12_HEAP_FLAG_NONE, &readDesc,
  378. D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&pBuffer)));
  379. *ppRenderTarget = pRenderTarget.Detach();
  380. *ppBuffer = pBuffer.Detach();
  381. }
  382. void CreateRootSignatureFromDesc(ID3D12Device *pDevice,
  383. const D3D12_ROOT_SIGNATURE_DESC *pDesc,
  384. ID3D12RootSignature **pRootSig) {
  385. CComPtr<ID3DBlob> signature;
  386. CComPtr<ID3DBlob> error;
  387. VERIFY_SUCCEEDED(D3D12SerializeRootSignature(pDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error));
  388. VERIFY_SUCCEEDED(pDevice->CreateRootSignature(
  389. 0, signature->GetBufferPointer(), signature->GetBufferSize(),
  390. IID_PPV_ARGS(pRootSig)));
  391. }
  392. void CreateRtvDescriptorHeap(ID3D12Device *pDevice, UINT numDescriptors,
  393. ID3D12DescriptorHeap **pRtvHeap, UINT *rtvDescriptorSize) {
  394. D3D12_DESCRIPTOR_HEAP_DESC rtvHeapDesc = {};
  395. rtvHeapDesc.NumDescriptors = numDescriptors;
  396. rtvHeapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_RTV;
  397. rtvHeapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_NONE;
  398. VERIFY_SUCCEEDED(
  399. pDevice->CreateDescriptorHeap(&rtvHeapDesc, IID_PPV_ARGS(pRtvHeap)));
  400. if (rtvDescriptorSize != nullptr) {
  401. *rtvDescriptorSize = pDevice->GetDescriptorHandleIncrementSize(
  402. D3D12_DESCRIPTOR_HEAP_TYPE_RTV);
  403. }
  404. }
  405. void CreateTestUavs(ID3D12Device *pDevice,
  406. ID3D12GraphicsCommandList *pCommandList, LPCVOID values,
  407. UINT32 valueSizeInBytes, ID3D12Resource **ppUavResource,
  408. ID3D12Resource **ppReadBuffer,
  409. ID3D12Resource **ppUploadResource) {
  410. CComPtr<ID3D12Resource> pUavResource;
  411. CComPtr<ID3D12Resource> pReadBuffer;
  412. CComPtr<ID3D12Resource> pUploadResource;
  413. D3D12_SUBRESOURCE_DATA transferData;
  414. D3D12_HEAP_PROPERTIES defaultHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
  415. D3D12_HEAP_PROPERTIES uploadHeapProperties = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
  416. D3D12_RESOURCE_DESC bufferDesc = CD3DX12_RESOURCE_DESC::Buffer(valueSizeInBytes, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS);
  417. D3D12_RESOURCE_DESC uploadBufferDesc = CD3DX12_RESOURCE_DESC::Buffer(valueSizeInBytes);
  418. CD3DX12_HEAP_PROPERTIES readHeap(D3D12_HEAP_TYPE_READBACK);
  419. CD3DX12_RESOURCE_DESC readDesc(CD3DX12_RESOURCE_DESC::Buffer(valueSizeInBytes));
  420. VERIFY_SUCCEEDED(pDevice->CreateCommittedResource(
  421. &defaultHeapProperties,
  422. D3D12_HEAP_FLAG_NONE,
  423. &bufferDesc,
  424. D3D12_RESOURCE_STATE_COPY_DEST,
  425. nullptr,
  426. IID_PPV_ARGS(&pUavResource)));
  427. VERIFY_SUCCEEDED(pDevice->CreateCommittedResource(
  428. &uploadHeapProperties,
  429. D3D12_HEAP_FLAG_NONE,
  430. &uploadBufferDesc,
  431. D3D12_RESOURCE_STATE_GENERIC_READ,
  432. nullptr,
  433. IID_PPV_ARGS(&pUploadResource)));
  434. VERIFY_SUCCEEDED(pDevice->CreateCommittedResource(
  435. &readHeap, D3D12_HEAP_FLAG_NONE, &readDesc,
  436. D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&pReadBuffer)));
  437. transferData.pData = values;
  438. transferData.RowPitch = valueSizeInBytes;
  439. transferData.SlicePitch = transferData.RowPitch;
  440. UpdateSubresources<1>(pCommandList, pUavResource.p, pUploadResource.p, 0, 0, 1, &transferData);
  441. RecordTransitionBarrier(pCommandList, pUavResource, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_UNORDERED_ACCESS);
  442. *ppUavResource = pUavResource.Detach();
  443. *ppReadBuffer = pReadBuffer.Detach();
  444. *ppUploadResource = pUploadResource.Detach();
  445. }
  446. template <typename TVertex, int len>
  447. void CreateVertexBuffer(ID3D12Device *pDevice, TVertex(&vertices)[len],
  448. ID3D12Resource **ppVertexBuffer,
  449. D3D12_VERTEX_BUFFER_VIEW *pVertexBufferView) {
  450. size_t vertexBufferSize = sizeof(vertices);
  451. CComPtr<ID3D12Resource> pVertexBuffer;
  452. CD3DX12_HEAP_PROPERTIES heapProps(D3D12_HEAP_TYPE_UPLOAD);
  453. CD3DX12_RESOURCE_DESC bufferDesc(
  454. CD3DX12_RESOURCE_DESC::Buffer(vertexBufferSize));
  455. VERIFY_SUCCEEDED(pDevice->CreateCommittedResource(
  456. &heapProps, D3D12_HEAP_FLAG_NONE, &bufferDesc,
  457. D3D12_RESOURCE_STATE_GENERIC_READ, nullptr,
  458. IID_PPV_ARGS(&pVertexBuffer)));
  459. UINT8 *pVertexDataBegin;
  460. CD3DX12_RANGE readRange(0, 0);
  461. VERIFY_SUCCEEDED(pVertexBuffer->Map(
  462. 0, &readRange, reinterpret_cast<void **>(&pVertexDataBegin)));
  463. memcpy(pVertexDataBegin, vertices, vertexBufferSize);
  464. pVertexBuffer->Unmap(0, nullptr);
  465. // Initialize the vertex buffer view.
  466. pVertexBufferView->BufferLocation = pVertexBuffer->GetGPUVirtualAddress();
  467. pVertexBufferView->StrideInBytes = sizeof(TVertex);
  468. pVertexBufferView->SizeInBytes = vertexBufferSize;
  469. *ppVertexBuffer = pVertexBuffer.Detach();
  470. }
  471. // Requires Anniversary Edition headers, so simplifying things for current setup.
  472. const UINT D3D12_FEATURE_D3D12_OPTIONS1 = 8;
  473. struct D3D12_FEATURE_DATA_D3D12_OPTIONS1 {
  474. BOOL WaveOps;
  475. UINT WaveLaneCountMin;
  476. UINT WaveLaneCountMax;
  477. UINT TotalLaneCount;
  478. BOOL ExpandedComputeResourceStates;
  479. BOOL Int64ShaderOps;
  480. };
  481. bool DoesDeviceSupportInt64(ID3D12Device *pDevice) {
  482. D3D12_FEATURE_DATA_D3D12_OPTIONS1 O;
  483. if (FAILED(pDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS1, &O, sizeof(O))))
  484. return false;
  485. return O.Int64ShaderOps != FALSE;
  486. }
  487. bool DoesDeviceSupportWaveOps(ID3D12Device *pDevice) {
  488. D3D12_FEATURE_DATA_D3D12_OPTIONS1 O;
  489. if (FAILED(pDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS1, &O, sizeof(O))))
  490. return false;
  491. return O.WaveOps != FALSE;
  492. }
  493. void DXBCFromText(LPCSTR pText, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, ID3DBlob **ppBlob) {
  494. CW2A pEntryPointA(pEntryPoint, CP_UTF8);
  495. CW2A pTargetProfileA(pTargetProfile, CP_UTF8);
  496. CComPtr<ID3DBlob> pErrors;
  497. D3D_SHADER_MACRO d3dMacro[2];
  498. ZeroMemory(d3dMacro, sizeof(d3dMacro));
  499. d3dMacro[0].Definition = "1";
  500. d3dMacro[0].Name = "USING_DXBC";
  501. HRESULT hr = D3DCompile(pText, strlen(pText), "hlsl.hlsl", d3dMacro, nullptr, pEntryPointA, pTargetProfileA, 0, 0, ppBlob, &pErrors);
  502. if (pErrors != nullptr) {
  503. CA2W errors((char *)pErrors->GetBufferPointer(), CP_ACP);
  504. LogCommentFmt(L"Compilation failure: %s", errors.m_szBuffer);
  505. }
  506. VERIFY_SUCCEEDED(hr);
  507. }
  508. HRESULT EnableDebugLayer() {
  509. // The debug layer does net yet validate DXIL programs that require rewriting,
  510. // but basic logging should work properly.
  511. HRESULT hr = S_FALSE;
  512. if (UseDebugIfaces()) {
  513. CComPtr<ID3D12Debug> debugController;
  514. hr = D3D12GetDebugInterface(IID_PPV_ARGS(&debugController));
  515. if (SUCCEEDED(hr)) {
  516. debugController->EnableDebugLayer();
  517. hr = S_OK;
  518. }
  519. }
  520. return hr;
  521. }
  522. HRESULT EnableExperimentalMode() {
  523. if (m_ExperimentalModeEnabled) {
  524. return S_OK;
  525. }
  526. if (!GetTestParamBool(L"ExperimentalShaders")) {
  527. return S_OK;
  528. }
  529. HRESULT hr = EnableExperimentalShaderModels();
  530. if (SUCCEEDED(hr)) {
  531. m_ExperimentalModeEnabled = true;
  532. }
  533. return hr;
  534. }
  535. struct FenceObj {
  536. HANDLE m_fenceEvent = NULL;
  537. CComPtr<ID3D12Fence> m_fence;
  538. UINT64 m_fenceValue;
  539. ~FenceObj() {
  540. if (m_fenceEvent) CloseHandle(m_fenceEvent);
  541. }
  542. };
  543. void InitFenceObj(ID3D12Device *pDevice, FenceObj *pObj) {
  544. pObj->m_fenceValue = 1;
  545. VERIFY_SUCCEEDED(pDevice->CreateFence(0, D3D12_FENCE_FLAG_NONE,
  546. IID_PPV_ARGS(&pObj->m_fence)));
  547. // Create an event handle to use for frame synchronization.
  548. pObj->m_fenceEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
  549. if (pObj->m_fenceEvent == nullptr) {
  550. VERIFY_SUCCEEDED(HRESULT_FROM_WIN32(GetLastError()));
  551. }
  552. }
  553. void ReadHlslDataIntoNewStream(LPCWSTR relativePath, IStream **ppStream) {
  554. VERIFY_SUCCEEDED(m_support.Initialize());
  555. CComPtr<IDxcLibrary> pLibrary;
  556. CComPtr<IDxcBlobEncoding> pBlob;
  557. CComPtr<IStream> pStream;
  558. std::wstring path = GetPathToHlslDataFile(relativePath);
  559. VERIFY_SUCCEEDED(m_support.CreateInstance(CLSID_DxcLibrary, &pLibrary));
  560. VERIFY_SUCCEEDED(pLibrary->CreateBlobFromFile(path.c_str(), nullptr, &pBlob));
  561. VERIFY_SUCCEEDED(pLibrary->CreateStreamFromBlobReadOnly(pBlob, &pStream));
  562. *ppStream = pStream.Detach();
  563. }
  564. void RecordRenderAndReadback(ID3D12GraphicsCommandList *pList,
  565. ID3D12DescriptorHeap *pRtvHeap,
  566. UINT rtvDescriptorSize,
  567. UINT instanceCount,
  568. D3D12_VERTEX_BUFFER_VIEW *pVertexBufferView,
  569. ID3D12RootSignature *pRootSig,
  570. ID3D12Resource *pRenderTarget,
  571. ID3D12Resource *pReadBuffer) {
  572. D3D12_RESOURCE_DESC rtDesc = pRenderTarget->GetDesc();
  573. D3D12_VIEWPORT viewport;
  574. D3D12_RECT scissorRect;
  575. memset(&viewport, 0, sizeof(viewport));
  576. viewport.Height = rtDesc.Height;
  577. viewport.Width = rtDesc.Width;
  578. viewport.MaxDepth = 1.0f;
  579. memset(&scissorRect, 0, sizeof(scissorRect));
  580. scissorRect.right = rtDesc.Width;
  581. scissorRect.bottom = rtDesc.Height;
  582. if (pRootSig != nullptr) {
  583. pList->SetGraphicsRootSignature(pRootSig);
  584. }
  585. pList->RSSetViewports(1, &viewport);
  586. pList->RSSetScissorRects(1, &scissorRect);
  587. // Indicate that the buffer will be used as a render target.
  588. RecordTransitionBarrier(pList, pRenderTarget, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_RENDER_TARGET);
  589. CD3DX12_CPU_DESCRIPTOR_HANDLE rtvHandle(pRtvHeap->GetCPUDescriptorHandleForHeapStart(), 0, rtvDescriptorSize);
  590. pList->OMSetRenderTargets(1, &rtvHandle, FALSE, nullptr);
  591. pList->ClearRenderTargetView(rtvHandle, ClearColor, 0, nullptr);
  592. pList->IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
  593. pList->IASetVertexBuffers(0, 1, pVertexBufferView);
  594. pList->DrawInstanced(3, instanceCount, 0, 0);
  595. // Transition to copy source and copy into read-back buffer.
  596. RecordTransitionBarrier(pList, pRenderTarget, D3D12_RESOURCE_STATE_RENDER_TARGET, D3D12_RESOURCE_STATE_COPY_SOURCE);
  597. // Copy into read-back buffer.
  598. UINT rowPitch = rtDesc.Width * 4;
  599. if (rowPitch % D3D12_TEXTURE_DATA_PITCH_ALIGNMENT)
  600. rowPitch += D3D12_TEXTURE_DATA_PITCH_ALIGNMENT - (rowPitch % D3D12_TEXTURE_DATA_PITCH_ALIGNMENT);
  601. D3D12_PLACED_SUBRESOURCE_FOOTPRINT Footprint;
  602. Footprint.Offset = 0;
  603. Footprint.Footprint = CD3DX12_SUBRESOURCE_FOOTPRINT(DXGI_FORMAT_R8G8B8A8_UNORM, rtDesc.Width, rtDesc.Height, 1, rowPitch);
  604. CD3DX12_TEXTURE_COPY_LOCATION DstLoc(pReadBuffer, Footprint);
  605. CD3DX12_TEXTURE_COPY_LOCATION SrcLoc(pRenderTarget, 0);
  606. pList->CopyTextureRegion(&DstLoc, 0, 0, 0, &SrcLoc, nullptr);
  607. }
  608. void RunRWByteBufferComputeTest(ID3D12Device *pDevice, LPCSTR shader, std::vector<uint32_t> &values);
  609. void SetDescriptorHeap(ID3D12GraphicsCommandList *pCommandList, ID3D12DescriptorHeap *pHeap) {
  610. ID3D12DescriptorHeap *const pHeaps[1] = { pHeap };
  611. pCommandList->SetDescriptorHeaps(1, pHeaps);
  612. }
  613. void WaitForSignal(ID3D12CommandQueue *pCQ, FenceObj &FO) {
  614. ::WaitForSignal(pCQ, FO.m_fence, FO.m_fenceEvent, FO.m_fenceValue++);
  615. }
  616. };
  617. const float ExecutionTest::ClearColor[4] = { 0.0f, 0.2f, 0.4f, 1.0f };
// HLSL snippet meant to be prepended to shader source.
// When USING_DXBC is defined (DXBCFromText defines USING_DXBC=1 for FXC
// builds), it declares placeholder versions of the SM6 wave and quad
// intrinsics so the same shader text can still be compiled to DXBC.
// The stubs return fixed values (mostly 1, or their input) and do not emulate
// real wave behavior. Results from DXBC runs are therefore not meaningful
// for these intrinsics.
#define WAVE_INTRINSIC_DXBC_GUARD \
  "#ifdef USING_DXBC\r\n" \
  "uint WaveGetLaneIndex() { return 1; }\r\n" \
  "uint WaveReadLaneFirst(uint u) { return u; }\r\n" \
  "bool WaveIsFirstLane() { return true; }\r\n" \
  "uint WaveGetLaneCount() { return 1; }\r\n" \
  "uint WaveReadLaneAt(uint n, uint u) { return u; }\r\n" \
  "bool WaveActiveAnyTrue(bool b) { return b; }\r\n" \
  "bool WaveActiveAllTrue(bool b) { return false; }\r\n" \
  "uint WaveActiveAllEqual(uint u) { return u; }\r\n" \
  "uint4 WaveActiveBallot(bool b) { return 1; }\r\n" \
  "uint WaveActiveCountBits(uint u) { return 1; }\r\n" \
  "uint WaveActiveSum(uint u) { return 1; }\r\n" \
  "uint WaveActiveProduct(uint u) { return 1; }\r\n" \
  "uint WaveActiveBitAnd(uint u) { return 1; }\r\n" \
  "uint WaveActiveBitOr(uint u) { return 1; }\r\n" \
  "uint WaveActiveBitXor(uint u) { return 1; }\r\n" \
  "uint WaveActiveMin(uint u) { return 1; }\r\n" \
  "uint WaveActiveMax(uint u) { return 1; }\r\n" \
  "uint WavePrefixCountBits(uint u) { return 1; }\r\n" \
  "uint WavePrefixSum(uint u) { return 1; }\r\n" \
  "uint WavePrefixProduct(uint u) { return 1; }\r\n" \
  "uint QuadReadLaneAt(uint a, uint u) { return 1; }\r\n" \
  "uint QuadReadAcrossX(uint u) { return 1; }\r\n" \
  "uint QuadReadAcrossY(uint u) { return 1; }\r\n" \
  "uint QuadReadAcrossDiagonal(uint u) { return 1; }\r\n" \
  "#endif\r\n"
  645. static void SetupComputeValuePattern(std::vector<uint32_t> &values, size_t count) {
  646. values.resize(count); // one element per dispatch group, in bytes
  647. for (size_t i = 0; i < count; ++i) {
  648. values[i] = i;
  649. }
  650. }
  651. bool ExecutionTest::ExecutionTestClassSetup() {
  652. HRESULT hr = EnableExperimentalMode();
  653. if (FAILED(hr)) {
  654. LogCommentFmt(L"Unable to enable shader experimental mode - 0x%08x.", hr);
  655. }
  656. else {
  657. LogCommentFmt(L"Experimental mode enabled.");
  658. }
  659. hr = EnableDebugLayer();
  660. if (FAILED(hr)) {
  661. LogCommentFmt(L"Unable to enable debug layer - 0x%08x.", hr);
  662. }
  663. else {
  664. LogCommentFmt(L"Debug layer enabled.");
  665. }
  666. return true;
  667. }
  668. void ExecutionTest::RunRWByteBufferComputeTest(ID3D12Device *pDevice, LPCSTR pShader, std::vector<uint32_t> &values) {
  669. static const int DispatchGroupX = 1;
  670. static const int DispatchGroupY = 1;
  671. static const int DispatchGroupZ = 1;
  672. CComPtr<ID3D12GraphicsCommandList> pCommandList;
  673. CComPtr<ID3D12CommandQueue> pCommandQueue;
  674. CComPtr<ID3D12DescriptorHeap> pUavHeap;
  675. CComPtr<ID3D12CommandAllocator> pCommandAllocator;
  676. UINT uavDescriptorSize;
  677. FenceObj FO;
  678. const size_t valueSizeInBytes = values.size() * sizeof(uint32_t);
  679. CreateComputeCommandQueue(pDevice, L"RunRWByteBufferComputeTest Command Queue", &pCommandQueue);
  680. InitFenceObj(pDevice, &FO);
  681. // Describe and create a UAV descriptor heap.
  682. D3D12_DESCRIPTOR_HEAP_DESC heapDesc = {};
  683. heapDesc.NumDescriptors = 1;
  684. heapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
  685. heapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
  686. VERIFY_SUCCEEDED(pDevice->CreateDescriptorHeap(&heapDesc, IID_PPV_ARGS(&pUavHeap)));
  687. uavDescriptorSize = pDevice->GetDescriptorHandleIncrementSize(heapDesc.Type);
  688. // Create root signature.
  689. CComPtr<ID3D12RootSignature> pRootSignature;
  690. {
  691. CD3DX12_DESCRIPTOR_RANGE ranges[1];
  692. ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0, 0);
  693. CD3DX12_ROOT_PARAMETER rootParameters[1];
  694. rootParameters[0].InitAsDescriptorTable(1, &ranges[0], D3D12_SHADER_VISIBILITY_ALL);
  695. CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc;
  696. rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_NONE);
  697. CreateRootSignatureFromDesc(pDevice, &rootSignatureDesc, &pRootSignature);
  698. }
  699. // Create pipeline state object.
  700. CComPtr<ID3D12PipelineState> pComputeState;
  701. CreateComputePSO(pDevice, pRootSignature, pShader, &pComputeState);
  702. // Create a command allocator and list for compute.
  703. VERIFY_SUCCEEDED(pDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&pCommandAllocator)));
  704. VERIFY_SUCCEEDED(pDevice->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE, pCommandAllocator, pComputeState, IID_PPV_ARGS(&pCommandList)));
  705. pCommandList->SetName(L"ExecutionTest::RunRWByteButterComputeTest Command List");
  706. // Set up UAV resource.
  707. CComPtr<ID3D12Resource> pUavResource;
  708. CComPtr<ID3D12Resource> pReadBuffer;
  709. CComPtr<ID3D12Resource> pUploadResource;
  710. CreateTestUavs(pDevice, pCommandList, values.data(), valueSizeInBytes, &pUavResource, &pReadBuffer, &pUploadResource);
  711. VERIFY_SUCCEEDED(pUavResource->SetName(L"RunRWByteBufferComputeText UAV"));
  712. VERIFY_SUCCEEDED(pReadBuffer->SetName(L"RunRWByteBufferComputeText UAV Read Buffer"));
  713. VERIFY_SUCCEEDED(pUploadResource->SetName(L"RunRWByteBufferComputeText UAV Upload Buffer"));
  714. // Close the command list and execute it to perform the GPU setup.
  715. pCommandList->Close();
  716. ExecuteCommandList(pCommandQueue, pCommandList);
  717. WaitForSignal(pCommandQueue, FO);
  718. VERIFY_SUCCEEDED(pCommandAllocator->Reset());
  719. VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputeState));
  720. // Run the compute shader and copy the results back to readable memory.
  721. {
  722. D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {};
  723. uavDesc.Format = DXGI_FORMAT_R32_TYPELESS;
  724. uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
  725. uavDesc.Buffer.FirstElement = 0;
  726. uavDesc.Buffer.NumElements = values.size();
  727. uavDesc.Buffer.StructureByteStride = 0;
  728. uavDesc.Buffer.CounterOffsetInBytes = 0;
  729. uavDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW;
  730. CD3DX12_CPU_DESCRIPTOR_HANDLE uavHandle(pUavHeap->GetCPUDescriptorHandleForHeapStart());
  731. CD3DX12_GPU_DESCRIPTOR_HANDLE uavHandleGpu(pUavHeap->GetGPUDescriptorHandleForHeapStart());
  732. pDevice->CreateUnorderedAccessView(pUavResource, nullptr, &uavDesc, uavHandle);
  733. SetDescriptorHeap(pCommandList, pUavHeap);
  734. pCommandList->SetComputeRootSignature(pRootSignature);
  735. pCommandList->SetComputeRootDescriptorTable(0, uavHandleGpu);
  736. }
  737. pCommandList->Dispatch(DispatchGroupX, DispatchGroupY, DispatchGroupZ);
  738. RecordTransitionBarrier(pCommandList, pUavResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE);
  739. pCommandList->CopyResource(pReadBuffer, pUavResource);
  740. pCommandList->Close();
  741. ExecuteCommandList(pCommandQueue, pCommandList);
  742. WaitForSignal(pCommandQueue, FO);
  743. {
  744. MappedData mappedData(pReadBuffer, valueSizeInBytes);
  745. uint32_t *pData = (uint32_t *)mappedData.data();
  746. memcpy(values.data(), pData, valueSizeInBytes);
  747. }
  748. WaitForSignal(pCommandQueue, FO);
  749. }
  750. TEST_F(ExecutionTest, BasicComputeTest) {
  751. //
  752. // BasicComputeTest is a simple compute shader that can be used as the basis
  753. // for more interesting compute execution tests.
  754. // The HLSL is compatible with shader models <=5.1 to allow using the DXBC
  755. // rendering code paths for comparison.
  756. //
  757. static const char pShader[] =
  758. "RWByteAddressBuffer g_bab : register(u0);\r\n"
  759. "[numthreads(8,8,1)]\r\n"
  760. "void main(uint GI : SV_GroupIndex) {"
  761. " uint addr = GI * 4;\r\n"
  762. " uint val = g_bab.Load(addr);\r\n"
  763. " DeviceMemoryBarrierWithGroupSync();\r\n"
  764. " g_bab.Store(addr, val + 1);\r\n"
  765. "}";
  766. static const int NumtheadsX = 8;
  767. static const int NumtheadsY = 8;
  768. static const int NumtheadsZ = 1;
  769. static const int ThreadsPerGroup = NumtheadsX * NumtheadsY * NumtheadsZ;
  770. static const int DispatchGroupCount = 1;
  771. CComPtr<ID3D12Device> pDevice;
  772. if (!CreateDevice(&pDevice))
  773. return;
  774. std::vector<uint32_t> values;
  775. SetupComputeValuePattern(values, ThreadsPerGroup * DispatchGroupCount);
  776. VERIFY_ARE_EQUAL(values[0], 0);
  777. RunRWByteBufferComputeTest(pDevice, pShader, values);
  778. VERIFY_ARE_EQUAL(values[0], 1);
  779. }
// Renders one triangle into an offscreen render target with a trivial VS/PS
// pair, reads the target back and checks two pixels: the top-center pixel
// must hold the clear color (0xff663300 per the check below) and the center
// pixel, which lies inside the triangle, must be white because PSMain
// returns 1 for every channel.
TEST_F(ExecutionTest, BasicTriangleTest) {
  // Passed to CreateRtvDescriptorHeap below; presumably the number of RTV
  // descriptors to allocate — TODO confirm (only one render target is used).
  static const UINT FrameCount = 2;
  static const UINT m_width = 320;
  static const UINT m_height = 200;
  static const float m_aspectRatio = static_cast<float>(m_width) / static_cast<float>(m_height);
  // CPU vertex layout; must match the input layout below (POSITION at
  // offset 0 as 3 floats, COLOR at offset 12 as 4 floats).
  struct Vertex {
    XMFLOAT3 position;
    XMFLOAT4 color;
  };
  // Pipeline objects.
  CComPtr<ID3D12Device> pDevice;
  CComPtr<ID3D12Resource> pRenderTarget;
  CComPtr<ID3D12CommandAllocator> pCommandAllocator;
  CComPtr<ID3D12CommandQueue> pCommandQueue;
  CComPtr<ID3D12RootSignature> pRootSig;
  CComPtr<ID3D12DescriptorHeap> pRtvHeap;
  CComPtr<ID3D12PipelineState> pPipelineState;
  CComPtr<ID3D12GraphicsCommandList> pCommandList;
  CComPtr<ID3D12Resource> pReadBuffer;
  UINT rtvDescriptorSize;
  CComPtr<ID3D12Resource> pVertexBuffer;
  D3D12_VERTEX_BUFFER_VIEW vertexBufferView;
  // Synchronization objects.
  FenceObj FO;
  // Shaders. The pixel shader ignores the interpolated color and writes 1
  // (white), which is what the center-pixel check expects.
  static const char pShaders[] =
    "struct PSInput {\r\n"
    " float4 position : SV_POSITION;\r\n"
    " float4 color : COLOR;\r\n"
    "};\r\n\r\n"
    "PSInput VSMain(float4 position : POSITION, float4 color : COLOR) {\r\n"
    " PSInput result;\r\n"
    "\r\n"
    " result.position = position;\r\n"
    " result.color = color;\r\n"
    " return result;\r\n"
    "}\r\n\r\n"
    "float4 PSMain(PSInput input) : SV_TARGET {\r\n"
    " return 1; //input.color;\r\n"
    "};\r\n";
  if (!CreateDevice(&pDevice))
    return;
  // Diagnostic helper. If the test leaves scope before SetOK(true) is
  // called, the destructor inspects the device's info queue. It logs a hint
  // if an invalid VS/PS bytecode message is found; otherwise it dumps the
  // whole queue.
  struct BasicTestChecker {
    CComPtr<ID3D12Device> m_pDevice;
    CComPtr<ID3D12InfoQueue> m_pInfoQueue;
    bool m_OK = false;
    void SetOK(bool value) { m_OK = value; }
    BasicTestChecker(ID3D12Device *pDevice) : m_pDevice(pDevice) {
      // The info queue is only available when the debug layer is active;
      // without it the checker does nothing.
      if (FAILED(m_pDevice.QueryInterface(&m_pInfoQueue)))
        return;
      // Empty filters so that (per the ID3D12InfoQueue docs) no messages are
      // filtered out of storage or retrieval.
      m_pInfoQueue->PushEmptyStorageFilter();
      m_pInfoQueue->PushEmptyRetrievalFilter();
    }
    ~BasicTestChecker() {
      if (!m_OK && m_pInfoQueue != nullptr) {
        UINT64 count = m_pInfoQueue->GetNumStoredMessages();
        bool invalidBytecodeFound = false;
        CAtlArray<BYTE> m_pBytes;
        for (UINT64 i = 0; i < count; ++i) {
          // First call (null buffer) queries the message size; the buffer is
          // grown if needed before the second call fetches the message.
          SIZE_T len = 0;
          if (FAILED(m_pInfoQueue->GetMessageA(i, nullptr, &len)))
            continue;
          if (m_pBytes.GetCount() < len && !m_pBytes.SetCount(len))
            continue;
          D3D12_MESSAGE *pMsg = (D3D12_MESSAGE *)m_pBytes.GetData();
          if (FAILED(m_pInfoQueue->GetMessageA(i, pMsg, &len)))
            continue;
          if (pMsg->ID == D3D12_MESSAGE_ID_CREATEVERTEXSHADER_INVALIDSHADERBYTECODE ||
              pMsg->ID == D3D12_MESSAGE_ID_CREATEPIXELSHADER_INVALIDSHADERBYTECODE) {
            invalidBytecodeFound = true;
            break;
          }
        }
        if (invalidBytecodeFound) {
          LogCommentFmt(L"%s", L"Found an invalid bytecode message. This "
                               L"typically indicates that experimental mode "
                               L"is not set up properly.");
          if (!GetTestParamBool(L"ExperimentalShaders")) {
            LogCommentFmt(L"Note that the ExperimentalShaders test parameter isn't set.");
          }
        }
        else {
          LogCommentFmt(L"Did not find corrupt pixel or vertex shaders in "
                        L"queue - dumping complete queue.");
          WriteInfoQueueMessages(nullptr, OutputFn, m_pInfoQueue);
        }
      }
    }
    // Callback handed to WriteInfoQueueMessages; forwards each message to the log.
    static void __stdcall OutputFn(void *pCtx, const wchar_t *pMsg) {
      LogCommentFmt(L"%s", pMsg);
    }
  };
  // Declared after pDevice, so it is destroyed (and runs its diagnostics)
  // while the device is still alive.
  BasicTestChecker BTC(pDevice);
  {
    InitFenceObj(pDevice, &FO);
    CreateRtvDescriptorHeap(pDevice, FrameCount, &pRtvHeap, &rtvDescriptorSize);
    CreateRenderTargetAndReadback(pDevice, pRtvHeap, m_width, m_height, &pRenderTarget, &pReadBuffer);
    // Create an empty root signature.
    CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc;
    rootSignatureDesc.Init(
        0, nullptr, 0, nullptr,
        D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
    CreateRootSignatureFromDesc(pDevice, &rootSignatureDesc, &pRootSig);
    // Create the pipeline state, which includes compiling and loading shaders.
    // Define the vertex input layout.
    D3D12_INPUT_ELEMENT_DESC inputElementDescs[] = {
        {"POSITION", 0, DXGI_FORMAT_R32G32B32_FLOAT, 0, 0,
         D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA, 0},
        {"COLOR", 0, DXGI_FORMAT_R32G32B32A32_FLOAT, 0, 12,
         D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA, 0}};
    D3D12_INPUT_LAYOUT_DESC InputLayout = { inputElementDescs, _countof(inputElementDescs) };
    CreateGraphicsPSO(pDevice, &InputLayout, pRootSig, pShaders, &pPipelineState);
    CreateGraphicsCommandQueueAndList(pDevice, &pCommandQueue,
                                      &pCommandAllocator, &pCommandList,
                                      pPipelineState);
    // Define the geometry for a triangle. The apex is at y = 0.25 * aspect in
    // clip space, so the top row of the target stays at the clear color.
    Vertex triangleVertices[] = {
        { { 0.0f, 0.25f * m_aspectRatio, 0.0f },{ 1.0f, 0.0f, 0.0f, 1.0f } },
        { { 0.25f, -0.25f * m_aspectRatio, 0.0f },{ 0.0f, 1.0f, 0.0f, 1.0f } },
        { { -0.25f, -0.25f * m_aspectRatio, 0.0f },{ 0.0f, 0.0f, 1.0f, 1.0f } } };
    CreateVertexBuffer(pDevice, triangleVertices, &pVertexBuffer, &vertexBufferView);
    WaitForSignal(pCommandQueue, FO);
  }
  // Render and execute the command list.
  RecordRenderAndReadback(pCommandList, pRtvHeap, rtvDescriptorSize, 1,
                          &vertexBufferView, pRootSig, pRenderTarget,
                          pReadBuffer);
  VERIFY_SUCCEEDED(pCommandList->Close());
  ExecuteCommandList(pCommandQueue, pCommandList);
  // Wait for previous frame.
  WaitForSignal(pCommandQueue, FO);
  // At this point, we've verified that execution succeeded with DXIL.
  BTC.SetOK(true);
  // Read back to CPU and examine contents.
  {
    // Pixels are read as one uint32_t each, indexed with a row pitch of
    // m_width * 4 bytes. 320 * 4 = 1280 is a multiple of 256, so presumably
    // the readback copy has no row padding — TODO confirm against
    // CreateRenderTargetAndReadback.
    MappedData data(pReadBuffer, m_width * m_height * 4);
    const uint32_t *pPixels = (uint32_t *)data.data();
    if (SaveImages()) {
      SavePixelsToFile(pPixels, DXGI_FORMAT_R8G8B8A8_UNORM, m_width, m_height, L"basic.bmp");
    }
    uint32_t top = pPixels[m_width / 2]; // Top center.
    uint32_t mid = pPixels[m_width / 2 + m_width * (m_height / 2)]; // Middle center.
    VERIFY_ARE_EQUAL(0xff663300, top); // clear color
    VERIFY_ARE_EQUAL(0xffffffff, mid); // white
  }
}
  926. TEST_F(ExecutionTest, Int64Test) {
  927. static const char pShader[] =
  928. "RWByteAddressBuffer g_bab : register(u0);\r\n"
  929. "[numthreads(8,8,1)]\r\n"
  930. "void main(uint GI : SV_GroupIndex) {"
  931. " uint addr = GI * 4;\r\n"
  932. " uint val = g_bab.Load(addr);\r\n"
  933. " uint64_t u64 = val;\r\n"
  934. " u64 *= val;\r\n"
  935. " g_bab.Store(addr, (uint)(u64 >> 32));\r\n"
  936. "}";
  937. static const int NumtheadsX = 8;
  938. static const int NumtheadsY = 8;
  939. static const int NumtheadsZ = 1;
  940. static const int ThreadsPerGroup = NumtheadsX * NumtheadsY * NumtheadsZ;
  941. static const int DispatchGroupCount = 1;
  942. CComPtr<ID3D12Device> pDevice;
  943. if (!CreateDevice(&pDevice))
  944. return;
  945. if (!DoesDeviceSupportInt64(pDevice)) {
  946. // Optional feature, so it's correct to not support it if declared as such.
  947. WEX::Logging::Log::Comment(L"Device does not support int64 operations.");
  948. return;
  949. }
  950. std::vector<uint32_t> values;
  951. SetupComputeValuePattern(values, ThreadsPerGroup * DispatchGroupCount);
  952. VERIFY_ARE_EQUAL(values[0], 0);
  953. RunRWByteBufferComputeTest(pDevice, pShader, values);
  954. VERIFY_ARE_EQUAL(values[0], 0);
  955. }
  956. TEST_F(ExecutionTest, SignTest) {
  957. static const char pShader[] =
  958. "RWByteAddressBuffer g_bab : register(u0);\r\n"
  959. "[numthreads(8,1,1)]\r\n"
  960. "void main(uint GI : SV_GroupIndex) {"
  961. " uint addr = GI * 4;\r\n"
  962. " int val = g_bab.Load(addr);\r\n"
  963. " g_bab.Store(addr, (uint)(sign(val)));\r\n"
  964. "}";
  965. static const int NumtheadsX = 8;
  966. static const int NumtheadsY = 1;
  967. static const int NumtheadsZ = 1;
  968. static const int ThreadsPerGroup = NumtheadsX * NumtheadsY * NumtheadsZ;
  969. static const int DispatchGroupCount = 1;
  970. CComPtr<ID3D12Device> pDevice;
  971. if (!CreateDevice(&pDevice))
  972. return;
  973. std::vector<uint32_t> values = { (uint32_t)-3, (uint32_t)-2, (uint32_t)-1, 0, 1, 2, 3, 4};
  974. RunRWByteBufferComputeTest(pDevice, pShader, values);
  975. VERIFY_ARE_EQUAL(values[0], -1);
  976. VERIFY_ARE_EQUAL(values[1], -1);
  977. VERIFY_ARE_EQUAL(values[2], -1);
  978. VERIFY_ARE_EQUAL(values[3], 0);
  979. VERIFY_ARE_EQUAL(values[4], 1);
  980. VERIFY_ARE_EQUAL(values[5], 1);
  981. VERIFY_ARE_EQUAL(values[6], 1);
  982. VERIFY_ARE_EQUAL(values[7], 1);
  983. }
  984. TEST_F(ExecutionTest, WaveIntrinsicsTest) {
  985. WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
  986. struct PerThreadData {
  987. uint32_t id, flags, laneIndex, laneCount, firstLaneId, preds, firstlaneX, lane1X;
  988. uint32_t allBC, allSum, allProd, allAND, allOR, allXOR, allMin, allMax;
  989. uint32_t pfBC, pfSum, pfProd;
  990. uint32_t ballot[4];
  991. uint32_t diver; // divergent value, used in calculation
  992. int32_t i_diver; // divergent value, used in calculation
  993. int32_t i_allMax, i_allMin, i_allSum, i_allProd;
  994. int32_t i_pfSum, i_pfProd;
  995. };
  996. static const char pShader[] =
  997. WAVE_INTRINSIC_DXBC_GUARD
  998. "struct PerThreadData {\r\n"
  999. " uint id, flags, laneIndex, laneCount, firstLaneId, preds, firstlaneX, lane1X;\r\n"
  1000. " uint allBC, allSum, allProd, allAND, allOR, allXOR, allMin, allMax;\r\n"
  1001. " uint pfBC, pfSum, pfProd;\r\n"
  1002. " uint4 ballot;\r\n"
  1003. " uint diver;\r\n"
  1004. " int i_diver;\r\n"
  1005. " int i_allMax, i_allMin, i_allSum, i_allProd;\r\n"
  1006. " int i_pfSum, i_pfProd;\r\n"
  1007. "};\r\n"
  1008. "RWStructuredBuffer<PerThreadData> g_sb : register(u0);\r\n"
  1009. "[numthreads(8,8,1)]\r\n"
  1010. "void main(uint GI : SV_GroupIndex, uint3 GTID : SV_GroupThreadID) {"
  1011. " PerThreadData pts = g_sb[GI];\r\n"
  1012. " uint diver = GTID.x + 2;\r\n"
  1013. " pts.diver = diver;\r\n"
  1014. " pts.flags = 0;\r\n"
  1015. " pts.preds = 0;\r\n"
  1016. " if (WaveIsFirstLane()) pts.flags |= 1;\r\n"
  1017. " pts.laneIndex = WaveGetLaneIndex();\r\n"
  1018. " pts.laneCount = WaveGetLaneCount();\r\n"
  1019. " pts.firstLaneId = WaveReadLaneFirst(pts.id);\r\n"
  1020. " pts.preds |= ((WaveActiveAnyTrue(diver == 1) ? 1 : 0) << 0);\r\n"
  1021. " pts.preds |= ((WaveActiveAllTrue(diver == 1) ? 1 : 0) << 1);\r\n"
  1022. " pts.preds |= ((WaveActiveAllEqual(diver) ? 1 : 0) << 2);\r\n"
  1023. " pts.preds |= ((WaveActiveAllEqual(GTID.z) ? 1 : 0) << 3);\r\n"
  1024. " pts.preds |= ((WaveActiveAllEqual(WaveReadLaneFirst(diver)) ? 1 : 0) << 4);\r\n"
  1025. " pts.ballot = WaveActiveBallot(diver > 3);\r\n"
  1026. " pts.firstlaneX = WaveReadLaneFirst(GTID.x);\r\n"
  1027. " pts.lane1X = WaveReadLaneAt(GTID.x, 1);\r\n"
  1028. "\r\n"
  1029. " pts.allBC = WaveActiveCountBits(diver > 3);\r\n"
  1030. " pts.allSum = WaveActiveSum(diver);\r\n"
  1031. " pts.allProd = WaveActiveProduct(diver);\r\n"
  1032. " pts.allAND = WaveActiveBitAnd(diver);\r\n"
  1033. " pts.allOR = WaveActiveBitOr(diver);\r\n"
  1034. " pts.allXOR = WaveActiveBitXor(diver);\r\n"
  1035. " pts.allMin = WaveActiveMin(diver);\r\n"
  1036. " pts.allMax = WaveActiveMax(diver);\r\n"
  1037. "\r\n"
  1038. " pts.pfBC = WavePrefixCountBits(diver > 3);\r\n"
  1039. " pts.pfSum = WavePrefixSum(diver);\r\n"
  1040. " pts.pfProd = WavePrefixProduct(diver);\r\n"
  1041. "\r\n"
  1042. " int i_diver = pts.i_diver;\r\n"
  1043. " pts.i_allMax = WaveActiveMax(i_diver);\r\n"
  1044. " pts.i_allMin = WaveActiveMin(i_diver);\r\n"
  1045. " pts.i_allSum = WaveActiveSum(i_diver);\r\n"
  1046. " pts.i_allProd = WaveActiveProduct(i_diver);\r\n"
  1047. " pts.i_pfSum = WavePrefixSum(i_diver);\r\n"
  1048. " pts.i_pfProd = WavePrefixProduct(i_diver);\r\n"
  1049. "\r\n"
  1050. " g_sb[GI] = pts;\r\n"
  1051. "}";
  1052. static const int NumtheadsX = 8;
  1053. static const int NumtheadsY = 8;
  1054. static const int NumtheadsZ = 1;
  1055. static const int ThreadsPerGroup = NumtheadsX * NumtheadsY * NumtheadsZ;
  1056. static const int DispatchGroupCount = 1;
  1057. CComPtr<ID3D12Device> pDevice;
  1058. if (!CreateDevice(&pDevice))
  1059. return;
  1060. if (!DoesDeviceSupportWaveOps(pDevice)) {
  1061. // Optional feature, so it's correct to not support it if declared as such.
  1062. WEX::Logging::Log::Comment(L"Device does not support wave operations.");
  1063. return;
  1064. }
  1065. std::vector<PerThreadData> values;
  1066. values.resize(ThreadsPerGroup * DispatchGroupCount);
  1067. for (size_t i = 0; i < values.size(); ++i) {
  1068. memset(&values[i], 0, sizeof(PerThreadData));
  1069. values[i].id = i;
  1070. values[i].i_diver = (int)i;
  1071. values[i].i_diver *= (i % 2) ? 1 : -1;
  1072. }
  1073. static const int DispatchGroupX = 1;
  1074. static const int DispatchGroupY = 1;
  1075. static const int DispatchGroupZ = 1;
  1076. CComPtr<ID3D12GraphicsCommandList> pCommandList;
  1077. CComPtr<ID3D12CommandQueue> pCommandQueue;
  1078. CComPtr<ID3D12DescriptorHeap> pUavHeap;
  1079. CComPtr<ID3D12CommandAllocator> pCommandAllocator;
  1080. UINT uavDescriptorSize;
  1081. FenceObj FO;
  1082. bool dxbc = UseDxbc();
  1083. const size_t valueSizeInBytes = values.size() * sizeof(PerThreadData);
  1084. CreateComputeCommandQueue(pDevice, L"WaveIntrinsicsTest Command Queue", &pCommandQueue);
  1085. InitFenceObj(pDevice, &FO);
  1086. // Describe and create a UAV descriptor heap.
  1087. D3D12_DESCRIPTOR_HEAP_DESC heapDesc = {};
  1088. heapDesc.NumDescriptors = 1;
  1089. heapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
  1090. heapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
  1091. VERIFY_SUCCEEDED(pDevice->CreateDescriptorHeap(&heapDesc, IID_PPV_ARGS(&pUavHeap)));
  1092. uavDescriptorSize = pDevice->GetDescriptorHandleIncrementSize(heapDesc.Type);
  1093. // Create root signature.
  1094. CComPtr<ID3D12RootSignature> pRootSignature;
  1095. {
  1096. CD3DX12_DESCRIPTOR_RANGE ranges[1];
  1097. ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0, 0, 0);
  1098. CD3DX12_ROOT_PARAMETER rootParameters[1];
  1099. rootParameters[0].InitAsDescriptorTable(1, &ranges[0], D3D12_SHADER_VISIBILITY_ALL);
  1100. CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc;
  1101. rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_NONE);
  1102. CComPtr<ID3DBlob> signature;
  1103. CComPtr<ID3DBlob> error;
  1104. VERIFY_SUCCEEDED(D3D12SerializeRootSignature(&rootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &signature, &error));
  1105. VERIFY_SUCCEEDED(pDevice->CreateRootSignature(0, signature->GetBufferPointer(), signature->GetBufferSize(), IID_PPV_ARGS(&pRootSignature)));
  1106. }
  1107. // Create pipeline state object.
  1108. CComPtr<ID3D12PipelineState> pComputeState;
  1109. CreateComputePSO(pDevice, pRootSignature, pShader, &pComputeState);
  1110. // Create a command allocator and list for compute.
  1111. VERIFY_SUCCEEDED(pDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&pCommandAllocator)));
  1112. VERIFY_SUCCEEDED(pDevice->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE, pCommandAllocator, pComputeState, IID_PPV_ARGS(&pCommandList)));
  1113. // Set up UAV resource.
  1114. CComPtr<ID3D12Resource> pUavResource;
  1115. CComPtr<ID3D12Resource> pReadBuffer;
  1116. CComPtr<ID3D12Resource> pUploadResource;
  1117. CreateTestUavs(pDevice, pCommandList, values.data(), valueSizeInBytes, &pUavResource, &pReadBuffer, &pUploadResource);
  1118. // Close the command list and execute it to perform the GPU setup.
  1119. pCommandList->Close();
  1120. ExecuteCommandList(pCommandQueue, pCommandList);
  1121. WaitForSignal(pCommandQueue, FO);
  1122. VERIFY_SUCCEEDED(pCommandAllocator->Reset());
  1123. VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pComputeState));
  1124. // Run the compute shader and copy the results back to readable memory.
  1125. {
  1126. D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {};
  1127. uavDesc.Format = DXGI_FORMAT_UNKNOWN;
  1128. uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
  1129. uavDesc.Buffer.FirstElement = 0;
  1130. uavDesc.Buffer.NumElements = values.size();
  1131. uavDesc.Buffer.StructureByteStride = sizeof(PerThreadData);
  1132. uavDesc.Buffer.CounterOffsetInBytes = 0;
  1133. uavDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_NONE;
  1134. CD3DX12_CPU_DESCRIPTOR_HANDLE uavHandle(pUavHeap->GetCPUDescriptorHandleForHeapStart());
  1135. CD3DX12_GPU_DESCRIPTOR_HANDLE uavHandleGpu(pUavHeap->GetGPUDescriptorHandleForHeapStart());
  1136. pDevice->CreateUnorderedAccessView(pUavResource, nullptr, &uavDesc, uavHandle);
  1137. SetDescriptorHeap(pCommandList, pUavHeap);
  1138. pCommandList->SetComputeRootSignature(pRootSignature);
  1139. pCommandList->SetComputeRootDescriptorTable(0, uavHandleGpu);
  1140. }
  1141. pCommandList->Dispatch(DispatchGroupX, DispatchGroupY, DispatchGroupZ);
  1142. RecordTransitionBarrier(pCommandList, pUavResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE);
  1143. pCommandList->CopyResource(pReadBuffer, pUavResource);
  1144. pCommandList->Close();
  1145. ExecuteCommandList(pCommandQueue, pCommandList);
  1146. WaitForSignal(pCommandQueue, FO);
  1147. {
  1148. MappedData mappedData(pReadBuffer, valueSizeInBytes);
  1149. PerThreadData *pData = (PerThreadData *)mappedData.data();
  1150. memcpy(values.data(), pData, valueSizeInBytes);
  1151. // Gather some general data.
  1152. // The 'firstLaneId' captures a unique number per first-lane per wave.
  1153. // Counting the number distinct firstLaneIds gives us the number of waves.
  1154. std::vector<uint32_t> firstLaneIds;
  1155. for (size_t i = 0; i < values.size(); ++i) {
  1156. PerThreadData &pts = values[i];
  1157. uint32_t firstLaneId = pts.firstLaneId;
  1158. if (!contains(firstLaneIds, firstLaneId)) {
  1159. firstLaneIds.push_back(firstLaneId);
  1160. }
  1161. }
  1162. // Waves should cover 4 threads or more.
  1163. LogCommentFmt(L"Found %u distinct lane ids: %u", firstLaneIds.size());
  1164. if (!dxbc) {
  1165. VERIFY_IS_GREATER_THAN_OR_EQUAL(values.size() / 4, firstLaneIds.size());
  1166. }
  1167. // Now, group threads into waves.
  1168. std::map<uint32_t, std::unique_ptr<std::vector<PerThreadData *> > > waves;
  1169. for (size_t i = 0; i < firstLaneIds.size(); ++i) {
  1170. waves[firstLaneIds[i]] = std::make_unique<std::vector<PerThreadData *> >();
  1171. }
  1172. for (size_t i = 0; i < values.size(); ++i) {
  1173. PerThreadData &pts = values[i];
  1174. std::unique_ptr<std::vector<PerThreadData *> > &wave = waves[pts.firstLaneId];
  1175. wave->push_back(&pts);
  1176. }
  1177. // Verify that all the wave values are coherent across the wave.
  1178. for (size_t i = 0; i < values.size(); ++i) {
  1179. PerThreadData &pts = values[i];
  1180. std::unique_ptr<std::vector<PerThreadData *> > &wave = waves[pts.firstLaneId];
  1181. // Sort the lanes by increasing lane ID.
  1182. struct LaneIdOrderPred {
  1183. bool operator()(PerThreadData *a, PerThreadData *b) {
  1184. return a->laneIndex < b->laneIndex;
  1185. }
  1186. };
  1187. std::sort(wave.get()->begin(), wave.get()->end(), LaneIdOrderPred());
  1188. // Verify some interesting properties of the first lane.
  1189. uint32_t pfBC, pfSum, pfProd;
  1190. int32_t i_pfSum, i_pfProd;
  1191. int32_t i_allMax, i_allMin;
  1192. {
  1193. PerThreadData *ptdFirst = wave->front();
  1194. VERIFY_IS_TRUE(0 != (ptdFirst->flags & 1)); // FirstLane sets this bit.
  1195. VERIFY_IS_TRUE(0 == ptdFirst->pfBC);
  1196. VERIFY_IS_TRUE(0 == ptdFirst->pfSum);
  1197. VERIFY_IS_TRUE(1 == ptdFirst->pfProd);
  1198. VERIFY_IS_TRUE(0 == ptdFirst->i_pfSum);
  1199. VERIFY_IS_TRUE(1 == ptdFirst->i_pfProd);
  1200. pfBC = (ptdFirst->diver > 3) ? 1 : 0;
  1201. pfSum = ptdFirst->diver;
  1202. pfProd = ptdFirst->diver;
  1203. i_pfSum = ptdFirst->i_diver;
  1204. i_pfProd = ptdFirst->i_diver;
  1205. i_allMax = i_allMin = ptdFirst->i_diver;
  1206. }
  1207. // Calculate values which take into consideration all lanes.
  1208. uint32_t preds = 0;
  1209. preds |= 1 << 1; // AllTrue starts true, switches to false if needed.
  1210. preds |= 1 << 2; // AllEqual starts true, switches to false if needed.
  1211. preds |= 1 << 3; // WaveActiveAllEqual(GTID.z) is always true
  1212. preds |= 1 << 4; // (WaveActiveAllEqual(WaveReadLaneFirst(diver)) is always true
  1213. uint32_t ballot[4] = { 0, 0, 0, 0 };
  1214. int32_t i_allSum = 0, i_allProd = 1;
  1215. for (size_t n = 0; n < wave->size(); ++n) {
  1216. std::vector<PerThreadData *> &lanes = *wave.get();
  1217. // pts.preds |= ((WaveActiveAnyTrue(diver == 1) ? 1 : 0) << 0);
  1218. if (lanes[n]->diver == 1) preds |= (1 << 0);
  1219. // pts.preds |= ((WaveActiveAllTrue(diver == 1) ? 1 : 0) << 1);
  1220. if (lanes[n]->diver != 1) preds &= ~(1 << 1);
  1221. // pts.preds |= ((WaveActiveAllEqual(diver) ? 1 : 0) << 2);
  1222. if (lanes[0]->diver != lanes[n]->diver) preds &= ~(1 << 2);
  1223. // pts.ballot = WaveActiveBallot(diver > 3);\r\n"
  1224. if (lanes[n]->diver > 3) {
  1225. // This is the uint4 result layout:
  1226. // .x -> bits 0 .. 31
  1227. // .y -> bits 32 .. 63
  1228. // .z -> bits 64 .. 95
  1229. // .w -> bits 96 ..127
  1230. uint32_t component = lanes[n]->laneIndex / 32;
  1231. uint32_t bit = lanes[n]->laneIndex % 32;
  1232. ballot[component] |= 1 << bit;
  1233. }
  1234. i_allMax = std::max(lanes[n]->i_diver, i_allMax);
  1235. i_allMin = std::min(lanes[n]->i_diver, i_allMin);
  1236. i_allProd *= lanes[n]->i_diver;
  1237. i_allSum += lanes[n]->i_diver;
  1238. }
  1239. for (size_t n = 1; n < wave->size(); ++n) {
  1240. // 'All' operations are uniform across the wave.
  1241. std::vector<PerThreadData *> &lanes = *wave.get();
  1242. VERIFY_IS_TRUE(0 == (lanes[n]->flags & 1)); // non-firstlanes do not set this bit
  1243. VERIFY_ARE_EQUAL(lanes[0]->allBC, lanes[n]->allBC);
  1244. VERIFY_ARE_EQUAL(lanes[0]->allSum, lanes[n]->allSum);
  1245. VERIFY_ARE_EQUAL(lanes[0]->allProd, lanes[n]->allProd);
  1246. VERIFY_ARE_EQUAL(lanes[0]->allAND, lanes[n]->allAND);
  1247. VERIFY_ARE_EQUAL(lanes[0]->allOR, lanes[n]->allOR);
  1248. VERIFY_ARE_EQUAL(lanes[0]->allXOR, lanes[n]->allXOR);
  1249. VERIFY_ARE_EQUAL(lanes[0]->allMin, lanes[n]->allMin);
  1250. VERIFY_ARE_EQUAL(lanes[0]->allMax, lanes[n]->allMax);
  1251. VERIFY_ARE_EQUAL(i_allMax, lanes[n]->i_allMax);
  1252. VERIFY_ARE_EQUAL(i_allMin, lanes[n]->i_allMin);
  1253. VERIFY_ARE_EQUAL(i_allProd, lanes[n]->i_allProd);
  1254. VERIFY_ARE_EQUAL(i_allSum, lanes[n]->i_allSum);
  1255. // first-lane reads and uniform reads are uniform across the wave.
  1256. VERIFY_ARE_EQUAL(lanes[0]->firstlaneX, lanes[n]->firstlaneX);
  1257. VERIFY_ARE_EQUAL(lanes[0]->lane1X, lanes[n]->lane1X);
  1258. // the lane count is uniform across the wave.
  1259. VERIFY_ARE_EQUAL(lanes[0]->laneCount, lanes[n]->laneCount);
  1260. // The predicates are uniform across the wave.
  1261. VERIFY_ARE_EQUAL(lanes[n]->preds, preds);
  1262. // the lane index is distinct per thread.
  1263. for (size_t prior = 0; prior < n; ++prior) {
  1264. VERIFY_ARE_NOT_EQUAL(lanes[prior]->laneIndex, lanes[n]->laneIndex);
  1265. }
  1266. // Ballot results are uniform across the wave.
  1267. VERIFY_ARE_EQUAL(0, memcmp(ballot, lanes[n]->ballot, sizeof(ballot)));
  1268. // Keep running total of prefix calculation. Prefix values are exclusive to
  1269. // the executing lane.
  1270. VERIFY_ARE_EQUAL(pfBC, lanes[n]->pfBC);
  1271. VERIFY_ARE_EQUAL(pfSum, lanes[n]->pfSum);
  1272. VERIFY_ARE_EQUAL(pfProd, lanes[n]->pfProd);
  1273. VERIFY_ARE_EQUAL(i_pfSum, lanes[n]->i_pfSum);
  1274. VERIFY_ARE_EQUAL(i_pfProd, lanes[n]->i_pfProd);
  1275. pfBC += (lanes[n]->diver > 3) ? 1 : 0;
  1276. pfSum += lanes[n]->diver;
  1277. pfProd *= lanes[n]->diver;
  1278. i_pfSum += lanes[n]->i_diver;
  1279. i_pfProd *= lanes[n]->i_diver;
  1280. }
  1281. // TODO: add divergent branching and verify that the otherwise uniform values properly diverge
  1282. }
  1283. // Compare each value of each per-thread element.
  1284. for (size_t i = 0; i < values.size(); ++i) {
  1285. PerThreadData &pts = values[i];
  1286. VERIFY_ARE_EQUAL(i, pts.id); // ID is unchanged.
  1287. }
  1288. }
  1289. }
  1290. TEST_F(ExecutionTest, WaveIntrinsicsInPSTest) {
  1291. WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
  1292. struct Vertex {
  1293. XMFLOAT3 position;
  1294. };
  1295. struct PerPixelData {
  1296. XMFLOAT4 position;
  1297. uint32_t id, flags, laneIndex, laneCount, firstLaneId, sum1;
  1298. uint32_t id0, id1, id2, id3;
  1299. uint32_t acrossX, acrossY, acrossDiag, quadActiveCount;
  1300. };
  1301. const UINT RTWidth = 128;
  1302. const UINT RTHeight = 128;
  1303. // Shaders.
  1304. static const char pShaders[] =
  1305. WAVE_INTRINSIC_DXBC_GUARD
  1306. "struct PSInput {\r\n"
  1307. " float4 position : SV_POSITION;\r\n"
  1308. "};\r\n\r\n"
  1309. "PSInput VSMain(float4 position : POSITION) {\r\n"
  1310. " PSInput result;\r\n"
  1311. "\r\n"
  1312. " result.position = position;\r\n"
  1313. " return result;\r\n"
  1314. "}\r\n\r\n"
  1315. "typedef uint uint32_t;\r\n"
  1316. "uint pos_to_id(float4 pos) { return pos.x * 128 + pos.y; }\r\n"
  1317. "struct PerPixelData {\r\n"
  1318. " float4 position;\r\n"
  1319. " uint32_t id, flags, laneIndex, laneCount, firstLaneId, sum1;\r\n"
  1320. " uint32_t id0, id1, id2, id3;\r\n"
  1321. " uint32_t acrossX, acrossY, acrossDiag, quadActiveCount;\r\n"
  1322. "};\r\n"
  1323. "AppendStructuredBuffer<PerPixelData> g_sb : register(u1);\r\n"
  1324. "float4 PSMain(PSInput input) : SV_TARGET {\r\n"
  1325. " uint one = 1;\r\n"
  1326. " PerPixelData d;\r\n"
  1327. " d.position = input.position;\r\n"
  1328. " d.id = pos_to_id(input.position);\r\n"
  1329. " d.flags = 0;\r\n"
  1330. " if (WaveIsFirstLane()) d.flags |= 1;\r\n"
  1331. " d.laneIndex = WaveGetLaneIndex();\r\n"
  1332. " d.laneCount = WaveGetLaneCount();\r\n"
  1333. " d.firstLaneId = WaveReadLaneFirst(d.id);\r\n"
  1334. " d.sum1 = WaveActiveSum(one);\r\n"
  1335. " d.id0 = QuadReadLaneAt(d.id, 0);\r\n"
  1336. " d.id1 = QuadReadLaneAt(d.id, 1);\r\n"
  1337. " d.id2 = QuadReadLaneAt(d.id, 2);\r\n"
  1338. " d.id3 = QuadReadLaneAt(d.id, 3);\r\n"
  1339. " d.acrossX = QuadReadAcrossX(d.id);\r\n"
  1340. " d.acrossY = QuadReadAcrossY(d.id);\r\n"
  1341. " d.acrossDiag = QuadReadAcrossDiagonal(d.id);\r\n"
  1342. " d.quadActiveCount = one + QuadReadAcrossX(one) + QuadReadAcrossY(one) + QuadReadAcrossDiagonal(one);\r\n"
  1343. " g_sb.Append(d);\r\n"
  1344. " return 1;\r\n"
  1345. "};\r\n";
  1346. CComPtr<ID3D12Device> pDevice;
  1347. CComPtr<ID3D12CommandQueue> pCommandQueue;
  1348. CComPtr<ID3D12DescriptorHeap> pUavHeap, pRtvHeap;
  1349. CComPtr<ID3D12CommandAllocator> pCommandAllocator;
  1350. CComPtr<ID3D12GraphicsCommandList> pCommandList;
  1351. CComPtr<ID3D12PipelineState> pPSO;
  1352. CComPtr<ID3D12Resource> pRenderTarget, pReadBuffer;
  1353. UINT uavDescriptorSize, rtvDescriptorSize;
  1354. CComPtr<ID3D12Resource> pVertexBuffer;
  1355. D3D12_VERTEX_BUFFER_VIEW vertexBufferView;
  1356. if (!CreateDevice(&pDevice))
  1357. return;
  1358. if (!DoesDeviceSupportWaveOps(pDevice)) {
  1359. // Optional feature, so it's correct to not support it if declared as such.
  1360. WEX::Logging::Log::Comment(L"Device does not support wave operations.");
  1361. return;
  1362. }
  1363. FenceObj FO;
  1364. InitFenceObj(pDevice, &FO);
  1365. // Describe and create a UAV descriptor heap.
  1366. D3D12_DESCRIPTOR_HEAP_DESC heapDesc = {};
  1367. heapDesc.NumDescriptors = 1;
  1368. heapDesc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV;
  1369. heapDesc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE;
  1370. VERIFY_SUCCEEDED(pDevice->CreateDescriptorHeap(&heapDesc, IID_PPV_ARGS(&pUavHeap)));
  1371. uavDescriptorSize = pDevice->GetDescriptorHandleIncrementSize(heapDesc.Type);
  1372. CreateRtvDescriptorHeap(pDevice, 1, &pRtvHeap, &rtvDescriptorSize);
  1373. CreateRenderTargetAndReadback(pDevice, pRtvHeap, RTHeight, RTWidth, &pRenderTarget, &pReadBuffer);
  1374. // Create root signature: one UAV.
  1375. CComPtr<ID3D12RootSignature> pRootSignature;
  1376. {
  1377. CD3DX12_DESCRIPTOR_RANGE ranges[1];
  1378. ranges[0].Init(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 1, 0, 0);
  1379. CD3DX12_ROOT_PARAMETER rootParameters[1];
  1380. rootParameters[0].InitAsDescriptorTable(1, &ranges[0], D3D12_SHADER_VISIBILITY_ALL);
  1381. CD3DX12_ROOT_SIGNATURE_DESC rootSignatureDesc;
  1382. rootSignatureDesc.Init(_countof(rootParameters), rootParameters, 0, nullptr, D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT);
  1383. CreateRootSignatureFromDesc(pDevice, &rootSignatureDesc, &pRootSignature);
  1384. }
  1385. D3D12_INPUT_ELEMENT_DESC elementDesc[] = {
  1386. {"POSITION", 0, DXGI_FORMAT_R32G32B32_FLOAT, 0, 0,
  1387. D3D12_INPUT_CLASSIFICATION_PER_VERTEX_DATA, 0}};
  1388. D3D12_INPUT_LAYOUT_DESC InputLayout = {elementDesc, _countof(elementDesc)};
  1389. CreateGraphicsPSO(pDevice, &InputLayout, pRootSignature, pShaders, &pPSO);
  1390. CreateGraphicsCommandQueueAndList(pDevice, &pCommandQueue, &pCommandAllocator,
  1391. &pCommandList, pPSO);
  1392. // Single triangle covering half the target.
  1393. Vertex vertices[] = {
  1394. { { -1.0f, 1.0f, 0.0f } },
  1395. { { 1.0f, 1.0f, 0.0f } },
  1396. { { -1.0f, -1.0f, 0.0f } } };
  1397. const UINT TriangleCount = _countof(vertices) / 3;
  1398. CreateVertexBuffer(pDevice, vertices, &pVertexBuffer, &vertexBufferView);
  1399. bool dxbc = UseDxbc();
  1400. // Set up UAV resource.
  1401. std::vector<PerPixelData> values;
  1402. values.resize(RTWidth * RTHeight * 2);
  1403. UINT valueSizeInBytes = values.size() * sizeof(PerPixelData);
  1404. memset(values.data(), 0, valueSizeInBytes);
  1405. CComPtr<ID3D12Resource> pUavResource;
  1406. CComPtr<ID3D12Resource> pUavReadBuffer;
  1407. CComPtr<ID3D12Resource> pUploadResource;
  1408. CreateTestUavs(pDevice, pCommandList, values.data(), valueSizeInBytes, &pUavResource, &pUavReadBuffer, &pUploadResource);
  1409. // Set up the append counter resource.
  1410. CComPtr<ID3D12Resource> pUavCounterResource;
  1411. CComPtr<ID3D12Resource> pReadCounterBuffer;
  1412. CComPtr<ID3D12Resource> pUploadCounterResource;
  1413. BYTE zero[sizeof(UINT)] = { 0 };
  1414. CreateTestUavs(pDevice, pCommandList, zero, sizeof(zero), &pUavCounterResource, &pReadCounterBuffer, &pUploadCounterResource);
  1415. // Close the command list and execute it to perform the GPU setup.
  1416. pCommandList->Close();
  1417. ExecuteCommandList(pCommandQueue, pCommandList);
  1418. WaitForSignal(pCommandQueue, FO);
  1419. VERIFY_SUCCEEDED(pCommandAllocator->Reset());
  1420. VERIFY_SUCCEEDED(pCommandList->Reset(pCommandAllocator, pPSO));
  1421. pCommandList->SetGraphicsRootSignature(pRootSignature);
  1422. SetDescriptorHeap(pCommandList, pUavHeap);
  1423. {
  1424. D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {};
  1425. uavDesc.Format = DXGI_FORMAT_UNKNOWN;
  1426. uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER;
  1427. uavDesc.Buffer.FirstElement = 0;
  1428. uavDesc.Buffer.NumElements = values.size();
  1429. uavDesc.Buffer.StructureByteStride = sizeof(PerPixelData);
  1430. uavDesc.Buffer.CounterOffsetInBytes = 0;
  1431. uavDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_NONE;
  1432. CD3DX12_CPU_DESCRIPTOR_HANDLE uavHandle(pUavHeap->GetCPUDescriptorHandleForHeapStart());
  1433. CD3DX12_GPU_DESCRIPTOR_HANDLE uavHandleGpu(pUavHeap->GetGPUDescriptorHandleForHeapStart());
  1434. pDevice->CreateUnorderedAccessView(pUavResource, pUavCounterResource, &uavDesc, uavHandle);
  1435. pCommandList->SetGraphicsRootDescriptorTable(0, uavHandleGpu);
  1436. }
  1437. RecordRenderAndReadback(pCommandList, pRtvHeap, rtvDescriptorSize, TriangleCount, &vertexBufferView, nullptr, pRenderTarget, pReadBuffer);
  1438. RecordTransitionBarrier(pCommandList, pUavResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE);
  1439. RecordTransitionBarrier(pCommandList, pUavCounterResource, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, D3D12_RESOURCE_STATE_COPY_SOURCE);
  1440. pCommandList->CopyResource(pUavReadBuffer, pUavResource);
  1441. pCommandList->CopyResource(pReadCounterBuffer, pUavCounterResource);
  1442. VERIFY_SUCCEEDED(pCommandList->Close());
  1443. LogCommentFmt(L"Rendering to %u by %u", RTWidth, RTHeight);
  1444. ExecuteCommandList(pCommandQueue, pCommandList);
  1445. WaitForSignal(pCommandQueue, FO);
  1446. {
  1447. MappedData data(pReadBuffer, RTWidth * RTHeight * 4);
  1448. const uint32_t *pPixels = (uint32_t *)data.data();
  1449. if (SaveImages()) {
  1450. SavePixelsToFile(pPixels, DXGI_FORMAT_R8G8B8A8_UNORM, RTWidth, RTHeight, L"psintrin.bmp");
  1451. }
  1452. }
  1453. uint32_t appendCount;
  1454. {
  1455. MappedData mappedData(pReadCounterBuffer, sizeof(uint32_t));
  1456. appendCount = *((uint32_t *)mappedData.data());
  1457. LogCommentFmt(L"%u elements in append buffer");
  1458. }
  1459. {
  1460. MappedData mappedData(pUavReadBuffer, values.size());
  1461. PerPixelData *pData = (PerPixelData *)mappedData.data();
  1462. memcpy(values.data(), pData, valueSizeInBytes);
  1463. // DXBC is handy to test pipeline setup, but interesting functions are
  1464. // stubbed out, so there is no point in further validation.
  1465. if (dxbc)
  1466. return;
  1467. uint32_t maxActiveLaneCount = 0;
  1468. uint32_t maxLaneCount = 0;
  1469. for (uint32_t i = 0; i < appendCount; ++i) {
  1470. maxActiveLaneCount = std::max(maxActiveLaneCount, values[i].sum1);
  1471. maxLaneCount = std::max(maxLaneCount, values[i].laneCount);
  1472. }
  1473. uint32_t peerOfHelperLanes = 0;
  1474. for (uint32_t i = 0; i < appendCount; ++i) {
  1475. if (values[i].sum1 != maxActiveLaneCount) {
  1476. ++peerOfHelperLanes;
  1477. }
  1478. }
  1479. LogCommentFmt(
  1480. L"Found: %u threads. Waves reported up to %u total lanes, up "
  1481. L"to %u active lanes, and %u threads had helper/inactive lanes.",
  1482. appendCount, maxLaneCount, maxActiveLaneCount, peerOfHelperLanes);
  1483. // Group threads into quad invocations.
  1484. uint32_t singlePixelCount = 0;
  1485. uint32_t multiPixelCount = 0;
  1486. std::unordered_set<uint32_t> ids;
  1487. std::multimap<uint32_t, PerPixelData *> idGroups;
  1488. std::multimap<uint32_t, PerPixelData *> firstIdGroups;
  1489. for (uint32_t i = 0; i < appendCount; ++i) {
  1490. ids.insert(values[i].id);
  1491. idGroups.insert(std::make_pair(values[i].id, &values[i]));
  1492. firstIdGroups.insert(std::make_pair(values[i].firstLaneId, &values[i]));
  1493. }
  1494. for (uint32_t id : ids) {
  1495. if (idGroups.count(id) == 1)
  1496. ++singlePixelCount;
  1497. else
  1498. ++multiPixelCount;
  1499. }
  1500. LogCommentFmt(L"%u pixels were processed by a single thread. %u invocations were for shared pixels.",
  1501. singlePixelCount, multiPixelCount);
  1502. // Multiple threads may have tried to shade the same pixel.
  1503. // Where every pixel is distinct, it's very straightforward to validate.
  1504. {
  1505. auto cur = firstIdGroups.begin(), end = firstIdGroups.end();
  1506. while (cur != end) {
  1507. bool simpleWave = true;
  1508. uint32_t firstId = (*cur).first;
  1509. auto groupEnd = cur;
  1510. while (groupEnd != end && (*groupEnd).first == firstId) {
  1511. if (idGroups.count((*groupEnd).second->id) > 1)
  1512. simpleWave = false;
  1513. ++groupEnd;
  1514. }
  1515. if (simpleWave) {
  1516. // Break the wave into quads.
  1517. struct QuadData {
  1518. unsigned count;
  1519. PerPixelData *data[4];
  1520. };
  1521. std::map<uint32_t, QuadData> quads;
  1522. for (auto i = cur; i != groupEnd; ++i) {
  1523. uint32_t quadId = (*i).second->id0;
  1524. auto match = quads.find(quadId);
  1525. if (match == quads.end()) {
  1526. QuadData qdata;
  1527. qdata.count = 1;
  1528. qdata.data[0] = (*i).second;
  1529. quads.insert(std::make_pair(quadId, qdata));
  1530. }
  1531. else {
  1532. VERIFY_IS_TRUE((*match).second.count < 4);
  1533. (*match).second.data[(*match).second.count++] = (*i).second;
  1534. }
  1535. }
  1536. for (auto quadPair : quads) {
  1537. unsigned count = quadPair.second.count;
  1538. if (count < 2) continue;
  1539. PerPixelData **data = quadPair.second.data;
  1540. bool isTop[4];
  1541. bool isLeft[4];
  1542. PerPixelData helperData;
  1543. memset(&helperData, sizeof(helperData), 0);
  1544. PerPixelData *layout[4]; // tl,tr,bl,br
  1545. memset(layout, sizeof(layout), 0);
  1546. auto fnToLayout = [&](bool top, bool left) -> PerPixelData ** {
  1547. int idx = top ? 0 : 2;
  1548. idx += left ? 0 : 1;
  1549. return &layout[idx];
  1550. };
  1551. auto fnToLayoutData = [&](bool top, bool left) -> PerPixelData * {
  1552. PerPixelData **pResult = fnToLayout(top, left);
  1553. if (*pResult == nullptr) return &helperData;
  1554. return *pResult;
  1555. };
  1556. VERIFY_IS_TRUE(count <= 4);
  1557. if (count == 2) {
  1558. isTop[0] = data[0]->position.y < data[1]->position.y;
  1559. isTop[1] = (data[0]->position.y == data[1]->position.y) ? isTop[0] : !isTop[0];
  1560. isLeft[0] = data[0]->position.x < data[1]->position.x;
  1561. isLeft[1] = (data[0]->position.x == data[1]->position.x) ? isLeft[0] : !isLeft[0];
  1562. }
  1563. else {
  1564. // with at least three samples, we have distinct x and y coordinates.
  1565. float left = std::min(data[0]->position.x, data[1]->position.x);
  1566. left = std::min(data[2]->position.x, left);
  1567. float top = std::min(data[0]->position.y, data[1]->position.y);
  1568. top = std::min(data[2]->position.y, top);
  1569. for (unsigned i = 0; i < count; ++i) {
  1570. isTop[i] = data[i]->position.y == top;
  1571. isLeft[i] = data[i]->position.x == left;
  1572. }
  1573. }
  1574. for (unsigned i = 0; i < count; ++i) {
  1575. *(fnToLayout(isTop[i], isLeft[i])) = data[i];
  1576. }
  1577. // Finally, we have a proper quad reconstructed. Validate.
  1578. for (unsigned i = 0; i < count; ++i) {
  1579. PerPixelData *d = data[i];
  1580. VERIFY_ARE_EQUAL(d->id0, fnToLayoutData(true, true)->id);
  1581. VERIFY_ARE_EQUAL(d->id1, fnToLayoutData(true, false)->id);
  1582. VERIFY_ARE_EQUAL(d->id2, fnToLayoutData(false, true)->id);
  1583. VERIFY_ARE_EQUAL(d->id3, fnToLayoutData(false, false)->id);
  1584. VERIFY_ARE_EQUAL(d->acrossX, fnToLayoutData(isTop[i], !isLeft[i])->id);
  1585. VERIFY_ARE_EQUAL(d->acrossY, fnToLayoutData(!isTop[i], isLeft[i])->id);
  1586. VERIFY_ARE_EQUAL(d->acrossDiag, fnToLayoutData(!isTop[i], !isLeft[i])->id);
  1587. VERIFY_ARE_EQUAL(d->quadActiveCount, count);
  1588. }
  1589. }
  1590. }
  1591. cur = groupEnd;
  1592. }
  1593. }
  1594. // TODO: provide validation for quads where the same pixel was shaded multiple times
  1595. //
  1596. // Consider: for pixels that were shaded multiple times, check whether
  1597. // some grouping of threads into quads satisfies all value requirements.
  1598. }
  1599. }
  1600. struct ShaderOpTestResult {
  1601. st::ShaderOp *ShaderOp;
  1602. std::shared_ptr<st::ShaderOpSet> ShaderOpSet;
  1603. std::shared_ptr<st::ShaderOpTest> Test;
  1604. };
  1605. struct SPrimitives {
  1606. float f_float;
  1607. float f_float2;
  1608. float f_float_o;
  1609. float f_float2_o;
  1610. };
  1611. static float g_SinCosFloats[] = {
  1612. -(INFINITY),
  1613. -1.0f,
  1614. -(FLT_MIN/2),
  1615. -0.0f,
  1616. 0.0f,
  1617. FLT_MIN / 2,
  1618. 1.0f,
  1619. INFINITY,
  1620. NAN
  1621. };
  1622. std::shared_ptr<ShaderOpTestResult>
  1623. RunShaderOpTest(ID3D12Device *pDevice, dxc::DxcDllSupport &support,
  1624. IStream *pStream, LPCSTR pName,
  1625. st::ShaderOpTest::TInitCallbackFn pInitCallback) {
  1626. DXASSERT_NOMSG(pStream != nullptr);
  1627. std::shared_ptr<st::ShaderOpSet> ShaderOpSet =
  1628. std::make_shared<st::ShaderOpSet>();
  1629. st::ParseShaderOpSetFromStream(pStream, ShaderOpSet.get());
  1630. st::ShaderOp *pShaderOp;
  1631. if (pName == nullptr) {
  1632. if (ShaderOpSet->ShaderOps.size() != 1) {
  1633. VERIFY_FAIL(L"Expected a single shader operation.");
  1634. }
  1635. pShaderOp = ShaderOpSet->ShaderOps[0].get();
  1636. }
  1637. else {
  1638. pShaderOp = ShaderOpSet->GetShaderOp(pName);
  1639. }
  1640. if (pShaderOp == nullptr) {
  1641. std::string msg = "Unable to find shader op ";
  1642. msg += pName;
  1643. msg += "; available ops";
  1644. const char sep = ':';
  1645. for (auto &pAvailOp : ShaderOpSet->ShaderOps) {
  1646. msg += sep;
  1647. msg += pAvailOp->Name ? pAvailOp->Name : "[n/a]";
  1648. }
  1649. CA2W msgWide(msg.c_str());
  1650. VERIFY_FAIL(msgWide.m_psz);
  1651. }
  1652. std::shared_ptr<st::ShaderOpTest> test = std::make_shared<st::ShaderOpTest>();
  1653. test->SetDxcSupport(&support);
  1654. test->SetInitCallback(pInitCallback);
  1655. test->RunShaderOp(pShaderOp);
  1656. std::shared_ptr<ShaderOpTestResult> result =
  1657. std::make_shared<ShaderOpTestResult>();
  1658. result->ShaderOpSet = ShaderOpSet;
  1659. result->Test = test;
  1660. result->ShaderOp = pShaderOp;
  1661. return result;
  1662. }
  1663. static bool isdenorm(float f) {
  1664. return FP_SUBNORMAL == fpclassify(f);
  1665. }
  1666. static bool isdenorm(double d) {
  1667. return FP_SUBNORMAL == fpclassify(d);
  1668. }
  1669. TEST_F(ExecutionTest, DoShaderOpArithTest) {
  1670. WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
  1671. CComPtr<IStream> pStream;
  1672. ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
  1673. CComPtr<ID3D12Device> pDevice;
  1674. if (!CreateDevice(&pDevice))
  1675. return;
  1676. // Single operation test at the moment.
  1677. std::shared_ptr<ShaderOpTestResult> test = RunShaderOpTest(pDevice, m_support, pStream, "SinCos",
  1678. [](LPCSTR Name, std::vector<BYTE> &Data) {
  1679. // Initialize the SPrimitives buffer.
  1680. VERIFY_IS_TRUE(0 == _stricmp(Name, "SPrimitives"));
  1681. size_t count = 8 * 8;
  1682. size_t size = sizeof(SPrimitives) * count;
  1683. Data.resize(size);
  1684. SPrimitives *pPrimitives = (SPrimitives *)Data.data();
  1685. for (size_t i = 0; i < count; ++i) {
  1686. SPrimitives *p = &pPrimitives[i];
  1687. p->f_float = g_SinCosFloats[i % _countof(g_SinCosFloats)];
  1688. p->f_float2 = p->f_float;
  1689. }
  1690. });
  1691. MappedData data;
  1692. test->Test->GetReadBackData("SPrimitives", &data);
  1693. // data.dump(); // Uncomment to dump raw bytes from buffer.
  1694. unsigned count = 8 * 8;
  1695. SPrimitives *pPrimitives = (SPrimitives *)data.data();
  1696. WEX::TestExecution::DisableVerifyExceptions dve;
  1697. static const float Error = 0.0008f;
  1698. for (unsigned i = 0; i < count; ++i) {
  1699. SPrimitives *p = &pPrimitives[i];
  1700. float input = p->f_float;
  1701. float sin_o = p->f_float_o;
  1702. float cos_o = p->f_float2_o;
  1703. LogCommentFmt(L"Element #%u, input %f, sin=%f, cos=%f", i, input, sin_o, cos_o);
  1704. if (isinf(input)) {
  1705. VERIFY_IS_TRUE(isnan(sin_o));
  1706. VERIFY_IS_TRUE(isnan(cos_o));
  1707. }
  1708. else if (isnan(input)) {
  1709. VERIFY_IS_TRUE(isnan(sin_o));
  1710. VERIFY_IS_TRUE(isnan(cos_o));
  1711. }
  1712. else if (isdenorm(input)) {
  1713. VERIFY_IS_TRUE(1.0f == cos_o);
  1714. if (signbit(input)) {
  1715. VERIFY_IS_TRUE(-0.0f == sin_o);
  1716. }
  1717. else {
  1718. VERIFY_IS_TRUE(0.0f == sin_o);
  1719. }
  1720. }
  1721. else if (input == 0.0f) {
  1722. VERIFY_IS_TRUE(0.0f == sin_o);
  1723. VERIFY_IS_TRUE(1.0f == cos_o);
  1724. }
  1725. else if (input == -0.0f) {
  1726. VERIFY_IS_TRUE(-0.0f == sin_o);
  1727. VERIFY_IS_TRUE(1.0f == cos_o);
  1728. }
  1729. else {
  1730. float f_sin = sin(input);
  1731. float f_cos = cos(input);
  1732. VERIFY_IS_TRUE((f_sin - Error) <= sin_o && sin_o <= (f_sin + Error));
  1733. VERIFY_IS_TRUE((f_cos - Error) <= cos_o && cos_o <= (f_cos + Error));
  1734. }
  1735. }
  1736. }
  1737. static float ifdenorm_flushf(float a) {
  1738. return isdenorm(a) ? copysign(0.0f, a) : a;
  1739. }
  1740. static bool ifdenorm_flushf_eq(float a, float b) {
  1741. return ifdenorm_flushf(a) == ifdenorm_flushf(b);
  1742. }
  1743. static bool ifdenorm_flushf_eq_or_nans(float a, float b) {
  1744. if (isnan(a) && isnan(b)) return true;
  1745. return ifdenorm_flushf(a) == ifdenorm_flushf(b);
  1746. }
  1747. TEST_F(ExecutionTest, MinMaxTest) {
  1748. WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
  1749. CComPtr<IStream> pStream;
  1750. ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
  1751. struct SMinMaxElem {
  1752. float f_fa;
  1753. float f_fb;
  1754. float f_fmin_o;
  1755. float f_fmax_o;
  1756. };
  1757. float TestValues[] = {
  1758. -(INFINITY),
  1759. -1.0f,
  1760. -(FLT_MIN/2),
  1761. -0.0f,
  1762. 0.0f,
  1763. FLT_MIN / 2,
  1764. 1.0f,
  1765. INFINITY,
  1766. NAN
  1767. };
  1768. // Single operation test at the moment.
  1769. CComPtr<ID3D12Device> pDevice;
  1770. if (!CreateDevice(&pDevice))
  1771. return;
  1772. std::shared_ptr<ShaderOpTestResult> test = RunShaderOpTest(pDevice, m_support, pStream, "MinMax",
  1773. [&TestValues](LPCSTR Name, std::vector<BYTE> &Data) {
  1774. // Initialize the SPrimitives buffer.
  1775. VERIFY_IS_TRUE(0 == _stricmp(Name, "SPrimitives"));
  1776. size_t count = 10 * 10;
  1777. size_t size = sizeof(SMinMaxElem) * count;
  1778. Data.resize(size);
  1779. SMinMaxElem *pElems = (SMinMaxElem *)Data.data();
  1780. for (size_t a = 0; a < 10; ++a) {
  1781. float fa = TestValues[a % _countof(TestValues)];
  1782. for (size_t b = 0; b < 10; ++b) {
  1783. SMinMaxElem *p = &pElems[a * 10 + b];
  1784. ZeroMemory(p, sizeof(*p));
  1785. p->f_fa = fa;
  1786. p->f_fb = TestValues[b % _countof(TestValues)];
  1787. }
  1788. }
  1789. });
  1790. MappedData data;
  1791. test->Test->GetReadBackData("SPrimitives", &data);
  1792. // data.dump(); // Uncomment to dump raw bytes from buffer.
  1793. unsigned count = 10 * 10;
  1794. SMinMaxElem *pPrimitives = (SMinMaxElem *)data.data();
  1795. WEX::TestExecution::DisableVerifyExceptions dve;
  1796. static const float Error = 0.0008f;
  1797. for (unsigned i = 0; i < count; ++i) {
  1798. SMinMaxElem *p = &pPrimitives[i];
  1799. float fa = p->f_fa;
  1800. float fb = p->f_fb;
  1801. float fmin = p->f_fmin_o;
  1802. float fmax = p->f_fmax_o;
  1803. LogCommentFmt(L"Element #%u, a %f, b %f, min=%f, max=%f", i, fa, fb, fmin, fmax);
  1804. if (isnan(fa)) {
  1805. VERIFY_IS_TRUE(ifdenorm_flushf_eq_or_nans(fmin, fb));
  1806. VERIFY_IS_TRUE(ifdenorm_flushf_eq_or_nans(fmax, fb));
  1807. }
  1808. else if (isnan(fb)) {
  1809. VERIFY_IS_TRUE(ifdenorm_flushf_eq_or_nans(fmin, fa));
  1810. VERIFY_IS_TRUE(ifdenorm_flushf_eq_or_nans(fmax, fa));
  1811. }
  1812. else {
  1813. // Flushing is allowed - check both cases.
  1814. float fmax_0 = fa >= fb ? fa : fb;
  1815. float fmax_1 = ifdenorm_flushf(fmax_0);
  1816. VERIFY_IS_TRUE(fmax == fmax_0 || fmax == fmax_1);
  1817. float fmin_0 = fa < fb ? fa : fb;
  1818. float fmin_1 = ifdenorm_flushf(fmin_0);
  1819. VERIFY_IS_TRUE(fmin == fmin_0 || fmin == fmin_1);
  1820. }
  1821. }
  1822. }
  1823. TEST_F(ExecutionTest, OutOfBoundsTest) {
  1824. WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
  1825. CComPtr<IStream> pStream;
  1826. ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
  1827. // Single operation test at the moment.
  1828. CComPtr<ID3D12Device> pDevice;
  1829. if (!CreateDevice(&pDevice))
  1830. return;
  1831. std::shared_ptr<ShaderOpTestResult> test = RunShaderOpTest(pDevice, m_support, pStream, "OOB", nullptr);
  1832. MappedData data;
  1833. // Read back to CPU and examine contents - should get pure red.
  1834. {
  1835. MappedData data;
  1836. test->Test->GetReadBackData("RTarget", &data);
  1837. const uint32_t *pPixels = (uint32_t *)data.data();
  1838. uint32_t first = *pPixels;
  1839. VERIFY_ARE_EQUAL(0xff0000ff, first); // pure red - only first component is read
  1840. }
  1841. }
  1842. TEST_F(ExecutionTest, SaturateTest) {
  1843. WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
  1844. CComPtr<IStream> pStream;
  1845. ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
  1846. // Single operation test at the moment.
  1847. CComPtr<ID3D12Device> pDevice;
  1848. if (!CreateDevice(&pDevice))
  1849. return;
  1850. std::shared_ptr<ShaderOpTestResult> test = RunShaderOpTest(pDevice, m_support, pStream, "Saturate", nullptr);
  1851. MappedData data;
  1852. test->Test->GetReadBackData("U0", &data);
  1853. const float *pValues = (float *)data.data();
  1854. // Everything is zero except for 1.5f and +Inf, which saturate to 1.0f
  1855. const float ExpectedCases[9] = {
  1856. 0.0f, 0.0f, 0.0f, 0.0f, // -inf, -1.5, -denorm, -0
  1857. 0.0f, 0.0f, 1.0f, 1.0f, // 0, denorm, 1.5f, inf
  1858. 0.0f // nan
  1859. };
  1860. for (size_t i = 0; i < _countof(ExpectedCases); ++i) {
  1861. VERIFY_ARE_EQUAL(*pValues, ExpectedCases[i]);
  1862. ++pValues;
  1863. }
  1864. }
  1865. TEST_F(ExecutionTest, BasicTriangleOpTest) {
  1866. WEX::TestExecution::SetVerifyOutput verifySettings(WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
  1867. CComPtr<IStream> pStream;
  1868. ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream);
  1869. // Single operation test at the moment.
  1870. CComPtr<ID3D12Device> pDevice;
  1871. if (!CreateDevice(&pDevice))
  1872. return;
  1873. std::shared_ptr<ShaderOpTestResult> test = RunShaderOpTest(pDevice, m_support, pStream, "Triangle", nullptr);
  1874. MappedData data;
  1875. D3D12_RESOURCE_DESC &D = test->ShaderOp->GetResourceByName("RTarget")->Desc;
  1876. UINT width = (UINT64)D.Width;
  1877. UINT height = (UINT64)D.Height;
  1878. test->Test->GetReadBackData("RTarget", &data);
  1879. const uint32_t *pPixels = (uint32_t *)data.data();
  1880. if (SaveImages()) {
  1881. SavePixelsToFile(pPixels, DXGI_FORMAT_R8G8B8A8_UNORM, 320, 200, L"basic.bmp");
  1882. }
  1883. uint32_t top = pPixels[width / 2]; // Top center.
  1884. uint32_t mid = pPixels[width / 2 + width * (height / 2)]; // Middle center.
  1885. VERIFY_ARE_EQUAL(0xff663300, top); // clear color
  1886. VERIFY_ARE_EQUAL(0xffffffff, mid); // white
  1887. // This is the basic validation test for shader operations, so it's good to
  1888. // check this here at least for this one test case.
  1889. data.reset();
  1890. test.reset();
  1891. ReportLiveObjects();
  1892. }
  1893. static void WriteReadBackDump(st::ShaderOp *pShaderOp, st::ShaderOpTest *pTest,
  1894. char **pReadBackDump) {
  1895. std::stringstream str;
  1896. unsigned count = 0;
  1897. for (auto &R : pShaderOp->Resources) {
  1898. if (!R.ReadBack)
  1899. continue;
  1900. ++count;
  1901. str << "Resource: " << R.Name << "\r\n";
  1902. // Find a descriptor that can tell us how to dump this resource.
  1903. bool found = false;
  1904. for (auto &Heaps : pShaderOp->DescriptorHeaps) {
  1905. for (auto &D : Heaps.Descriptors) {
  1906. if (_stricmp(D.ResName, R.Name) != 0) {
  1907. continue;
  1908. }
  1909. found = true;
  1910. if (_stricmp(D.Kind, "UAV") != 0) {
  1911. str << "Resource dump for kind " << D.Kind << " not implemented yet.\r\n";
  1912. break;
  1913. }
  1914. if (D.UavDesc.ViewDimension != D3D12_UAV_DIMENSION_BUFFER) {
  1915. str << "Resource dump for this kind of view dimension not implemented yet.\r\n";
  1916. break;
  1917. }
  1918. // We can map back to the structure if a structured buffer via the shader, but
  1919. // we'll keep this simple and simply dump out 32-bit uint/float representations.
  1920. MappedData data;
  1921. pTest->GetReadBackData(R.Name, &data);
  1922. uint32_t *pData = (uint32_t *)data.data();
  1923. size_t u32_count = R.Desc.Width / sizeof(uint32_t);
  1924. for (size_t i = 0; i < u32_count; ++i) {
  1925. float f = *(float *)pData;
  1926. str << i << ": 0n" << *pData << " 0x" << std::hex << *pData
  1927. << std::dec << " " << f << "\r\n";
  1928. ++pData;
  1929. }
  1930. break;
  1931. }
  1932. if (found) break;
  1933. }
  1934. if (!found) {
  1935. str << "Unable to find a view for the resource.\r\n";
  1936. }
  1937. }
  1938. str << "Resources read back: " << count << "\r\n";
  1939. std::string s(str.str());
  1940. CComHeapPtr<char> pDump;
  1941. if (!pDump.Allocate(s.size() + 1))
  1942. throw std::bad_alloc();
  1943. memcpy(pDump.m_pData, s.data(), s.size());
  1944. pDump.m_pData[s.size()] = '\0';
  1945. *pReadBackDump = pDump.Detach();
  1946. }
// This is the interface exported for use by HLSLHost.exe.
// It is mutually exclusive with using the DLL as a TAEF target.
  1949. extern "C" {
  1950. __declspec(dllexport) HRESULT WINAPI InitializeOpTests(void *pStrCtx, st::OutputStringFn pOutputStrFn) {
  1951. HRESULT hr = EnableExperimentalShaderModels();
  1952. if (FAILED(hr)) {
  1953. pOutputStrFn(pStrCtx, L"Unable to enable experimental shader models.\r\n.");
  1954. }
  1955. return S_OK;
  1956. }
  1957. __declspec(dllexport) HRESULT WINAPI
  1958. RunOpTest(void *pStrCtx, st::OutputStringFn pOutputStrFn, LPCSTR pText,
  1959. ID3D12Device *pDevice, ID3D12CommandQueue *pCommandQueue,
  1960. ID3D12Resource *pRenderTarget, char **pReadBackDump) {
  1961. HRESULT hr;
  1962. if (pReadBackDump) *pReadBackDump = nullptr;
  1963. st::SetOutputFn(pStrCtx, pOutputStrFn);
  1964. CComPtr<ID3D12InfoQueue> pInfoQueue;
  1965. CComHeapPtr<char> pDump;
  1966. bool FilterCreation = false;
  1967. if (SUCCEEDED(pDevice->QueryInterface(&pInfoQueue))) {
  1968. // Creation is largely driven by inputs, so don't log create/destroy messages.
  1969. pInfoQueue->PushEmptyStorageFilter();
  1970. pInfoQueue->PushEmptyRetrievalFilter();
  1971. if (FilterCreation) {
  1972. D3D12_INFO_QUEUE_FILTER filter;
  1973. D3D12_MESSAGE_CATEGORY denyCategories[] = { D3D12_MESSAGE_CATEGORY_STATE_CREATION };
  1974. ZeroMemory(&filter, sizeof(filter));
  1975. filter.DenyList.NumCategories = _countof(denyCategories);
  1976. filter.DenyList.pCategoryList = denyCategories;
  1977. pInfoQueue->PushStorageFilter(&filter);
  1978. }
  1979. }
  1980. else {
  1981. pOutputStrFn(pStrCtx, L"Unable to enable info queue for D3D.\r\n.");
  1982. }
  1983. try {
  1984. dxc::DxcDllSupport m_support;
  1985. m_support.Initialize();
  1986. const char *pName = nullptr;
  1987. CComPtr<IStream> pStream = SHCreateMemStream((BYTE *)pText, strlen(pText));
  1988. std::shared_ptr<st::ShaderOpSet> ShaderOpSet =
  1989. std::make_shared<st::ShaderOpSet>();
  1990. st::ParseShaderOpSetFromStream(pStream, ShaderOpSet.get());
  1991. st::ShaderOp *pShaderOp;
  1992. if (pName == nullptr) {
  1993. if (ShaderOpSet->ShaderOps.size() != 1) {
  1994. pOutputStrFn(pStrCtx, L"Expected a single shader operation.\r\n");
  1995. return E_FAIL;
  1996. }
  1997. pShaderOp = ShaderOpSet->ShaderOps[0].get();
  1998. }
  1999. else {
  2000. pShaderOp = ShaderOpSet->GetShaderOp(pName);
  2001. }
  2002. if (pShaderOp == nullptr) {
  2003. std::string msg = "Unable to find shader op ";
  2004. msg += pName;
  2005. msg += "; available ops";
  2006. const char sep = ':';
  2007. for (auto &pAvailOp : ShaderOpSet->ShaderOps) {
  2008. msg += sep;
  2009. msg += pAvailOp->Name ? pAvailOp->Name : "[n/a]";
  2010. }
  2011. CA2W msgWide(msg.c_str());
  2012. pOutputStrFn(pStrCtx, msgWide);
  2013. return E_FAIL;
  2014. }
  2015. std::shared_ptr<st::ShaderOpTest> test = std::make_shared<st::ShaderOpTest>();
  2016. test->SetupRenderTarget(pShaderOp, pDevice, pCommandQueue, pRenderTarget);
  2017. test->SetDxcSupport(&m_support);
  2018. test->RunShaderOp(pShaderOp);
  2019. test->PresentRenderTarget(pShaderOp, pCommandQueue, pRenderTarget);
  2020. pOutputStrFn(pStrCtx, L"Rendering complete.\r\n");
  2021. if (!pShaderOp->IsCompute()) {
  2022. D3D12_QUERY_DATA_PIPELINE_STATISTICS stats;
  2023. test->GetPipelineStats(&stats);
  2024. wchar_t statsText[400];
  2025. StringCchPrintfW(statsText, _countof(statsText),
  2026. L"Vertices/primitives read by input assembler: %I64u/%I64u\r\n"
  2027. L"Vertex shader invocations: %I64u\r\n"
  2028. L"Geometry shader invocations/output primitive: %I64u/%I64u\r\n"
  2029. L"Primitives sent to rasterizer/rendered: %I64u/%I64u\r\n"
  2030. L"PS/HS/DS/CS invocations: %I64u/%I64u/%I64u/%I64u\r\n",
  2031. stats.IAVertices, stats.IAPrimitives, stats.VSInvocations,
  2032. stats.GSInvocations, stats.GSPrimitives, stats.CInvocations,
  2033. stats.CPrimitives, stats.PSInvocations, stats.HSInvocations,
  2034. stats.DSInvocations, stats.CSInvocations);
  2035. pOutputStrFn(pStrCtx, statsText);
  2036. }
  2037. if (pReadBackDump) {
  2038. WriteReadBackDump(pShaderOp, test.get(), &pDump);
  2039. }
  2040. hr = S_OK;
  2041. }
  2042. catch (const CAtlException &E)
  2043. {
  2044. hr = E.m_hr;
  2045. }
  2046. catch (const std::bad_alloc &)
  2047. {
  2048. hr = E_OUTOFMEMORY;
  2049. }
  2050. catch (const std::exception &)
  2051. {
  2052. hr = E_FAIL;
  2053. }
  2054. // Drain the device message queue if available.
  2055. if (pInfoQueue != nullptr) {
  2056. wchar_t buf[200];
  2057. StringCchPrintfW(buf, _countof(buf),
  2058. L"NumStoredMessages=%u limit/discarded by limit=%u/%u "
  2059. L"allowed/denied by storage filter=%u/%u "
  2060. L"NumStoredMessagesAllowedByRetrievalFilter=%u\r\n",
  2061. (unsigned)pInfoQueue->GetNumStoredMessages(),
  2062. (unsigned)pInfoQueue->GetMessageCountLimit(),
  2063. (unsigned)pInfoQueue->GetNumMessagesDiscardedByMessageCountLimit(),
  2064. (unsigned)pInfoQueue->GetNumMessagesAllowedByStorageFilter(),
  2065. (unsigned)pInfoQueue->GetNumMessagesDeniedByStorageFilter(),
  2066. (unsigned)pInfoQueue->GetNumStoredMessagesAllowedByRetrievalFilter());
  2067. pOutputStrFn(pStrCtx, buf);
  2068. WriteInfoQueueMessages(pStrCtx, pOutputStrFn, pInfoQueue);
  2069. pInfoQueue->ClearStoredMessages();
  2070. pInfoQueue->PopRetrievalFilter();
  2071. pInfoQueue->PopStorageFilter();
  2072. if (FilterCreation) {
  2073. pInfoQueue->PopStorageFilter();
  2074. }
  2075. }
  2076. if (pReadBackDump) *pReadBackDump = pDump.Detach();
  2077. return hr;
  2078. }
  2079. }