2
0

HLSLHost.cpp 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873
  1. ///////////////////////////////////////////////////////////////////////////////
  2. // //
  3. // HLSLHost.cpp //
  4. // Copyright (C) Microsoft Corporation. All rights reserved. //
  5. // This file is distributed under the University of Illinois Open Source //
  6. // License. See LICENSE.TXT for details. //
  7. // //
  8. // Provides a Win32 application that can act as a host for HLSL programs. //
  9. // //
  10. ///////////////////////////////////////////////////////////////////////////////
  11. #include "dxc/Support/Global.h"
  12. #include "dxc/Support/WinIncludes.h"
  13. #include "dxc/Support/FileIOHelper.h"
  14. #include "dxc/Support/microcom.h"
  15. #include <vector>
  16. #include <string>
  17. #include <comdef.h>
  18. #include <algorithm>
  19. #include <unordered_map>
  20. #include <d3d12.h>
  21. #include <dxgi1_4.h>
  22. #include <atlcoll.h>
  23. #pragma comment(lib, "dxgi.lib")
  24. #pragma comment(lib, "d3d12.lib")
  25. #if 0
  26. Pending work for rendering hosting.
  27. - Pass width / height information to the test DLL.
- Clean up all TODO markers.
  29. #endif
  30. // Forward declarations.
  31. class ServerFactory;
  32. // RAII helpers.
  33. class HhEvent {
  34. public:
  35. HANDLE m_handle = INVALID_HANDLE_VALUE;
  36. public:
  37. HRESULT Init() {
  38. if (m_handle == INVALID_HANDLE_VALUE) {
  39. m_handle = CreateEvent(nullptr, TRUE, FALSE, nullptr);
  40. if (m_handle == INVALID_HANDLE_VALUE) {
  41. return HRESULT_FROM_WIN32(GetLastError());
  42. }
  43. }
  44. return S_OK;
  45. }
  46. void SetEvent() {
  47. ::SetEvent(m_handle);
  48. }
  49. void ResetEvent() {
  50. ::ResetEvent(m_handle);
  51. }
  52. ~HhEvent() {
  53. if (m_handle != INVALID_HANDLE_VALUE) {
  54. CloseHandle(m_handle);
  55. }
  56. }
  57. };
  58. class HhCriticalSection {
  59. private:
  60. CRITICAL_SECTION m_cs;
  61. public:
  62. HhCriticalSection() { InitializeCriticalSection(&m_cs); }
  63. ~HhCriticalSection() { DeleteCriticalSection(&m_cs); }
  64. class Lock {
  65. private:
  66. CRITICAL_SECTION *m_pLock;
  67. public:
  68. Lock() = delete;
  69. Lock(const Lock&) = delete;
  70. Lock(Lock &&other) { std::swap(m_pLock, other.m_pLock); }
  71. Lock(CRITICAL_SECTION *pLock) : m_pLock(pLock) {
  72. EnterCriticalSection(m_pLock);
  73. }
  74. ~Lock() {
  75. if (m_pLock) LeaveCriticalSection(m_pLock);
  76. }
  77. };
  78. Lock LockCS() {
  79. return Lock(&m_cs);
  80. }
  81. };
  82. class ClassObjectRegistration {
  83. private:
  84. DWORD m_reg = 0;
  85. HRESULT m_hr = E_FAIL;
  86. public:
  87. HRESULT Register(REFCLSID rclsid, IUnknown *pUnk, DWORD dwClsContext, DWORD flags) {
  88. m_hr = CoRegisterClassObject(rclsid, pUnk, dwClsContext, flags, &m_reg);
  89. return m_hr;
  90. }
  91. ~ClassObjectRegistration() {
  92. if (SUCCEEDED(m_hr)) CoRevokeClassObject(m_reg);
  93. }
  94. };
  95. class CoInit {
  96. private:
  97. HRESULT m_hr = E_FAIL;
  98. public:
  99. HRESULT Initialize(DWORD dwCoInit) {
  100. m_hr = CoInitializeEx(nullptr, dwCoInit);
  101. return m_hr;
  102. }
  103. ~CoInit() {
  104. if (SUCCEEDED(m_hr)) CoUninitialize();
  105. }
  106. };
  107. // Globals.
  108. static DWORD g_ProcessLocks;
  109. static const GUID CLSID_HLSLHostServer = // {7FD7A859-6C6B-4352-8F1E-C67BB62E774B}
  110. { 0x7fd7a859, 0x6c6b, 0x4352,{ 0x8f, 0x1e, 0xc6, 0x7b, 0xb6, 0x2e, 0x77, 0x4b } };
  111. static HhEvent g_ShutdownServerEvent;
  112. static const DWORD GetPidMsgId = 1;
  113. static const DWORD ShutdownMsgId = 2;
  114. static const DWORD StartRendererMsgId = 3;
  115. static const DWORD StopRendererMsgId = 4;
  116. static const DWORD SetPayloadMsgId = 5;
  117. static const DWORD ReadLogMsgId = 6;
  118. static const DWORD SetSizeMsgId = 7;
  119. static const DWORD SetParentWndMsgId = 8;
  120. static const DWORD GetPidMsgReplyId = 100 + GetPidMsgId;
  121. static const DWORD StartRendererMsgReplyId = 100 + StartRendererMsgId;
  122. static const DWORD StopRendererMsgReplyId = 100 + StopRendererMsgId;
  123. static const DWORD SetPayloadMsgReplyId = 100 + SetPayloadMsgId;
  124. static const DWORD ReadLogMsgReplyId = 100 + ReadLogMsgId;
  125. static const DWORD SetSizeMsgReplyId = 100 + SetSizeMsgId;
  126. static const DWORD SetParentWndMsgReplyId = 100 + SetParentWndMsgId;
  127. struct HhMessageHeader {
  128. DWORD Length;
  129. DWORD Kind;
  130. };
  131. struct HhGetPidMessage {
  132. HhMessageHeader Header;
  133. };
  134. struct HhSetSizeMessage {
  135. HhMessageHeader Header;
  136. DWORD Width;
  137. DWORD Height;
  138. };
  139. struct HhSetParentWndMessage {
  140. HhMessageHeader Header;
  141. UINT64 Handle;
  142. };
  143. struct HhGetPidReply {
  144. HhMessageHeader Header;
  145. DWORD Pid;
  146. };
  147. struct HhResultReply {
  148. HhMessageHeader Header;
  149. HRESULT hr;
  150. };
  151. // Logging and tracing.
  152. static void HhTrace(LPWSTR pMessage) {
  153. wprintf(L"%s\n", pMessage);
  154. }
  155. template <typename TInterface, typename TObject>
  156. HRESULT DoBasicQueryInterfaceWithRemote(TObject* self, REFIID iid, void** ppvObject)
  157. {
  158. if (ppvObject == nullptr) return E_POINTER;
  159. // Support INoMarshal to void GIT shenanigans.
  160. if (IsEqualIID(iid, __uuidof(IUnknown))) {
  161. *ppvObject = reinterpret_cast<IUnknown*>(self);
  162. reinterpret_cast<IUnknown*>(self)->AddRef();
  163. return S_OK;
  164. }
  165. if (IsEqualIID(iid, __uuidof(TInterface))) {
  166. *(TInterface**)ppvObject = self;
  167. self->AddRef();
  168. return S_OK;
  169. }
  170. return E_NOINTERFACE;
  171. }
  172. // Rendering.
  173. ATOM g_RenderingWindowClass;
  174. HhCriticalSection g_RenderingWindowClassCS;
  175. LRESULT CALLBACK RendererWndProc(HWND, UINT, WPARAM, LPARAM);
  176. DWORD WINAPI RendererStart(LPVOID lpThreadParameter);
  177. void __stdcall RendererLog(void *pRenderer, const wchar_t *pText);
  178. typedef void(__stdcall *OutputStringFn)(void *pCtx, const wchar_t *pText);
  179. typedef HRESULT(WINAPI *InitOpTestFn)(void *pStrCtx, OutputStringFn pOutputStrFn);
  180. typedef HRESULT(WINAPI *RunOpTestFn)(void *pStrCtx, OutputStringFn pOutputStrFn,
  181. LPCSTR pText, ID3D12Device *pDevice,
  182. ID3D12CommandQueue *pCommandQueue,
  183. ID3D12Resource *pRenderTarget,
  184. char **pReadBackDump);
  185. #define WM_RENDERER_SETPAYLOAD (WM_USER)
  186. #define WM_RENDERER_QUIT (WM_USER + 1)
  187. class Renderer {
  188. private:
  189. // This state is accessed by a messaging thread.
  190. DWORD m_tid = 0;
  191. HANDLE m_thread = nullptr;
  192. // This state is used to coordinate the messaging and the rendering threads.
  193. HWND m_hwnd = nullptr;
  194. HhEvent m_threadReady;
  195. HRESULT m_threadStartResult = E_PENDING;
  196. UINT m_height = 0;
  197. UINT m_width = 0;
  198. OutputStringFn m_pLog;
  199. void *m_pLogCtx;
  200. bool m_userQuit = false;
  201. // This state is used by the rendering thread.
  202. CComPtr<ID3D12Device> m_device;
  203. CComPtr<ID3D12CommandQueue> m_commandQueue;
  204. CComPtr<IDXGISwapChain3> m_swapChain;
  205. UINT FrameCount = 2;
  206. UINT m_TargetDeviceIndex = 0;
  207. UINT m_frameIndex;
  208. UINT m_renderCount = 0;
  209. HMODULE m_TestDLL = NULL;
  210. RunOpTestFn m_pRunOpTestFn = nullptr;
  211. InitOpTestFn m_pInitOpTestFn = nullptr;
  212. LPVOID m_ShaderOpText = nullptr;
  213. CComHeapPtr<char> m_ResourceViewText;
  214. HRESULT LoadTestDll() {
  215. if (m_TestDLL == NULL) {
  216. m_TestDLL = LoadLibrary("clang-hlsl-tests.dll");
  217. m_pRunOpTestFn = (RunOpTestFn)GetProcAddress(m_TestDLL, "RunOpTest");
  218. m_pInitOpTestFn = (InitOpTestFn)GetProcAddress(m_TestDLL, "InitializeOpTests");
  219. HRESULT hrInit = m_pInitOpTestFn(this, RendererLog);
  220. if (FAILED(hrInit)) {
  221. CloseHandle(m_TestDLL);
  222. m_TestDLL = nullptr;
  223. return hrInit;
  224. }
  225. }
  226. return S_OK;
  227. }
  228. void HandleCopy() {
  229. LPCSTR OpMessage;
  230. UINT MessageType = MB_ICONERROR;
  231. if (m_ResourceViewText.m_pData == nullptr) {
  232. OpMessage = "No resources read back from a prior render.";
  233. }
  234. else {
  235. OpMessage = "Unable to copy resource data to clipboard.";
  236. if (OpenClipboard(m_hwnd)) {
  237. if (EmptyClipboard()) {
  238. HGLOBAL hMem = GlobalAlloc(GMEM_MOVEABLE, 1 + strlen(m_ResourceViewText.m_pData));
  239. if (hMem) {
  240. LPSTR pCopy = (LPSTR)GlobalLock(hMem);
  241. strcpy(pCopy, m_ResourceViewText.m_pData);
  242. GlobalUnlock(hMem);
  243. SetClipboardData(CF_TEXT, hMem);
  244. OpMessage = "Read back resources copied to clipboard.";
  245. MessageType = MB_ICONINFORMATION;
  246. }
  247. }
  248. CloseClipboard();
  249. }
  250. }
  251. MessageBox(m_hwnd, OpMessage, "Resource Copy", MessageType);
  252. }
  253. void HandleDeviceCycle() {
  254. ReleaseD3DResources();
  255. ++m_TargetDeviceIndex;
  256. SetupD3D();
  257. }
  258. void HandleHelp() {
  259. MessageBoxW(m_hwnd, L"Commands:\r\n"
  260. L"(C)opy Results\r\n"
  261. L"(D)evice Cycle\r\n"
  262. L"(H)elp (show this message)\r\n"
  263. L"(R)un\r\n"
  264. L"(Q)uit",
  265. L"HLSL Host Help", MB_OK);
  266. }
  267. HRESULT HandlePayload() {
  268. CComPtr<ID3D12Resource> pRT;
  269. m_frameIndex = m_swapChain->GetCurrentBackBufferIndex();
  270. IFR(m_swapChain->GetBuffer(m_frameIndex, IID_PPV_ARGS(&pRT)));
  271. wchar_t ResName[32];
  272. StringCchPrintfW(ResName, _countof(ResName), L"SwapChain Buffer #%u", m_frameIndex);
  273. pRT->SetName(ResName);
  274. StringCchPrintfW(ResName, _countof(ResName), L"Render %u\r\n", ++m_renderCount);
  275. Log(ResName);
  276. m_ResourceViewText.Free();
  277. LPSTR pText = (LPSTR)InterlockedExchangePointer(&m_ShaderOpText, nullptr);
  278. HRESULT hr = m_pRunOpTestFn(this, RendererLog, pText, m_device, m_commandQueue, pRT, &m_ResourceViewText);
  279. // If we can restore it, we're set; otherwise we should delete our stale copy.
  280. if (nullptr != InterlockedCompareExchangePointer(&m_ShaderOpText, pText, nullptr))
  281. free(pText);
  282. wchar_t ErrMsg[64];
  283. if (FAILED(hr)) {
  284. StringCchPrintfW(ErrMsg, _countof(ErrMsg), L"Shader operation failed: 0x%08x\r\n", hr);
  285. Log(ErrMsg);
  286. return hr;
  287. }
  288. hr = m_swapChain->Present(1, 0);
  289. if (FAILED(hr)) {
  290. StringCchPrintfW(ErrMsg, _countof(ErrMsg), L"Present failed: 0x%08x\r\n", hr);
  291. Log(ErrMsg);
  292. return hr;
  293. }
  294. return S_OK;
  295. }
  296. void ReleaseD3DResources() {
  297. m_device.Release();
  298. m_commandQueue.Release();
  299. m_swapChain.Release();
  300. }
  301. HRESULT SetupD3D() {
  302. IFR(LoadTestDll());
  303. CComPtr<ID3D12Debug> debugController;
  304. if (SUCCEEDED(D3D12GetDebugInterface(IID_PPV_ARGS(&debugController)))) {
  305. debugController->EnableDebugLayer();
  306. }
  307. CComPtr<IDXGIFactory4> factory;
  308. IFR(CreateDXGIFactory1(IID_PPV_ARGS(&factory)));
  309. CComPtr<IDXGIAdapter> adapter;
  310. if (m_TargetDeviceIndex > 0) {
  311. UINT hardwareIndex = m_TargetDeviceIndex - 1;
  312. HRESULT hrEnum = factory->EnumAdapters(hardwareIndex, &adapter);
  313. if (hrEnum == DXGI_ERROR_NOT_FOUND) {
  314. m_TargetDeviceIndex = 0; // cycle to WARP
  315. }
  316. else {
  317. IFR(hrEnum);
  318. }
  319. }
  320. if (m_TargetDeviceIndex == 0) {
  321. IFR(factory->EnumWarpAdapter(IID_PPV_ARGS(&adapter)));
  322. }
  323. HRESULT hrCreate = D3D12CreateDevice(adapter, D3D_FEATURE_LEVEL_11_0,
  324. IID_PPV_ARGS(&m_device));
  325. IFR(SetWindowTextToDevice(hrCreate, m_hwnd, adapter, m_device));
  326. IFR(hrCreate);
  327. // Describe and create the command queue.
  328. D3D12_COMMAND_QUEUE_DESC queueDesc = {};
  329. queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
  330. queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
  331. IFR(m_device->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&m_commandQueue)));
  332. // Describe and create the swap chain.
  333. DXGI_SWAP_CHAIN_DESC1 swapChainDesc = {};
  334. swapChainDesc.BufferCount = FrameCount;
  335. swapChainDesc.Width = m_width;
  336. swapChainDesc.Height = m_height;
  337. swapChainDesc.Format = DXGI_FORMAT_R8G8B8A8_UNORM;
  338. swapChainDesc.BufferUsage = DXGI_USAGE_RENDER_TARGET_OUTPUT;
  339. swapChainDesc.SwapEffect = DXGI_SWAP_EFFECT_FLIP_DISCARD;
  340. swapChainDesc.SampleDesc.Count = 1;
  341. CComPtr<IDXGISwapChain1> swapChain;
  342. IFR(factory->CreateSwapChainForHwnd(m_commandQueue, m_hwnd, &swapChainDesc,
  343. nullptr, nullptr, &swapChain));
  344. // Do not support fullscreen transitions.
  345. IFR(factory->MakeWindowAssociation(m_hwnd, DXGI_MWA_NO_ALT_ENTER));
  346. IFR(swapChain.QueryInterface(&m_swapChain));
  347. m_frameIndex = m_swapChain->GetCurrentBackBufferIndex();
  348. return S_OK;
  349. }
  350. HRESULT SetWindowTextToDevice(HRESULT hrCreate, HWND hwnd, IDXGIAdapter *pAdapter, ID3D12Device *pDevice) {
  351. DXGI_ADAPTER_DESC AdapterDesc;
  352. D3D12_FEATURE_DATA_D3D12_OPTIONS1 DeviceOptions;
  353. D3D12_FEATURE_DATA_SHADER_MODEL DeviceSM;
  354. wchar_t title[200];
  355. IFR(pAdapter->GetDesc(&AdapterDesc));
  356. IFR(pDevice->CheckFeatureSupport(D3D12_FEATURE_D3D12_OPTIONS1, &DeviceOptions, sizeof(DeviceOptions)));
  357. DeviceSM.HighestShaderModel = D3D_SHADER_MODEL_6_0;
  358. IFR(pDevice->CheckFeatureSupport(D3D12_FEATURE_SHADER_MODEL, &DeviceSM, sizeof(DeviceSM)));
  359. IFR(StringCchPrintfW(
  360. title, _countof(title),
  361. L"%s%s - caps:%s%s%s",
  362. SUCCEEDED(hrCreate) ? L"" : L"(cannot create D3D12 device)",
  363. AdapterDesc.Description,
  364. (DeviceSM.HighestShaderModel >= D3D_SHADER_MODEL_6_0) ? L" SM6" : L"",
  365. DeviceOptions.WaveOps ? L" WaveOps" : L"",
  366. DeviceOptions.Int64ShaderOps ? L" I64" : L""));
  367. SetWindowTextW(hwnd, title);
  368. return S_OK;
  369. }
  370. public:
  371. HRESULT Start(void *pLogCtx, OutputStringFn pLog) {
  372. if (SUCCEEDED(m_threadStartResult)) {
  373. return m_threadStartResult;
  374. }
  375. IFR(m_threadReady.Init());
  376. if (m_width == 0) m_width = 320;
  377. if (m_height == 0) m_height = 200;
  378. m_pLog = pLog;
  379. m_pLogCtx = pLogCtx;
  380. m_thread = CreateThread(nullptr, 0, RendererStart, this, 0, &m_tid);
  381. if (!m_thread) return HRESULT_FROM_WIN32(GetLastError());
  382. WaitForSingleObject(m_threadReady.m_handle, INFINITE);
  383. if (FAILED(m_threadStartResult)) {
  384. WaitForSingleObject(m_thread, INFINITE);
  385. CloseHandle(m_thread);
  386. }
  387. return m_threadStartResult;
  388. }
  389. void Run() {
  390. LPCSTR WindowClassName = "Renderer";
  391. HINSTANCE procInstance = GetModuleHandle(nullptr);
  392. {
  393. auto lock = g_RenderingWindowClassCS.LockCS();
  394. if (g_RenderingWindowClass == NULL) {
  395. WNDCLASS wndClass;
  396. ZeroMemory(&wndClass, sizeof(wndClass));
  397. wndClass.lpszClassName = WindowClassName;
  398. wndClass.hInstance = procInstance; // GetModuleHandle("HLSLHost.exe");
  399. wndClass.style = WS_OVERLAPPED;
  400. wndClass.cbWndExtra = sizeof(void*);
  401. wndClass.lpfnWndProc = RendererWndProc;
  402. ATOM atom = RegisterClass(&wndClass);
  403. if (atom == INVALID_ATOM) {
  404. m_threadStartResult = HRESULT_FROM_WIN32(GetLastError());
  405. m_threadReady.SetEvent();
  406. return;
  407. }
  408. }
  409. }
  410. DWORD style = WS_OVERLAPPEDWINDOW;
  411. RECT windowRect = { 0, 0, (LONG)m_width, (LONG)m_height };
  412. AdjustWindowRect(&windowRect, WS_OVERLAPPEDWINDOW, FALSE);
  413. LPVOID lParam = this;
  414. m_hwnd = CreateWindow(WindowClassName, "Renderer", style, CW_USEDEFAULT,
  415. CW_USEDEFAULT,
  416. windowRect.right - windowRect.left,
  417. windowRect.bottom - windowRect.top, NULL, NULL, procInstance, lParam);
  418. if (m_hwnd == NULL) {
  419. m_threadStartResult = HRESULT_FROM_WIN32(GetLastError());
  420. m_threadReady.SetEvent();
  421. return;
  422. }
  423. LONG_PTR l = (LONG_PTR)(Renderer *)this;
  424. SetWindowLongPtr(m_hwnd, 0, l);
  425. ShowWindow(m_hwnd, SW_NORMAL);
  426. m_threadStartResult = S_OK;
  427. m_threadReady.SetEvent();
  428. // Basic dispatch loop.
  429. MSG msg;
  430. while (GetMessage(&msg, NULL, 0, 0)) {
  431. DispatchMessage(&msg);
  432. if (msg.message == WM_QUIT) {
  433. break;
  434. }
  435. }
  436. if (m_userQuit) {
  437. g_ShutdownServerEvent.SetEvent();
  438. }
  439. }
  440. void SetPayload(LPSTR pText) {
  441. LPSTR textCopy = strdup(pText);
  442. LPSTR oldText = (LPSTR)InterlockedExchangePointer(&m_ShaderOpText, textCopy);
  443. if (oldText != nullptr)
  444. free(oldText);
  445. PostMessage(m_hwnd, WM_RENDERER_SETPAYLOAD, 0, 0);
  446. }
  447. HRESULT SetSize(DWORD width, DWORD height) {
  448. RECT windowRect;
  449. GetWindowRect(m_hwnd, &windowRect);
  450. RECT client = { 0, 0, (LONG)width, (LONG)height };
  451. AdjustWindowRect(&client, WS_OVERLAPPEDWINDOW, FALSE);
  452. SetWindowPos(m_hwnd, 0, windowRect.left, windowRect.top,
  453. client.right - client.left,
  454. client.bottom - client.top, SWP_NOZORDER);
  455. return S_OK;
  456. }
  457. HRESULT SetParentHwnd(HWND handle) {
  458. HWND prior = SetParent(m_hwnd, handle);
  459. if (prior == NULL) {
  460. return HRESULT_FROM_WIN32(GetLastError());
  461. }
  462. if (handle == NULL) {
  463. // Top-level, so restore original style.
  464. SetWindowLong(m_hwnd, GWL_STYLE, WS_OVERLAPPEDWINDOW | WS_VISIBLE);
  465. }
  466. else {
  467. // Child window, so set new style and reset position.
  468. SetWindowPos(m_hwnd, 0, 0, 0, 0, 0, SWP_NOZORDER | SWP_NOSIZE);
  469. SetWindowLong(m_hwnd, GWL_STYLE, WS_CHILD | WS_VISIBLE);
  470. }
  471. return S_OK;
  472. }
  473. void Stop() {
  474. if (m_hwnd != NULL) {
  475. PostMessage(m_hwnd, WM_RENDERER_QUIT, 0, 0);
  476. WaitForSingleObject(m_thread, INFINITE);
  477. CloseHandle(m_thread);
  478. m_threadStartResult = E_PENDING;
  479. m_threadReady.ResetEvent();
  480. m_thread = NULL;
  481. m_hwnd = NULL;
  482. }
  483. }
  484. LRESULT HandleMessage(HWND hWnd, UINT msg, WPARAM wParam, LPARAM lParam) {
  485. switch (msg) {
  486. case WM_SHOWWINDOW:
  487. if (m_device == nullptr) {
  488. SetupD3D();
  489. }
  490. break;
  491. case WM_SIZE:
  492. if (m_device) {
  493. RECT r;
  494. GetClientRect(hWnd, &r);
  495. m_width = r.right - r.left;
  496. m_height = r.bottom - r.top;
  497. HRESULT hr = m_swapChain->ResizeBuffers(FrameCount, 0, 0, DXGI_FORMAT_UNKNOWN, 0);
  498. Log(SUCCEEDED(hr) ? L"Swapchain buffers resized." : L"Failed to resize swapchain buffers.");
  499. }
  500. break;
  501. case WM_KEYDOWN:
  502. if (wParam == 'Q') {
  503. m_userQuit = true;
  504. DestroyWindow(hWnd);
  505. }
  506. if (wParam == 'R') {
  507. HandlePayload();
  508. }
  509. if (wParam == 'C') {
  510. HandleCopy();
  511. }
  512. if (wParam == 'H') {
  513. HandleHelp();
  514. }
  515. if (wParam == 'D') {
  516. HandleDeviceCycle();
  517. }
  518. if (wParam == '2') {
  519. DXGI_MODE_DESC d;
  520. ZeroMemory(&d, sizeof(d));
  521. d.Height = 256;
  522. d.Width = 256;
  523. m_swapChain->ResizeTarget(&d);
  524. }
  525. break;
  526. case WM_DESTROY:
  527. ReleaseD3DResources();
  528. PostQuitMessage(0);
  529. break;
  530. case WM_RENDERER_SETPAYLOAD:
  531. HandlePayload();
  532. break;
  533. case WM_RENDERER_QUIT:
  534. DestroyWindow(hWnd);
  535. break;
  536. }
  537. return DefWindowProc(hWnd, msg, wParam, lParam);
  538. }
  539. void Log(const wchar_t *pLog) {
  540. if (m_pLog)
  541. m_pLog(m_pLogCtx, pLog);
  542. }
  543. };
  544. void __stdcall RendererLog(void *pCtx, const wchar_t *pText) {
  545. ((Renderer *)pCtx)->Log(pText);
  546. }
  547. LRESULT CALLBACK RendererWndProc(HWND hWnd, UINT msg, WPARAM wParam, LPARAM lParam) {
  548. if (msg == WM_CREATE) {
  549. return DefWindowProc(hWnd, msg, wParam, lParam);
  550. }
  551. LONG_PTR l = GetWindowLongPtr(hWnd, 0);
  552. return ((Renderer *)l)->HandleMessage(hWnd, msg, wParam, lParam);
  553. }
  554. DWORD WINAPI RendererStart(LPVOID lpThreadParameter) {
  555. Renderer *R = (Renderer *)lpThreadParameter;
  556. R->Run();
  557. return 0;
  558. }
  559. // Server object to handle messaging.
  560. void __stdcall ServerObjLog(void *pCtx, const wchar_t *pText);
  561. class ServerObj : public IStream {
  562. private:
  563. DXC_MICROCOM_REF_FIELD(m_dwRef);
  564. Renderer m_renderer;
  565. HhCriticalSection m_cs;
  566. CAtlArray<wchar_t> m_pMessages;
  567. CAtlArray<BYTE> m_pLog;
  568. DWORD m_pLogStart = 0;
  569. public:
  570. DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
  571. ServerObj() : m_dwRef(0) {}
  572. ~ServerObj() {}
  573. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  574. return DoBasicQueryInterfaceWithRemote<IStream>(this, iid, ppvObject);
  575. }
  576. HRESULT AppendLog(LPCWSTR pLog) {
  577. size_t logCount = wcslen(pLog);
  578. HRESULT hr = S_OK;
  579. auto l = m_cs.LockCS();
  580. size_t count = m_pMessages.GetCount();
  581. if (!m_pMessages.SetCount(count + logCount))
  582. hr = E_OUTOFMEMORY;
  583. else
  584. memcpy(m_pMessages.GetData() + count, pLog, logCount * sizeof(wchar_t));
  585. return hr;
  586. }
  587. // Write a message back to the server user.
  588. HRESULT WriteMessage(const HhMessageHeader *pHeader) {
  589. HRESULT hr = S_OK;
  590. auto l = m_cs.LockCS();
  591. size_t count = m_pLog.GetCount();
  592. if (!m_pLog.SetCount(count + pHeader->Length))
  593. hr = E_OUTOFMEMORY;
  594. else
  595. memcpy(m_pLog.GetData() + count, pHeader, pHeader->Length);
  596. return hr;
  597. }
  598. void WriteRequestResultReply(UINT requestKind, HRESULT hr) {
  599. HhResultReply result;
  600. result.Header.Kind = requestKind + 100;
  601. result.Header.Length = sizeof(HhResultReply);
  602. result.hr = hr;
  603. WriteMessage(&result.Header);
  604. }
  605. void HandleMessage(const HhMessageHeader *pHeader, ULONG cb) {
  606. DWORD MsgKind = pHeader->Kind;
  607. switch (MsgKind) {
  608. case GetPidMsgId:
  609. HhTrace(L"GetPID message received");
  610. HhGetPidReply R;
  611. R.Header.Kind = GetPidMsgReplyId;
  612. R.Header.Length = sizeof(HhGetPidReply);
  613. R.Pid = GetCurrentProcessId();
  614. WriteMessage(&R.Header);
  615. break;
  616. case ShutdownMsgId:
  617. HhTrace(L"Shutdown message received");
  618. m_renderer.Stop();
  619. g_ShutdownServerEvent.SetEvent();
  620. break;
  621. case StartRendererMsgId:
  622. HhTrace(L"StartRenderer message received");
  623. WriteRequestResultReply(MsgKind,
  624. m_renderer.Start(this, ServerObjLog));
  625. break;
  626. case StopRendererMsgId:
  627. HhTrace(L"StopRenderer message received");
  628. m_renderer.Stop();
  629. WriteRequestResultReply(MsgKind, S_OK);
  630. break;
  631. case SetPayloadMsgId:
  632. LPSTR pText;
  633. HhTrace(L"SetPayload message received");
  634. pText = (LPSTR)(pHeader + 1);
  635. m_renderer.SetPayload(pText);
  636. WriteRequestResultReply(MsgKind, S_OK);
  637. break;
  638. case ReadLogMsgId: {
  639. // Do a single grow and write.
  640. HRESULT hr = S_OK;
  641. HhMessageHeader H;
  642. DWORD messageLen;
  643. DWORD messageLenInBytes;
  644. wchar_t nullTerm = L'\0';
  645. auto l = m_cs.LockCS();
  646. messageLen = (DWORD)m_pMessages.GetCount();
  647. messageLenInBytes = messageLen * sizeof(wchar_t);
  648. H.Length = sizeof(HhMessageHeader) + sizeof(messageLen) + messageLenInBytes + sizeof(nullTerm);
  649. H.Kind = ReadLogMsgReplyId;
  650. size_t count = m_pLog.GetCount();
  651. size_t growBy = H.Length;
  652. if (!m_pLog.SetCount(count + growBy))
  653. hr = E_OUTOFMEMORY;
  654. else {
  655. LPBYTE pCursor = m_pLog.GetData() + count;
  656. memcpy(pCursor, &H, sizeof(H));
  657. pCursor += sizeof(H);
  658. memcpy(pCursor, &messageLen, sizeof(messageLen));
  659. pCursor += sizeof(messageLen);
  660. memcpy(pCursor, m_pMessages.GetData(), messageLenInBytes);
  661. pCursor += messageLenInBytes;
  662. memcpy(pCursor, &nullTerm, sizeof(nullTerm));
  663. m_pMessages.SetCount(0);
  664. }
  665. break;
  666. }
  667. case SetSizeMsgId: {
  668. if (cb < sizeof(HhSetSizeMessage)) {
  669. WriteRequestResultReply(MsgKind, E_INVALIDARG);
  670. return;
  671. }
  672. const HhSetSizeMessage *pSetSize = (const HhSetSizeMessage *)pHeader;
  673. WriteRequestResultReply(MsgKind,
  674. m_renderer.SetSize(pSetSize->Width, pSetSize->Height));
  675. }
  676. case SetParentWndMsgId: {
  677. if (cb < sizeof(HhSetParentWndMessage)) {
  678. WriteRequestResultReply(MsgKind, E_INVALIDARG);
  679. return;
  680. }
  681. const HhSetParentWndMessage *pSetParent = (const HhSetParentWndMessage *)pHeader;
  682. WriteRequestResultReply(MsgKind,
  683. m_renderer.SetParentHwnd((HWND)pSetParent->Handle));
  684. }
  685. }
  686. }
  687. // ISequentialStream implementation.
  688. HRESULT STDMETHODCALLTYPE Read(void *pv, ULONG cb, ULONG *pcbRead) override {
  689. if (!pv)
  690. return E_POINTER;
  691. if (cb == 0)
  692. return S_OK;
  693. HRESULT hr = S_OK;
  694. auto l = m_cs.LockCS();
  695. size_t count = m_pLog.GetCount();
  696. size_t countLeft = count - m_pLogStart;
  697. if (cb > countLeft) {
  698. cb = countLeft;
  699. hr = S_FALSE;
  700. }
  701. if (pcbRead)
  702. *pcbRead = cb;
  703. memcpy(pv, m_pLog.GetData() + m_pLogStart, cb);
  704. m_pLogStart += cb;
  705. // If we have more than 2K of wasted space, shrink.
  706. if (m_pLogStart > 2048) {
  707. size_t newSize = count - m_pLogStart;
  708. memmove(m_pLog.GetData(), m_pLog.GetData() + m_pLogStart, newSize);
  709. m_pLog.SetCount(newSize);
  710. m_pLogStart = 0;
  711. }
  712. return hr;
  713. }
  714. HRESULT STDMETHODCALLTYPE Write(void const *pv, ULONG cb,
  715. ULONG *pcbWritten) override {
  716. if (!pv || !pcbWritten)
  717. return E_POINTER;
  718. if (cb == 0)
  719. return S_OK;
  720. if (cb < sizeof(HhMessageHeader)) {
  721. HhTrace(L"Message is smaller than sizeof(HhMessageHeader).");
  722. return E_FAIL;
  723. }
  724. HandleMessage((const HhMessageHeader *)pv, cb);
  725. *pcbWritten = cb;
  726. return S_OK;
  727. }
  728. // IStream implementation.
  729. HRESULT STDMETHODCALLTYPE SetSize(ULARGE_INTEGER val) override {
  730. HhTrace(L"SetSize called - E_NOTIMPL");
  731. return E_NOTIMPL;
  732. }
  733. HRESULT STDMETHODCALLTYPE CopyTo(IStream *, ULARGE_INTEGER,
  734. ULARGE_INTEGER *,
  735. ULARGE_INTEGER *) override {
  736. return E_NOTIMPL;
  737. }
  738. HRESULT STDMETHODCALLTYPE Commit(DWORD) override { return E_NOTIMPL; }
  739. HRESULT STDMETHODCALLTYPE Revert(void) override { return E_NOTIMPL; }
  740. HRESULT STDMETHODCALLTYPE LockRegion(ULARGE_INTEGER,
  741. ULARGE_INTEGER, DWORD) override {
  742. return E_NOTIMPL;
  743. }
  744. HRESULT STDMETHODCALLTYPE UnlockRegion(ULARGE_INTEGER,
  745. ULARGE_INTEGER, DWORD) override {
  746. return E_NOTIMPL;
  747. }
  748. HRESULT STDMETHODCALLTYPE Clone(IStream **) override { return E_NOTIMPL; }
  749. HRESULT STDMETHODCALLTYPE Seek(LARGE_INTEGER, DWORD,
  750. ULARGE_INTEGER *) override {
  751. return E_NOTIMPL;
  752. }
  753. HRESULT STDMETHODCALLTYPE Stat(STATSTG *, DWORD) override {
  754. HhTrace(L"Stat called - E_NOTIMPL");
  755. return E_NOTIMPL;
  756. }
  757. };
  758. void __stdcall ServerObjLog(void *pCtx, const wchar_t *pText) {
  759. ((ServerObj *)pCtx)->AppendLog(pText);
  760. }
  761. // Server startup and lifetime.
  762. class ServerFactory : public IClassFactory {
  763. private:
  764. DXC_MICROCOM_REF_FIELD(m_dwRef);
  765. public:
  766. DXC_MICROCOM_ADDREF_RELEASE_IMPL(m_dwRef)
  767. ServerFactory() : m_dwRef(0) {}
  768. HRESULT STDMETHODCALLTYPE QueryInterface(REFIID iid, void **ppvObject) {
  769. return DoBasicQueryInterfaceWithRemote<IClassFactory>(this, iid, ppvObject);
  770. }
  771. HRESULT STDMETHODCALLTYPE CreateInstance(IUnknown *pUnk, REFIID riid, void **ppvObj) {
  772. if (pUnk) return CLASS_E_NOAGGREGATION;
  773. CComPtr<ServerObj> obj = new (std::nothrow) ServerObj();
  774. if (obj.p == nullptr) return E_OUTOFMEMORY;
  775. return obj.p->QueryInterface(riid, ppvObj);
  776. }
  777. HRESULT STDMETHODCALLTYPE LockServer(BOOL fLock) {
  778. // TODO: implement
  779. return S_OK;
  780. }
  781. };
  782. HRESULT RunServer(REFCLSID rclsid) {
  783. HhTrace(L"Starting HLSL Host...");
  784. CoInit coInit;
  785. ClassObjectRegistration registration;
  786. IFR(coInit.Initialize(COINIT_MULTITHREADED));
  787. IFR(g_ShutdownServerEvent.Init());
  788. CComPtr<ServerFactory> pServerFactory = new (std::nothrow) ServerFactory();
  789. IFR(registration.Register(rclsid, pServerFactory, CLSCTX_LOCAL_SERVER, REGCLS_MULTIPLEUSE));
  790. WaitForSingleObject(g_ShutdownServerEvent.m_handle, INFINITE);
  791. return S_OK;
  792. }
  793. HRESULT RunServer(const wchar_t *pCLSIDText) {
  794. CLSID clsid;
  795. if (pCLSIDText && *pCLSIDText) {
  796. IFR(CLSIDFromString(pCLSIDText, &clsid));
  797. }
  798. else {
  799. clsid = CLSID_HLSLHostServer;
  800. }
  801. return RunServer(clsid);
  802. }
  803. // Entry point for host process to render shaders.
  804. int wmain(int argc, wchar_t* argv[]) {
  805. int resultCode;
  806. resultCode = SUCCEEDED(RunServer(nullptr)) ? 0 : 1;
  807. return resultCode;
  808. }