//===------ SemaDXR.cpp - Semantic Analysis for DXR shader -----*- C++ -*-===//
///////////////////////////////////////////////////////////////////////////////
//                                                                           //
// SemaDXR.cpp                                                               //
// Copyright (C) Nvidia Corporation. All rights reserved.                    //
// This file is distributed under the University of Illinois Open Source    //
// License. See LICENSE.TXT for details.                                     //
//                                                                           //
// This file defines the semantic support for DXR.                          //
//                                                                           //
///////////////////////////////////////////////////////////////////////////////
#include "clang/AST/ASTContext.h"
#include "clang/AST/Attr.h"
#include "clang/AST/Decl.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/Expr.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/ExternalASTSource.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Sema/SemaHLSL.h"

#include "clang/Analysis/Analyses/Dominators.h"
#include "clang/Analysis/Analyses/ReachableCode.h"
#include "clang/Analysis/CFG.h"

#include "llvm/ADT/BitVector.h"

#include "dxc/DXIL/DxilConstants.h"

// Standard library headers for the containers and algorithms used below.
#include <algorithm>
#include <map>
#include <set>
#include <vector>

using namespace clang;
using namespace sema;
using namespace hlsl;
namespace {

struct PayloadUse {
  PayloadUse() = default;
  PayloadUse(const Stmt *S, const CFGBlock *Parent)
      : S(S), Parent(Parent), Member(nullptr) {}
  PayloadUse(const Stmt *S, const CFGBlock *Parent, const MemberExpr *Member)
      : S(S), Parent(Parent), Member(Member) {}

  bool operator<(const PayloadUse &Other) const { return S < Other.S; }

  const Stmt *S = nullptr;
  const CFGBlock *Parent = nullptr;
  const MemberExpr *Member = nullptr;
};

struct TraceRayCall {
  TraceRayCall() = default;
  TraceRayCall(const CallExpr *Call, const CFGBlock *Parent)
      : Call(Call), Parent(Parent) {}

  const CallExpr *Call = nullptr;
  const CFGBlock *Parent = nullptr;
};

struct PayloadAccessInfo {
  PayloadAccessInfo() = default;
  PayloadAccessInfo(const MemberExpr *Member, const CallExpr *Call,
                    bool IsLValue)
      : Member(Member), Call(Call), IsLValue(IsLValue) {}

  const MemberExpr *Member = nullptr;
  const CallExpr *Call = nullptr;
  bool IsLValue = false;
};

// Aggregated per-function information used while diagnosing payload access.
// Pointer members are default-initialized so partially filled instances
// (e.g. CalleeInfo below) can be tested safely.
struct DxrShaderDiagnoseInfo {
  const FunctionDecl *funcDecl = nullptr;
  const VarDecl *Payload = nullptr;
  DXIL::PayloadAccessShaderStage Stage =
      DXIL::PayloadAccessShaderStage::Invalid;
  std::vector<TraceRayCall> TraceCalls;
  std::map<const FieldDecl *, std::vector<PayloadUse>> WritesPerField;
  std::map<const FieldDecl *, std::vector<PayloadUse>> ReadsPerField;
  std::vector<PayloadUse> PayloadAsCallArg;
};
std::vector<const FieldDecl *>
DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
                      const std::set<const FieldDecl *> &FieldsToIgnoreRead,
                      const std::set<const FieldDecl *> &FieldsToIgnoreWrite,
                      std::set<const FunctionDecl *> VisitedFunctions);

const Stmt *IgnoreParensAndDecay(const Stmt *S);

// Converts the shader stage to a string for use in diagnostics.
StringRef GetStringForShaderStage(DXIL::PayloadAccessShaderStage Stage) {
  StringRef StageNames[] = {"caller", "closesthit", "miss", "anyhit"};
  if (Stage != DXIL::PayloadAccessShaderStage::Invalid)
    return StageNames[static_cast<unsigned>(Stage)];
  return "";
}
// Returns the qualifier for a field and a given shader stage.
DXIL::PayloadAccessQualifier
GetPayloadQualifierForStage(FieldDecl *Field,
                            DXIL::PayloadAccessShaderStage Stage) {
  bool hasRead = false;
  bool hasWrite = false;
  for (UnusualAnnotation *annotation : Field->getUnusualAnnotations()) {
    if (auto *payloadAnnotation =
            dyn_cast<hlsl::PayloadAccessAnnotation>(annotation)) {
      for (auto &ShaderStage : payloadAnnotation->ShaderStages) {
        if (ShaderStage != Stage)
          continue;
        hasRead |=
            payloadAnnotation->qualifier == DXIL::PayloadAccessQualifier::Read;
        hasWrite |=
            payloadAnnotation->qualifier == DXIL::PayloadAccessQualifier::Write;
      }
    }
  }

  if (hasRead && hasWrite)
    return DXIL::PayloadAccessQualifier::ReadWrite;
  if (hasRead)
    return DXIL::PayloadAccessQualifier::Read;
  if (hasWrite)
    return DXIL::PayloadAccessQualifier::Write;
  return DXIL::PayloadAccessQualifier::NoAccess;
}
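
// For illustration, a sketch of the HLSL payload access qualifier syntax this
// function consumes (field names are hypothetical):
//
//   struct [raypayload] Payload {
//     float4 color : read(caller) : write(closesthit, miss);
//     uint   depth : write(caller) : read(closesthit, miss);
//   };
//
// For Stage == Miss, 'color' maps to Write and 'depth' maps to Read; a field
// listed under both read() and write() for a stage maps to ReadWrite.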
// Returns the declaration of the payload used in a TraceRay call.
const VarDecl *GetPayloadParameterForTraceCall(const CallExpr *Trace) {
  const Decl *callee = Trace->getCalleeDecl();
  if (!callee)
    return nullptr;
  if (!isa<FunctionDecl>(callee))
    return nullptr;
  const FunctionDecl *FD = cast<FunctionDecl>(callee);
  if (FD->isImplicit() && FD->getName() == "TraceRay") {
    const Stmt *Param = IgnoreParensAndDecay(Trace->getArg(7));
    if (const DeclRefExpr *ParamRef = dyn_cast<DeclRefExpr>(Param)) {
      if (const VarDecl *Decl = dyn_cast<VarDecl>(ParamRef->getDecl()))
        return Decl;
    }
  }
  return nullptr;
}
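
// Note: the DXR intrinsic is declared as
//   TraceRay(RaytracingAccelerationStructure AccelerationStructure,
//            uint RayFlags, uint InstanceInclusionMask,
//            uint RayContributionToHitGroupIndex,
//            uint MultiplierForGeometryContributionToHitGroupIndex,
//            uint MissShaderIndex, RayDesc Ray, inout payload_t Payload);
// so the payload is argument index 7, which is what getArg(7) above inspects.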
// Recursively extracts accesses to a payload struct from a Stmt.
void GetPayloadAccesses(const Stmt *S, const DxrShaderDiagnoseInfo &Info,
                        std::vector<PayloadAccessInfo> &Accesses, bool IsLValue,
                        const MemberExpr *Member, const CallExpr *Call) {
  for (auto C : S->children()) {
    if (!C)
      continue;

    if (const DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(C)) {
      if (Ref->getDecl() == Info.Payload) {
        Accesses.push_back(PayloadAccessInfo{Member, Call, IsLValue});
      }
      return;
    }

    if (const ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(C)) {
      if (Cast->getCastKind() == CK_LValueToRValue) {
        IsLValue = false;
      }
    }

    GetPayloadAccesses(C, Info, Accesses, IsLValue,
                       Member ? Member : dyn_cast<MemberExpr>(C),
                       Call ? Call : dyn_cast<CallExpr>(C));
  }
}
// Collects all reads, writes, and calls involving the payload.
void CollectReadsWritesAndCallsForPayload(const Stmt *S,
                                          DxrShaderDiagnoseInfo &Info,
                                          const CFGBlock *Block) {
  std::vector<PayloadAccessInfo> PayloadAccesses;
  GetPayloadAccesses(S, Info, PayloadAccesses, true, dyn_cast<MemberExpr>(S),
                     dyn_cast<CallExpr>(S));

  for (auto &Access : PayloadAccesses) {
    // An access to a payload member was found.
    if (Access.Member) {
      FieldDecl *Field = cast<FieldDecl>(Access.Member->getMemberDecl());
      if (Access.IsLValue) {
        Info.WritesPerField[Field].push_back(
            PayloadUse{S, Block, Access.Member});
      } else {
        Info.ReadsPerField[Field].push_back(
            PayloadUse{S, Block, Access.Member});
      }
    } else if (Access.Call) {
      Info.PayloadAsCallArg.push_back(PayloadUse{S, Block});
    }
  }
}
// Collects all TraceRay calls.
void CollectTraceRayCalls(const Stmt *S, DxrShaderDiagnoseInfo &Info,
                          const CFGBlock *Block) {
  // TraceRay returns void, so it should never appear as anything other than
  // a plain CallExpr.
  if (const CallExpr *Call = dyn_cast<CallExpr>(S)) {
    const Decl *Callee = Call->getCalleeDecl();
    if (!Callee || !isa<FunctionDecl>(Callee))
      return;
    const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);
    // Record calls to the TraceRay intrinsic.
    if (CalledFunction->isImplicit() &&
        CalledFunction->getName() == "TraceRay") {
      Info.TraceCalls.push_back({Call, Block});
    }
  }
}
// Find the last write to the payload field in the given block.
PayloadUse GetLastWriteInBlock(CFGBlock &Block,
                               ArrayRef<PayloadUse> PayloadWrites) {
  PayloadUse LastWrite;
  for (auto &Element : Block) { // TODO: reverse iterate?
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      auto It =
          std::find_if(PayloadWrites.begin(), PayloadWrites.end(),
                       [&](const PayloadUse &V) { return V.S == S->getStmt(); });
      if (It != std::end(PayloadWrites)) {
        LastWrite = *It;
        LastWrite.Parent = &Block;
      }
    }
  }
  return LastWrite;
}
// Traverse the CFG until every path has reached a write or the ENTRY.
void TraverseCFGUntilWrite(CFGBlock &Current, std::vector<PayloadUse> &Writes,
                           ArrayRef<PayloadUse> PayloadWrites,
                           std::set<const CFGBlock *> &Visited) {
  if (Visited.count(&Current))
    return;
  Visited.insert(&Current);

  for (auto I = Current.pred_begin(), E = Current.pred_end(); I != E; ++I) {
    CFGBlock *Pred = *I;
    if (!Pred)
      continue;
    PayloadUse WriteInPred = GetLastWriteInBlock(*Pred, PayloadWrites);
    if (!WriteInPred.S)
      TraverseCFGUntilWrite(*Pred, Writes, PayloadWrites, Visited);
    else
      Writes.push_back(WriteInPred);
  }
}
// Traverse the CFG from the EXIT backwards and stop as soon as a block has a
// write to the payload field.
std::vector<PayloadUse>
GetAllWritesReachingExit(CFG &ShaderCFG, ArrayRef<PayloadUse> PayloadWrites) {
  std::vector<PayloadUse> Writes;
  CFGBlock &Exit = ShaderCFG.getExit();
  std::set<const CFGBlock *> Visited;
  TraverseCFGUntilWrite(Exit, Writes, PayloadWrites, Visited);
  return Writes;
}
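
// Sketch of the backward search on a hypothetical CFG, where block B writes
// the field:
//
//   ENTRY -> A -> B(write) -> D -> EXIT
//             \-> C ---------------^
//
// Starting from EXIT, the walk stops at B on the upper path and records its
// write; the lower path runs through C and A to ENTRY without finding one,
// so only B's write reaches the exit (the C path leaves the field unwritten).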
// Find the first read of the payload field in the given block.
PayloadUse GetFirstReadInBlock(CFGBlock &Block,
                               ArrayRef<PayloadUse> PayloadReads) {
  PayloadUse FirstRead;
  for (auto &Element : Block) {
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      auto It =
          std::find_if(PayloadReads.begin(), PayloadReads.end(),
                       [&](const PayloadUse &V) { return V.S == S->getStmt(); });
      if (It != std::end(PayloadReads)) {
        FirstRead = *It;
        FirstRead.Parent = &Block;
        break; // We found the first read and are done with this block.
      }
    }
  }
  return FirstRead;
}
// Traverse the CFG until every path has reached a read or the EXIT.
void TraverseCFGUntilRead(CFGBlock &Current, std::vector<PayloadUse> &Reads,
                          ArrayRef<PayloadUse> PayloadReads,
                          std::set<const CFGBlock *> &Visited) {
  if (Visited.count(&Current))
    return;
  Visited.insert(&Current);

  for (auto I = Current.succ_begin(), E = Current.succ_end(); I != E; ++I) {
    CFGBlock *Succ = *I;
    if (!Succ)
      continue;
    PayloadUse ReadInSucc = GetFirstReadInBlock(*Succ, PayloadReads);
    if (!ReadInSucc.S)
      TraverseCFGUntilRead(*Succ, Reads, PayloadReads, Visited);
    else
      Reads.push_back(ReadInSucc);
  }
}
// Traverse the CFG from the ENTRY down and stop as soon as a block has a read
// of the payload field.
std::vector<PayloadUse>
GetAllReadsReachedFromEntry(CFG &ShaderCFG, ArrayRef<PayloadUse> PayloadReads) {
  std::vector<PayloadUse> Reads;
  CFGBlock &Entry = ShaderCFG.getEntry();
  std::set<const CFGBlock *> Visited;
  TraverseCFGUntilRead(Entry, Reads, PayloadReads, Visited);
  return Reads;
}
// Returns the record type of a payload declaration.
CXXRecordDecl *GetPayloadType(const VarDecl *Payload) {
  auto PayloadType = Payload->getType();
  if (PayloadType->isStructureOrClassType()) {
    return PayloadType->getAsCXXRecordDecl();
  }
  return nullptr;
}
std::vector<FieldDecl *> GetAllPayloadFields(RecordDecl *PayloadType) {
  std::vector<FieldDecl *> PayloadFields;
  for (FieldDecl *Field : PayloadType->fields()) {
    QualType FieldType = Field->getType();
    if (RecordDecl *FieldRecordDecl = FieldType->getAsCXXRecordDecl()) {
      // Recurse into nested payload types instead of adding the aggregate
      // field itself.
      if (FieldRecordDecl->hasAttr<HLSLRayPayloadAttr>()) {
        auto SubTypeFields = GetAllPayloadFields(FieldRecordDecl);
        PayloadFields.insert(PayloadFields.end(), SubTypeFields.begin(),
                             SubTypeFields.end());
        continue;
      }
    }
    PayloadFields.push_back(Field);
  }
  return PayloadFields;
}
// Returns true if the field is writeable in an earlier shader stage.
bool IsFieldWriteableInEarlierStage(FieldDecl *Field,
                                    DXIL::PayloadAccessShaderStage ThisStage) {
  bool isWriteableInEarlierStage = false;
  switch (ThisStage) {
  case DXIL::PayloadAccessShaderStage::Anyhit:
  case DXIL::PayloadAccessShaderStage::Closesthit:
  case DXIL::PayloadAccessShaderStage::Miss: {
    auto Qualifier = GetPayloadQualifierForStage(
        Field, DXIL::PayloadAccessShaderStage::Caller);
    isWriteableInEarlierStage =
        Qualifier == DXIL::PayloadAccessQualifier::Write ||
        Qualifier == DXIL::PayloadAccessQualifier::ReadWrite;
    Qualifier = GetPayloadQualifierForStage(
        Field, DXIL::PayloadAccessShaderStage::Anyhit);
    isWriteableInEarlierStage |=
        Qualifier == DXIL::PayloadAccessQualifier::Write ||
        Qualifier == DXIL::PayloadAccessQualifier::ReadWrite;
  } break;
  default:
    break;
  }
  return isWriteableInEarlierStage;
}
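
// Example: for a miss or closesthit shader, a field qualified write(caller)
// or write(anyhit) counts as writeable in an earlier stage, since the caller
// launches the ray and any-hit shaders run during traversal, before the hit
// is resolved.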
// Emit warnings on payload writes.
void DiagnosePayloadWrites(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
                           const DxrShaderDiagnoseInfo &Info,
                           ArrayRef<FieldDecl *> NonWriteableFields,
                           RecordDecl *PayloadType) {
  for (FieldDecl *Field : NonWriteableFields) {
    auto WritesToField = Info.WritesPerField.find(Field);
    if (WritesToField == Info.WritesPerField.end())
      continue;

    const auto &WritesToDiagnose =
        GetAllWritesReachingExit(ShaderCFG, WritesToField->second);
    for (auto &Write : WritesToDiagnose) {
      FieldDecl *MemField = cast<FieldDecl>(Write.Member->getMemberDecl());
      auto Qualifier = GetPayloadQualifierForStage(MemField, Info.Stage);
      if (Qualifier != DXIL::PayloadAccessQualifier::Write &&
          Qualifier != DXIL::PayloadAccessQualifier::ReadWrite) {
        S.Diag(Write.Member->getExprLoc(),
               diag::warn_hlsl_payload_access_write_loss)
            << Field->getName() << GetStringForShaderStage(Info.Stage);
      }
    }
  }

  // Check if a field is not unconditionally written and a write from an
  // earlier stage will be lost.
  auto PayloadFields = GetAllPayloadFields(PayloadType);
  for (FieldDecl *Field : PayloadFields) {
    auto Qualifier = GetPayloadQualifierForStage(Field, Info.Stage);
    if (IsFieldWriteableInEarlierStage(Field, Info.Stage) &&
        Qualifier == DXIL::PayloadAccessQualifier::Write) {
      // The field is writeable in an earlier stage and write-only in this
      // stage. Check if we find a write that dominates the exit of the
      // function.
      bool fieldHasDominatingWrite = false;
      auto It = Info.WritesPerField.find(Field);
      if (It != Info.WritesPerField.end()) {
        for (auto &Write : It->second) {
          fieldHasDominatingWrite =
              DT.dominates(Write.Parent, &ShaderCFG.getExit());
          if (fieldHasDominatingWrite)
            break;
        }
      }
      if (!fieldHasDominatingWrite) {
        S.Diag(Info.Payload->getLocation(),
               diag::warn_hlsl_payload_access_data_loss)
            << Field->getName() << GetStringForShaderStage(Info.Stage);
      }
    }
  }
}
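
// Illustrative trigger for the write-loss warning (a sketch; names are
// hypothetical):
//
//   struct [raypayload] P { float t : write(caller) : read(miss); };
//   [shader("miss")] void Miss(inout P p) { p.t = 0.0; }
//
// 't' is not write(miss), so the value written here is dropped when control
// returns to the caller, and warn_hlsl_payload_access_write_loss fires.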
// Returns true if A is earlier than B in Parent.
bool IsEarlierStatementAs(const Stmt *A, const Stmt *B,
                          const CFGBlock &Parent) {
  for (auto Element : Parent) {
    if (auto S = Element.getAs<CFGStmt>()) {
      if (S->getStmt() == A)
        return true;
      if (S->getStmt() == B)
        return false;
    }
  }
  return true;
}
// Returns true if the write dominates the payload use.
template <typename T>
bool WriteDominatesUse(const PayloadUse &Write, const T &Use,
                       DominatorTree &DT) {
  if (Use.Parent == Write.Parent) {
    // Use and write are in the same block.
    return IsEarlierStatementAs(Write.S, Use.S, *Use.Parent);
  }
  return DT.dominates(Write.Parent, Use.Parent);
}
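
// Example: in "if (c) p.f = 1; use(p.f);" the block holding the write does
// not dominate the block holding the use (the false edge of the branch
// bypasses it), so WriteDominatesUse returns false; within a single block
// the statement order decides instead.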
// Emit warnings for payload reads.
void DiagnosePayloadReads(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
                          const DxrShaderDiagnoseInfo &Info,
                          ArrayRef<FieldDecl *> NonReadableFields) {
  for (FieldDecl *Field : NonReadableFields) {
    auto ReadsFromField = Info.ReadsPerField.find(Field);
    if (ReadsFromField == Info.ReadsPerField.end())
      continue;

    auto WritesToField = Info.WritesPerField.find(Field);
    bool FieldHasWrites = WritesToField != Info.WritesPerField.end();

    const auto &ReadsToDiagnose =
        GetAllReadsReachedFromEntry(ShaderCFG, ReadsFromField->second);
    for (auto &Read : ReadsToDiagnose) {
      bool ReadIsDominatedByWrite = false;
      if (FieldHasWrites) {
        // We found a read of a field that needs a diagnostic.
        // We do not want to warn about reads that are dominated by a write.
        // Find writes that dominate the read. If we find one, ignore the
        // read.
        for (auto Write : WritesToField->second) {
          ReadIsDominatedByWrite = WriteDominatesUse(Write, Read, DT);
          if (ReadIsDominatedByWrite)
            break;
        }
      }
      if (ReadIsDominatedByWrite)
        continue;

      FieldDecl *MemField = cast<FieldDecl>(Read.Member->getMemberDecl());
      auto Qualifier = GetPayloadQualifierForStage(MemField, Info.Stage);
      if (Qualifier != DXIL::PayloadAccessQualifier::Read &&
          Qualifier != DXIL::PayloadAccessQualifier::ReadWrite) {
        S.Diag(Read.Member->getExprLoc(),
               diag::warn_hlsl_payload_access_undef_read)
            << Field->getName() << GetStringForShaderStage(Info.Stage);
      }
    }
  }
}
// Generic CFG traversal that performs PerElementAction on every Stmt in the
// CFG.
template <bool Backward, typename Action>
void TraverseCFG(const CFGBlock &Block, Action PerElementAction,
                 std::set<const CFGBlock *> &Visited) {
  if (Visited.count(&Block))
    return;
  Visited.insert(&Block);

  for (const auto &Element : Block) {
    PerElementAction(Block, Element);
  }

  if (!Backward) {
    for (auto I = Block.succ_begin(), E = Block.succ_end(); I != E; ++I) {
      CFGBlock *Succ = *I;
      if (!Succ)
        continue;
      TraverseCFG</*Backward=*/false>(*Succ, PerElementAction, Visited);
    }
  } else {
    for (auto I = Block.pred_begin(), E = Block.pred_end(); I != E; ++I) {
      CFGBlock *Pred = *I;
      if (!Pred)
        continue;
      TraverseCFG<Backward>(*Pred, PerElementAction, Visited);
    }
  }
}
// Forward traverse the CFG and collect calls to TraceRay.
void ForwardTraverseCFGAndCollectTraceCalls(
    const CFGBlock &Block, DxrShaderDiagnoseInfo &Info,
    std::set<const CFGBlock *> &Visited) {
  auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      CollectTraceRayCalls(S->getStmt(), Info, &Block);
    }
  };
  TraverseCFG</*Backward=*/false>(Block, Action, Visited);
}
// Forward traverse the CFG and collect all reads and writes to the payload.
void ForwardTraverseCFGAndCollectReadsWrites(
    const CFGBlock &StartBlock, DxrShaderDiagnoseInfo &Info,
    std::set<const CFGBlock *> &Visited) {
  auto Action = [&Info](const CFGBlock &Block, const CFGElement &Element) {
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      CollectReadsWritesAndCallsForPayload(S->getStmt(), Info, &Block);
    }
  };
  TraverseCFG</*Backward=*/false>(StartBlock, Action, Visited);
}
// Backward traverse the CFG and collect all reads and writes to the payload.
void BackwardTraverseCFGAndCollectReadsWrites(
    const CFGBlock &StartBlock, DxrShaderDiagnoseInfo &Info,
    std::set<const CFGBlock *> &Visited) {
  auto Action = [&](const CFGBlock &Block, const CFGElement &Element) {
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      CollectReadsWritesAndCallsForPayload(S->getStmt(), Info, &Block);
    }
  };
  TraverseCFG</*Backward=*/true>(StartBlock, Action, Visited);
}
// Returns true if the Stmt uses the Payload.
bool IsPayloadArg(const Stmt *S, const Decl *Payload) {
  if (const DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(S)) {
    const Decl *Decl = Ref->getDecl();
    if (Decl == Payload)
      return true;
  }
  for (auto C : S->children()) {
    // Children can be null; guard before recursing.
    if (C && IsPayloadArg(C, Payload))
      return true;
  }
  return false;
}
bool DiagnoseCallExprForExternal(Sema &S, const FunctionDecl *FD,
                                 const CallExpr *CE,
                                 const ParmVarDecl *Payload);

// Collects all writes that dominate a PayloadUse in a CallExpr
// and returns the set of fields accessed.
std::set<const FieldDecl *>
CollectDominatingWritesForCall(PayloadUse &Use, DxrShaderDiagnoseInfo &Info,
                               DominatorTree &DT) {
  std::set<const FieldDecl *> FieldsToIgnore;
  for (auto P : Info.WritesPerField) {
    for (auto Write : P.second) {
      bool WriteDominatesCallSite = WriteDominatesUse(Write, Use, DT);
      if (WriteDominatesCallSite) {
        FieldsToIgnore.insert(P.first);
        break;
      }
    }
  }
  return FieldsToIgnore;
}
// Collects all writes that are reachable from a PayloadUse in a CallExpr
// and returns the set of fields written.
std::set<const FieldDecl *>
CollectReachableWritesForCall(PayloadUse &Use,
                              const DxrShaderDiagnoseInfo &Info) {
  std::set<const FieldDecl *> FieldsToIgnore;
  assert(Use.Parent);
  const CFGBlock *Current = Use.Parent;

  // Traverse the CFG beginning from the block of the call and collect all
  // fields written to after the call. These fields must not be diagnosed with
  // warnings about lost writes.
  DxrShaderDiagnoseInfo TempInfo;
  TempInfo.Payload = Info.Payload;

  bool foundCall = false;
  for (auto &Element : *Current) {
    // Search for the call in the block.
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      if (S->getStmt() == Use.S) {
        foundCall = true;
        continue;
      }
      if (foundCall)
        CollectReadsWritesAndCallsForPayload(S->getStmt(), TempInfo, Current);
    }
  }

  for (auto I = Current->succ_begin(); I != Current->succ_end(); ++I) {
    CFGBlock *Succ = *I;
    if (!Succ)
      continue;
    std::set<const CFGBlock *> Visited;
    ForwardTraverseCFGAndCollectReadsWrites(*Succ, TempInfo, Visited);
  }

  for (auto &p : TempInfo.WritesPerField)
    FieldsToIgnore.insert(p.first);

  return FieldsToIgnore;
}
// Emit diagnostics when the payload is used as an argument
// in a function call.
std::map<PayloadUse, std::vector<const FieldDecl *>>
DiagnosePayloadAsFunctionArg(
    Sema &S, DxrShaderDiagnoseInfo &Info, DominatorTree &DT,
    const std::set<const FieldDecl *> &ParentFieldsToIgnoreRead,
    const std::set<const FieldDecl *> &ParentFieldsToIgnoreWrite,
    std::set<const FunctionDecl *> VisitedFunctions) {
  std::map<PayloadUse, std::vector<const FieldDecl *>> WrittenFieldsInCalls;

  for (PayloadUse &Use : Info.PayloadAsCallArg) {
    if (const CallExpr *Call = dyn_cast<CallExpr>(Use.S)) {
      const Decl *Callee = Call->getCalleeDecl();
      if (!Callee || !isa<FunctionDecl>(Callee))
        continue;
      const FunctionDecl *CalledFunction = cast<FunctionDecl>(Callee);

      // Trace calls are recorded here and diagnosed separately.
      if (CalledFunction->isImplicit() &&
          CalledFunction->getName() == "TraceRay") {
        Info.TraceCalls.push_back(TraceRayCall{Call, Use.Parent});
        continue;
      }

      // Handle external function calls.
      if (!CalledFunction->hasBody()) {
        assert(isa<ParmVarDecl>(Info.Payload));
        DiagnoseCallExprForExternal(S, CalledFunction, Call,
                                    cast<ParmVarDecl>(Info.Payload));
        continue;
      }

      // Stop on recursion.
      if (VisitedFunctions.count(CalledFunction))
        return WrittenFieldsInCalls;
      VisitedFunctions.insert(CalledFunction);

      // Find the parameter of the called function that receives the payload.
      DxrShaderDiagnoseInfo CalleeInfo;
      for (unsigned i = 0; i < Call->getNumArgs(); ++i) {
        const Expr *Arg = Call->getArg(i);
        if (IsPayloadArg(Arg, Info.Payload)) {
          CalleeInfo.Payload = CalledFunction->getParamDecl(i);
          break;
        }
      }

      if (CalleeInfo.Payload) {
        CalleeInfo.funcDecl = CalledFunction;
        CalleeInfo.Stage = Info.Stage;
        auto FieldsToIgnoreRead = CollectDominatingWritesForCall(Use, Info, DT);
        auto FieldsToIgnoreWrite = CollectReachableWritesForCall(Use, Info);
        FieldsToIgnoreRead.insert(ParentFieldsToIgnoreRead.begin(),
                                  ParentFieldsToIgnoreRead.end());
        FieldsToIgnoreWrite.insert(ParentFieldsToIgnoreWrite.begin(),
                                   ParentFieldsToIgnoreWrite.end());
        WrittenFieldsInCalls[Use] =
            DiagnosePayloadAccess(S, CalleeInfo, FieldsToIgnoreRead,
                                  FieldsToIgnoreWrite, VisitedFunctions);
      }
    }
  }
  return WrittenFieldsInCalls;
}
// Collect all fields that cannot be accessed for the given shader stage.
// This function recurses into nested payload types.
void CollectNonAccessableFields(
    RecordDecl *PayloadType, DXIL::PayloadAccessShaderStage Stage,
    const std::set<const FieldDecl *> &FieldsToIgnoreRead,
    const std::set<const FieldDecl *> &FieldsToIgnoreWrite,
    std::vector<FieldDecl *> &NonWriteableFields,
    std::vector<FieldDecl *> &NonReadableFields) {
  for (FieldDecl *Field : PayloadType->fields()) {
    QualType FieldType = Field->getType();
    if (RecordDecl *FieldRecordDecl = FieldType->getAsCXXRecordDecl()) {
      if (FieldRecordDecl->hasAttr<HLSLRayPayloadAttr>()) {
        CollectNonAccessableFields(FieldRecordDecl, Stage, FieldsToIgnoreRead,
                                   FieldsToIgnoreWrite, NonWriteableFields,
                                   NonReadableFields);
        continue;
      }
    }

    auto Qualifier = GetPayloadQualifierForStage(Field, Stage);
    // Diagnose writes only if the field is not already written higher in the
    // call graph.
    if (!FieldsToIgnoreWrite.count(Field)) {
      if (Qualifier != DXIL::PayloadAccessQualifier::Write &&
          Qualifier != DXIL::PayloadAccessQualifier::ReadWrite)
        NonWriteableFields.push_back(Field);
    }
    // Diagnose reads only if the field has no write higher in the call graph.
    if (!FieldsToIgnoreRead.count(Field)) {
      if (Qualifier != DXIL::PayloadAccessQualifier::Read &&
          Qualifier != DXIL::PayloadAccessQualifier::ReadWrite)
        NonReadableFields.push_back(Field);
    }
  }
}
void CollectAccessableFields(RecordDecl *PayloadType,
                             const std::vector<FieldDecl *> &NonWriteableFields,
                             const std::vector<FieldDecl *> &NonReadableFields,
                             std::vector<FieldDecl *> &WriteableFields,
                             std::vector<FieldDecl *> &ReadableFields) {
  for (FieldDecl *Field : PayloadType->fields()) {
    QualType FieldType = Field->getType();
    if (RecordDecl *FieldRecordDecl = FieldType->getAsCXXRecordDecl()) {
      // Recurse into nested payload types.
      if (FieldRecordDecl->hasAttr<HLSLRayPayloadAttr>()) {
        CollectAccessableFields(FieldRecordDecl, NonWriteableFields,
                                NonReadableFields, WriteableFields,
                                ReadableFields);
        continue;
      }
    }

    if (std::find(NonWriteableFields.begin(), NonWriteableFields.end(),
                  Field) == NonWriteableFields.end())
      WriteableFields.push_back(Field);

    if (std::find(NonReadableFields.begin(), NonReadableFields.end(), Field) ==
        NonReadableFields.end())
      ReadableFields.push_back(Field);
  }
}
// Emit diagnostics for a TraceRay call.
void DiagnoseTraceCall(Sema &S, const VarDecl *Payload,
                       const TraceRayCall &Trace, DominatorTree &DT) {
  // For each TraceRay call check if write(caller) fields are written.
  const DXIL::PayloadAccessShaderStage CallerStage =
      DXIL::PayloadAccessShaderStage::Caller;

  std::vector<FieldDecl *> WriteableFields;
  std::vector<FieldDecl *> NonWriteableFields;
  std::vector<FieldDecl *> ReadableFields;
  std::vector<FieldDecl *> NonReadableFields;

  RecordDecl *PayloadType = GetPayloadType(Payload);
  if (!PayloadType)
    return;

  // Check if the type used for this trace call is a payload type.
  if (!PayloadType->hasAttr<HLSLRayPayloadAttr>()) {
    S.Diag(Payload->getLocation(), diag::err_payload_requires_attribute)
        << PayloadType->getName();
    return;
  }

  CollectNonAccessableFields(PayloadType, CallerStage, {}, {},
                             NonWriteableFields, NonReadableFields);
  CollectAccessableFields(PayloadType, NonWriteableFields, NonReadableFields,
                          WriteableFields, ReadableFields);

  // Find all writes to the payload that reach the trace call.
  DxrShaderDiagnoseInfo TraceInfo;
  TraceInfo.Payload = Payload;
  std::set<const CFGBlock *> Visited;
  const CFGBlock *Parent = Trace.Parent;
  Visited.insert(Parent);

  // Collect payload accesses in the same block until we reach the TraceRay
  // call.
  for (auto Element : *Parent) {
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      if (S->getStmt() == Trace.Call)
        break;
      CollectReadsWritesAndCallsForPayload(S->getStmt(), TraceInfo, Parent);
    }
  }

  for (auto I = Parent->pred_begin(); I != Parent->pred_end(); ++I) {
    CFGBlock *Pred = *I;
    if (!Pred)
      continue;
    BackwardTraverseCFGAndCollectReadsWrites(*Pred, TraceInfo, Visited);
  }

  // Warn if a writeable field has not been written.
  for (const FieldDecl *Field : WriteableFields) {
    if (!TraceInfo.WritesPerField.count(Field)) {
      S.Diag(Trace.Call->getArg(7)->getExprLoc(),
             diag::warn_hlsl_payload_access_no_write_for_trace_payload)
          << Field->getName();
    }
  }

  // Warn if a written field is not write(caller).
  for (const FieldDecl *Field : NonWriteableFields) {
    if (TraceInfo.WritesPerField.count(Field)) {
      S.Diag(
          Trace.Call->getArg(7)->getExprLoc(),
          diag::warn_hlsl_payload_access_write_but_no_write_for_trace_payload)
          << Field->getName();
    }
  }
  // After the trace call, collect all reads that are not dominated by another
  // write, and warn if a field is not read(caller) but its value is read
  // (undefined read).

  // Discard reads/writes found during the backward traversal.
  TraceInfo.ReadsPerField.clear();
  TraceInfo.WritesPerField.clear();

  bool CallFound = false;
  for (auto Element : *Parent) { // TODO: reverse iterate?
    if (Optional<CFGStmt> S = Element.getAs<CFGStmt>()) {
      if (S->getStmt() == Trace.Call) {
        CallFound = true;
        continue;
      }
      if (CallFound)
        CollectReadsWritesAndCallsForPayload(S->getStmt(), TraceInfo, Parent);
    }
  }

  for (auto I = Parent->succ_begin(); I != Parent->succ_end(); ++I) {
    CFGBlock *Succ = *I;
    if (!Succ)
      continue;
    ForwardTraverseCFGAndCollectReadsWrites(*Succ, TraceInfo, Visited);
  }

  // Warn if a readable field is never read after the trace call.
  for (const FieldDecl *Field : ReadableFields) {
    if (!TraceInfo.ReadsPerField.count(Field)) {
      S.Diag(Trace.Call->getArg(7)->getExprLoc(),
             diag::warn_hlsl_payload_access_read_but_no_read_after_trace)
          << Field->getName();
    }
  }

  for (const FieldDecl *Field : NonReadableFields) {
    auto WritesToField = TraceInfo.WritesPerField.find(Field);
    bool FieldHasWrites = WritesToField != TraceInfo.WritesPerField.end();
    for (auto &Read : TraceInfo.ReadsPerField[Field]) {
      bool ReadIsDominatedByWrite = false;
      if (FieldHasWrites) {
        // We found a read of a field that needs a diagnostic.
        // We do not want to warn about reads that are dominated by a write.
        // Find writes that dominate the read. If we find one, ignore the
        // read.
        for (auto Write : WritesToField->second) {
          ReadIsDominatedByWrite = WriteDominatesUse(Write, Read, DT);
          if (ReadIsDominatedByWrite)
            break;
        }
      }
      if (ReadIsDominatedByWrite)
        continue;

      S.Diag(Read.Member->getExprLoc(),
             diag::warn_hlsl_payload_access_read_of_undef_after_trace)
          << Field->getName();
    }
  }
}
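
// Illustrative trigger at a trace call (a sketch; names are hypothetical):
//
//   struct [raypayload] P { float4 c : write(caller) : read(caller); };
//   P p; // no write to p.c before the trace
//   TraceRay(Scene, RAY_FLAG_NONE, ~0, 0, 1, 0, Ray, p);
//
// 'c' is write(caller) but never written before the call, so
// warn_hlsl_payload_access_no_write_for_trace_payload fires for 'c'.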
// Emit diagnostics for all TraceRay calls.
void DiagnoseTraceCalls(Sema &S, CFG &ShaderCFG, DominatorTree &DT,
                        DxrShaderDiagnoseInfo &Info) {
  // Collect TraceRay calls in the shader.
  std::set<const CFGBlock *> Visited;
  ForwardTraverseCFGAndCollectTraceCalls(ShaderCFG.getEntry(), Info, Visited);

  std::set<const CallExpr *> Diagnosed;
  for (const TraceRayCall &TraceCall : Info.TraceCalls) {
    if (Diagnosed.count(TraceCall.Call))
      continue;
    Diagnosed.insert(TraceCall.Call);

    const VarDecl *Payload = GetPayloadParameterForTraceCall(TraceCall.Call);
    // Skip calls whose payload argument could not be resolved to a VarDecl.
    if (!Payload)
      continue;
    DiagnoseTraceCall(S, Payload, TraceCall, DT);
  }
}
// Emit diagnostics for all accesses to the payload of a shader,
// and for the input to TraceRay calls.
std::vector<const FieldDecl *>
DiagnosePayloadAccess(Sema &S, DxrShaderDiagnoseInfo &Info,
                      const std::set<const FieldDecl *> &FieldsToIgnoreRead,
                      const std::set<const FieldDecl *> &FieldsToIgnoreWrite,
                      std::set<const FunctionDecl *> VisitedFunctions) {
  clang::DominatorTree DT;
  AnalysisDeclContextManager AnalysisManager;
  AnalysisDeclContext *AnalysisContext =
      AnalysisManager.getContext(Info.funcDecl);
  CFG &TheCFG = *AnalysisContext->getCFG();
  DT.buildDominatorTree(*AnalysisContext);

  // Collect all fields that are written, to return them back up through the
  // recursion.
  std::vector<const FieldDecl *> WrittenFields;

  // Skip if we are in a ray generation shader, which has no payload.
  if (Info.Payload) {
    std::vector<FieldDecl *> NonWriteableFields;
    std::vector<FieldDecl *> NonReadableFields;

    RecordDecl *PayloadType = GetPayloadType(Info.Payload);
    if (!PayloadType)
      return WrittenFields;

    CollectNonAccessableFields(PayloadType, Info.Stage, FieldsToIgnoreRead,
                               FieldsToIgnoreWrite, NonWriteableFields,
                               NonReadableFields);

    std::set<const CFGBlock *> Visited;
    ForwardTraverseCFGAndCollectReadsWrites(TheCFG.getEntry(), Info, Visited);

    if (Info.Payload->hasAttr<HLSLOutAttr>() ||
        Info.Payload->hasAttr<HLSLInOutAttr>()) {
      // If there is copy-out semantic on the payload parameter,
      // save the written fields and return them to the caller for
      // better diagnostics in higher recursion levels.
      for (auto &p : Info.WritesPerField) {
        WrittenFields.push_back(p.first);
      }
      DiagnosePayloadWrites(S, TheCFG, DT, Info, NonWriteableFields,
                            PayloadType);
    }

    auto WrittenFieldsInCalls = DiagnosePayloadAsFunctionArg(
        S, Info, DT, FieldsToIgnoreRead, FieldsToIgnoreWrite, VisitedFunctions);

    // Add calls that write fields as writes to allow the diagnostics on reads
    // to check if a call that writes the field dominates the read.
    for (auto &P : WrittenFieldsInCalls) {
      for (const FieldDecl *Field : P.second) {
        Info.WritesPerField[Field].push_back(P.first);
      }
    }

    if (Info.Payload->hasAttr<HLSLInAttr>() ||
        Info.Payload->hasAttr<HLSLInOutAttr>())
      DiagnosePayloadReads(S, TheCFG, DT, Info, NonReadableFields);
  }

  DiagnoseTraceCalls(S, TheCFG, DT, Info);
  return WrittenFields;
}
const Stmt *IgnoreParensAndDecay(const Stmt *S) {
  for (;;) {
    switch (S->getStmtClass()) {
    case Stmt::ParenExprClass:
      S = cast<ParenExpr>(S)->getSubExpr();
      break;
    case Stmt::ImplicitCastExprClass: {
      const ImplicitCastExpr *castExpr = cast<ImplicitCastExpr>(S);
      if (castExpr->getCastKind() != CK_ArrayToPointerDecay &&
          castExpr->getCastKind() != CK_NoOp &&
          castExpr->getCastKind() != CK_LValueToRValue) {
        return S;
      }
      S = castExpr->getSubExpr();
    } break;
    default:
      return S;
    }
  }
}
// Emit warnings for calls that pass the payload to extern functions.
bool DiagnoseCallExprForExternal(Sema &S, const FunctionDecl *FD,
                                 const CallExpr *CE,
                                 const ParmVarDecl *Payload) {
  // Check whether the entire payload struct is passed to an extern function.
  // The analysis cannot see past the call boundary, so just issue a warning.
  if (!FD->hasBody()) {
    const DeclRefExpr *DRef = nullptr;
    for (unsigned i = 0; i < CE->getNumArgs(); ++i) {
      const Stmt *arg = IgnoreParensAndDecay(CE->getArg(i));
      if (const DeclRefExpr *ArgRef = dyn_cast<DeclRefExpr>(arg)) {
        if (ArgRef->getDecl() == Payload) {
          DRef = ArgRef;
          break;
        }
      }
    }
    if (DRef) {
      S.Diag(CE->getExprLoc(),
             diag::warn_qualified_payload_passed_to_extern_function);
      return true;
    }
  }
  return false;
}
// Emits diagnostics for the payload parameter of a DXR shader stage.
bool DiagnosePayloadParameter(Sema &S, ParmVarDecl *Payload, FunctionDecl *FD,
                              DXIL::PayloadAccessShaderStage stage) {
  if (!Payload) {
    // A missing payload is caught already during codegen of the function.
    return false;
  }

  if (!Payload->getAttr<HLSLInOutAttr>()) {
    // Error: the payload must be inout qualified.
    return false;
  }

  CXXRecordDecl *Decl = Payload->getType()->getAsCXXRecordDecl();
  if (!Decl || Decl->isImplicit()) {
    // Error: the payload is not a user-defined type.
    return false;
  }

  if (!Decl->hasAttr<HLSLRayPayloadAttr>()) {
    S.Diag(Payload->getLocation(), diag::err_payload_requires_attribute)
        << Decl->getName();
    return false;
  }
  return true;
}
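
// A well-formed entry signature for reference (a sketch; names are
// hypothetical):
//
//   [shader("closesthit")]
//   void ClosestHit(inout P payload,
//                   in BuiltInTriangleIntersectionAttributes attribs) { ... }
//
// The payload must be an inout parameter whose type is a user-defined struct
// carrying the [raypayload] attribute.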
class DXRShaderVisitor : public RecursiveASTVisitor<DXRShaderVisitor> {
public:
  DXRShaderVisitor(Sema &S) : S(S) {}

  void diagnose(TranslationUnitDecl *TU) { TraverseTranslationUnitDecl(TU); }

  bool VisitFunctionDecl(FunctionDecl *Decl) {
    auto attr = Decl->getAttr<HLSLShaderAttr>();
    if (!attr)
      return true;

    StringRef shaderStage = attr->getStage();
    if (StringRef("miss,closesthit,anyhit,raygeneration").count(shaderStage)) {
      ParmVarDecl *Payload = nullptr;
      if (shaderStage != "raygeneration")
        Payload = Decl->getParamDecl(0);

      DXIL::PayloadAccessShaderStage Stage =
          DXIL::PayloadAccessShaderStage::Invalid;
      if (shaderStage == "closesthit") {
        Stage = DXIL::PayloadAccessShaderStage::Closesthit;
      } else if (shaderStage == "miss") {
        Stage = DXIL::PayloadAccessShaderStage::Miss;
      } else if (shaderStage == "anyhit") {
        Stage = DXIL::PayloadAccessShaderStage::Anyhit;
      }

      // Diagnose the payload parameter.
      if (Payload) {
        DiagnosePayloadParameter(S, Payload, Decl, Stage);
      }

      DxrShaderDiagnoseInfo Info;
      Info.funcDecl = Decl;
      Info.Payload = Payload;
      Info.Stage = Stage;

      std::set<const FunctionDecl *> VisitedFunctions;
      DiagnosePayloadAccess(S, Info, {}, {}, VisitedFunctions);
    }
    return true;
  }

private:
  Sema &S;
};
} // namespace

namespace hlsl {
void DiagnoseRaytracingPayloadAccess(clang::Sema &S,
                                     clang::TranslationUnitDecl *TU) {
  DXRShaderVisitor visitor(S);
  visitor.diagnose(TU);
}
} // namespace hlsl